Merge branch 'dev' into dev_model_providers

# Conflicts:
#	model-providers/model_providers/__init__.py
#	model-providers/model_providers/__main__.py
#	model-providers/model_providers/core/provider_manager.py
#	model-providers/pyproject.toml
This commit is contained in:
glide-the 2024-03-29 12:09:01 +08:00
commit 3ed9162392
195 changed files with 9980 additions and 5932 deletions

View File

@ -92,7 +92,8 @@ class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
s = s[-1]
return AgentFinish({"output": s}, log=text)
else:
raise OutputParserException(f"Could not parse LLM output: {text}")
return AgentFinish({"output": text}, log=text)
# raise OutputParserException(f"Could not parse LLM output: {text}")
@property
def _type(self) -> str:

View File

@ -1,7 +1,7 @@
import base64
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
def save_base64_audio(base64_audio, file_path):
@ -16,8 +16,8 @@ def aqa_run(model, tokenizer, query):
return response
@regist_tool
def aqa_processor(query: str = Field(description="The question of the image in English")):
@regist_tool(title="音频问答")
def aqa_processor(query: str = Field(description="The question of the audio in English")):
'''use this tool to get answer for audio question'''
from chatchat.server.agent.container import container
@ -28,6 +28,8 @@ def aqa_processor(query: str = Field(description="The question of the image in E
"audio": file_path,
"text": query,
}
return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
ret = aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
else:
return "No Audio, Please Try Again"
ret = "No Audio, Please Try Again"
return BaseToolOutput(ret)

View File

@ -1,12 +1,12 @@
# LangChain 的 ArxivQueryRun 工具
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="ARXIV论文")
def arxiv(query: str = Field(description="The search query title")):
'''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.'''
from langchain.tools.arxiv.tool import ArxivQueryRun
tool = ArxivQueryRun()
return tool.run(tool_input=query)
return BaseToolOutput(tool.run(tool_input=query))

View File

@ -1,8 +1,8 @@
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="数学计算器")
def calculate(text: str = Field(description="a math expression")) -> float:
'''
Useful to answer questions about simple calculations.
@ -11,6 +11,8 @@ def calculate(text: str = Field(description="a math expression")) -> float:
import numexpr
try:
return str(numexpr.evaluate(text))
ret = str(numexpr.evaluate(text))
except Exception as e:
return f"wrong: {e}"
ret = f"wrong: {e}"
return BaseToolOutput(ret)

View File

@ -9,7 +9,7 @@ from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
def bing_search(text, config):
@ -95,8 +95,8 @@ def search_engine(query: str,
return context
@regist_tool
@regist_tool(title="互联网搜索")
def search_internet(query: str = Field(description="query for Internet search")):
'''Use this tool to use bing search engine to search the internet and get information.'''
tool_config = get_tool_config("search_internet")
return search_engine(query=query, config=tool_config)
return BaseToolOutput(search_engine(query=query, config=tool_config))

View File

@ -1,8 +1,9 @@
from urllib.parse import urlencode
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from chatchat.server.knowledge_base.kb_doc_api import search_docs
from .tools_registry import regist_tool, BaseToolOutput
from chatchat.server.knowledge_base.kb_api import list_kbs
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
from chatchat.configs import KB_INFO
@ -11,35 +12,44 @@ KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
class KBToolOutput(BaseToolOutput):
def __str__(self) -> str:
context = ""
docs = self.data
source_documents = []
for inum, doc in enumerate(docs):
doc = DocumentWithVSId.parse_obj(doc)
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": self.extras.get("database"), "file_name": filename})
url = f"download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0:
context = "没有找到相关文档,请更换关键词重试"
else:
for doc in source_documents:
context += doc + "\n"
return context
def search_knowledgebase(query: str, database: str, config: dict):
docs = search_docs(
query=query,
knowledge_base_name=database,
top_k=config["top_k"],
score_threshold=config["score_threshold"])
context = ""
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
url = f"download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0:
context = "没有找到相关文档,请更换关键词重试"
else:
for doc in source_documents:
context += doc + "\n"
return context
return docs
@regist_tool(description=template_knowledge)
@regist_tool(description=template_knowledge, title="本地知识库")
def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search"),
database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
query: str = Field(description="Query for Knowledge Search"),
):
''''''
tool_config = get_tool_config("search_local_knowledgebase")
return search_knowledgebase(query=query, database=database, config=tool_config)
ret = search_knowledgebase(query=query, database=database, config=tool_config)
return BaseToolOutput(ret, database=database)

View File

@ -1,10 +1,10 @@
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="油管视频")
def search_youtube(query: str = Field(description="Query for Videos search")):
'''use this tools_factory to search youtube videos'''
from langchain_community.tools import YouTubeSearchTool
tool = YouTubeSearchTool()
return tool.run(tool_input=query)
return BaseToolOutput(tool.run(tool_input=query))

View File

@ -2,11 +2,11 @@
from langchain.tools.shell import ShellTool
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@regist_tool(title="系统命令")
def shell(query: str = Field(description="The command to execute")):
'''Use Shell to execute system shell commands'''
tool = ShellTool()
return tool.run(tool_input=query)
return BaseToolOutput(tool.run(tool_input=query))

View File

@ -7,7 +7,7 @@ import uuid
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
import openai
from chatchat.configs.basic_config import MEDIA_PATH
@ -26,7 +26,7 @@ def get_image_model_config() -> dict:
return config
@regist_tool(return_direct=True)
@regist_tool(title="文生图", return_direct=True)
def text2images(
prompt: str,
n: int = Field(1, description="需生成图片的数量"),
@ -56,7 +56,7 @@ def text2images(
with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
fp.write(base64.b64decode(x.b64_json))
images.append(filename)
return json.dumps({"message_type": MsgType.IMAGE, "images": images})
return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json")
if __name__ == "__main__":

View File

@ -1,15 +1,22 @@
import json
import re
from typing import Any, Union, Dict, Tuple, Callable, Optional, Type
from langchain.agents import tool
from langchain_core.tools import BaseTool
from chatchat.server.pydantic_v1 import BaseModel
from chatchat.server.pydantic_v1 import BaseModel, Extra
__all__ = ["regist_tool", "BaseToolOutput"]
_TOOLS_REGISTRY = {}
# patch BaseTool to support extra fields e.g. a title
BaseTool.Config.extra = Extra.allow
################################### TODO: workaround to langchain #15855
# patch BaseTool to support tool parameters defined using pydantic Field
@ -60,6 +67,7 @@ BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
def regist_tool(
*args: Any,
title: str = "",
description: str = "",
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
@ -70,9 +78,10 @@ def regist_tool(
add tool to registry automatically
'''
def _parse_tool(t: BaseTool):
nonlocal description
nonlocal description, title
_TOOLS_REGISTRY[t.name] = t
# change default description
if not description:
if t.func is not None:
@ -80,6 +89,10 @@ def regist_tool(
elif t.coroutine is not None:
description = t.coroutine.__doc__
t.description = " ".join(re.split(r"\n+\s*", description))
# set a default title for human
if not title:
title = "".join([x.capitalize() for x in t.name.split("_")])
t.title = title
def wrapper(def_func: Callable) -> BaseTool:
partial_ = tool(*args,
@ -101,3 +114,30 @@ def regist_tool(
)
_parse_tool(t)
return t
class BaseToolOutput:
'''
LLM 要求 Tool 的输出为 str Tool 用在别处时希望它正常返回结构化数据
只需要将 Tool 返回值用该类封装能同时满足两者的需要
基类简单的将返回值字符串化或指定 format="json" 将其转为 json
用户也可以继承该类定义自己的转换方法
'''
def __init__(
self,
data: Any,
format: str="",
data_alias: str="",
**extras: Any,
) -> None:
self.data = data
self.format = format
self.extras = extras
if data_alias:
setattr(self, data_alias, property(lambda obj: obj.data))
def __str__(self) -> str:
if self.format == "json":
return json.dumps(self.data, ensure_ascii=False, indent=2)
else:
return str(self.data)
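The BaseToolOutput wrapper above is the common thread of this commit: every tool now returns one, so the agent receives a plain string via str() while other callers can still reach the structured data. A minimal sketch of how it behaves (hypothetical values, assuming only the class as defined above):

out = BaseToolOutput({"temperature": "22", "description": "晴"}, format="json")
str(out)    # json.dumps of the dict, ready to be fed to the LLM
out.data    # the original dict, for callers that want structured results
# Subclasses such as KBToolOutput earlier in this diff override __str__ to render their own view.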

View File

@ -6,7 +6,7 @@ from io import BytesIO
from PIL import Image, ImageDraw
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
import re
from chatchat.server.agent.container import container
@ -99,7 +99,7 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
return response
@regist_tool
@regist_tool(title="图片对话")
def vqa_processor(query: str = Field(description="The question of the image in English")):
'''use this tool to get answer for image question'''
@ -119,6 +119,8 @@ def vqa_processor(query: str = Field(description="The question of the image in E
# end_marker = "Grounded Operation:"
# ans = extract_between_markers(ans, start_marker, end_marker)
return ans
ret = ans
else:
return "No Image, Please Try Again"
ret = "No Image, Please Try Again"
return BaseToolOutput(ret)

View File

@ -3,11 +3,11 @@
"""
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
import requests
@regist_tool
@regist_tool(title="天气查询")
def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")):
'''Use this tool to check the weather at a specific city'''
@ -21,7 +21,7 @@ def weather_check(city: str = Field(description="City name,include city and coun
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
return BaseToolOutput(weather)
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")

View File

@ -2,7 +2,7 @@
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
from .tools_registry import regist_tool, BaseToolOutput
@regist_tool
@ -13,4 +13,4 @@ def wolfram(query: str = Field(description="The formula to be calculated")):
wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid"))
ans = wolfram.run(query)
return ans
return BaseToolOutput(ans)

View File

@ -1,10 +1,10 @@
from __future__ import annotations
import re
import json
import time
from typing import Dict, List, Literal, Optional, Union
from fastapi import UploadFile
from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl, root_validator
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
@ -12,8 +12,10 @@ from openai.types.chat import (
completion_create_params,
)
from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE, LLM_MODEL_CONFIG
from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
from chatchat.server.utils import MsgType
class OpenAIBaseInput(BaseModel):
user: Optional[str] = None
@ -21,7 +23,7 @@ class OpenAIBaseInput(BaseModel):
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict] = None
extra_query: Optional[Dict] = None
extra_body: Optional[Dict] = None
extra_json: Optional[Dict] = Field(None, alias="extra_body")
timeout: Optional[float] = None
class Config:
@ -44,8 +46,8 @@ class OpenAIChatInput(OpenAIBaseInput):
stop: Union[Optional[str], List[str]] = None
stream: Optional[bool] = None
temperature: Optional[float] = TEMPERATURE
tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
tools: List[ChatCompletionToolParam] = None
tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
tools: List[Union[ChatCompletionToolParam, str]] = None
top_logprobs: Optional[int] = None
top_p: Optional[float] = None
@ -98,3 +100,62 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
voice: str
response_format: Optional[Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]] = None
speed: Optional[float] = None
class OpenAIBaseOutput(BaseModel):
id: Optional[str] = None
content: Optional[str] = None
model: Optional[str] = None
object: Literal["chat.completion", "chat.completion.chunk"] = "chat.completion.chunk"
role: Literal["assistant"] = "assistant"
finish_reason: Optional[str] = None
created: int = Field(default_factory=lambda : int(time.time()))
tool_calls: List[Dict] = []
status: Optional[int] = None # AgentStatus
message_type: int = MsgType.TEXT
message_id: Optional[str] = None # id in database table
is_ref: bool = False # whether to show in a separate expander
class Config:
extra = "allow"
def model_dump(self) -> dict:
result = {
"id": self.id,
"object": self.object,
"model": self.model,
"created": self.created,
"status": self.status,
"message_type": self.message_type,
"message_id": self.message_id,
"is_ref": self.is_ref,
**(self.model_extra or {}),
}
if self.object == "chat.completion.chunk":
result["choices"] = [{
"delta": {
"content": self.content,
"tool_calls": self.tool_calls,
},
"role": self.role,
}]
elif self.object == "chat.completion":
result["choices"] = [{
"message": {
"role": self.role,
"content": self.content,
"finish_reason": self.finish_reason,
"tool_calls": self.tool_calls,
}
}]
return result
def model_dump_json(self):
return json.dumps(self.model_dump(), ensure_ascii=False)
class OpenAIChatOutput(OpenAIBaseOutput):
...
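model_dump() above shapes the payload according to object: streaming chunks carry a delta, full completions carry a message. A quick sketch of the chunk case (hypothetical field values, following the code above):

chunk = OpenAIChatOutput(content="你好", model="gpt-4", object="chat.completion.chunk")
chunk.model_dump()
# {"id": None, "object": "chat.completion.chunk", "model": "gpt-4", ...,
#  "choices": [{"delta": {"content": "你好", "tool_calls": []}, "role": "assistant"}]}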

View File

@ -1,12 +1,17 @@
from __future__ import annotations
from typing import List
from typing import List, Dict
from fastapi import APIRouter, Request
from langchain.prompts.prompt import PromptTemplate
from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus
from chatchat.server.chat.chat import chat
from chatchat.server.chat.feedback import chat_feedback
from chatchat.server.chat.file_chat import file_chat
from chatchat.server.db.repository import add_message_to_db
from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template
from .openai_routes import openai_request
chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])
@ -21,4 +26,115 @@ chat_router.post("/feedback",
chat_router.post("/file_chat",
summary="文件对话"
)(file_chat)
)(file_chat)
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
async def chat_completions(
request: Request,
body: OpenAIChatInput,
) -> Dict:
'''
请求参数与 openai.chat.completions.create 一致可以通过 extra_body 传入额外参数
tools tool_choice 可以直接传工具名称会根据项目里包含的 tools 进行转换
通过不同的参数组合调用不同的 chat 功能
- tool_choice
- extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input)
- extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice
- tools: agent 对话
- 其它LLM 对话
以后还要考虑其它的组合如文件对话
返回与 openai 兼容的 Dict
'''
client = get_OpenAIClient(model_name=body.model, is_async=True)
extra = {**body.model_extra} or {}
for key in list(extra):
delattr(body, key)
# check tools & tool_choice in request body
if isinstance(body.tool_choice, str):
if t := get_tool(body.tool_choice):
body.tool_choice = {"function": {"name": t.name}, "type": "function"}
if isinstance(body.tools, list):
for i in range(len(body.tools)):
if isinstance(body.tools[i], str):
if t := get_tool(body.tools[i]):
body.tools[i] = {
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.args,
}
}
conversation_id = extra.get("conversation_id")
# chat based on result from one choiced tool
if body.tool_choice:
tool = get_tool(body.tool_choice["function"]["name"])
if not body.tools:
body.tools = [{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.args,
}
}]
if tool_input := extra.get("tool_input"):
message_id = add_message_to_db(
chat_type="tool_call",
query=body.messages[-1]["content"],
conversation_id=conversation_id
) if conversation_id else None
tool_result = await tool.ainvoke(tool_input)
prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag"))
body.messages[-1]["content"] = prompt_template.format(context=tool_result, question=body.messages[-1]["content"])
del body.tools
del body.tool_choice
extra_json = {
"message_id": message_id,
"status": None,
}
header = [{**extra_json,
"content": f"知识库参考资料:\n\n{tool_result}\n\n",
"tool_output":tool_result.data,
"is_ref": True,
}]
return await openai_request(client.chat.completions.create, body, extra_json=extra_json, header=header)
# agent chat with tool calls
if body.tools:
message_id = add_message_to_db(
chat_type="agent_chat",
query=body.messages[-1]["content"],
conversation_id=conversation_id
) if conversation_id else None
chat_model_config = {} # TODO: 前端支持配置模型
tool_names = [x["function"]["name"] for x in body.tools]
tool_config = {name: get_tool_config(name) for name in tool_names}
result = await chat(query=body.messages[-1]["content"],
metadata=extra.get("metadata", {}),
conversation_id=extra.get("conversation_id", ""),
message_id=message_id,
history_len=-1,
history=body.messages[:-1],
stream=body.stream,
chat_model_config=extra.get("chat_model_config", chat_model_config),
tool_config=extra.get("tool_config", tool_config),
)
return result
else: # LLM chat directly
message_id = add_message_to_db(
chat_type="llm_chat",
query=body.messages[-1]["content"],
conversation_id=conversation_id
) if conversation_id else None
extra_json = {
"message_id": message_id,
"status": None,
}
return await openai_request(client.chat.completions.create, body, extra_json=extra_json)
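The docstring above describes how tools, tool_choice and extra_body combine; the webui changes later in this commit drive this endpoint through the regular openai client. A condensed sketch of such a call (hypothetical base URL; weather_check and its city argument come from the tool defined earlier in this diff):

import openai

client = openai.Client(base_url="http://127.0.0.1:7861/chat", api_key="NONE")
for chunk in client.chat.completions.create(
    model="gpt-4",                                  # any LLM configured on the platform
    messages=[{"role": "user", "content": "厦门今天天气怎么样?"}],
    stream=True,
    tools=["weather_check"],                        # plain tool names are expanded server-side
    tool_choice="weather_check",                    # force this tool instead of agent chat
    extra_body={"conversation_id": "", "tool_input": {"city": "厦门"}},
):
    print(chunk)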

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import Dict, Tuple, AsyncGenerator
from typing import Dict, Tuple, AsyncGenerator, Iterable
from fastapi import APIRouter, Request
from openai import AsyncClient
@ -19,7 +19,7 @@ openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]
@asynccontextmanager
async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
'''
对重名模型进行调度依次选择空闲的模型 -> 当前访问数最少的模型
'''
@ -49,19 +49,47 @@ async def acquire_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
semaphore.release()
async def openai_request(method, body):
async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[], tail: Iterable=[]):
'''
helper function to make openai request
helper function to make openai request with extra fields
'''
async def generator():
async for chunk in await method(**params):
yield {"data": chunk.json()}
for x in header:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {header}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()
async for chunk in await method(**params):
for k, v in extra_json.items():
setattr(chunk, k, v)
yield chunk.model_dump_json()
for x in tail:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {tail}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()
params = body.model_dump(exclude_unset=True)
params = body.dict(exclude_unset=True)
if hasattr(body, "stream") and body.stream:
return EventSourceResponse(generator())
else:
return (await method(**params)).dict()
result = await method(**params)
for k, v in extra_json.items():
setattr(result, k, v)
return result.model_dump()
@openai_router.get("/models")
@ -74,12 +102,12 @@ async def list_models() -> List:
client = get_OpenAIClient(name, is_async=True)
models = await client.models.list()
if config.get("platform_type") == "xinference":
models = models.dict(exclude={"data":..., "object":...})
models = models.model_dump(exclude={"data":..., "object":...})
for x in models:
models[x]["platform_name"] = name
return [{**v, "id": k} for k, v in models.items()]
elif config.get("platform_type") == "oneapi":
return [{**x.dict(), "platform_name": name} for x in models.data]
return [{**x.model_dump(), "platform_name": name} for x in models.data]
except Exception:
logger.error(f"failed request to platform: {name}", exc_info=True)
return {}
@ -97,12 +125,8 @@ async def create_chat_completions(
request: Request,
body: OpenAIChatInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
result = await openai_request(client.chat.completions.create, body)
# result["related_docs"] = ["doc1"]
# result["choices"][0]["message"]["related_docs"] = ["doc1"]
# print(result)
# breakpoint()
return result
@ -111,7 +135,7 @@ async def create_completions(
request: Request,
body: OpenAIChatInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.completions.create, body)
@ -130,7 +154,7 @@ async def create_image_generations(
request: Request,
body: OpenAIImageGenerationsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.images.generate, body)
@ -139,7 +163,7 @@ async def create_image_variations(
request: Request,
body: OpenAIImageVariationsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.images.create_variation, body)
@ -148,7 +172,7 @@ async def create_image_edit(
request: Request,
body: OpenAIImageEditsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.images.edit, body)
@ -157,7 +181,7 @@ async def create_audio_translations(
request: Request,
body: OpenAIAudioTranslationsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.audio.translations.create, body)
@ -166,7 +190,7 @@ async def create_audio_transcriptions(
request: Request,
body: OpenAIAudioTranscriptionsInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.audio.transcriptions.create, body)
@ -175,7 +199,7 @@ async def create_audio_speech(
request: Request,
body: OpenAIAudioSpeechInput,
):
async with acquire_model_client(body.model) as client:
async with get_model_client(body.model) as client:
return await openai_request(client.audio.speech.create, body)

View File

@ -5,7 +5,7 @@ from typing import List
from fastapi import APIRouter, Request, Body
from chatchat.configs import logger
from chatchat.server.utils import BaseResponse, get_tool
from chatchat.server.utils import BaseResponse, get_tool, get_tool_config
tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@ -14,20 +14,24 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@tool_router.get("/", response_model=BaseResponse)
async def list_tools():
tools = get_tool()
data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools}
data = {t.name: {
"name": t.name,
"title": t.title,
"description": t.description,
"args": t.args,
"config": get_tool_config(t.name),
} for t in tools.values()}
return {"data": data}
@tool_router.post("/call", response_model=BaseResponse)
async def call_tool(
name: str = Body(examples=["calculate"]),
kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
tool_input: dict = Body({}, examples=[{"text": "3+5/2"}]),
):
tools = get_tool()
if tool := tools.get(name):
if tool := get_tool(name):
try:
result = await tool.ainvoke(kwargs)
result = await tool.ainvoke(tool_input)
return {"data": result}
except Exception:
msg = f"failed to call tool '{name}'"

View File

@ -102,7 +102,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
pass
data = {
"run_id": str(run_id),
"status": AgentStatus.tool_start,
"tool": serialized["name"],
"tool_input": input_str,
}
self.queue.put_nowait(dumps(data))
async def on_tool_end(
@ -116,6 +122,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
) -> None:
"""Run when tool ends running."""
data = {
"run_id": str(run_id),
"status": AgentStatus.tool_end,
"tool_output": output,
}
@ -133,8 +140,10 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
) -> None:
"""Run when tool errors."""
data = {
"run_id": str(run_id),
"status": AgentStatus.tool_end,
"text": str(error),
"tool_output": str(error),
"is_error": True,
}
# self.done.clear()
self.queue.put_nowait(dumps(data))

View File

@ -1,6 +1,8 @@
import asyncio
import json
import time
from typing import AsyncIterable, List
import uuid
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
@ -10,19 +12,19 @@ from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from chatchat.configs.model_config import LLM_MODEL_CONFIG
from chatchat.server.agent.agent_factory.agents_registry import agents_registry
from chatchat.server.agent.container import container
from chatchat.server.api_server.api_schemas import OpenAIChatOutput
from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool
from chatchat.server.chat.utils import History
from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from chatchat.server.db.repository import add_message_to_db
from chatchat.server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
def create_models_from_config(configs, callbacks, stream):
if configs is None:
configs = {}
configs = configs or LLM_MODEL_CONFIG
models = {}
prompts = {}
for model_type, model_configs in configs.items():
@ -99,6 +101,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
conversation_id: str = Body("", description="对话框ID"),
message_id: str = Body(None, description="数据库消息ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: List[History] = Body(
[],
@ -115,13 +118,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]),
tool_config: dict = Body({}, description="工具配置", examples=[]),
):
async def chat_iterator() -> AsyncIterable[str]:
message_id = add_message_to_db(
chat_type="llm_chat",
query=query,
conversation_id=conversation_id
) if conversation_id else None
'''Agent 对话'''
async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
@ -145,9 +144,32 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
}
), callback.done))
last_tool = {}
async for chunk in callback.aiter():
data = json.loads(chunk)
if data["status"] == AgentStatus.tool_end:
data["tool_calls"] = []
data["message_type"] = MsgType.TEXT
if data["status"] == AgentStatus.tool_start:
last_tool = {
"index": 0,
"id": data["run_id"],
"type": "function",
"function": {
"name": data["tool"],
"arguments": data["tool_input"],
},
"tool_output": None,
"is_error": False,
}
data["tool_calls"].append(last_tool)
if data["status"] in [AgentStatus.tool_end]:
last_tool.update(
tool_output=data["tool_output"],
is_error=data.get("is_error", False)
)
data["tool_calls"] = [last_tool]
last_tool = {}
try:
tool_output = json.loads(data["tool_output"])
if message_type := tool_output.get("message_type"):
@ -161,10 +183,43 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
data["message_type"] = message_type
except:
...
data.setdefault("message_type", MsgType.TEXT)
data["message_id"] = message_id
yield json.dumps(data, ensure_ascii=False)
ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion.chunk",
content=data.get("text", ""),
role="assistant",
tool_calls=data["tool_calls"],
model=models["llm_model"].model_name,
status = data["status"],
message_type = data["message_type"],
message_id=message_id,
)
yield ret.model_dump_json()
await task
return EventSourceResponse(chat_iterator())
if stream:
return EventSourceResponse(chat_iterator())
else:
ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion",
content="",
role="assistant",
finish_reason="stop",
tool_calls=[],
status = AgentStatus.agent_finish,
message_type = MsgType.TEXT,
message_id=message_id,
)
async for chunk in chat_iterator():
data = json.loads(chunk)
if text := data["choices"][0]["delta"]["content"]:
ret.content += text
if data["status"] == AgentStatus.tool_end:
ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
ret.model = data["model"]
ret.created = data["created"]
return ret.model_dump()

View File

@ -96,11 +96,7 @@ class ESKBService(KBService):
return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")
def do_create_kb(self):
if os.path.exists(self.doc_path):
if not os.path.exists(os.path.join(self.kb_path, "vector_store")):
os.makedirs(os.path.join(self.kb_path, "vector_store"))
else:
logger.warning("directory `vector_store` already exists.")
...
def vs_type(self) -> str:
return SupportedVSType.ES

View File

@ -65,7 +65,6 @@ class FaissKBService(KBService):
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs:
embeddings = vs.embeddings.embed_query(query)
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
return docs

View File

@ -1,24 +1,25 @@
import base64
from chatchat.server.utils import get_tool_config
import streamlit as st
from streamlit_antd_components.utils import ParseItems
from chatchat.webui_pages.dialogue.utils import process_files
# from chatchat.webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
# get_select_model_endpoint
from chatchat.webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal
from datetime import datetime
import uuid
import os
import re
import time
from typing import List, Dict
import streamlit as st
from streamlit_antd_components.utils import ParseItems
import openai
from streamlit_chatbox import *
from streamlit_modal import Modal
from datetime import datetime
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.utils import MsgType, get_config_models
import uuid
from typing import List, Dict
from chatchat.server.utils import get_tool_config
from chatchat.webui_pages.utils import *
from chatchat.webui_pages.dialogue.utils import process_files
img_dir = (Path(__file__).absolute().parent.parent.parent)
@ -121,69 +122,66 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.write("\n\n".join(cmds))
with st.sidebar:
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
tab1, tab2 = st.tabs(["对话设置", "模型设置"])
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
with tab1:
use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力")
# 选择工具
tools = api.list_tools()
if use_agent:
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
else:
selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
selected_tools = [selected_tool]
selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
platform = st.selectbox("选择模型平台", platforms)
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
llm_model = st.selectbox("选择LLM模型", llm_models)
# 当不启用Agent时手动生成工具参数
# TODO: 需要更精细的控制控件
tool_input = {}
if not use_agent and len(selected_tools) == 1:
with st.expander("工具参数", True):
for k, v in tools[selected_tools[0]]["args"].items():
if choices := v.get("choices", v.get("enum")):
tool_input[k] = st.selectbox(v["title"], choices)
else:
if v["type"] == "integer":
tool_input[k] = st.slider(v["title"], value=v.get("default"))
elif v["type"] == "number":
tool_input[k] = st.slider(v["title"], value=v.get("default"), step=0.1)
else:
tool_input[k] = st.text_input(v["title"], v.get("default"))
# 传入后端的内容
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
tool_use = True
for key in LLM_MODEL_CONFIG:
if key == 'llm_model':
continue
if key == 'action_model':
first_key = next(iter(LLM_MODEL_CONFIG[key]))
if first_key not in SUPPORT_AGENT_MODELS:
st.warning("不支持Agent的模型无法执行任何工具调用")
tool_use = False
continue
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
# 选择工具
selected_tool_configs = {}
if tool_use:
from chatchat.configs import model_config as model_config_py
import importlib
importlib.reload(model_config_py)
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
with tab2:
# 会话
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
tools = get_tool_config()
with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = tools[tool]
# 模型
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
platform = st.selectbox("选择模型平台", platforms)
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
llm_model = st.selectbox("选择LLM模型", llm_models)
if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
# 传入后端的内容
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
for key in LLM_MODEL_CONFIG:
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
# print(len(files_upload["audios"])) if files_upload else None
if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
# if dialogue_mode == "文件对话":
# with st.expander("文件对话配置", True):
# files = st.file_uploader("上传知识文件:",
# [i for ls in LOADER_DICT.values() for i in ls],
# accept_multiple_files=True,
# )
# kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
# score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
# if st.button("开始上传", disabled=len(files) == 0):
# st.session_state["file_chat_id"] = upload_temp_docs(files, api)
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。输入/help查看自定义命令 "
@ -228,50 +226,85 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.ai_say("正在思考...")
text = ""
text_action = ""
element_index = 0
started = False
for d in api.chat_chat(query=prompt,
metadata=files_upload,
history=history,
chat_model_config=chat_model_config,
conversation_id=conversation_id,
tool_config=selected_tool_configs,
):
message_id = d.get("message_id", "")
client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
messages = history + [{"role": "user", "content": prompt}]
tools = list(selected_tool_configs)
if len(selected_tools) == 1:
tool_choice = selected_tools[0]
else:
tool_choice = None
# 如果 tool_input 中有空的字段,设为用户输入
for k in tool_input:
if tool_input[k] in [None, ""]:
tool_input[k] = prompt
extra_body = dict(
metadata=files_upload,
chat_model_config=chat_model_config,
conversation_id=conversation_id,
tool_input = tool_input,
)
for d in client.chat.completions.create(
messages=messages,
model=llm_model,
stream=True,
tools=tools,
tool_choice=tool_choice,
extra_body=extra_body,
):
print("\n\n", d.status, "\n", d, "\n\n")
message_id = d.message_id
metadata = {
"message_id": message_id,
}
print(d)
if d["status"] == AgentStatus.error:
st.error(d["text"])
elif d["status"] == AgentStatus.agent_action:
formatted_data = {
"Function": d["tool_name"],
"function_input": d["tool_input"]
}
element_index += 1
formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
chat_box.insert_msg(
Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
text = """\n```{}\n```\n""".format(formatted_json)
chat_box.update_msg(Markdown(text), element_index=element_index)
elif d["status"] == AgentStatus.tool_end:
text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
elif d["status"] == AgentStatus.llm_new_token:
text += d["text"]
chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
elif d["status"] == AgentStatus.llm_end:
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
elif d["status"] == AgentStatus.agent_finish:
if d["message_type"] == MsgType.IMAGE:
for url in json.loads(d["text"]).get("images", []):
url = f"{api.base_url}/media/{url}"
chat_box.insert_msg(Image(url))
chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
if d.status == AgentStatus.error:
st.error(d.choices[0].delta.content)
elif d.status == AgentStatus.llm_start:
if not started:
started = True
else:
chat_box.insert_msg(Markdown(d["text"], expanded=True))
chat_box.insert_msg("正在解读工具输出结果...")
text = d.choices[0].delta.content or ""
elif d.status == AgentStatus.llm_new_token:
text += d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
elif d.status == AgentStatus.llm_end:
text += d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata)
# tool 的输出与 llm 输出重复了
# elif d.status == AgentStatus.tool_start:
# formatted_data = {
# "Function": d.choices[0].delta.tool_calls[0].function.name,
# "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
# }
# formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
# text = """\n```{}\n```\n""".format(formatted_json)
# chat_box.insert_msg( # TODO: insert text directly not shown
# Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
# elif d.status == AgentStatus.tool_end:
# tool_output = d.choices[0].delta.tool_calls[0].tool_output
# if d.message_type == MsgType.IMAGE:
# for url in json.loads(tool_output).get("images", []):
# url = f"{api.base_url}/media/{url}"
# chat_box.insert_msg(Image(url))
# chat_box.update_msg(expanded=False, state="complete")
# else:
# text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
# chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status == None: # not agent chat
if getattr(d, "is_ref", False):
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
chat_box.insert_msg("")
else:
text += d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
chat_box.update_msg(text, streaming=False, metadata=metadata)
if os.path.exists("tmp/image.jpg"):
with open("tmp/image.jpg", "rb") as image_file:
@ -313,8 +346,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.rerun()
now = datetime.now()
with st.sidebar:
with tab1:
cols = st.columns(2)
export_btn = cols[0]
if cols[1].button(
@ -333,3 +365,5 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
mime="text/markdown",
use_container_width=True,
)
# st.write(chat_box.history)

View File

@ -50,6 +50,14 @@ class ApiRequest:
timeout=self.timeout)
return self._client
def _check_url(self, url: str) -> str:
'''
新版 httpx 强制要求 url / 结尾否则会返回 307
'''
if not url.endswith("/"):
url = url + "/"
return url
def get(
self,
url: str,
@ -58,6 +66,7 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
if stream:
@ -79,6 +88,7 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
# print(kwargs)
@ -101,6 +111,7 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
if stream:
@ -638,6 +649,27 @@ class ApiRequest:
resp = self.post("/chat/feedback", json=data)
return self._get_response_value(resp)
def list_tools(self) -> Dict:
'''
列出所有工具
'''
resp = self.get("/tools")
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data", {}))
def call_tool(
self,
name: str,
tool_input: Dict = {},
):
'''
调用工具
'''
data = {
"name": name,
"tool_input": tool_input,
}
resp = self.post("/tools/call", json=data)
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
class AsyncApiRequest(ApiRequest):
def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
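The new list_tools / call_tool helpers wrap the /tools routes added earlier in this commit; a usage sketch (assuming a server reachable at the default api_address(), with the calculate tool registered as shown above):

api = ApiRequest()
print(api.list_tools())                                   # {"calculate": {"title": "数学计算器", ...}, ...}
print(api.call_tool("calculate", tool_input={"text": "3+5/2"}))   # tool result as returned by /tools/call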

View File

@ -79,6 +79,12 @@ optional = true
ruff = "^0.1.5"
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.dev]
optional = true
@ -222,4 +228,14 @@ markers = [
"scheduled: mark tests to run in scheduled testing",
"compile: mark placeholder test used to compile integration tests without running them"
]
asyncio_mode = "auto"
asyncio_mode = "auto"
[tool.codespell]
skip = '.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,*.trig,*.json,*.md,*.html,*.txt,*.csv'
# Ignore latin etc
ignore-regex = '.*(Stati Uniti|Tense=Pres).*'
# whats is a typo but used frequently in queries so kept as is
# aapply - async apply
# unsecure - typo but part of API, decided to not bother for now
ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin'

model-providers/Makefile Normal file
View File

@ -0,0 +1,85 @@
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
######################
# TESTING AND COVERAGE
######################
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
# Run unit tests and generate a coverage report.
coverage:
poetry run pytest --cov \
--cov-config=.coveragerc \
--cov-report xml \
--cov-report term-missing:skip-covered \
$(TEST_FILE)
test tests:
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
extended_tests:
poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
test_watch:
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket tests/unit_tests
test_watch_extended:
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests
integration_tests:
poetry run pytest tests/integration_tests
scheduled_tests:
poetry run pytest -m scheduled tests/integration_tests
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=model_providers
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '-- LINTING --'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'spell_check - run codespell on the project'
@echo 'spell_fix - run codespell on the project and fix the errors'
@echo '-- TESTS --'
@echo 'coverage - run unit tests and generate coverage report'
@echo 'test - run unit tests'
@echo 'tests - run unit tests (alias for "make test")'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

View File

@ -1,43 +1,50 @@
import os
from typing import cast, Generator
from typing import Generator, cast
from model_providers import provider_manager
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
if __name__ == '__main__':
# 基于配置管理器创建的模型实例
# Invoke model
model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
model_instance = provider_manager.get_model_instance(
provider="openai", model_type=ModelType.LLM, model="gpt-4"
)
response = model_instance.invoke_llm(
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
'plugin_web_search': True,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ''
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
total_message += chunk.delta.message.content
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
assert (
len(chunk.delta.message.content) > 0
if not chunk.delta.finish_reason
else True
)
print(total_message)
assert '参考资料' in total_message
assert "参考资料" in total_message

View File

@ -1,60 +1,58 @@
import asyncio
import os
from typing import Optional, Any, Dict
from fastapi import (APIRouter,
FastAPI,
HTTPException,
Response,
Request,
status
)
import logging
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
import json
import pprint
import tiktoken
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest, \
ChatCompletionResponse, ModelList, EmbeddingsResponse, ChatCompletionStreamResponse, FunctionAvailable
from uvicorn import Config, Server
from fastapi.middleware.cors import CORSMiddleware
import logging
import multiprocessing as mp
import os
import pprint
import threading
from typing import Any, Dict, Optional
import tiktoken
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
EmbeddingsRequest,
EmbeddingsResponse,
FunctionAvailable,
ModelList,
)
from model_providers.core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.utils.generic import dictify, jsonify
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.utils.generic import dictify, jsonify
logger = logging.getLogger(__name__)
async def create_stream_chat_completion(model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest):
async def create_stream_chat_completion(
model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
):
try:
response = model_type_instance.invoke(
model=chat_request.model,
credentials={
'openai_api_key': "sk-",
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
model_parameters={
**chat_request.to_model_parameters_dict()
"openai_api_key": "sk-",
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={**chat_request.to_model_parameters_dict()},
stop=chat_request.stop,
stream=chat_request.stream,
user="abc-123"
user="abc-123",
)
return response
@ -81,7 +79,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
host = cfg.get("host", "127.0.0.1")
port = cfg.get("port", 20000)
logger.info(f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}")
logger.info(
f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}"
)
return cls(host=host, port=port)
def serve(self, logging_conf: Optional[dict] = None):
@ -140,8 +140,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
async def list_models(self, request: Request):
pass
async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
logger.info(f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}")
async def create_embeddings(
self, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
)
if os.environ["API_KEY"] is None:
authorization = request.headers.get("Authorization")
authorization = authorization.split("Bearer ")[-1]
@ -171,42 +175,41 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
return EmbeddingsResponse(**dictify(response))
async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
logger.info(f"Received chat completion request: {pprint.pformat(chat_request.dict())}")
async def create_chat_completion(
self, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
)
if os.environ["API_KEY"] is None:
authorization = request.headers.get("Authorization")
authorization = authorization.split("Bearer ")[-1]
else:
authorization = os.environ["API_KEY"]
model_provider_factory.get_providers(provider_name='openai')
provider_instance = model_provider_factory.get_provider_instance('openai')
model_provider_factory.get_providers(provider_name="openai")
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
if chat_request.stream:
generator = create_stream_chat_completion(model_type_instance, chat_request)
return EventSourceResponse(generator, media_type="text/event-stream")
else:
response = model_type_instance.invoke(
model='gpt-4',
model="gpt-4",
credentials={
'openai_api_key': "sk-",
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"openai_api_key": "sk-",
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
'plugin_web_search': True,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=['you'],
stop=["you"],
stream=False,
user="abc-123"
user="abc-123",
)
chat_response = ChatCompletionResponse(**dictify(response))
@ -215,15 +218,19 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def run(
cfg: Dict, logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
):
logging.config.dictConfig(logging_conf) # type: ignore
try:
import signal
# Skip keyboard interrupts and rely on xoscar's signal handling instead
signal.signal(signal.SIGINT, lambda *_: None)
api = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg=cfg.get("run_openai_api", {}))
api = RESTFulOpenAIBootstrapBaseWeb.from_config(
cfg=cfg.get("run_openai_api", {})
)
api.set_app_event(started_event=started_event)
api.serve(logging_conf=logging_conf)

View File

@ -1,6 +1,6 @@
from model_providers.core.bootstrap.base import Bootstrap, OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.bootstrap_register import bootstrap_register
__all__ = [
"bootstrap_register",
"Bootstrap",

View File

@ -1,11 +1,13 @@
from abc import abstractmethod
from collections import deque
from fastapi import Request
class Bootstrap:
"""最大的任务队列"""
_MAX_ONGOING_TASKS: int = 1
"""任务队列"""
@ -37,7 +39,6 @@ class Bootstrap:
class OpenAIBootstrapBaseWeb(Bootstrap):
def __init__(self):
super().__init__()
@ -46,9 +47,13 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
pass
@abstractmethod
async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
async def create_embeddings(
self, request: Request, embeddings_request: EmbeddingsRequest
):
pass
@abstractmethod
async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
async def create_chat_completion(
self, request: Request, chat_request: ChatCompletionRequest
):
pass

View File

@ -5,6 +5,7 @@ class BootstrapRegister:
"""
Registration manager
"""
mapping = {
"bootstrap": {},
}
@ -48,4 +49,3 @@ class BootstrapRegister:
bootstrap_register = BootstrapRegister()

View File

@ -1,6 +1,7 @@
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator
from typing_extensions import Literal
@ -86,13 +87,15 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stop: Optional[list[str]] = None,
stop: Optional[list[str]] = None
stream: Optional[bool] = False
def to_model_parameters_dict(self, *args, **kwargs):
# Call the parent class's dict() method, excluding the tools/messages/function fields
helper.dump_model
return super().dict(exclude={'tools','messages','functions','function_call'}, *args, **kwargs)
return super().dict(
exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
)
class ChatCompletionResponseChoice(BaseModel):
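Note that to_model_parameters_dict is just pydantic's dict() with an exclude set, so chat-only fields (messages, tools, functions, function_call) never leak into the provider's model_parameters. A minimal, self-contained sketch of the same pattern — the field names below are illustrative, not the full request schema:

from typing import Optional
from pydantic import BaseModel

class DemoChatRequest(BaseModel):
    model: str
    messages: list[dict]                 # handled separately as prompt_messages
    tools: Optional[list[dict]] = None
    temperature: float = 0.7
    top_p: float = 1.0
    stream: bool = False

    def to_model_parameters_dict(self, *args, **kwargs):
        # Keep only tunable model parameters; drop the chat-specific fields.
        return super().dict(exclude={"tools", "messages"}, *args, **kwargs)

req = DemoChatRequest(model="gpt-4", messages=[{"role": "user", "content": "hi"}])
print(req.to_model_parameters_dict())
# {'model': 'gpt-4', 'temperature': 0.7, 'top_p': 1.0, 'stream': False}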

View File

@ -2,7 +2,7 @@ from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
ROUTER = "router"
REACT_ROUTER = "react_router"
REACT = "react"
FUNCTION_CALL = "function_call"

View File

@ -5,7 +5,9 @@ from pydantic import BaseModel
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.file.file_obj import FileObj
from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole
from model_providers.core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
@ -13,6 +15,7 @@ class ModelConfigEntity(BaseModel):
"""
Model Config Entity.
"""
provider: str
model: str
model_schema: AIModelEntity
@ -27,6 +30,7 @@ class AdvancedChatMessageEntity(BaseModel):
"""
Advanced Chat Message Entity.
"""
text: str
role: PromptMessageRole
@ -35,6 +39,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
"""
Advanced Chat Prompt Template Entity.
"""
messages: list[AdvancedChatMessageEntity]
@ -47,6 +52,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Role Prefix Entity.
"""
user: str
assistant: str
@ -64,11 +70,12 @@ class PromptTemplateEntity(BaseModel):
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = 'simple'
ADVANCED = 'advanced'
SIMPLE = "simple"
ADVANCED = "advanced"
@classmethod
def value_of(cls, value: str) -> 'PromptType':
def value_of(cls, value: str) -> "PromptType":
"""
Get value of given mode.
@ -78,18 +85,21 @@ class PromptTemplateEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt type value {value}')
raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[
AdvancedCompletionPromptTemplateEntity
] = None
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
"""
variable: str
type: str
config: dict[str, Any] = {}
@ -105,11 +115,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = 'single'
MULTIPLE = 'multiple'
SINGLE = "single"
MULTIPLE = "multiple"
@classmethod
def value_of(cls, value: str) -> 'RetrieveStrategy':
def value_of(cls, value: str) -> "RetrieveStrategy":
"""
Get value of given mode.
@ -119,7 +130,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid retrieve strategy value {value}')
raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion
@ -134,6 +145,7 @@ class DatasetEntity(BaseModel):
"""
Dataset Config Entity.
"""
dataset_ids: list[str]
retrieve_config: DatasetRetrieveConfigEntity
@ -142,6 +154,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
type: str
config: dict[str, Any] = {}
@ -150,6 +163,7 @@ class TextToSpeechEntity(BaseModel):
"""
Text To Speech Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
@ -159,6 +173,7 @@ class FileUploadEntity(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
@ -166,6 +181,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
@ -176,6 +192,7 @@ class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
@ -189,6 +206,7 @@ class AgentScratchpadUnit(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
@ -208,8 +226,9 @@ class AgentEntity(BaseModel):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"
provider: str
model: str
@ -223,6 +242,7 @@ class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
"""
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
external_data_variables: list[ExternalDataVariableEntity] = []
@ -244,13 +264,14 @@ class InvokeFrom(Enum):
"""
Invoke From.
"""
SERVICE_API = 'service-api'
WEB_APP = 'web-app'
EXPLORE = 'explore'
DEBUGGER = 'debugger'
SERVICE_API = "service-api"
WEB_APP = "web-app"
EXPLORE = "explore"
DEBUGGER = "debugger"
@classmethod
def value_of(cls, value: str) -> 'InvokeFrom':
def value_of(cls, value: str) -> "InvokeFrom":
"""
Get value of given mode.
@ -260,7 +281,7 @@ class InvokeFrom(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid invoke from value {value}')
raise ValueError(f"invalid invoke from value {value}")
def to_source(self) -> str:
"""
@ -269,21 +290,22 @@ class InvokeFrom(Enum):
:return: source
"""
if self == InvokeFrom.WEB_APP:
return 'web_app'
return "web_app"
elif self == InvokeFrom.DEBUGGER:
return 'dev'
return "dev"
elif self == InvokeFrom.EXPLORE:
return 'explore_app'
return "explore_app"
elif self == InvokeFrom.SERVICE_API:
return 'api'
return "api"
return 'dev'
return "dev"
class ApplicationGenerateEntity(BaseModel):
"""
Application Generate Entity.
"""
task_id: str
tenant_id: str

View File

@ -1,7 +1,13 @@
import enum
from typing import Any, cast
from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
from langchain.schema import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.message_entities import (
@ -16,7 +22,7 @@ from model_providers.core.model_runtime.entities.message_entities import (
class PromptMessageFileType(enum.Enum):
IMAGE = 'image'
IMAGE = "image"
@staticmethod
def value_of(value):
@ -33,8 +39,8 @@ class PromptMessageFile(BaseModel):
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = 'low'
HIGH = 'high'
LOW = "low"
HIGH = "high"
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW
@ -55,32 +61,39 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
file_prompt_message_contents.append(ImagePromptMessageContent(
data=file.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
))
file_prompt_message_contents.append(
ImagePromptMessageContent(
data=file.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if file.detail.value == "high"
else ImagePromptMessageContent.DETAIL.LOW,
)
)
prompt_message_contents = [TextPromptMessageContent(data=message.content)]
prompt_message_contents = [
TextPromptMessageContent(data=message.content)
]
prompt_message_contents.extend(file_prompt_message_contents)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
prompt_messages.append(
UserPromptMessage(content=prompt_message_contents)
)
else:
prompt_messages.append(UserPromptMessage(content=message.content))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content
}
message_kwargs = {"content": message.content}
if 'function_call' in message.additional_kwargs:
message_kwargs['tool_calls'] = [
if "function_call" in message.additional_kwargs:
message_kwargs["tool_calls"] = [
AssistantPromptMessage.ToolCall(
id=message.additional_kwargs['function_call']['id'],
type='function',
id=message.additional_kwargs["function_call"]["id"],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.additional_kwargs['function_call']['name'],
arguments=message.additional_kwargs['function_call']['arguments']
)
name=message.additional_kwargs["function_call"]["name"],
arguments=message.additional_kwargs["function_call"][
"arguments"
],
),
)
]
@ -88,12 +101,16 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe
elif isinstance(message, SystemMessage):
prompt_messages.append(SystemPromptMessage(content=message.content))
elif isinstance(message, FunctionMessage):
prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
prompt_messages.append(
ToolPromptMessage(content=message.content, tool_call_id=message.name)
)
return prompt_messages
def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
def prompt_messages_to_lc_messages(
prompt_messages: list[PromptMessage],
) -> list[BaseMessage]:
messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
@ -105,24 +122,24 @@ def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list
if isinstance(content, TextPromptMessageContent):
message_contents.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
message_contents.append({
'type': 'image',
'data': content.data,
'detail': content.detail.value
})
message_contents.append(
{
"type": "image",
"data": content.data,
"detail": content.detail.value,
}
)
messages.append(HumanMessage(content=message_contents))
elif isinstance(prompt_message, AssistantPromptMessage):
message_kwargs = {
'content': prompt_message.content
}
message_kwargs = {"content": prompt_message.content}
if prompt_message.tool_calls:
message_kwargs['additional_kwargs'] = {
'function_call': {
'id': prompt_message.tool_calls[0].id,
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
message_kwargs["additional_kwargs"] = {
"function_call": {
"id": prompt_message.tool_calls[0].id,
"name": prompt_message.tool_calls[0].function.name,
"arguments": prompt_message.tool_calls[0].function.arguments,
}
}
@ -130,6 +147,10 @@ def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list
elif isinstance(prompt_message, SystemPromptMessage):
messages.append(SystemMessage(content=prompt_message.content))
elif isinstance(prompt_message, ToolPromptMessage):
messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
messages.append(
FunctionMessage(
name=prompt_message.tool_call_id, content=prompt_message.content
)
)
return messages
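Taken together, these two helpers convert between LangChain messages and the runtime's PromptMessage types in both directions. A rough usage sketch — the module path of the helpers themselves is not visible in this diff, so they are assumed to be importable:

from model_providers.core.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    UserPromptMessage,
)

prompt_msgs = [
    SystemPromptMessage(content="You are a helpful assistant."),
    UserPromptMessage(content="What's the weather in Beijing today?"),
]

# PromptMessage -> LangChain BaseMessage, e.g. before handing history to a LangChain agent
lc_msgs = prompt_messages_to_lc_messages(prompt_msgs)
print([type(m).__name__ for m in lc_msgs])   # e.g. ['SystemMessage', 'HumanMessage']

# ...and back again with lc_messages_to_prompt_messages(lc_msgs)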

View File

@ -4,7 +4,10 @@ from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import ModelType, ProviderModel
from model_providers.core.model_runtime.entities.model_entities import (
ModelType,
ProviderModel,
)
from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
@ -12,6 +15,7 @@ class ModelStatus(Enum):
"""
Enum class for model status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
@ -22,6 +26,7 @@ class SimpleModelProviderEntity(BaseModel):
"""
Simple provider.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -39,7 +44,7 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types
supported_model_types=provider_entity.supported_model_types,
)
@ -47,6 +52,7 @@ class ModelWithProviderEntity(ProviderModel):
"""
Model with provider entity.
"""
provider: SimpleModelProviderEntity
status: ModelStatus
@ -55,6 +61,7 @@ class DefaultModelProviderEntity(BaseModel):
"""
Default model provider entity.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -66,6 +73,7 @@ class DefaultModelEntity(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: DefaultModelProviderEntity

View File

@ -7,9 +7,16 @@ from typing import Optional
from pydantic import BaseModel
from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from model_providers.core.entities.model_entities import (
ModelStatus,
ModelWithProviderEntity,
SimpleModelProviderEntity,
)
from model_providers.core.entities.provider_entities import CustomConfiguration
from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
from model_providers.core.model_runtime.entities.model_entities import (
FetchFrom,
ModelType,
)
from model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
@ -18,7 +25,9 @@ from model_providers.core.model_runtime.entities.provider_entities import (
)
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
@ -27,13 +36,16 @@ class ProviderConfiguration(BaseModel):
"""
Model class for provider configuration.
"""
provider: ProviderEntity
custom_configuration: CustomConfiguration
def __init__(self, **data):
super().__init__(**data)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
def get_current_credentials(
self, model_type: ModelType, model: str
) -> Optional[dict]:
"""
Get current credentials.
@ -43,7 +55,10 @@ class ProviderConfiguration(BaseModel):
"""
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
if (
model_configuration.model_type == model_type
and model_configuration.model == model
):
return model_configuration.credentials
if self.custom_configuration.provider:
@ -69,8 +84,9 @@ class ProviderConfiguration(BaseModel):
copy_credentials = credentials.copy()
return copy_credentials
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
def get_custom_model_credentials(
self, model_type: ModelType, model: str, obfuscated: bool = False
) -> Optional[dict]:
"""
Get custom model credentials.
@ -83,7 +99,10 @@ class ProviderConfiguration(BaseModel):
return None
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
if (
model_configuration.model_type == model_type
and model_configuration.model == model
):
credentials = model_configuration.credentials
if not obfuscated:
return credentials
@ -113,9 +132,9 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
def get_provider_model(self, model_type: ModelType,
model: str,
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
def get_provider_model(
self, model_type: ModelType, model: str, only_active: bool = False
) -> Optional[ModelWithProviderEntity]:
"""
Get provider model.
:param model_type: model type
@ -131,8 +150,9 @@ class ProviderConfiguration(BaseModel):
return None
def get_provider_models(self, model_type: Optional[ModelType] = None,
only_active: bool = False) -> list[ModelWithProviderEntity]:
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
@ -148,18 +168,19 @@ class ProviderConfiguration(BaseModel):
model_types = provider_instance.get_provider_schema().supported_model_types
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance
model_types=model_types, provider_instance=provider_instance
)
if only_active:
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
provider_models = [
m for m in provider_models if m.status == ModelStatus.ACTIVE
]
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
def _get_custom_provider_models(
self, model_types: list[ModelType], provider_instance: ModelProvider
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
@ -189,7 +210,9 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
status=ModelStatus.ACTIVE
if credentials
else ModelStatus.NO_CONFIGURE,
)
)
@ -199,15 +222,13 @@ class ProviderConfiguration(BaseModel):
continue
try:
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
custom_model_schema = provider_instance.get_model_instance(
model_configuration.model_type
).get_customizable_model_schema_from_credentials(
model_configuration.model, model_configuration.credentials
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
logger.warning(f"get custom model schema failed, {ex}")
continue
if not custom_model_schema:
@ -223,7 +244,7 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
status=ModelStatus.ACTIVE,
)
)
@ -234,16 +255,18 @@ class ProviderConfigurations(BaseModel):
"""
Model class for provider configuration dict.
"""
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self):
super().__init__()
def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False) \
-> list[ModelWithProviderEntity]:
def get_models(
self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False,
) -> list[ModelWithProviderEntity]:
"""
Get available models.
@ -278,7 +301,9 @@ class ProviderConfigurations(BaseModel):
if provider and provider_configuration.provider.provider != provider:
continue
all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
all_models.extend(
provider_configuration.get_provider_models(model_type, only_active)
)
return all_models
@ -310,6 +335,7 @@ class ProviderModelBundle(BaseModel):
"""
Provider model bundle.
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel

View File

@ -12,11 +12,11 @@ class RestrictModel(BaseModel):
model_type: ModelType
class CustomProviderConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
credentials: dict
@ -24,6 +24,7 @@ class CustomModelConfiguration(BaseModel):
"""
Model class for provider custom model configuration.
"""
model: str
model_type: ModelType
credentials: dict
@ -33,5 +34,6 @@ class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []

View File

@ -3,13 +3,17 @@ from typing import Any
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
class QueueEvent(Enum):
"""
QueueEvent enum
"""
MESSAGE = "message"
AGENT_MESSAGE = "agent_message"
MESSAGE_REPLACE = "message-replace"
@ -27,6 +31,7 @@ class AppQueueEvent(BaseModel):
"""
QueueEvent entity
"""
event: QueueEvent
@ -34,21 +39,25 @@ class QueueMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
event = QueueEvent.MESSAGE
chunk: LLMResultChunk
class QueueAgentMessageEvent(AppQueueEvent):
"""
QueueAgentMessageEvent entity
"""
event = QueueEvent.AGENT_MESSAGE
chunk: LLMResultChunk
class QueueMessageReplaceEvent(AppQueueEvent):
"""
QueueMessageReplaceEvent entity
"""
event = QueueEvent.MESSAGE_REPLACE
text: str
@ -57,6 +66,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
QueueRetrieverResourcesEvent entity
"""
event = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
@ -65,6 +75,7 @@ class AnnotationReplyEvent(AppQueueEvent):
"""
AnnotationReplyEvent entity
"""
event = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str
@ -73,28 +84,34 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
QueueMessageEndEvent entity
"""
event = QueueEvent.MESSAGE_END
llm_result: LLMResult
class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event = QueueEvent.AGENT_THOUGHT
agent_thought_id: str
class QueueMessageFileEvent(AppQueueEvent):
"""
QueueMessageFileEvent entity
"""
event = QueueEvent.MESSAGE_FILE
message_file_id: str
class QueueErrorEvent(AppQueueEvent):
"""
QueueErrorEvent entity
"""
event = QueueEvent.ERROR
error: Any
@ -103,6 +120,7 @@ class QueuePingEvent(AppQueueEvent):
"""
QueuePingEvent entity
"""
event = QueueEvent.PING
@ -110,10 +128,12 @@ class QueueStopEvent(AppQueueEvent):
"""
QueueStopEvent entity
"""
class StopBy(Enum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
@ -126,6 +146,7 @@ class QueueMessage(BaseModel):
"""
QueueMessage entity
"""
task_id: str
message_id: str
conversation_id: str

View File

@ -2,23 +2,40 @@ from collections.abc import Generator
from typing import IO, Optional, Union, cast
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.errors.error import ProviderTokenNotInitError
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.__base.moderation_model import (
ModerationModel,
)
from model_providers.core.model_runtime.model_providers.__base.rerank_model import (
RerankModel,
)
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import (
Speech2TextModel,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
from model_providers.core.provider_manager import ProviderManager
from model_providers.errors.error import ProviderTokenNotInitError
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
def _fetch_credentials_from_bundle(
provider_model_bundle: ProviderModelBundle, model: str
) -> dict:
"""
Fetch credentials from provider model bundle
:param provider_model_bundle: provider model bundle
@ -26,12 +43,13 @@ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, m
:return:
"""
credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=provider_model_bundle.model_type_instance.model_type,
model=model
model_type=provider_model_bundle.model_type_instance.model_type, model=model
)
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
raise ProviderTokenNotInitError(
f"Model {model} credentials is not initialized."
)
return credentials
@ -48,10 +66,16 @@ class ModelInstance:
self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self._provider_model_bundle.model_type_instance
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -77,11 +101,12 @@ class ModelInstance:
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def invoke_text_embedding(
self, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -94,16 +119,17 @@ class ModelInstance:
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
texts=texts,
user=user
model=self.model, credentials=self.credentials, texts=texts, user=user
)
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
def invoke_rerank(
self,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -125,11 +151,10 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user
user=user,
)
def invoke_moderation(self, text: str, user: Optional[str] = None) \
-> bool:
def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
"""
Invoke moderation model
@ -142,14 +167,10 @@ class ModelInstance:
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
text=text,
user=user
model=self.model, credentials=self.credentials, text=text, user=user
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-> str:
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech-to-text model
@ -162,14 +183,17 @@ class ModelInstance:
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
file=file,
user=user
model=self.model, credentials=self.credentials, file=file, user=user
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
-> str:
def invoke_tts(
self,
content_text: str,
tenant_id: str,
voice: str,
streaming: bool,
user: Optional[str] = None,
) -> str:
"""
Invoke text-to-speech model
@ -191,7 +215,7 @@ class ModelInstance:
user=user,
tenant_id=tenant_id,
voice=voice,
streaming=streaming
streaming=streaming,
)
def get_tts_voices(self, language: str) -> list:
@ -206,21 +230,24 @@ class ModelInstance:
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model,
credentials=self.credentials,
language=language
model=self.model, credentials=self.credentials, language=language
)
class ModelManager:
def __init__(self,
provider_name_to_provider_records_dict: dict,
provider_name_to_provider_model_records_dict: dict) -> None:
def __init__(
self,
provider_name_to_provider_records_dict: dict,
provider_name_to_provider_model_records_dict: dict,
) -> None:
self._provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict)
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
def get_model_instance(self, provider: str, model_type: ModelType, model: str) -> ModelInstance:
def get_model_instance(
self, provider: str, model_type: ModelType, model: str
) -> ModelInstance:
"""
Get model instance
:param provider: provider name
@ -231,8 +258,7 @@ class ModelManager:
if not provider:
return self.get_default_model_instance(model_type)
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
provider=provider,
model_type=model_type
provider=provider, model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
@ -253,5 +279,5 @@ class ModelManager:
return self.get_model_instance(
provider=default_model_entity.provider.provider,
model_type=model_type,
model=default_model_entity.model
model=default_model_entity.model,
)
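End to end, the manager resolves a provider bundle and hands back a ModelInstance whose invoke_* methods wrap the underlying runtime model. A rough sketch, assuming the provider record dicts are already built elsewhere and that ModelManager is importable from this module (the file path is not shown here); the record contents below are placeholders:

from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
from model_providers.core.model_runtime.entities.model_entities import ModelType

manager = ModelManager(
    provider_name_to_provider_records_dict=provider_records,                # placeholder dict
    provider_name_to_provider_model_records_dict=provider_model_records,    # placeholder dict
)

llm = manager.get_model_instance(
    provider="openai", model_type=ModelType.LLM, model="gpt-3.5-turbo"
)
result = llm.invoke_llm(
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7},
    stream=False,
)
print(result.message.content)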

View File

@ -1,8 +1,14 @@
from abc import ABC
from typing import Optional
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
from model_providers.core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
_TEXT_COLOR_MAPPING = {
@ -19,12 +25,21 @@ class Callback(ABC):
Base class for callbacks.
Only for LLM.
"""
raise_error: bool = False
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Before invoke callback
@ -40,10 +55,19 @@ class Callback(ABC):
"""
raise NotImplementedError()
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
"""
On new chunk callback
@ -60,10 +84,19 @@ class Callback(ABC):
"""
raise NotImplementedError()
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
After invoke callback
@ -80,10 +113,19 @@ class Callback(ABC):
"""
raise NotImplementedError()
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Invoke error callback
@ -100,9 +142,7 @@ class Callback(ABC):
"""
raise NotImplementedError()
def print_text(
self, text: str, color: Optional[str] = None, end: str = ""
) -> None:
def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end)

View File

@ -4,17 +4,32 @@ import sys
from typing import Optional
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
from model_providers.core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class LoggingCallback(Callback):
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Before invoke callback
@ -28,40 +43,49 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_before_invoke]\n", color='blue')
self.print_text(f"Model: {model}\n", color='blue')
self.print_text("Parameters:\n", color='blue')
self.print_text("\n[on_llm_before_invoke]\n", color="blue")
self.print_text(f"Model: {model}\n", color="blue")
self.print_text("Parameters:\n", color="blue")
for key, value in model_parameters.items():
self.print_text(f"\t{key}: {value}\n", color='blue')
self.print_text(f"\t{key}: {value}\n", color="blue")
if stop:
self.print_text(f"\tstop: {stop}\n", color='blue')
self.print_text(f"\tstop: {stop}\n", color="blue")
if tools:
self.print_text("\tTools:\n", color='blue')
self.print_text("\tTools:\n", color="blue")
for tool in tools:
self.print_text(f"\t\t{tool.name}\n", color='blue')
self.print_text(f"\t\t{tool.name}\n", color="blue")
self.print_text(f"Stream: {stream}\n", color='blue')
self.print_text(f"Stream: {stream}\n", color="blue")
if user:
self.print_text(f"User: {user}\n", color='blue')
self.print_text(f"User: {user}\n", color="blue")
self.print_text("Prompt messages:\n", color='blue')
self.print_text("Prompt messages:\n", color="blue")
for prompt_message in prompt_messages:
if prompt_message.name:
self.print_text(f"\tname: {prompt_message.name}\n", color='blue')
self.print_text(f"\tname: {prompt_message.name}\n", color="blue")
self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue')
self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue')
self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")
if stream:
self.print_text("\n[on_llm_new_chunk]")
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
"""
On new chunk callback
@ -79,10 +103,19 @@ class LoggingCallback(Callback):
sys.stdout.write(chunk.delta.message.content)
sys.stdout.flush()
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
After invoke callback
@ -97,24 +130,37 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_after_invoke]\n", color='yellow')
self.print_text(f"Content: {result.message.content}\n", color='yellow')
self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
self.print_text(f"Content: {result.message.content}\n", color="yellow")
if result.message.tool_calls:
self.print_text("Tool calls:\n", color='yellow')
self.print_text("Tool calls:\n", color="yellow")
for tool_call in result.message.tool_calls:
self.print_text(f"\t{tool_call.id}\n", color='yellow')
self.print_text(f"\t{tool_call.function.name}\n", color='yellow')
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow')
self.print_text(f"\t{tool_call.id}\n", color="yellow")
self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
self.print_text(
f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow"
)
self.print_text(f"Model: {result.model}\n", color='yellow')
self.print_text(f"Usage: {result.usage}\n", color='yellow')
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow')
self.print_text(f"Model: {result.model}\n", color="yellow")
self.print_text(f"Usage: {result.usage}\n", color="yellow")
self.print_text(
f"System Fingerprint: {result.system_fingerprint}\n", color="yellow"
)
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Invoke error callback
@ -129,5 +175,5 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_invoke_error]\n", color='red')
self.print_text("\n[on_llm_invoke_error]\n", color="red")
logger.exception(ex)

View File

@ -7,6 +7,7 @@ class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
zh_Hans: Optional[str] = None
en_US: str

View File

@ -1,98 +1,99 @@
from model_providers.core.model_runtime.entities.model_entities import DefaultParameterName
from model_providers.core.model_runtime.entities.model_entities import (
DefaultParameterName,
)
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
'label': {
'en_US': 'Temperature',
'zh_Hans': '温度',
"label": {
"en_US": "Temperature",
"zh_Hans": "温度",
},
'type': 'float',
'help': {
'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.',
'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。',
"type": "float",
"help": {
"en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.TOP_P: {
'label': {
'en_US': 'Top P',
'zh_Hans': 'Top P',
"label": {
"en_US": "Top P",
"zh_Hans": "Top P",
},
'type': 'float',
'help': {
'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.',
'zh_Hans': '通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。',
"type": "float",
"help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
"zh_Hans": "通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。",
},
'required': False,
'default': 1.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 1.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.PRESENCE_PENALTY: {
'label': {
'en_US': 'Presence Penalty',
'zh_Hans': '存在惩罚',
"label": {
"en_US": "Presence Penalty",
"zh_Hans": "存在惩罚",
},
'type': 'float',
'help': {
'en_US': 'Applies a penalty to the log-probability of tokens already in the text.',
'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。',
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens already in the text.",
"zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.FREQUENCY_PENALTY: {
'label': {
'en_US': 'Frequency Penalty',
'zh_Hans': '频率惩罚',
"label": {
"en_US": "Frequency Penalty",
"zh_Hans": "频率惩罚",
},
'type': 'float',
'help': {
'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.',
'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。',
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
"zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.MAX_TOKENS: {
'label': {
'en_US': 'Max Tokens',
'zh_Hans': '最大标记',
"label": {
"en_US": "Max Tokens",
"zh_Hans": "最大标记",
},
'type': 'int',
'help': {
'en_US': 'The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.',
'zh_Hans': '要生成的标记的最大数量。请求可以使用最多2048个标记这些标记在提示和完成之间共享。',
"type": "int",
"help": {
"en_US": "The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.",
"zh_Hans": "要生成的标记的最大数量。请求可以使用最多2048个标记这些标记在提示和完成之间共享。",
},
'required': False,
'default': 64,
'min': 1,
'max': 2048,
'precision': 0,
"required": False,
"default": 64,
"min": 1,
"max": 2048,
"precision": 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
'label': {
'en_US': 'Response Format',
'zh_Hans': '回复格式',
"label": {
"en_US": "Response Format",
"zh_Hans": "回复格式",
},
'type': 'string',
'help': {
'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.',
'zh_Hans': '设置一个返回格式确保llm的输出尽可能是有效的代码块如JSON、XML等',
"type": "string",
"help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式确保llm的输出尽可能是有效的代码块如JSON、XML等",
},
'required': False,
'options': ['JSON', 'XML'],
}
"required": False,
"options": ["JSON", "XML"],
},
}
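One plausible way a provider implementation could resolve these defaults by raw parameter name — the helper below is illustrative, and PARAMETER_RULE_TEMPLATE is assumed to be in scope or imported from this module:

from model_providers.core.model_runtime.entities.model_entities import (
    DefaultParameterName,
)

def default_rule_for(name: str) -> dict:
    # Map a raw parameter name such as "temperature" to its rule template.
    return PARAMETER_RULE_TEMPLATE[DefaultParameterName.value_of(name)]

rule = default_rule_for("temperature")
print(rule["label"]["en_US"], rule["default"], rule["min"], rule["max"])
# Temperature 0.0 0.0 1.0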

View File

@ -4,19 +4,26 @@ from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from model_providers.core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import (
ModelUsage,
PriceInfo,
)
class LLMMode(Enum):
"""
Enum class for large language model mode.
"""
COMPLETION = "completion"
CHAT = "chat"
@classmethod
def value_of(cls, value: str) -> 'LLMMode':
def value_of(cls, value: str) -> "LLMMode":
"""
Get value of given mode.
@ -26,13 +33,14 @@ class LLMMode(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
"""
prompt_tokens: int
prompt_unit_price: Decimal
prompt_price_unit: Decimal
@ -50,17 +58,17 @@ class LLMUsage(ModelUsage):
def empty_usage(cls):
return cls(
prompt_tokens=0,
prompt_unit_price=Decimal('0.0'),
prompt_price_unit=Decimal('0.0'),
prompt_price=Decimal('0.0'),
prompt_unit_price=Decimal("0.0"),
prompt_price_unit=Decimal("0.0"),
prompt_price=Decimal("0.0"),
completion_tokens=0,
completion_unit_price=Decimal('0.0'),
completion_price_unit=Decimal('0.0'),
completion_price=Decimal('0.0'),
completion_unit_price=Decimal("0.0"),
completion_price_unit=Decimal("0.0"),
completion_price=Decimal("0.0"),
total_tokens=0,
total_price=Decimal('0.0'),
currency='USD',
latency=0.0
total_price=Decimal("0.0"),
currency="USD",
latency=0.0,
)
@ -68,6 +76,7 @@ class LLMResult(BaseModel):
"""
Model class for llm result.
"""
model: str
prompt_messages: list[PromptMessage]
message: AssistantPromptMessage
@ -79,6 +88,7 @@ class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int
message: AssistantPromptMessage
usage: Optional[LLMUsage] = None
@ -89,6 +99,7 @@ class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str
prompt_messages: list[PromptMessage]
system_fingerprint: Optional[str] = None
@ -99,4 +110,5 @@ class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.
"""
tokens: int

View File

@ -9,13 +9,14 @@ class PromptMessageRole(Enum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@classmethod
def value_of(cls, value: str) -> 'PromptMessageRole':
def value_of(cls, value: str) -> "PromptMessageRole":
"""
Get value of given mode.
@ -25,13 +26,14 @@ class PromptMessageRole(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt message type value {value}')
raise ValueError(f"invalid prompt message type value {value}")
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str
description: str
parameters: dict
@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel):
"""
Model class for prompt message function.
"""
type: str = 'function'
type: str = "function"
function: PromptMessageTool
@ -49,14 +52,16 @@ class PromptMessageContentType(Enum):
"""
Enum class for prompt message content type.
"""
TEXT = 'text'
IMAGE = 'image'
TEXT = "text"
IMAGE = "image"
class PromptMessageContent(BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
data: str
@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(Enum):
LOW = 'low'
HIGH = 'high'
LOW = "low"
HIGH = "high"
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None
name: Optional[str] = None
@ -93,6 +101,7 @@ class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
@ -100,14 +109,17 @@ class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str
arguments: str
@ -123,6 +135,7 @@ class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
@ -130,5 +143,6 @@ class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str

View File

@ -11,6 +11,7 @@ class ModelType(Enum):
"""
Enum class for model type.
"""
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
@ -26,22 +27,28 @@ class ModelType(Enum):
:return: model type
"""
if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value:
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
return cls.LLM
elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value:
elif (
origin_model_type == "embeddings"
or origin_model_type == cls.TEXT_EMBEDDING.value
):
return cls.TEXT_EMBEDDING
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
return cls.RERANK
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
elif (
origin_model_type == "speech2text"
or origin_model_type == cls.SPEECH2TEXT.value
):
return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
else:
raise ValueError(f'invalid origin model type {origin_model_type}')
raise ValueError(f"invalid origin model type {origin_model_type}")
def to_origin_model_type(self) -> str:
"""
@ -50,26 +57,28 @@ class ModelType(Enum):
:return: origin model type
"""
if self == self.LLM:
return 'text-generation'
return "text-generation"
elif self == self.TEXT_EMBEDDING:
return 'embeddings'
return "embeddings"
elif self == self.RERANK:
return 'reranking'
return "reranking"
elif self == self.SPEECH2TEXT:
return 'speech2text'
return "speech2text"
elif self == self.TTS:
return 'tts'
return "tts"
elif self == self.MODERATION:
return 'moderation'
return "moderation"
elif self == self.TEXT2IMG:
return 'text2img'
return "text2img"
else:
raise ValueError(f'invalid model type {self}')
raise ValueError(f"invalid model type {self}")
class FetchFrom(Enum):
"""
Enum class for fetch from.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
@ -78,6 +87,7 @@ class ModelFeature(Enum):
"""
Enum class for llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
@ -89,6 +99,7 @@ class DefaultParameterName(Enum):
"""
Enum class for parameter template variable.
"""
TEMPERATURE = "temperature"
TOP_P = "top_p"
PRESENCE_PENALTY = "presence_penalty"
@ -97,7 +108,7 @@ class DefaultParameterName(Enum):
RESPONSE_FORMAT = "response_format"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
def value_of(cls, value: Any) -> "DefaultParameterName":
"""
Get parameter name from value.
@ -107,13 +118,14 @@ class DefaultParameterName(Enum):
for name in cls:
if name.value == value:
return name
raise ValueError(f'invalid parameter name {value}')
raise ValueError(f"invalid parameter name {value}")
class ParameterType(Enum):
"""
Enum class for parameter type.
"""
FLOAT = "float"
INT = "int"
STRING = "string"
@ -124,6 +136,7 @@ class ModelPropertyKey(Enum):
"""
Enum class for model property key.
"""
MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
@ -141,6 +154,7 @@ class ProviderModel(BaseModel):
"""
Model class for provider model.
"""
model: str
label: I18nObject
model_type: ModelType
@ -157,6 +171,7 @@ class ParameterRule(BaseModel):
"""
Model class for parameter rule.
"""
name: str
use_template: Optional[str] = None
label: I18nObject
@ -174,6 +189,7 @@ class PriceConfig(BaseModel):
"""
Model class for pricing info.
"""
input: Decimal
output: Optional[Decimal] = None
unit: Decimal
@ -184,6 +200,7 @@ class AIModelEntity(ProviderModel):
"""
Model class for AI model.
"""
parameter_rules: list[ParameterRule] = []
pricing: Optional[PriceConfig] = None
@ -196,6 +213,7 @@ class PriceType(Enum):
"""
Enum class for price type.
"""
INPUT = "input"
OUTPUT = "output"
@ -204,6 +222,7 @@ class PriceInfo(BaseModel):
"""
Model class for price info.
"""
unit_price: Decimal
unit: Decimal
total_amount: Decimal
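
For reference, the value_of / to_origin_model_type pair defined earlier in this file is simply a two-way mapping between the runtime enum and the origin API names. A minimal standalone sketch, trimmed to three members purely for illustration (MiniModelType is not the project's class):

from enum import Enum

class MiniModelType(Enum):
    # condensed illustration of the ModelType mapping shown above
    LLM = "llm"
    TEXT_EMBEDDING = "text-embedding"
    RERANK = "rerank"

    @classmethod
    def value_of(cls, origin_model_type: str) -> "MiniModelType":
        # accept both the origin API name and the enum value itself
        if origin_model_type in ("text-generation", cls.LLM.value):
            return cls.LLM
        elif origin_model_type in ("embeddings", cls.TEXT_EMBEDDING.value):
            return cls.TEXT_EMBEDDING
        elif origin_model_type in ("reranking", cls.RERANK.value):
            return cls.RERANK
        raise ValueError(f"invalid origin model type {origin_model_type}")

    def to_origin_model_type(self) -> str:
        return {
            MiniModelType.LLM: "text-generation",
            MiniModelType.TEXT_EMBEDDING: "embeddings",
            MiniModelType.RERANK: "reranking",
        }[self]

assert MiniModelType.value_of("embeddings") is MiniModelType.TEXT_EMBEDDING
assert MiniModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings"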

View File

@ -4,13 +4,18 @@ from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
ProviderModel,
)
class ConfigurateMethod(Enum):
"""
Enum class for configurate method of provider model.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
@ -19,6 +24,7 @@ class FormType(Enum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = "select"
@ -30,6 +36,7 @@ class FormShowOnObject(BaseModel):
"""
Model class for form show on.
"""
variable: str
value: str
@ -38,6 +45,7 @@ class FormOption(BaseModel):
"""
Model class for form option.
"""
label: I18nObject
value: str
show_on: list[FormShowOnObject] = []
@ -45,15 +53,14 @@ class FormOption(BaseModel):
def __init__(self, **data):
super().__init__(**data)
if not self.label:
self.label = I18nObject(
en_US=self.value
)
self.label = I18nObject(en_US=self.value)
class CredentialFormSchema(BaseModel):
"""
Model class for credential form schema.
"""
variable: str
label: I18nObject
type: FormType
@ -69,6 +76,7 @@ class ProviderCredentialSchema(BaseModel):
"""
Model class for provider credential schema.
"""
credential_form_schemas: list[CredentialFormSchema]
@ -81,6 +89,7 @@ class ModelCredentialSchema(BaseModel):
"""
Model class for model credential schema.
"""
model: FieldModelSchema
credential_form_schemas: list[CredentialFormSchema]
@ -89,6 +98,7 @@ class SimpleProviderEntity(BaseModel):
"""
Simple model class for provider.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -101,6 +111,7 @@ class ProviderHelpEntity(BaseModel):
"""
Model class for provider help.
"""
title: I18nObject
url: I18nObject
@ -109,6 +120,7 @@ class ProviderEntity(BaseModel):
"""
Model class for provider.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
@ -137,7 +149,7 @@ class ProviderEntity(BaseModel):
icon_small=self.icon_small,
icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models
models=self.models,
)
@ -145,5 +157,6 @@ class ProviderConfig(BaseModel):
"""
Model class for provider config.
"""
provider: str
credentials: dict
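
The FormOption.__init__ override above only back-fills the label from the raw option value. The same pattern in a self-contained pydantic sketch (the classes here are simplified stand-ins, not the project's I18nObject/FormOption):

from typing import Optional
from pydantic import BaseModel

class I18nLabel(BaseModel):
    # stand-in for I18nObject, reduced to the locales used in the diff
    en_US: str
    zh_Hans: Optional[str] = None

class Option(BaseModel):
    label: Optional[I18nLabel] = None
    value: str

    def __init__(self, **data):
        super().__init__(**data)
        if not self.label:
            # default the display label to the option value, as FormOption does
            self.label = I18nLabel(en_US=self.value)

print(Option(value="claude-instant-1.2").label.en_US)  # claude-instant-1.2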

View File

@ -5,6 +5,7 @@ class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int
text: str
score: float
@ -14,5 +15,6 @@ class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str
docs: list[RerankDocument]

View File

@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int
total_tokens: int
unit_price: Decimal
@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

View File

@ -3,6 +3,7 @@ from typing import Optional
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
@ -14,24 +15,29 @@ class InvokeError(Exception):
class InvokeConnectionError(InvokeError):
"""Raised when the Invoke returns connection error."""
description = "Connection Error"
class InvokeServerUnavailableError(InvokeError):
"""Raised when the Invoke returns server unavailable error."""
description = "Server Unavailable Error"
class InvokeRateLimitError(InvokeError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"
class InvokeAuthorizationError(InvokeError):
"""Raised when the Invoke returns authorization error."""
description = "Incorrect model credentials provided, please check and try again. "
class InvokeBadRequestError(InvokeError):
"""Raised when the Invoke returns bad request."""
description = "Bad Request Error"

View File

@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception):
"""
Credentials validate failed error
"""
pass

View File

@ -16,15 +16,24 @@ from model_providers.core.model_runtime.entities.model_entities import (
PriceInfo,
PriceType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from model_providers.core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from model_providers.core.utils.position_helper import get_position_map, sort_by_position_map
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeError,
)
from model_providers.core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import (
GPT2Tokenizer,
)
from model_providers.core.utils.position_helper import (
get_position_map,
sort_by_position_map,
)
class AIModel(ABC):
"""
Base class for all models.
"""
model_type: ModelType
model_schemas: list[AIModelEntity] = None
started_at: float = 0
@ -60,18 +69,24 @@ class AIModel(ABC):
:param error: model invoke error
:return: unified error
"""
provider_name = self.__class__.__module__.split('.')[-3]
provider_name = self.__class__.__module__.split(".")[-3]
for invoke_error, model_errors in self._invoke_error_mapping.items():
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ")
return invoke_error(
description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. "
)
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
return invoke_error(
description=f"[{provider_name}] {invoke_error.description}, {str(error)}"
)
return InvokeError(description=f"[{provider_name}] Error: {str(error)}")
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
def get_price(
self, model: str, credentials: dict, price_type: PriceType, tokens: int
) -> PriceInfo:
"""
Get price for given model and tokens
@ -99,15 +114,17 @@ class AIModel(ABC):
if unit_price is None:
return PriceInfo(
unit_price=decimal.Decimal('0.0'),
unit=decimal.Decimal('0.0'),
total_amount=decimal.Decimal('0.0'),
unit_price=decimal.Decimal("0.0"),
unit=decimal.Decimal("0.0"),
total_amount=decimal.Decimal("0.0"),
currency="USD",
)
# calculate total amount
total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
total_amount = total_amount.quantize(
decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP
)
return PriceInfo(
unit_price=unit_price,
@ -128,24 +145,28 @@ class AIModel(ABC):
model_schemas = []
# get module name
model_type = self.__class__.__module__.split('.')[-1]
model_type = self.__class__.__module__.split(".")[-1]
# get provider name
provider_name = self.__class__.__module__.split('.')[-3]
provider_name = self.__class__.__module__.split(".")[-3]
# get the path of current classes
current_path = os.path.abspath(__file__)
# get parent path of the current path
provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type)
provider_model_type_path = os.path.join(
os.path.dirname(os.path.dirname(current_path)), provider_name, model_type
)
# get all yaml files path under provider_model_type_path that do not start with __
model_schema_yaml_paths = [
os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
if not model_schema_yaml.startswith("__")
and not model_schema_yaml.startswith("_")
and os.path.isfile(
os.path.join(provider_model_type_path, model_schema_yaml)
)
and model_schema_yaml.endswith(".yaml")
]
# get _position.yaml file path
@ -154,59 +175,73 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f:
with open(model_schema_yaml_path, encoding="utf-8") as f:
yaml_data = yaml.safe_load(f)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):
if 'use_template' in parameter_rule:
for parameter_rule in yaml_data.get("parameter_rules", []):
if "use_template" in parameter_rule:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template'])
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
default_parameter_name = DefaultParameterName.value_of(
parameter_rule["use_template"]
)
default_parameter_rule = (
self._get_default_parameter_rule_variable_map(
default_parameter_name
)
)
copy_default_parameter_rule = default_parameter_rule.copy()
copy_default_parameter_rule.update(parameter_rule)
parameter_rule = copy_default_parameter_rule
except ValueError:
pass
if 'label' not in parameter_rule:
parameter_rule['label'] = {
'zh_Hans': parameter_rule['name'],
'en_US': parameter_rule['name']
if "label" not in parameter_rule:
parameter_rule["label"] = {
"zh_Hans": parameter_rule["name"],
"en_US": parameter_rule["name"],
}
new_parameter_rules.append(parameter_rule)
yaml_data['parameter_rules'] = new_parameter_rules
yaml_data["parameter_rules"] = new_parameter_rules
if 'label' not in yaml_data:
yaml_data['label'] = {
'zh_Hans': yaml_data['model'],
'en_US': yaml_data['model']
if "label" not in yaml_data:
yaml_data["label"] = {
"zh_Hans": yaml_data["model"],
"en_US": yaml_data["model"],
}
yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value
yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value
try:
# yaml_data to entity
model_schema = AIModelEntity(**yaml_data)
except Exception as e:
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:'
f' {str(e)}')
model_schema_yaml_file_name = os.path.basename(
model_schema_yaml_path
).rstrip(".yaml")
raise Exception(
f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:"
f" {str(e)}"
)
# cache model schema
model_schemas.append(model_schema)
# resort model schemas by position
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
model_schemas = sort_by_position_map(
position_map, model_schemas, lambda x: x.model
)
# cache model schemas
self.model_schemas = model_schemas
return model_schemas
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
def get_model_schema(
self, model: str, credentials: Optional[dict] = None
) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials
@ -222,13 +257,17 @@ class AIModel(ABC):
return model_map[model]
if credentials:
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
model_schema = self.get_customizable_model_schema_from_credentials(
model, credentials
)
if model_schema:
return model_schema
return None
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
@ -238,7 +277,9 @@ class AIModel(ABC):
"""
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def _get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
@ -252,26 +293,51 @@ class AIModel(ABC):
for parameter_rule in schema.parameter_rules:
if parameter_rule.use_template:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max and 'max' in default_parameter_rule:
parameter_rule.max = default_parameter_rule['max']
if not parameter_rule.min and 'min' in default_parameter_rule:
parameter_rule.min = default_parameter_rule['min']
if not parameter_rule.default and 'default' in default_parameter_rule:
parameter_rule.default = default_parameter_rule['default']
if not parameter_rule.precision and 'precision' in default_parameter_rule:
parameter_rule.precision = default_parameter_rule['precision']
if not parameter_rule.required and 'required' in default_parameter_rule:
parameter_rule.required = default_parameter_rule['required']
if not parameter_rule.help and 'help' in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
default_parameter_name = DefaultParameterName.value_of(
parameter_rule.use_template
)
default_parameter_rule = (
self._get_default_parameter_rule_variable_map(
default_parameter_name
)
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
)
if not parameter_rule.max and "max" in default_parameter_rule:
parameter_rule.max = default_parameter_rule["max"]
if not parameter_rule.min and "min" in default_parameter_rule:
parameter_rule.min = default_parameter_rule["min"]
if (
not parameter_rule.default
and "default" in default_parameter_rule
):
parameter_rule.default = default_parameter_rule["default"]
if (
not parameter_rule.precision
and "precision" in default_parameter_rule
):
parameter_rule.precision = default_parameter_rule["precision"]
if (
not parameter_rule.required
and "required" in default_parameter_rule
):
parameter_rule.required = default_parameter_rule["required"]
if not parameter_rule.help and "help" in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule["help"]["en_US"],
)
if not parameter_rule.help.en_US and (
"help" in default_parameter_rule
and "en_US" in default_parameter_rule["help"]
):
parameter_rule.help.en_US = default_parameter_rule["help"][
"en_US"
]
if not parameter_rule.help.zh_Hans and (
"help" in default_parameter_rule
and "zh_Hans" in default_parameter_rule["help"]
):
parameter_rule.help.zh_Hans = default_parameter_rule[
"help"
].get("zh_Hans", default_parameter_rule["help"]["en_US"])
except ValueError:
pass
@ -281,7 +347,9 @@ class AIModel(ABC):
return schema
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
"""
Get customizable model schema
@ -291,7 +359,9 @@ class AIModel(ABC):
"""
return None
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
def _get_default_parameter_rule_variable_map(
self, name: DefaultParameterName
) -> dict:
"""
Get default parameter rule for given name
@ -301,7 +371,7 @@ class AIModel(ABC):
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
if not default_parameter_rule:
raise Exception(f'Invalid model parameter rule name {name}')
raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule
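
The pricing arithmetic in get_price above reduces to one Decimal expression plus a fixed quantization; a quick standalone check (the unit price and unit below are made-up numbers, not any provider's rates):

import decimal

# hypothetical rate: 0.002 per 1K tokens, expressed as unit_price * unit in the schema
tokens = 1234
unit_price = decimal.Decimal("0.002")
unit = decimal.Decimal("0.001")

total_amount = tokens * unit_price * unit
# same rounding as AIModel.get_price: seven decimal places, half-up
total_amount = total_amount.quantize(
    decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP
)
print(total_amount)  # 0.0024680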

View File

@ -7,8 +7,16 @@ from collections.abc import Generator
from typing import Optional, Union
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.callbacks.logging_callback import LoggingCallback
from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from model_providers.core.model_runtime.callbacks.logging_callback import (
LoggingCallback,
)
from model_providers.core.model_runtime.entities.llm_entities import (
LLMMode,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -32,13 +40,21 @@ class LargeLanguageModel(AIModel):
"""
Model class for large language model.
"""
model_type: ModelType = ModelType.LLM
def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
def invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -57,7 +73,9 @@ class LargeLanguageModel(AIModel):
if model_parameters is None:
model_parameters = {}
model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials)
model_parameters = self._validate_and_filter_model_parameters(
model, model_parameters, credentials
)
self.started_at = time.perf_counter()
@ -76,7 +94,7 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
try:
@ -90,10 +108,19 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
else:
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
result = self._invoke(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
except Exception as e:
self._trigger_invoke_error_callbacks(
model=model,
@ -105,7 +132,7 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
raise self._transform_invoke_error(e)
@ -121,7 +148,7 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
else:
self._trigger_after_invoke_callbacks(
@ -134,15 +161,23 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
return result
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
@ -177,36 +212,44 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
if len(prompt_messages) > 0 and isinstance(
prompt_messages[0], SystemPromptMessage
):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
content=block_prompts.replace(
"{{instructions}}", prompt_messages[0].content
)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
prompt_messages.insert(
0,
SystemPromptMessage(
content=block_prompts.replace(
"{{instructions}}",
f"Please output a valid {code_block} object.",
)
),
)
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
if len(prompt_messages) > 0 and isinstance(
prompt_messages[-1], UserPromptMessage
):
# add ```JSON\n to the last message
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))
response = self._invoke(
model=model,
@ -216,33 +259,40 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
if isinstance(response, Generator):
first_chunk = next(response)
def new_generator():
yield first_chunk
yield from response
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
if (
first_chunk.delta.message.content
and first_chunk.delta.message.content.startswith("`")
):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
input_generator=new_generator(),
)
else:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
input_generator=new_generator(),
)
return response
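
The code-block wrapper above has to look at the first streamed chunk (to see whether it starts with a backtick) without losing it, so it rebuilds the generator with the peeked chunk spliced back in front. That peek-and-rechain idiom in isolation:

def peek_and_rechain(stream):
    # pull the first item so it can be inspected, then splice it back in front
    first = next(stream)

    def rebuilt():
        yield first
        yield from stream

    return first, rebuilt()

first, whole = peek_and_rechain(iter(["```json", '{"answer": "ok"}', "```"]))
print(first)        # ```json
print(list(whole))  # ['```json', '{"answer": "ok"}', '```']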
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
def _code_block_mode_stream_processor(
self,
model: str,
prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None],
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote
@ -291,15 +341,17 @@ if you are not sure about the structure.
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
content=new_piece, tool_calls=[]
),
)
),
)
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]:
def _code_block_mode_stream_processor_with_backtick(
self,
model: str,
prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None],
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote.
This version skips the language identifier that follows the opening triple backticks.
@ -366,26 +418,31 @@ if you are not sure about the structure.
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
content=new_piece, tool_calls=[]
),
)
),
)
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
def _invoke_result_generator(
self,
model: str,
result: Generator,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Generator:
"""
Invoke result generator
:param result: result generator
:return: result generator
"""
prompt_message = AssistantPromptMessage(
content=""
)
prompt_message = AssistantPromptMessage(content="")
usage = None
system_fingerprint = None
real_model = model
@ -404,7 +461,7 @@ if you are not sure about the structure.
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
prompt_message.content += chunk.delta.message.content
@ -424,7 +481,7 @@ if you are not sure about the structure.
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint
system_fingerprint=system_fingerprint,
),
credentials=credentials,
prompt_messages=prompt_messages,
@ -433,15 +490,21 @@ if you are not sure about the structure.
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -456,10 +519,15 @@ if you are not sure about the structure.
:return: full response or stream response chunk generator result
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -489,7 +557,9 @@ if you are not sure about the structure.
for word in result.message.content:
assistant_prompt_message = AssistantPromptMessage(
content=word,
tool_calls=tool_calls if index == (len(result.message.content) - 1) else []
tool_calls=tool_calls
if index == (len(result.message.content) - 1)
else [],
)
yield LLMResultChunk(
@ -499,7 +569,7 @@ if you are not sure about the structure.
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
),
)
index += 1
@ -531,11 +601,15 @@ if you are not sure about the structure.
mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE])
mode = LLMMode.value_of(
model_schema.model_properties[ModelPropertyKey.MODE]
)
return mode
def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
def _calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
) -> LLMUsage:
"""
Calculate response usage
@ -558,7 +632,7 @@ if you are not sure about the structure.
model=model,
credentials=credentials,
price_type=PriceType.OUTPUT,
tokens=completion_tokens
tokens=completion_tokens,
)
# transform usage
@ -572,18 +646,26 @@ if you are not sure about the structure.
completion_price_unit=completion_price_info.unit,
completion_price=completion_price_info.total_amount,
total_tokens=prompt_tokens + completion_tokens,
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
total_price=prompt_price_info.total_amount
+ completion_price_info.total_amount,
currency=prompt_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
def _trigger_before_invoke_callbacks(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> None:
"""
Trigger before invoke callbacks
@ -609,19 +691,29 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
logger.warning(
f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}"
)
def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
def _trigger_new_chunk_callbacks(
self,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> None:
"""
Trigger new chunk callbacks
@ -648,19 +740,29 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
logger.warning(
f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}"
)
def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
def _trigger_after_invoke_callbacks(
self,
model: str,
result: LLMResult,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> None:
"""
Trigger after invoke callbacks
@ -688,19 +790,29 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
logger.warning(
f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}"
)
def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
def _trigger_invoke_error_callbacks(
self,
model: str,
ex: Exception,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> None:
"""
Trigger invoke error callbacks
@ -728,15 +840,19 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
logger.warning(
f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}"
)
def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict:
def _validate_and_filter_model_parameters(
self, model: str, model_parameters: dict, credentials: dict
) -> dict:
"""
Validate model parameters
@ -753,16 +869,23 @@ if you are not sure about the structure.
parameter_name = parameter_rule.name
parameter_value = model_parameters.get(parameter_name)
if parameter_value is None:
if parameter_rule.use_template and parameter_rule.use_template in model_parameters:
if (
parameter_rule.use_template
and parameter_rule.use_template in model_parameters
):
# if parameter value is None, use template value variable name instead
parameter_value = model_parameters[parameter_rule.use_template]
else:
if parameter_rule.required:
if parameter_rule.default is not None:
filtered_model_parameters[parameter_name] = parameter_rule.default
filtered_model_parameters[
parameter_name
] = parameter_rule.default
continue
else:
raise ValueError(f"Model Parameter {parameter_name} is required.")
raise ValueError(
f"Model Parameter {parameter_name} is required."
)
else:
continue
@ -772,47 +895,81 @@ if you are not sure about the structure.
raise ValueError(f"Model Parameter {parameter_name} should be int.")
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
if (
parameter_rule.min is not None
and parameter_value < parameter_rule.min
):
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
)
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
if (
parameter_rule.max is not None
and parameter_value > parameter_rule.max
):
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.FLOAT:
if not isinstance(parameter_value, float | int):
raise ValueError(f"Model Parameter {parameter_name} should be float.")
raise ValueError(
f"Model Parameter {parameter_name} should be float."
)
# validate parameter value precision
if parameter_rule.precision is not None:
if parameter_rule.precision == 0:
if parameter_value != int(parameter_value):
raise ValueError(f"Model Parameter {parameter_name} should be int.")
else:
if parameter_value != round(parameter_value, parameter_rule.precision):
raise ValueError(
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.")
f"Model Parameter {parameter_name} should be int."
)
else:
if parameter_value != round(
parameter_value, parameter_rule.precision
):
raise ValueError(
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places."
)
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
if (
parameter_rule.min is not None
and parameter_value < parameter_rule.min
):
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
)
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
if (
parameter_rule.max is not None
and parameter_value > parameter_rule.max
):
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.BOOLEAN:
if not isinstance(parameter_value, bool):
raise ValueError(f"Model Parameter {parameter_name} should be bool.")
raise ValueError(
f"Model Parameter {parameter_name} should be bool."
)
elif parameter_rule.type == ParameterType.STRING:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.")
raise ValueError(
f"Model Parameter {parameter_name} should be string."
)
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
if (
parameter_rule.options
and parameter_value not in parameter_rule.options
):
raise ValueError(
f"Model Parameter {parameter_name} should be one of {parameter_rule.options}."
)
else:
raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.")
raise ValueError(
f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported."
)
filtered_model_parameters[parameter_name] = parameter_value
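
The loop above checks each parameter against its ParameterRule: type, precision, min/max range and allowed options. The float branch on its own, as a trimmed sketch (rule values are invented for the example):

def check_float(name: str, value, minimum: float, maximum: float, precision: int) -> float:
    # type check mirrors the ParameterType.FLOAT handling above
    if not isinstance(value, (float, int)):
        raise ValueError(f"Model Parameter {name} should be float.")
    # precision check: the value must survive rounding unchanged
    if value != round(value, precision):
        raise ValueError(
            f"Model Parameter {name} should be round to {precision} decimal places."
        )
    # range check
    if value < minimum or value > maximum:
        raise ValueError(f"Model Parameter {name} is out of range [{minimum}, {maximum}].")
    return value

print(check_float("temperature", 0.7, 0.0, 2.0, 2))  # 0.7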

View File

@ -4,7 +4,10 @@ from abc import ABC, abstractmethod
import yaml
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
@ -36,24 +39,26 @@ class ModelProvider(ABC):
return self.provider_schema
# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
provider_name = self.__class__.__module__.split(".")[-1]
# get the path of the model_provider classes
base_path = os.path.abspath(__file__)
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
current_path = os.path.join(
os.path.dirname(os.path.dirname(base_path)), provider_name
)
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_path = os.path.join(current_path, f"{provider_name}.yaml")
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
with open(yaml_path, encoding="utf-8") as f:
yaml_data = yaml.safe_load(f)
try:
# yaml_data to entity
provider_schema = ProviderEntity(**yaml_data)
except Exception as e:
raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}')
raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}")
# cache schema
self.provider_schema = provider_schema
@ -88,37 +93,52 @@ class ModelProvider(ABC):
:return:
"""
# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
provider_name = self.__class__.__module__.split(".")[-1]
if f"{provider_name}.{model_type.value}" in self.model_instance_map:
return self.model_instance_map[f"{provider_name}.{model_type.value}"]
# get the path of the model type classes
base_path = os.path.abspath(__file__)
model_type_name = model_type.value.replace('-', '_')
model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py')
model_type_name = model_type.value.replace("-", "_")
model_type_path = os.path.join(
os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name
)
model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py")
if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
raise Exception(f'Invalid model type {model_type} for provider {provider_name}')
raise Exception(
f"Invalid model type {model_type} for provider {provider_name}"
)
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
parent_module = ".".join(self.__class__.__module__.split(".")[:-1])
spec = importlib.util.spec_from_file_location(
f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel and obj.__module__ == mod.__name__):
if (
isinstance(obj, type)
and issubclass(obj, AIModel)
and not obj.__abstractmethods__
and obj != AIModel
and obj.__module__ == mod.__name__
):
model_class = obj
break
if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
raise Exception(
f"Missing AIModel Class for model type {model_type} in {model_type_py_path}"
)
model_instance_map = model_class()
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
self.model_instance_map[
f"{provider_name}.{model_type.value}"
] = model_instance_map
return model_instance_map
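
get_model_type_instance above is the standard importlib pattern: load a module straight from its file path, then scan the module namespace for the concrete subclass it defines. A minimal self-contained sketch of that pattern (the names are placeholders, not the project's API):

import importlib.util

def load_subclass_from_file(py_path: str, module_name: str, base_class: type) -> type:
    # load the module directly from a file path
    spec = importlib.util.spec_from_file_location(module_name, py_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    # pick the first concrete subclass of base_class defined in that module
    for obj in vars(mod).values():
        if (
            isinstance(obj, type)
            and issubclass(obj, base_class)
            and obj is not base_class
            and not getattr(obj, "__abstractmethods__", None)
            and obj.__module__ == mod.__name__
        ):
            return obj
    raise RuntimeError(f"No {base_class.__name__} subclass found in {py_path}")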

View File

@ -10,11 +10,12 @@ class ModerationModel(AIModel):
"""
Model class for moderation model.
"""
model_type: ModelType = ModelType.MODERATION
def invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
def invoke(
self, model: str, credentials: dict, text: str, user: Optional[str] = None
) -> bool:
"""
Invoke moderation model
@ -32,9 +33,9 @@ class ModerationModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
def _invoke(
self, model: str, credentials: dict, text: str, user: Optional[str] = None
) -> bool:
"""
Invoke large language model
@ -45,4 +46,3 @@ class ModerationModel(AIModel):
:return: false if text is safe, true otherwise
"""
raise NotImplementedError

View File

@ -11,12 +11,19 @@ class RerankModel(AIModel):
"""
Base Model class for rerank model.
"""
model_type: ModelType = ModelType.RERANK
def invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
def invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -32,15 +39,23 @@ class RerankModel(AIModel):
self.started_at = time.perf_counter()
try:
return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
return self._invoke(
model, credentials, query, docs, score_threshold, top_n, user
)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model

View File

@ -10,11 +10,12 @@ class Speech2TextModel(AIModel):
"""
Model class for speech2text model.
"""
model_type: ModelType = ModelType.SPEECH2TEXT
def invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
def invoke(
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
) -> str:
"""
Invoke large language model
@ -30,9 +31,9 @@ class Speech2TextModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
def _invoke(
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
) -> str:
"""
Invoke large language model
@ -54,4 +55,4 @@ class Speech2TextModel(AIModel):
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the audio file
return os.path.join(current_dir, 'audio.mp3')
return os.path.join(current_dir, "audio.mp3")

View File

@ -9,11 +9,17 @@ class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG
def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
def invoke(
self,
model: str,
credentials: dict,
prompt: str,
model_parameters: dict,
user: Optional[str] = None,
) -> list[IO[bytes]]:
"""
Invoke Text2Image model
@ -31,9 +37,14 @@ class Text2ImageModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
def _invoke(
self,
model: str,
credentials: dict,
prompt: str,
model_parameters: dict,
user: Optional[str] = None,
) -> list[IO[bytes]]:
"""
Invoke Text2Image model

View File

@ -2,8 +2,13 @@ import time
from abc import abstractmethod
from typing import Optional
from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from model_providers.core.model_runtime.entities.model_entities import (
ModelPropertyKey,
ModelType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
)
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
@ -11,11 +16,16 @@ class TextEmbeddingModel(AIModel):
"""
Model class for text embedding model.
"""
model_type: ModelType = ModelType.TEXT_EMBEDDING
def invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke large language model
@ -33,9 +43,13 @@ class TextEmbeddingModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke large language model
@ -69,7 +83,10 @@ class TextEmbeddingModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
if (
model_schema
and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties
):
return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
return 1000
@ -84,7 +101,10 @@ class TextEmbeddingModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
if (
model_schema
and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
):
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return 1

View File

@ -7,27 +7,30 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
_tokenizer = None
_lock = Lock()
class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text, verbose=False)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'gpt2')
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(
gpt2_tokenizer_path
)
return _tokenizer
return _tokenizer
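
The encoder above is a lazily created, lock-guarded module-level singleton; the shape of that pattern on its own, with the transformers call replaced by a placeholder loader:

from threading import Lock

_instance = None
_lock = Lock()

def _expensive_load():
    # stands in for TransformerGPT2Tokenizer.from_pretrained(...)
    return object()

def get_instance():
    global _instance
    with _lock:
        if _instance is None:
            _instance = _expensive_load()
        return _instance

assert get_instance() is get_instance()  # loaded once, reused afterwards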

View File

@ -4,7 +4,10 @@ import uuid
from abc import abstractmethod
from typing import Optional
from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from model_providers.core.model_runtime.entities.model_entities import (
ModelPropertyKey,
ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
@ -13,10 +16,19 @@ class TTSModel(AIModel):
"""
Model class for ttstext model.
"""
model_type: ModelType = ModelType.TTS
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
user: Optional[str] = None):
def invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
streaming: bool,
user: Optional[str] = None,
):
"""
Invoke large language model
@ -31,14 +43,29 @@ class TTSModel(AIModel):
"""
try:
self._is_ffmpeg_installed()
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
content_text=content_text, voice=voice, tenant_id=tenant_id)
return self._invoke(
model=model,
credentials=credentials,
user=user,
streaming=streaming,
content_text=content_text,
voice=voice,
tenant_id=tenant_id,
)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
user: Optional[str] = None):
def _invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
streaming: bool,
user: Optional[str] = None,
):
"""
Invoke large language model
@ -53,7 +80,9 @@ class TTSModel(AIModel):
"""
raise NotImplementedError
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
def get_tts_model_voices(
self, model: str, credentials: dict, language: Optional[str] = None
) -> list:
"""
Get voice for given tts model voices
@ -67,9 +96,13 @@ class TTSModel(AIModel):
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [{'name': d['name'], 'value': d['mode']} for d in voices if language and language in d.get('language')]
return [
{"name": d["name"], "value": d["mode"]}
for d in voices
if language and language in d.get("language")
]
else:
return [{'name': d['name'], 'value': d['mode']} for d in voices]
return [{"name": d["name"], "value": d["mode"]} for d in voices]
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
"""
@ -81,7 +114,10 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
if (
model_schema
and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties
):
return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
@ -94,7 +130,10 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
if (
model_schema
and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties
):
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
@ -104,7 +143,10 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
if (
model_schema
and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties
):
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
@ -114,13 +156,16 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
if (
model_schema
and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties
):
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
if delimiters is None:
delimiters = set('。!?;\n')
delimiters = set("。!?;\n")
buf = []
word_count = 0
@ -128,7 +173,7 @@ class TTSModel(AIModel):
buf.append(char)
if char in delimiters:
if word_count >= limit:
yield ''.join(buf)
yield "".join(buf)
buf = []
word_count = 0
else:
@ -137,7 +182,7 @@ class TTSModel(AIModel):
word_count += 1
if buf:
yield ''.join(buf)
yield "".join(buf)
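
_split_text_into_sentences buffers characters and only flushes at a delimiter once the word limit has been reached; a standalone sketch of the same idea with a short run (the middle of the loop is partly reconstructed from context, so treat it as an approximation of the original):

def split_text_into_sentences(text: str, limit: int, delimiters=None):
    # approximation of the method above, kept outside the class for illustration
    if delimiters is None:
        delimiters = set("。!?;\n")
    buf = []
    word_count = 0
    for char in text:
        buf.append(char)
        if char in delimiters:
            if word_count >= limit:
                # flush the buffered sentence once it is long enough
                yield "".join(buf)
                buf = []
                word_count = 0
        else:
            word_count += 1
    if buf:
        yield "".join(buf)

for piece in split_text_into_sentences("第一句。第二句比较长一些。结尾", limit=3):
    print(piece)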
@staticmethod
def _is_ffmpeg_installed():
@ -146,13 +191,17 @@ class TTSModel(AIModel):
if "ffmpeg version" in output.decode("utf-8"):
return True
else:
raise InvokeBadRequestError("ffmpeg is not installed, "
"details: https://docs.dify.ai/getting-started/install-self-hosted"
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
raise InvokeBadRequestError(
"ffmpeg is not installed, "
"details: https://docs.dify.ai/getting-started/install-self-hosted"
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech"
)
except Exception:
raise InvokeBadRequestError("ffmpeg is not installed, "
"details: https://docs.dify.ai/getting-started/install-self-hosted"
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
raise InvokeBadRequestError(
"ffmpeg is not installed, "
"details: https://docs.dify.ai/getting-started/install-self-hosted"
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech"
)
# Todo: To improve the streaming function
@staticmethod
@ -160,6 +209,6 @@ class TTSModel(AIModel):
hash_object = hashlib.sha256(file_content.encode())
hex_digest = hash_object.hexdigest()
namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31')
namespace_uuid = uuid.UUID("a5da6ef9-b303-596f-8e88-bf8fa40f4b31")
unique_uuid = uuid.uuid5(namespace_uuid, hex_digest)
return str(unique_uuid)
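
The final helper derives a deterministic UUID from a SHA-256 digest of the content, so identical text always maps to the same generated-audio file name. A quick standalone check (the namespace UUID is copied from the diff):

import hashlib
import uuid

def content_uuid(file_content: str) -> str:
    # same derivation as the static method above
    hex_digest = hashlib.sha256(file_content.encode()).hexdigest()
    namespace_uuid = uuid.UUID("a5da6ef9-b303-596f-8e88-bf8fa40f4b31")
    return str(uuid.uuid5(namespace_uuid, hex_digest))

assert content_uuid("hello") == content_uuid("hello")
assert content_uuid("hello") != content_uuid("world")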

View File

@ -1,3 +1,5 @@
from model_providers.core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from model_providers.core.model_runtime.model_providers.model_provider_factory import (
ModelProviderFactory,
)
model_provider_factory = ModelProviderFactory()

View File

@ -1,8 +1,12 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
@ -21,11 +25,12 @@ class AnthropicProvider(ModelProvider):
# Use `claude-instant-1` model for validate,
model_instance.validate_credentials(
model='claude-instant-1.2',
credentials=credentials
model="claude-instant-1.2", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -18,7 +18,11 @@ from anthropic.types import (
from httpx import Timeout
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
@ -37,8 +41,12 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
@ -51,11 +59,17 @@ if you are not sure about the structure.
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -70,11 +84,20 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._chat_generate(
model, credentials, prompt_messages, model_parameters, stop, stream, user
)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
@ -91,23 +114,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# transform model parameters from completion api of anthropic to chat api
if 'max_tokens_to_sample' in model_parameters:
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
if "max_tokens_to_sample" in model_parameters:
model_parameters["max_tokens"] = model_parameters.pop(
"max_tokens_to_sample"
)
# init model client
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
extra_model_kwargs["stop_sequences"] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
extra_model_kwargs["metadata"] = completion_create_params.Metadata(
user_id=user
)
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
if system:
extra_model_kwargs['system'] = system
extra_model_kwargs["system"] = system
# chat model
response = client.messages.create(
@ -115,22 +142,37 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_stream_response(
model, credentials, response, prompt_messages
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(
model, credentials, response, prompt_messages
)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
if (
"response_format" in model_parameters
and model_parameters["response_format"]
):
stop = stop or []
# chat model
self._transform_chat_json_prompts(
@ -142,17 +184,33 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
response_format=model_parameters["response_format"],
)
model_parameters.pop('response_format')
model_parameters.pop("response_format")
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return self._invoke(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
def _transform_chat_json_prompts(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
response_format: str = "JSON",
) -> None:
"""
Transform json prompts
"""
@ -162,25 +220,40 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
if len(prompt_messages) > 0 and isinstance(
prompt_messages[0], SystemPromptMessage
):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace(
"{{instructions}}", prompt_messages[0].content
).replace("{{block}}", response_format)
)
prompt_messages.append(
AssistantPromptMessage(content=f"\n```{response_format}")
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
prompt_messages.insert(
0,
SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace(
"{{instructions}}",
f"Please output a valid {response_format} object.",
).replace("{{block}}", response_format)
),
)
prompt_messages.append(
AssistantPromptMessage(content=f"\n```{response_format}")
)
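
The prompt surgery above reduces to two string replacements on a block-mode template plus a priming assistant message. A shortened, illustrative version, assuming only the first line of ANTHROPIC_BLOCK_MODE_PROMPT (the full template lives in this module):

BLOCK_MODE_PROMPT = (
    "You should always follow the instructions and output a valid {{block}} object.\n"
    "{{instructions}}"
)

response_format = "JSON"
system_content = BLOCK_MODE_PROMPT.replace(
    "{{instructions}}", "Please output a valid JSON object."
).replace("{{block}}", response_format)

# the assistant is then primed with an opening fence so the model completes the block
priming_content = f"\n```{response_format}"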
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -214,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"temperature": 0,
"max_tokens": 20,
},
stream=False
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: Message,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm chat response
@ -243,24 +321,32 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
completion_tokens = self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
# transform response
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
usage=usage,
)
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm chat stream response
@ -269,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
full_assistant_content = ""
return_model = None
input_tokens = 0
output_tokens = 0
@ -284,28 +370,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
# transform usage
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
usage = self._calc_response_usage(
model, credentials, input_tokens, output_tokens
)
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
message=AssistantPromptMessage(
content=''
),
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason,
usage=usage
)
usage=usage,
),
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
chunk_text = chunk.delta.text if chunk.delta.text else ""
full_assistant_content += chunk_text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text
)
assistant_prompt_message = AssistantPromptMessage(content=chunk_text)
index = chunk.index
@ -315,7 +399,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=chunk.index,
message=assistant_prompt_message,
)
),
)
def _to_credential_kwargs(self, credentials: dict) -> dict:
@ -326,18 +410,22 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return:
"""
credentials_kwargs = {
"api_key": credentials['anthropic_api_key'],
"api_key": credentials["anthropic_api_key"],
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
}
if 'anthropic_api_url' in credentials and credentials['anthropic_api_url']:
credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/')
credentials_kwargs['base_url'] = credentials['anthropic_api_url']
if "anthropic_api_url" in credentials and credentials["anthropic_api_url"]:
credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip(
"/"
)
credentials_kwargs["base_url"] = credentials["anthropic_api_url"]
return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
def _convert_prompt_messages(
self, prompt_messages: list[PromptMessage]
) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
@ -348,7 +436,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if isinstance(message, SystemPromptMessage):
system += message.content + ("\n" if not system else "")
else:
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
prompt_message_dicts.append(
self._convert_prompt_message_to_dict(message)
)
return system, prompt_message_dicts
@ -364,38 +454,57 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
message_content = cast(
TextPromptMessageContent, message_content
)
sub_message_dict = {
"type": "text",
"text": message_content.data
"text": message_content.data,
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
message_content = cast(
ImagePromptMessageContent, message_content
)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
image_content = requests.get(
message_content.data
).content
mime_type, _ = mimetypes.guess_type(
message_content.data
)
base64_data = base64.b64encode(image_content).decode(
"utf-8"
)
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
if mime_type not in [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
]:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"
)
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
"data": base64_data,
},
}
sub_messages.append(sub_message_dict)
@ -450,7 +559,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
def _convert_messages_to_prompt_anthropic(
self, messages: list[PromptMessage]
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
@ -458,15 +569,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
return ""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
self._convert_one_message_to_text(message) for message in messages
)
# trim off the trailing ' ' that might come from the "Assistant: "
@ -485,22 +595,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return {
InvokeConnectionError: [
anthropic.APIConnectionError,
anthropic.APITimeoutError
],
InvokeServerUnavailableError: [
anthropic.InternalServerError
],
InvokeRateLimitError: [
anthropic.RateLimitError
anthropic.APITimeoutError,
],
InvokeServerUnavailableError: [anthropic.InternalServerError],
InvokeRateLimitError: [anthropic.RateLimitError],
InvokeAuthorizationError: [
anthropic.AuthenticationError,
anthropic.PermissionDeniedError
anthropic.PermissionDeniedError,
],
InvokeBadRequestError: [
anthropic.BadRequestError,
anthropic.NotFoundError,
anthropic.UnprocessableEntityError,
anthropic.APIError
]
anthropic.APIError,
],
}
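
An illustrative dispatch sketch, not the project's actual code, showing how a mapping like _invoke_error_mapping above is typically consumed: each SDK exception is translated into the matching unified invoke error, so callers only handle one error family.

def translate_invoke_error(mapping: dict, ex: Exception) -> Exception:
    # mapping: {InvokeError subclass: [SDK exception classes]}
    for invoke_error, sdk_errors in mapping.items():
        if isinstance(ex, tuple(sdk_errors)):
            return invoke_error(str(ex))
    return ex  # unknown exceptions pass through unchanged


# stand-in classes for illustration only
class FakeSDKRateLimitError(Exception): ...
class FakeInvokeRateLimitError(Exception): ...

mapping = {FakeInvokeRateLimitError: [FakeSDKRateLimitError]}
print(type(translate_invoke_error(mapping, FakeSDKRateLimitError("429"))).__name__)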

View File

@ -9,16 +9,18 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
AZURE_OPENAI_API_VERSION,
)
class _CommonAzureOpenAI:
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION)
api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION)
credentials_kwargs = {
"api_key": credentials['openai_api_key'],
"azure_endpoint": credentials['openai_api_base'],
"api_key": credentials["openai_api_key"],
"azure_endpoint": credentials["openai_api_base"],
"api_version": api_version,
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
@ -29,24 +31,17 @@ class _CommonAzureOpenAI:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
openai.APIConnectionError,
openai.APITimeoutError
],
InvokeServerUnavailableError: [
openai.InternalServerError
],
InvokeRateLimitError: [
openai.RateLimitError
],
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],
InvokeRateLimitError: [openai.RateLimitError],
InvokeAuthorizationError: [
openai.AuthenticationError,
openai.PermissionDeniedError
openai.PermissionDeniedError,
],
InvokeBadRequestError: [
openai.BadRequestError,
openai.NotFoundError,
openai.UnprocessableEntityError,
openai.APIError
]
openai.APIError,
],
}
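
For reference, a hedged sketch of the dictionary the credential mapping above produces; endpoint and key values are placeholders, and the version default mirrors the AZURE_OPENAI_API_VERSION constant.

from httpx import Timeout

credentials = {
    "openai_api_key": "<azure-api-key>",  # placeholder
    "openai_api_base": "https://my-resource.openai.azure.com",  # placeholder
    # "openai_api_version" is optional and falls back to AZURE_OPENAI_API_VERSION
}

credentials_kwargs = {
    "api_key": credentials["openai_api_key"],
    "azure_endpoint": credentials["openai_api_base"],
    "api_version": credentials.get("openai_api_version", "2024-02-15-preview"),
    "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
    "max_retries": 1,
}
# downstream code expands these into the client: AzureOpenAI(**credentials_kwargs)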

View File

@ -14,11 +14,12 @@ from model_providers.core.model_runtime.entities.model_entities import (
PriceConfig,
)
AZURE_OPENAI_API_VERSION = '2024-02-15-preview'
AZURE_OPENAI_API_VERSION = "2024-02-15-preview"
def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
rule = ParameterRule(
name='max_tokens',
name="max_tokens",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS],
)
rule.default = default
@ -34,11 +35,11 @@ class AzureBaseModel(BaseModel):
LLM_BASE_MODELS = [
AzureBaseModel(
base_model_name='gpt-35-turbo',
base_model_name="gpt-35-turbo",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
@ -53,37 +54,37 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096)
_get_max_tokens(default=512, min_val=1, max_val=4096),
],
pricing=PriceConfig(
input=0.001,
output=0.002,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='gpt-35-turbo-16k',
base_model_name="gpt-35-turbo-16k",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
@ -98,37 +99,37 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=16385)
_get_max_tokens(default=512, min_val=1, max_val=16385),
],
pricing=PriceConfig(
input=0.003,
output=0.004,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='gpt-4',
base_model_name="gpt-4",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
@ -143,32 +144,29 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=8192),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
required=False,
precision=2,
@ -176,34 +174,31 @@ LLM_BASE_MODELS = [
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
zh_Hans="指定模型必须输出的格式",
en_US="specifying the format that the model must output",
),
required=False,
options=['text', 'json_object']
options=["text", "json_object"],
),
],
pricing=PriceConfig(
input=0.03,
output=0.06,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='gpt-4-32k',
base_model_name="gpt-4-32k",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
@ -218,32 +213,29 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=32768),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
required=False,
precision=2,
@ -251,34 +243,31 @@ LLM_BASE_MODELS = [
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
zh_Hans="指定模型必须输出的格式",
en_US="specifying the format that the model must output",
),
required=False,
options=['text', 'json_object']
options=["text", "json_object"],
),
],
pricing=PriceConfig(
input=0.06,
output=0.12,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='gpt-4-1106-preview',
base_model_name="gpt-4-1106-preview",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
@ -293,32 +282,29 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
required=False,
precision=2,
@ -326,39 +312,34 @@ LLM_BASE_MODELS = [
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
zh_Hans="指定模型必须输出的格式",
en_US="specifying the format that the model must output",
),
required=False,
options=['text', 'json_object']
options=["text", "json_object"],
),
],
pricing=PriceConfig(
input=0.01,
output=0.03,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='gpt-4-vision-preview',
base_model_name="gpt-4-vision-preview",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
ModelFeature.VISION
],
features=[ModelFeature.VISION],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
@ -366,32 +347,29 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
required=False,
precision=2,
@ -399,34 +377,31 @@ LLM_BASE_MODELS = [
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
zh_Hans="指定模型必须输出的格式",
en_US="specifying the format that the model must output",
),
required=False,
options=['text', 'json_object']
options=["text", "json_object"],
),
],
pricing=PriceConfig(
input=0.01,
output=0.03,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='gpt-35-turbo-instruct',
base_model_name="gpt-35-turbo-instruct",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
@ -436,19 +411,19 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
@ -457,16 +432,16 @@ LLM_BASE_MODELS = [
input=0.0015,
output=0.002,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='text-davinci-003',
base_model_name="text-davinci-003",
entity=AIModelEntity(
model='fake-deployment-name',
model="fake-deployment-name",
label=I18nObject(
en_US='fake-deployment-name-label',
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
@ -476,19 +451,19 @@ LLM_BASE_MODELS = [
},
parameter_rules=[
ParameterRule(
name='temperature',
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
@ -497,20 +472,18 @@ LLM_BASE_MODELS = [
input=0.02,
output=0.02,
unit=0.001,
currency='USD',
)
)
)
currency="USD",
),
),
),
]
EMBEDDING_BASE_MODELS = [
AzureBaseModel(
base_model_name='text-embedding-ada-002',
base_model_name="text-embedding-ada-002",
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
model="fake-deployment-name",
label=I18nObject(en_US="fake-deployment-name-label"),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
@ -520,17 +493,15 @@ EMBEDDING_BASE_MODELS = [
pricing=PriceConfig(
input=0.0001,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='text-embedding-3-small',
base_model_name="text-embedding-3-small",
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
model="fake-deployment-name",
label=I18nObject(en_US="fake-deployment-name-label"),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
@ -540,17 +511,15 @@ EMBEDDING_BASE_MODELS = [
pricing=PriceConfig(
input=0.00002,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='text-embedding-3-large',
base_model_name="text-embedding-3-large",
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
model="fake-deployment-name",
label=I18nObject(en_US="fake-deployment-name-label"),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
@ -560,135 +529,237 @@ EMBEDDING_BASE_MODELS = [
pricing=PriceConfig(
input=0.00013,
unit=0.001,
currency='USD',
)
)
)
currency="USD",
),
),
),
]
SPEECH2TEXT_BASE_MODELS = [
AzureBaseModel(
base_model_name='whisper-1',
base_model_name="whisper-1",
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
model="fake-deployment-name",
label=I18nObject(en_US="fake-deployment-name-label"),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={
ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
}
)
ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm",
},
),
)
]
TTS_BASE_MODELS = [
AzureBaseModel(
base_model_name='tts-1',
base_model_name="tts-1",
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
model="fake-deployment-name",
label=I18nObject(en_US="fake-deployment-name-label"),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
ModelPropertyKey.DEFAULT_VOICE: "alloy",
ModelPropertyKey.VOICES: [
{
'mode': 'alloy',
'name': 'Alloy',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "alloy",
"name": "Alloy",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'echo',
'name': 'Echo',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "echo",
"name": "Echo",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'fable',
'name': 'Fable',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "fable",
"name": "Fable",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'onyx',
'name': 'Onyx',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "onyx",
"name": "Onyx",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'nova',
'name': 'Nova',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "nova",
"name": "Nova",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'shimmer',
'name': 'Shimmer',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "shimmer",
"name": "Shimmer",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
],
ModelPropertyKey.WORD_LIMIT: 120,
ModelPropertyKey.AUDIO_TYPE: 'mp3',
ModelPropertyKey.MAX_WORKERS: 5
ModelPropertyKey.AUDIO_TYPE: "mp3",
ModelPropertyKey.MAX_WORKERS: 5,
},
pricing=PriceConfig(
input=0.015,
unit=0.001,
currency='USD',
)
)
currency="USD",
),
),
),
AzureBaseModel(
base_model_name='tts-1-hd',
base_model_name="tts-1-hd",
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
model="fake-deployment-name",
label=I18nObject(en_US="fake-deployment-name-label"),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
ModelPropertyKey.DEFAULT_VOICE: "alloy",
ModelPropertyKey.VOICES: [
{
'mode': 'alloy',
'name': 'Alloy',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "alloy",
"name": "Alloy",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'echo',
'name': 'Echo',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "echo",
"name": "Echo",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'fable',
'name': 'Fable',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "fable",
"name": "Fable",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'onyx',
'name': 'Onyx',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "onyx",
"name": "Onyx",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'nova',
'name': 'Nova',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "nova",
"name": "Nova",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
{
'mode': 'shimmer',
'name': 'Shimmer',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
"mode": "shimmer",
"name": "Shimmer",
"language": [
"zh-Hans",
"en-US",
"de-DE",
"fr-FR",
"es-ES",
"it-IT",
"th-TH",
"id-ID",
],
},
],
ModelPropertyKey.WORD_LIMIT: 120,
ModelPropertyKey.AUDIO_TYPE: 'mp3',
ModelPropertyKey.MAX_WORKERS: 5
ModelPropertyKey.AUDIO_TYPE: "mp3",
ModelPropertyKey.MAX_WORKERS: 5,
},
pricing=PriceConfig(
input=0.03,
unit=0.001,
currency='USD',
)
)
)
currency="USD",
),
),
),
]
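
A hedged sketch of how tables such as LLM_BASE_MODELS above are consulted: the configured base_model_name selects the matching AzureBaseModel entry. The helper name is ours; the project's own lookup lives in the llm module further down.

def find_base_model(base_models, base_model_name):
    # first entry whose base_model_name matches, else None
    return next(
        (m for m in base_models if m.base_model_name == base_model_name), None
    )


# e.g. find_base_model(LLM_BASE_MODELS, "gpt-4-1106-preview") returns the entry whose
# .entity carries the parameter rules and pricing declared above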

View File

@ -1,11 +1,12 @@
import logging
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class AzureOpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -6,11 +6,23 @@ from typing import Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
from openai.types import Completion
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaFunctionCall,
ChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_message import FunctionCall
from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMMode,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
@ -22,26 +34,47 @@ from model_providers.core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from model_providers.core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelPropertyKey,
)
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.azure_openai._common import (
_CommonAzureOpenAI,
)
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
LLM_BASE_MODELS,
AzureBaseModel,
)
logger = logging.getLogger(__name__)
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
ai_model_entity = self._get_ai_model_entity(
credentials.get("base_model_name"), model
)
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if (
ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
== LLMMode.CHAT.value
):
# chat model
return self._chat_generate(
model=model,
@ -51,7 +84,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
else:
# text completion model
@ -62,14 +95,19 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
ModelPropertyKey.MODE)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
model_mode = self._get_ai_model_entity(
credentials.get("base_model_name"), model
).entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
# chat model
@ -79,27 +117,36 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return self._num_tokens_from_string(credentials, prompt_messages[0].content)
def validate_credentials(self, model: str, credentials: dict) -> None:
if 'openai_api_base' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError(
"Azure OpenAI API Base Endpoint is required"
)
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
ai_model_entity = self._get_ai_model_entity(
credentials.get("base_model_name"), model
)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
raise CredentialsValidateFailedError(
f'Base Model Name {credentials["base_model_name"]} is invalid'
)
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if (
ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
== LLMMode.CHAT.value
):
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": 'ping'}],
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=0,
max_tokens=20,
@ -108,7 +155,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
else:
# text completion model
client.completions.create(
prompt='ping',
prompt="ping",
model=model,
temperature=0,
max_tokens=20,
@ -117,23 +164,33 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
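
For orientation, a hedged sketch of the minimal credential dict the validation above requires; every value is a placeholder.

credentials = {
    "openai_api_base": "https://my-resource.openai.azure.com",  # required endpoint
    "openai_api_key": "<azure-api-key>",  # required key
    "base_model_name": "gpt-35-turbo",  # must match an AzureBaseModel entry
}
# validate_credentials(model="my-deployment", credentials=credentials) then issues a
# short "ping" request against the deployment to confirm the settings work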
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(
credentials.get("base_model_name"), model
)
return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
# text completion model
response = client.completions.create(
@ -141,22 +198,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_stream_response(
model, credentials, response, prompt_messages
)
return self._handle_generate_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(
model, credentials, response, prompt_messages
)
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self,
model: str,
credentials: dict,
response: Completion,
prompt_messages: list[PromptMessage],
) -> LLMResult:
assistant_text = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
if response.usage:
@ -165,11 +229,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
prompt_tokens = self._num_tokens_from_string(
credentials, prompt_messages[0].content
)
completion_tokens = self._num_tokens_from_string(
credentials, assistant_text
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
# transform response
result = LLMResult(
@ -182,23 +252,26 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]) -> Generator:
full_text = ''
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[Completion],
prompt_messages: list[PromptMessage],
) -> Generator:
full_text = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.text is None or delta.text == ''):
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
continue
# transform assistant message to prompt message
text = delta.text if delta.text else ''
assistant_prompt_message = AssistantPromptMessage(
content=text
)
text = delta.text if delta.text else ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
@ -210,11 +283,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
completion_tokens = chunk.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(credentials, full_text)
prompt_tokens = self._num_tokens_from_string(
credentials, prompt_messages[0].content
)
completion_tokens = self._num_tokens_from_string(
credentials, full_text
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=chunk.model,
@ -224,8 +303,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
)
usage=usage,
),
)
else:
yield LLMResultChunk(
@ -235,14 +314,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
)
),
)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
response_format = model_parameters.get("response_format")
@ -258,17 +343,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if tools:
# extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
extra_model_kwargs['functions'] = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
extra_model_kwargs["functions"] = [
{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
for tool in tools
]
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
# chat model
response = client.chat.completions.create(
@ -280,27 +368,36 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
return self._handle_chat_generate_stream_response(
model, credentials, response, prompt_messages, tools
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
return self._handle_chat_generate_response(
model, credentials, response, prompt_messages, tools
)
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> LLMResult:
assistant_message = response.choices[0].message
# assistant_message_tool_calls = assistant_message.tool_calls
assistant_message_function_call = assistant_message.function_call
# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
function_call = self._extract_response_function_call(
assistant_message_function_call
)
tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content,
tool_calls=tool_calls
content=assistant_message.content, tool_calls=tool_calls
)
# calculate num tokens
@ -310,11 +407,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
prompt_tokens = self._num_tokens_from_messages(
credentials, prompt_messages, tools
)
completion_tokens = self._num_tokens_from_messages(
credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
# transform response
response = LLMResult(
@ -327,24 +430,31 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> Generator:
index = 0
full_assistant_content = ''
full_assistant_content = ""
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model
system_fingerprint = None
completion = ''
completion = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
delta.delta.function_call is None:
if (
delta.finish_reason is None
and (delta.delta.content is None or delta.delta.content == "")
and delta.delta.function_call is None
):
continue
# assistant_message_tool_calls = delta.delta.tool_calls
@ -355,36 +465,44 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# handle process of stream function call
if assistant_message_function_call:
# message has not ended ever
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
delta_assistant_message_function_call_storage.arguments += (
assistant_message_function_call.arguments
)
continue
else:
# message has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
assistant_message_function_call = (
delta_assistant_message_function_call_storage
)
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
delta_assistant_message_function_call_storage = (
assistant_message_function_call
)
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
delta_assistant_message_function_call_storage.arguments = ""
continue
# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
function_call = self._extract_response_function_call(
assistant_message_function_call
)
tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=tool_calls
content=delta.delta.content if delta.delta.content else "",
tool_calls=tool_calls,
)
full_assistant_content += delta.delta.content if delta.delta.content else ''
full_assistant_content += delta.delta.content if delta.delta.content else ""
real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content if delta.delta.content else ''
completion += delta.delta.content if delta.delta.content else ""
yield LLMResultChunk(
model=real_model,
@ -393,21 +511,25 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
),
)
index += 0
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
full_assistant_prompt_message = AssistantPromptMessage(
content=completion
prompt_tokens = self._num_tokens_from_messages(
credentials, prompt_messages, tools
)
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
completion_tokens = self._num_tokens_from_messages(
credentials, [full_assistant_prompt_message]
)
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=real_model,
@ -415,55 +537,52 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
finish_reason='stop',
usage=usage
)
message=AssistantPromptMessage(content=""),
finish_reason="stop",
usage=usage,
),
)
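A minimal standalone sketch of the delta-storage pattern used earlier in this stream handler: fragments of a streamed function_call's arguments are concatenated until the call is complete. The FakeFunctionCall dataclass and the sample chunks below are invented stand-ins for openai's ChoiceDeltaFunctionCall objects, not code from this diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeFunctionCall:  # stand-in for openai's ChoiceDeltaFunctionCall
    name: Optional[str] = None
    arguments: Optional[str] = None


chunks = [
    FakeFunctionCall(name="get_weather", arguments=None),  # start of the call
    FakeFunctionCall(arguments='{"city": '),               # argument fragment
    FakeFunctionCall(arguments='"Paris"}'),                # argument fragment
]

storage: Optional[FakeFunctionCall] = None
for delta_call in chunks:
    if storage is None:
        # start of the streamed function call
        storage = delta_call
        if storage.arguments is None:
            storage.arguments = ""
    else:
        # the call has not ended yet, keep appending argument fragments
        storage.arguments += delta_call.arguments or ""

print(storage.name, storage.arguments)  # get_weather {"city": "Paris"}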
@staticmethod
def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-> list[AssistantPromptMessage.ToolCall]:
def _extract_response_tool_calls(
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
) -> list[AssistantPromptMessage.ToolCall]:
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments
arguments=response_tool_call.function.arguments,
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id,
type=response_tool_call.type,
function=function
function=function,
)
tool_calls.append(tool_call)
return tool_calls
@staticmethod
def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-> AssistantPromptMessage.ToolCall:
def _extract_response_function_call(
response_function_call: FunctionCall | ChoiceDeltaFunctionCall,
) -> AssistantPromptMessage.ToolCall:
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name,
arguments=response_function_call.arguments
arguments=response_function_call.arguments,
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.name,
type="function",
function=function
id=response_function_call.name, type="function", function=function
)
return tool_call
@staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
@ -472,20 +591,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
message_content = cast(
TextPromptMessageContent, message_content
)
sub_message_dict = {
"type": "text",
"text": message_content.data
"text": message_content.data,
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
message_content = cast(
ImagePromptMessageContent, message_content
)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
"detail": message_content.detail.value,
},
}
sub_messages.append(sub_message_dict)
@ -514,7 +637,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
message_dict = {
"role": "function",
"content": message.content,
"name": message.tool_call_id
"name": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
@ -524,10 +647,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return message_dict
def _num_tokens_from_string(self, credentials: dict, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_string(
self,
credentials: dict,
text: str,
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials['base_model_name'])
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
@ -538,13 +665,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens
def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_messages(
self,
credentials: dict,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model = credentials['base_model_name']
model = credentials["base_model_name"]
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
@ -578,10 +709,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
@ -611,41 +742,42 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens
@staticmethod
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
def _num_tokens_for_tools(
encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
) -> int:
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode('function'))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode("function"))
# calculate num tokens for function object
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode("name"))
num_tokens += len(encoding.encode(tool.name))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode("description"))
num_tokens += len(encoding.encode(tool.description))
parameters = tool.parameters
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode("parameters"))
if "title" in parameters:
num_tokens += len(encoding.encode("title"))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
if "properties" in parameters:
num_tokens += len(encoding.encode("properties"))
for key, value in parameters.get("properties").items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
if "required" in parameters:
num_tokens += len(encoding.encode("required"))
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
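A rough standalone sketch of the field-by-field counting above: each tool schema's keys and values are encoded with tiktoken and their lengths summed. The tool dict is a made-up example and cl100k_base is assumed as the fallback encoding; the result is an estimate of per-tool prompt overhead, not an exact OpenAI count.

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

tool = {  # hypothetical tool definition, not taken from the diff
    "name": "get_weather",
    "description": "Return the current weather for a city",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "City name"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["city"],
    },
}

num_tokens = len(encoding.encode("type")) + len(encoding.encode("function"))
num_tokens += len(encoding.encode("name")) + len(encoding.encode(tool["name"]))
num_tokens += len(encoding.encode("description")) + len(encoding.encode(tool["description"]))
num_tokens += len(encoding.encode("parameters"))
for key, value in tool["parameters"]["properties"].items():
    # approximate: encode the property name and its stringified schema
    num_tokens += len(encoding.encode(key)) + len(encoding.encode(str(value)))
for required_field in tool["parameters"]["required"]:
    num_tokens += 3 + len(encoding.encode(required_field))

print(num_tokens)  # rough per-tool prompt overhead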

View File

@ -4,10 +4,19 @@ from typing import IO, Optional
from openai import AzureOpenAI
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from model_providers.core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import (
Speech2TextModel,
)
from model_providers.core.model_runtime.model_providers.azure_openai._common import (
_CommonAzureOpenAI,
)
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
SPEECH2TEXT_BASE_MODELS,
AzureBaseModel,
)
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
@ -15,9 +24,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
Model class for OpenAI Speech to text model.
"""
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
def _invoke(
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
) -> str:
"""
Invoke speech2text model
@ -40,12 +49,14 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
with open(audio_file_path, "rb") as audio_file:
self._speech2text_invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
def _speech2text_invoke(
self, model: str, credentials: dict, file: IO[bytes]
) -> str:
"""
Invoke speech2text model
@ -64,11 +75,14 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
return response.text
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(
credentials["base_model_name"], model
)
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:

View File

@ -7,28 +7,46 @@ import numpy as np
import tiktoken
from openai import AzureOpenAI
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from model_providers.core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
PriceType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.azure_openai._common import (
_CommonAzureOpenAI,
)
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
EMBEDDING_BASE_MODELS,
AzureBaseModel,
)
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
base_model_name = credentials['base_model_name']
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
base_model_name = credentials["base_model_name"]
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
extra_model_kwargs = {}
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
extra_model_kwargs['encoding_format'] = 'base64'
extra_model_kwargs["encoding_format"] = "base64"
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
@ -44,11 +62,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
enc = tiktoken.get_encoding("cl100k_base")
for i, text in enumerate(texts):
token = enc.encode(
text
)
token = enc.encode(text)
for j in range(0, len(token), context_size):
tokens += [token[j: j + context_size]]
tokens += [token[j : j + context_size]]
indices += [i]
batched_embeddings = []
@ -58,8 +74,8 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=tokens[i: i + max_chunks],
extra_model_kwargs=extra_model_kwargs
texts=tokens[i : i + max_chunks],
extra_model_kwargs=extra_model_kwargs,
)
used_tokens += embedding_used_tokens
@ -78,7 +94,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
model=model,
client=client,
texts="",
extra_model_kwargs=extra_model_kwargs
extra_model_kwargs=extra_model_kwargs,
)
used_tokens += embedding_used_tokens
@ -89,15 +105,11 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
model=model, credentials=credentials, tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=embeddings,
usage=usage,
model=base_model_name
embeddings=embeddings, usage=usage, model=base_model_name
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
@ -105,7 +117,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
return 0
try:
enc = tiktoken.encoding_for_model(credentials['base_model_name'])
enc = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
enc = tiktoken.get_encoding("cl100k_base")
@ -118,57 +130,78 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
return total_num_tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
if 'openai_api_base' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError(
"Azure OpenAI API Base Endpoint is required"
)
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
if not self._get_ai_model_entity(credentials['base_model_name'], model):
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
if not self._get_ai_model_entity(credentials["base_model_name"], model):
raise CredentialsValidateFailedError(
f'Base Model Name {credentials["base_model_name"]} is invalid'
)
try:
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
self._embedding_invoke(
model=model,
client=client,
texts=['ping'],
extra_model_kwargs={}
model=model, client=client, texts=["ping"], extra_model_kwargs={}
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(
credentials["base_model_name"], model
)
return ai_model_entity.entity
@staticmethod
def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str],
extra_model_kwargs: dict) -> tuple[list[list[float]], int]:
def _embedding_invoke(
model: str,
client: AzureOpenAI,
texts: Union[list[str], str],
extra_model_kwargs: dict,
) -> tuple[list[list[float]], int]:
response = client.embeddings.create(
input=texts,
model=model,
**extra_model_kwargs,
)
if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
if (
"encoding_format" in extra_model_kwargs
and extra_model_kwargs["encoding_format"] == "base64"
):
# decode base64 embedding
return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
response.usage.total_tokens)
return (
[
list(
np.frombuffer(base64.b64decode(data.embedding), dtype="float32")
)
for data in response.data
],
response.usage.total_tokens,
)
return [data.embedding for data in response.data], response.usage.total_tokens
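A minimal sketch of the base64 branch above: when encoding_format is "base64", each embedding arrives as a base64 string that decodes to raw float32 bytes. The vector below is synthetic; in the real flow the string comes from response.data[i].embedding.

import base64

import numpy as np

# pretend the API returned a 4-dimensional float32 vector encoded as base64
original = np.array([0.1, -0.2, 0.3, 0.4], dtype="float32")
encoded = base64.b64encode(original.tobytes()).decode("ascii")

# same decode path as the base64 branch above
decoded = list(np.frombuffer(base64.b64decode(encoded), dtype="float32"))
print(decoded)  # approximately [0.1, -0.2, 0.3, 0.4]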
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
tokens=tokens,
)
# transform usage
@ -179,7 +212,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage

View File

@ -3,16 +3,24 @@ import copy
from functools import reduce
from io import BytesIO
from typing import Optional
from fastapi.responses import StreamingResponse
from openai import AzureOpenAI
from pydub import AudioSegment
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from model_providers.core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
from model_providers.core.model_runtime.model_providers.azure_openai._common import (
_CommonAzureOpenAI,
)
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
TTS_BASE_MODELS,
AzureBaseModel,
)
from model_providers.extensions.ext_storage import storage
@ -21,8 +29,16 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
Model class for OpenAI Speech to text model.
"""
def _invoke(self, model: str, tenant_id: str, credentials: dict,
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
def _invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
streaming: bool,
user: Optional[str] = None,
) -> any:
"""
_invoke text2speech model
@ -36,20 +52,34 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in
self.get_tts_model_voices(model=model, credentials=credentials)]:
if not voice or voice not in [
d["value"]
for d in self.get_tts_model_voices(model=model, credentials=credentials)
]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return StreamingResponse(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice), media_type='text/event-stream')
return StreamingResponse(
self._tts_invoke_streaming(
model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice,
),
media_type="text/event-stream",
)
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
return self._tts_invoke(
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
def validate_credentials(
self, model: str, credentials: dict, user: Optional[str] = None
) -> None:
"""
validate credentials text2speech model
@ -62,13 +92,15 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
self._tts_invoke(
model=model,
credentials=credentials,
content_text='Hello Dify!',
content_text="Hello Dify!",
voice=self._get_model_default_voice(model, credentials),
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse:
def _tts_invoke(
self, model: str, credentials: dict, content_text: str, voice: str
) -> StreamingResponse:
"""
_tts_invoke text2speech model
@ -82,13 +114,25 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials)
try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
sentences = list(
self._split_text_into_sentences(text=content_text, limit=word_limit)
)
audio_bytes_list = list()
# Create a thread pool and map the function to the list of sentences
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
credentials=credentials) for sentence in sentences]
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
futures = [
executor.submit(
self._process_sentence,
sentence=sentence,
model=model,
voice=voice,
credentials=credentials,
)
for sentence in sentences
]
for future in futures:
try:
if future.result():
@ -97,8 +141,11 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
raise InvokeBadRequestError(str(ex))
if len(audio_bytes_list) > 0:
audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
audio_bytes_list if audio_bytes]
audio_segments = [
AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
for audio_bytes in audio_bytes_list
if audio_bytes
]
combined_segment = reduce(lambda x, y: x + y, audio_segments)
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
@ -108,8 +155,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
raise InvokeBadRequestError(str(ex))
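A minimal sketch of the merge step above, assuming pydub is installed: per-sentence audio segments are concatenated with reduce and exported once. Silent segments stand in for the real TTS bytes, and wav output is used so no ffmpeg install is needed.

from functools import reduce
from io import BytesIO

from pydub import AudioSegment

# silent segments stand in for the per-sentence audio bytes returned by the API
segments = [AudioSegment.silent(duration=250), AudioSegment.silent(duration=500)]

combined = reduce(lambda x, y: x + y, segments)
buffer = BytesIO()
combined.export(buffer, format="wav")  # wav avoids the ffmpeg dependency
print(len(combined))  # 750 (milliseconds)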
# Todo: To improve the streaming function
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
voice: str) -> any:
def _tts_invoke_streaming(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
) -> any:
"""
_tts_invoke_streaming text2speech model
@ -122,24 +175,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
if not voice or voice not in self.get_tts_model_voices(
model=model, credentials=credentials
):
voice = self._get_model_default_voice(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
file_path = f"generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}"
try:
client = AzureOpenAI(**credentials_kwargs)
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
sentences = list(
self._split_text_into_sentences(text=content_text, limit=word_limit)
)
for sentence in sentences:
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
response = client.audio.speech.create(
model=model, voice=voice, input=sentence.strip()
)
# response.stream_to_file(file_path)
storage.save(file_path, response.read())
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _process_sentence(self, sentence: str, model: str,
voice, credentials: dict):
def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
"""
_tts_invoke openai text2speech model api
@ -152,12 +210,18 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
response = client.audio.speech.create(
model=model, voice=voice, input=sentence.strip()
)
if isinstance(response.read(), bytes):
return response.read()
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(
credentials["base_model_name"], model
)
return ai_model_entity.entity
@staticmethod

View File

@ -1,11 +1,16 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class BaichuanProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@ -20,11 +25,12 @@ class BaichuanProvider(ModelProvider):
# Use `baichuan2-turbo` model for validate,
model_instance.validate_credentials(
model='baichuan2-turbo',
credentials=credentials
model="baichuan2-turbo", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -4,17 +4,20 @@ import re
class BaichuanTokenizer:
@classmethod
def count_chinese_characters(cls, text: str) -> int:
return len(re.findall(r'[\u4e00-\u9fa5]', text))
return len(re.findall(r"[\u4e00-\u9fa5]", text))
@classmethod
def count_english_vocabularies(cls, text: str) -> int:
# remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
# count the number of words not characters
return len(text.split())
@classmethod
def _get_num_tokens(cls, text: str) -> int:
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
# https://platform.baichuan-ai.com/docs/text-Embedding
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
return int(
cls.count_chinese_characters(text)
+ cls.count_english_vocabularies(text) * 1.3
)
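A standalone sketch of the estimate documented above (Chinese characters plus English words times 1.3), assuming only the standard library; handy for sanity-checking the heuristic without the provider classes.

import re


def estimate_baichuan_tokens(text: str) -> int:
    # Chinese characters count as one token each
    chinese = len(re.findall(r"[\u4e00-\u9fa5]", text))
    # English words (punctuation stripped) are weighted by 1.3
    english_words = len(re.sub(r"[^a-zA-Z0-9\s]", "", text).split())
    return int(chinese + english_words * 1.3)


print(estimate_baichuan_tokens("今天天气不错 hello world"))  # 6 + 2 * 1.3 -> 8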

View File

@ -18,153 +18,188 @@ from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tu
class BaichuanMessage:
class Role(Enum):
USER = 'user'
ASSISTANT = 'assistant'
USER = "user"
ASSISTANT = "assistant"
# Baichuan does not have system message
_SYSTEM = 'system'
_SYSTEM = "system"
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
stop_reason: str = ''
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
return {
'role': self.role,
'content': self.content,
"role": self.role,
"content": self.content,
}
def __init__(self, content: str, role: str = 'user') -> None:
def __init__(self, content: str, role: str = "user") -> None:
self.content = content
self.role = role
class BaichuanModel:
api_key: str
secret_key: str
def __init__(self, api_key: str, secret_key: str = '') -> None:
def __init__(self, api_key: str, secret_key: str = "") -> None:
self.api_key = api_key
self.secret_key = secret_key
def _model_mapping(self, model: str) -> str:
return {
'baichuan2-turbo': 'Baichuan2-Turbo',
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
'baichuan2-53b': 'Baichuan2-53B',
"baichuan2-turbo": "Baichuan2-Turbo",
"baichuan2-turbo-192k": "Baichuan2-Turbo-192k",
"baichuan2-53b": "Baichuan2-53B",
}[model]
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
resp = response.json()
choices = resp.get('choices', [])
message = BaichuanMessage(content='', role='assistant')
for choice in choices:
message.content += choice['message']['content']
message.role = choice['message']['role']
if choice['finish_reason']:
message.stop_reason = choice['finish_reason']
resp = response.json()
choices = resp.get("choices", [])
message = BaichuanMessage(content="", role="assistant")
for choice in choices:
message.content += choice["message"]["content"]
message.role = choice["message"]["role"]
if choice["finish_reason"]:
message.stop_reason = choice["finish_reason"]
if "usage" in resp:
message.usage = {
"prompt_tokens": resp["usage"]["prompt_tokens"],
"completion_tokens": resp["usage"]["completion_tokens"],
"total_tokens": resp["usage"]["total_tokens"],
}
return message
if 'usage' in resp:
message.usage = {
'prompt_tokens': resp['usage']['prompt_tokens'],
'completion_tokens': resp['usage']['completion_tokens'],
'total_tokens': resp['usage']['total_tokens'],
}
return message
def _handle_chat_stream_generate_response(self, response) -> Generator:
for line in response.iter_lines():
if not line:
continue
line = line.decode('utf-8')
line = line.decode("utf-8")
# remove the first `data: ` prefix
if line.startswith('data:'):
if line.startswith("data:"):
line = line[5:].strip()
try:
data = loads(line)
except Exception as e:
if line.strip() == '[DONE]':
if line.strip() == "[DONE]":
return
choices = data.get('choices', [])
choices = data.get("choices", [])
# save stop reason temporarily
stop_reason = ''
stop_reason = ""
for choice in choices:
if 'finish_reason' in choice and choice['finish_reason']:
stop_reason = choice['finish_reason']
if "finish_reason" in choice and choice["finish_reason"]:
stop_reason = choice["finish_reason"]
if len(choice['delta']['content']) == 0:
if len(choice["delta"]["content"]) == 0:
continue
yield BaichuanMessage(**choice['delta'])
yield BaichuanMessage(**choice["delta"])
# if there is usage, the response is the last one, yield it and return
if 'usage' in data:
message = BaichuanMessage(content='', role='assistant')
if "usage" in data:
message = BaichuanMessage(content="", role="assistant")
message.usage = {
'prompt_tokens': data['usage']['prompt_tokens'],
'completion_tokens': data['usage']['completion_tokens'],
'total_tokens': data['usage']['total_tokens'],
"prompt_tokens": data["usage"]["prompt_tokens"],
"completion_tokens": data["usage"]["completion_tokens"],
"total_tokens": data["usage"]["total_tokens"],
}
message.stop_reason = stop_reason
yield message
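A minimal sketch of the SSE parsing the handler above performs: strip the "data:" prefix, stop on "[DONE]", and read the delta content from each choice. The sample byte strings are fabricated; real lines come from response.iter_lines().

from json import loads

sample_lines = [  # fabricated stand-ins for response.iter_lines()
    b'data: {"choices": [{"delta": {"role": "assistant", "content": "Hel"}, "finish_reason": null}]}',
    b'data: {"choices": [{"delta": {"role": "assistant", "content": "lo"}, "finish_reason": "stop"}]}',
    b"data: [DONE]",
]

for raw in sample_lines:
    line = raw.decode("utf-8")
    if line.startswith("data:"):
        line = line[5:].strip()
    if line == "[DONE]":
        break
    data = loads(line)
    for choice in data.get("choices", []):
        print(choice["delta"].get("content", ""), end="")
print()  # prints "Hello"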
def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
parameters: dict[str, Any]) \
-> dict[str, Any]:
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
def _build_parameters(
self,
model: str,
stream: bool,
messages: list[BaichuanMessage],
parameters: dict[str, Any],
) -> dict[str, Any]:
if (
model == "baichuan2-turbo"
or model == "baichuan2-turbo-192k"
or model == "baichuan2-53b"
):
prompt_messages = []
for message in messages:
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
if (
message.role == BaichuanMessage.Role.USER.value
or message.role == BaichuanMessage.Role._SYSTEM.value
):
# check if the latest message is a user message
if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
prompt_messages[-1]['content'] += message.content
if (
len(prompt_messages) > 0
and prompt_messages[-1]["role"]
== BaichuanMessage.Role.USER.value
):
prompt_messages[-1]["content"] += message.content
else:
prompt_messages.append({
'content': message.content,
'role': BaichuanMessage.Role.USER.value,
})
prompt_messages.append(
{
"content": message.content,
"role": BaichuanMessage.Role.USER.value,
}
)
elif message.role == BaichuanMessage.Role.ASSISTANT.value:
prompt_messages.append({
'content': message.content,
'role': message.role,
})
prompt_messages.append(
{
"content": message.content,
"role": message.role,
}
)
# [baichuan] frequency_penalty must be between 1 and 2
if 'frequency_penalty' in parameters:
if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
parameters['frequency_penalty'] = 1
if "frequency_penalty" in parameters:
if (
parameters["frequency_penalty"] < 1
or parameters["frequency_penalty"] > 2
):
parameters["frequency_penalty"] = 1
# turbo api accepts flat parameters
return {
'model': self._model_mapping(model),
'stream': stream,
'messages': prompt_messages,
"model": self._model_mapping(model),
"stream": stream,
"messages": prompt_messages,
**parameters,
}
else:
raise BadRequestError(f"Unknown model: {model}")
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
if (
model == "baichuan2-turbo"
or model == "baichuan2-turbo-192k"
or model == "baichuan2-53b"
):
# there is no secret key for turbo api
return {
'Content-Type': 'application/json',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
'Authorization': 'Bearer ' + self.api_key,
"Content-Type": "application/json",
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ",
"Authorization": "Bearer " + self.api_key,
}
else:
raise BadRequestError(f"Unknown model: {model}")
def _calculate_md5(self, input_string):
return md5(input_string.encode('utf-8')).hexdigest()
def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
parameters: dict[str, Any], timeout: int) \
-> Union[Generator, BaichuanMessage]:
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
def _calculate_md5(self, input_string):
return md5(input_string.encode("utf-8")).hexdigest()
def generate(
self,
model: str,
stream: bool,
messages: list[BaichuanMessage],
parameters: dict[str, Any],
timeout: int,
) -> Union[Generator, BaichuanMessage]:
if (
model == "baichuan2-turbo"
or model == "baichuan2-turbo-192k"
or model == "baichuan2-53b"
):
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
else:
raise BadRequestError(f"Unknown model: {model}")
try:
data = self._build_parameters(model, stream, messages, parameters)
headers = self._build_headers(model, data)
@ -177,35 +212,37 @@ class BaichuanModel:
headers=headers,
data=dumps(data),
timeout=timeout,
stream=stream
stream=stream,
)
except Exception as e:
raise InternalServerError(f"Failed to invoke model: {e}")
if response.status_code != 200:
try:
resp = response.json()
# try to parse error message
err = resp['error']['code']
msg = resp['error']['message']
err = resp["error"]["code"]
msg = resp["error"]["message"]
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InternalServerError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
if err == 'invalid_api_key':
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == 'insufficient_quota':
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
elif err == 'invalid_authentication':
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif 'rate' in err:
elif "rate" in err:
raise RateLimitReachedError(msg)
elif 'internal' in err:
elif "internal" in err:
raise InternalServerError(msg)
elif err == 'api_key_empty':
elif err == "api_key_empty":
raise InvalidAPIKeyError(msg)
else:
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
if stream:
return self._handle_chat_stream_generate_response(response)
else:

View File

@ -1,17 +1,22 @@
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass
pass

View File

@ -1,7 +1,11 @@
from collections.abc import Generator
from typing import cast
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -17,10 +21,19 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import (
BaichuanTokenizer,
)
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import (
BaichuanMessage,
BaichuanModel,
)
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
@ -32,20 +45,43 @@ from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tu
class BaichuanLarguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int:
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
def tokens(text: str):
return BaichuanTokenizer._get_num_tokens(text)
@ -57,10 +93,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
@ -87,89 +123,123 @@ class BaichuanLarguageModel(LargeLanguageModel):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def validate_credentials(self, model: str, credentials: dict) -> None:
# ping
instance = BaichuanModel(
api_key=credentials['api_key'],
secret_key=credentials.get('secret_key', '')
api_key=credentials["api_key"], secret_key=credentials.get("secret_key", "")
)
try:
instance.generate(model=model, stream=False, messages=[
BaichuanMessage(content='ping', role='user')
], parameters={
'max_tokens': 1,
}, timeout=60)
instance.generate(
model=model,
stream=False,
messages=[BaichuanMessage(content="ping", role="user")],
parameters={
"max_tokens": 1,
},
timeout=60,
)
except Exception as e:
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
if tools is not None and len(tools) > 0:
raise InvokeBadRequestError("Baichuan model doesn't support tools")
instance = BaichuanModel(
api_key=credentials['api_key'],
secret_key=credentials.get('secret_key', '')
api_key=credentials["api_key"], secret_key=credentials.get("secret_key", "")
)
# convert prompt messages to baichuan messages
messages = [
BaichuanMessage(
content=message.content if isinstance(message.content, str) else ''.join([
content.data for content in message.content
]),
role=message.role.value
) for message in prompt_messages
content=message.content
if isinstance(message.content, str)
else "".join([content.data for content in message.content]),
role=message.role.value,
)
for message in prompt_messages
]
# invoke model
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60)
response = instance.generate(
model=model,
stream=stream,
messages=messages,
parameters=model_parameters,
timeout=60,
)
if stream:
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
return self._handle_chat_generate_stream_response(
model, prompt_messages, credentials, response
)
def _handle_chat_generate_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: BaichuanMessage) -> LLMResult:
return self._handle_chat_generate_response(
model, prompt_messages, credentials, response
)
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: BaichuanMessage,
) -> LLMResult:
# convert baichuan message to llm result
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens'])
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=response.usage["prompt_tokens"],
completion_tokens=response.usage["completion_tokens"],
)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=response.content,
tool_calls=[]
),
message=AssistantPromptMessage(content=response.content, tool_calls=[]),
usage=usage,
)
def _handle_chat_generate_stream_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Generator[BaichuanMessage, None, None]) -> Generator:
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Generator[BaichuanMessage, None, None],
) -> Generator:
for message in response:
if message.usage:
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens'])
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=message.usage["prompt_tokens"],
completion_tokens=message.usage["completion_tokens"],
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=message.content,
tool_calls=[]
content=message.content, tool_calls=[]
),
usage=usage,
finish_reason=message.stop_reason if message.stop_reason else None,
finish_reason=message.stop_reason
if message.stop_reason
else None,
),
)
else:
@ -179,10 +249,11 @@ class BaichuanLarguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=message.content,
tool_calls=[]
content=message.content, tool_calls=[]
),
finish_reason=message.stop_reason if message.stop_reason else None,
finish_reason=message.stop_reason
if message.stop_reason
else None,
),
)
@ -197,21 +268,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
InvokeBadRequestError: [BadRequestError, KeyError],
}

View File

@ -5,7 +5,10 @@ from typing import Optional
from requests import post
from model_providers.core.model_runtime.entities.model_entities import PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from model_providers.core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -14,9 +17,15 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import (
BaichuanTokenizer,
)
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
@ -31,11 +40,16 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for BaiChuan text embedding model.
"""
api_base: str = 'http://api.baichuan-ai.com/v1/embeddings'
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "http://api.baichuan-ai.com/v1/embeddings"
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -45,27 +59,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
if model != 'baichuan-text-embedding':
raise ValueError('Invalid model name')
api_key = credentials["api_key"]
if model != "baichuan-text-embedding":
raise ValueError("Invalid model name")
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
raise CredentialsValidateFailedError("api_key is required")
# split into chunks of batch size 16
chunks = []
for i in range(0, len(texts), 16):
chunks.append(texts[i:i + 16])
chunks.append(texts[i : i + 16])
embeddings = []
token_usage = 0
for chunk in chunks:
# embeding chunk
# embedding chunk
chunk_embeddings, chunk_usage = self.embedding(
model=model,
api_key=api_key,
texts=chunk,
user=user
model=model, api_key=api_key, texts=chunk, user=user
)
embeddings.extend(chunk_embeddings)
@ -75,16 +86,15 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
model=model, credentials=credentials, tokens=token_usage
),
)
return result
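A quick sketch of the batch-of-16 slicing used in _invoke above, shown on a synthetic list of texts.

texts = [f"doc {i}" for i in range(37)]  # synthetic input
chunks = [texts[i : i + 16] for i in range(0, len(texts), 16)]
print([len(c) for c in chunks])  # [16, 16, 5]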
def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \
-> tuple[list[list[float]], int]:
def embedding(
self, model: str, api_key, texts: list[str], user: Optional[str] = None
) -> tuple[list[list[float]], int]:
"""
Embed given texts
@ -96,55 +106,53 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
"""
url = self.api_base
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
"Authorization": "Bearer " + api_key,
"Content-Type": "application/json",
}
data = {
'model': 'Baichuan-Text-Embedding',
'input': texts
}
data = {"model": "Baichuan-Text-Embedding", "input": texts}
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
# try to parse error message
err = resp['error']['code']
msg = resp['error']['message']
err = resp["error"]["code"]
msg = resp["error"]["message"]
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InternalServerError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
if err == 'invalid_api_key':
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == 'insufficient_quota':
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
elif err == 'invalid_authentication':
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif err and 'rate' in err:
elif err and "rate" in err:
raise RateLimitReachedError(msg)
elif err and 'internal' in err:
elif err and "internal" in err:
raise InternalServerError(msg)
elif err == 'api_key_empty':
elif err == "api_key_empty":
raise InvalidAPIKeyError(msg)
else:
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
return [
data['embedding'] for data in embeddings
], usage['total_tokens']
raise InternalServerError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
return [data["embedding"] for data in embeddings], usage["total_tokens"]
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
@ -170,33 +178,27 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvalidAPIKeyError:
raise CredentialsValidateFailedError('Invalid api key')
raise CredentialsValidateFailedError("Invalid api key")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
InvokeBadRequestError: [BadRequestError, KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
"""
Calculate response usage
@ -210,7 +212,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
tokens=tokens,
)
# transform usage
@ -221,7 +223,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage

View File

@ -1,11 +1,16 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class BedrockProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@ -20,11 +25,12 @@ class BedrockProvider(ModelProvider):
# Use `amazon.titan-text-lite-v1` model to validate,
model_instance.validate_credentials(
model='amazon.titan-text-lite-v1',
credentials=credentials
model="amazon.titan-text-lite-v1", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex
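
A hedged usage sketch: the credential keys mirror the ones consumed by BedrockLargeLanguageModel later in this diff (aws_region, aws_access_key_id, aws_secret_access_key); the concrete values are placeholders.

provider = BedrockProvider()
try:
    provider.validate_provider_credentials(
        credentials={
            "aws_region": "us-east-1",
            "aws_access_key_id": "AKIA...",   # placeholder
            "aws_secret_access_key": "...",   # placeholder
        }
    )
except CredentialsValidateFailedError as exc:
    logger.warning("Bedrock credentials rejected: %s", exc)
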

View File

@ -13,7 +13,11 @@ from botocore.exceptions import (
UnknownServiceError,
)
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -29,18 +33,28 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
logger = logging.getLogger(__name__)
class BedrockLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
class BedrockLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -55,10 +69,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._generate(
model, credentials, prompt_messages, model_parameters, stop, stream, user
)
def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -68,7 +89,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling
        :return: number of tokens
"""
prefix = model.split('.')[0]
prefix = model.split(".")[0]
if isinstance(messages, str):
prompt = messages
@ -76,8 +97,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
prompt = self._convert_messages_to_prompt(messages, prefix)
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
def _convert_messages_to_prompt(
self, model_prefix: str, messages: list[PromptMessage]
) -> str:
"""
Format a list of messages into a full prompt for the Google model
@ -85,7 +108,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message, model_prefix)
for message in messages
@ -101,32 +124,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param credentials: model credentials
:return:
"""
try:
ping_message = UserPromptMessage(content="ping")
self._generate(model=model,
credentials=credentials,
prompt_messages=[ping_message],
model_parameters={},
stream=False)
self._generate(
model=model,
credentials=credentials,
prompt_messages=[ping_message],
model_parameters={},
stream=False,
)
except ClientError as ex:
error_code = ex.response['Error']['Code']
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg)))
raise CredentialsValidateFailedError(
str(self._map_client_to_invoke_error(error_code, full_error_msg))
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
def _convert_one_message_to_text(
self, message: PromptMessage, model_prefix: str
) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
if model_prefix == "anthropic":
human_prompt_prefix = "\n\nHuman:"
human_prompt_postfix = ""
@ -141,7 +170,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
human_prompt_prefix = "\n\nUser:"
human_prompt_postfix = ""
ai_prompt = "\n\nBot:"
else:
human_prompt_prefix = ""
human_prompt_postfix = ""
@ -160,7 +189,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
def _convert_messages_to_prompt(
self, messages: list[PromptMessage], model_prefix: str
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
@ -168,7 +199,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
return ""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
@ -182,23 +213,36 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
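
Illustratively, for an "anthropic"-prefixed model the conversion above produces a Human/Assistant transcript; the exact string shape is inferred from the prefixes defined in this hunk and should be treated as an assumption:

messages = [UserPromptMessage(content="What is 2 + 2?")]
prompt = self._convert_messages_to_prompt(messages, "anthropic")
# an empty AssistantPromptMessage is appended automatically, so the result should
# look roughly like: "\n\nHuman: What is 2 + 2?\n\nAssistant:"   (shape assumed)
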
def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
def _create_payload(
self,
model_prefix: str,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
):
"""
Create payload for bedrock api call depending on model provider
"""
payload = dict()
if model_prefix == "amazon":
payload["textGenerationConfig"] = { **model_parameters }
payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else [])
payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload["textGenerationConfig"] = {**model_parameters}
payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (
stop if stop else []
)
payload["inputText"] = self._convert_messages_to_prompt(
prompt_messages, model_prefix
)
elif model_prefix == "ai21":
payload["temperature"] = model_parameters.get("temperature")
payload["topP"] = model_parameters.get("topP")
payload["maxTokens"] = model_parameters.get("maxTokens")
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload["prompt"] = self._convert_messages_to_prompt(
prompt_messages, model_prefix
)
# jurassic models only support a single stop sequence
if stop:
@ -212,28 +256,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
payload["countPenalty"] = {model_parameters.get("countPenalty")}
elif model_prefix == "anthropic":
payload = { **model_parameters }
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload = {**model_parameters}
payload["prompt"] = self._convert_messages_to_prompt(
prompt_messages, model_prefix
)
payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
elif model_prefix == "cohere":
payload = { **model_parameters }
payload = {**model_parameters}
payload["prompt"] = prompt_messages[0].content
payload["stream"] = stream
elif model_prefix == "meta":
payload = { **model_parameters }
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload = {**model_parameters}
payload["prompt"] = self._convert_messages_to_prompt(
prompt_messages, model_prefix
)
else:
raise ValueError(f"Got unknown model prefix {model_prefix}")
return payload
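
For an "amazon."-prefixed (Titan) model the payload assembled above comes out roughly as follows; textGenerationConfig simply mirrors whatever arrives in model_parameters, so the concrete keys here are placeholders:

payload = {
    "textGenerationConfig": {
        "maxTokenCount": 512,        # example model_parameters entry (assumed name)
        "temperature": 0.7,          # example model_parameters entry (assumed name)
        "stopSequences": ["User:"],  # "User:" is always prepended to any caller-supplied stops
    },
    "inputText": "\n\nUser: hello\n\nBot:",
}
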
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -246,19 +300,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
client_config = Config(
region_name=credentials["aws_region"]
)
client_config = Config(region_name=credentials["aws_region"])
runtime_client = boto3.client(
service_name='bedrock-runtime',
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"]
aws_secret_access_key=credentials["aws_secret_access_key"],
)
model_prefix = model.split('.')[0]
payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
model_prefix = model.split(".")[0]
payload = self._create_payload(
model_prefix, prompt_messages, model_parameters, stop, stream
)
        # need a workaround for ai21 models, which don't support streaming
if stream and model_prefix != "ai21":
@ -267,18 +321,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
invoke = runtime_client.invoke_model
try:
body_jsonstr=json.dumps(payload)
body_jsonstr = json.dumps(payload)
response = invoke(
modelId=model,
contentType="application/json",
accept= "*/*",
body=body_jsonstr
accept="*/*",
body=body_jsonstr,
)
except ClientError as ex:
error_code = ex.response['Error']['Code']
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))
@ -287,15 +341,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise InvokeError(str(ex))
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_stream_response(
model, credentials, response, prompt_messages
)
return self._handle_generate_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(
model, credentials, response, prompt_messages
)
def _handle_generate_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self,
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@ -305,7 +367,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response
"""
response_body = json.loads(response.get('body').read().decode('utf-8'))
response_body = json.loads(response.get("body").read().decode("utf-8"))
finish_reason = response_body.get("error")
@ -313,43 +375,51 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
raise InvokeError(finish_reason)
# get output text and calculate num tokens based on model / provider
model_prefix = model.split('.')[0]
model_prefix = model.split(".")[0]
if model_prefix == "amazon":
output = response_body.get("results")[0].get("outputText").strip('\n')
output = response_body.get("results")[0].get("outputText").strip("\n")
prompt_tokens = response_body.get("inputTextTokenCount")
completion_tokens = response_body.get("results")[0].get("tokenCount")
elif model_prefix == "ai21":
output = response_body.get('completions')[0].get('data').get('text')
output = response_body.get("completions")[0].get("data").get("text")
prompt_tokens = len(response_body.get("prompt").get("tokens"))
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
completion_tokens = len(
response_body.get("completions")[0].get("data").get("tokens")
)
elif model_prefix == "anthropic":
output = response_body.get("completion")
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
completion_tokens = self.get_num_tokens(
model, credentials, output if output else ""
)
elif model_prefix == "cohere":
output = response_body.get("generations")[0].get("text")
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
completion_tokens = self.get_num_tokens(
model, credentials, output if output else ""
)
elif model_prefix == "meta":
output = response_body.get("generation").strip('\n')
output = response_body.get("generation").strip("\n")
prompt_tokens = response_body.get("prompt_token_count")
completion_tokens = response_body.get("generation_token_count")
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
raise ValueError(
f"Got unknown model prefix {model_prefix} when handling block response"
)
# construct assistant message from output
assistant_prompt_message = AssistantPromptMessage(
content=output
)
assistant_prompt_message = AssistantPromptMessage(content=output)
# calculate usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
# construct response
result = LLMResult(
@ -361,8 +431,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -372,48 +447,52 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
model_prefix = model.split('.')[0]
model_prefix = model.split(".")[0]
if model_prefix == "ai21":
response_body = json.loads(response.get('body').read().decode('utf-8'))
response_body = json.loads(response.get("body").read().decode("utf-8"))
content = response_body.get('completions')[0].get('data').get('text')
finish_reason = response_body.get('completions')[0].get('finish_reason')
content = response_body.get("completions")[0].get("data").get("text")
finish_reason = response_body.get("completions")[0].get("finish_reason")
prompt_tokens = len(response_body.get("prompt").get("tokens"))
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
completion_tokens = len(
response_body.get("completions")[0].get("data").get("tokens")
)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=content),
finish_reason=finish_reason,
usage=usage
)
)
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=content),
finish_reason=finish_reason,
usage=usage,
),
)
return
stream = response.get('body')
stream = response.get("body")
if not stream:
raise InvokeError('No response body')
raise InvokeError("No response body")
index = -1
for event in stream:
chunk = event.get('chunk')
chunk = event.get("chunk")
if not chunk:
exception_name = next(iter(event))
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
payload = json.loads(chunk.get('bytes').decode())
payload = json.loads(chunk.get("bytes").decode())
model_prefix = model.split('.')[0]
model_prefix = model.split(".")[0]
if model_prefix == "amazon":
content_delta = payload.get("outputText").strip('\n')
content_delta = payload.get("outputText").strip("\n")
finish_reason = payload.get("completion_reason")
elif model_prefix == "anthropic":
content_delta = payload.get("completion")
finish_reason = payload.get("stop_reason")
@ -421,38 +500,45 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif model_prefix == "cohere":
content_delta = payload.get("text")
finish_reason = payload.get("finish_reason")
elif model_prefix == "meta":
content_delta = payload.get("generation").strip('\n')
content_delta = payload.get("generation").strip("\n")
finish_reason = payload.get("stop_reason")
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
raise ValueError(
f"Got unknown model prefix {model_prefix} when handling stream response"
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content = content_delta if content_delta else '',
content=content_delta if content_delta else "",
)
index += 1
if not finish_reason:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
index=index, message=assistant_prompt_message
),
)
else:
# get num tokens from metrics in last chunk
prompt_tokens = payload["amazon-bedrock-invocationMetrics"]["inputTokenCount"]
completion_tokens = payload["amazon-bedrock-invocationMetrics"]["outputTokenCount"]
prompt_tokens = payload["amazon-bedrock-invocationMetrics"][
"inputTokenCount"
]
completion_tokens = payload["amazon-bedrock-invocationMetrics"][
"outputTokenCount"
]
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@ -460,10 +546,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
index=index,
message=assistant_prompt_message,
finish_reason=finish_reason,
usage=usage
)
usage=usage,
),
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
@ -479,10 +565,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
InvokeBadRequestError: [],
}
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
def _map_client_to_invoke_error(
self, error_code: str, error_msg: str
) -> type[InvokeError]:
"""
Map client error to invoke error
@ -497,7 +585,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
return InvokeRateLimitError(error_msg)
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
elif error_code in [
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
]:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)
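
A quick check of the mapping above; the error strings are made up, and whether the method hands back an error class or an instance follows the body shown here rather than the annotated return type:

err = self._map_client_to_invoke_error(
    "ThrottlingException", "ThrottlingException: too many requests"
)
# per the branch above this yields an InvokeRateLimitError carrying the message
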

View File

@ -1,8 +1,12 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
@ -21,11 +25,12 @@ class ChatGLMProvider(ModelProvider):
            # Use `chatglm3-6b` model for validation,
model_instance.validate_credentials(
model='chatglm3-6b',
credentials=credentials
model="chatglm3-6b", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -20,7 +20,11 @@ from openai import (
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -37,18 +41,29 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class ChatGLMLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
"""
Invoke large language model
@ -71,11 +86,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -96,11 +116,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, prompt_messages=[
UserPromptMessage(content="ping"),
], model_parameters={
"max_tokens": 16,
})
self._invoke(
model=model,
credentials=credentials,
prompt_messages=[
UserPromptMessage(content="ping"),
],
model_parameters={
"max_tokens": 16,
},
)
except Exception as e:
raise CredentialsValidateFailedError(str(e))
@ -124,24 +149,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
ConflictError,
NotFoundError,
UnprocessableEntityError,
PermissionDeniedError
PermissionDeniedError,
],
InvokeRateLimitError: [
RateLimitError
],
InvokeAuthorizationError: [
AuthenticationError
],
InvokeBadRequestError: [
ValueError
]
InvokeRateLimitError: [RateLimitError],
InvokeAuthorizationError: [AuthenticationError],
InvokeBadRequestError: [ValueError],
}
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
"""
Invoke large language model
@ -155,7 +180,9 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
self._check_chatglm_parameters(model=model, model_parameters=model_parameters, tools=tools)
self._check_chatglm_parameters(
model=model, model_parameters=model_parameters, tools=tools
)
kwargs = self._to_client_kwargs(credentials)
# init model client
@ -163,13 +190,13 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
if tools and len(tools) > 0:
extra_model_kwargs['functions'] = [
extra_model_kwargs["functions"] = [
helper.dump_model(tool) for tool in tools
]
@ -178,21 +205,29 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model,
credentials=credentials,
response=result,
tools=tools,
prompt_messages=prompt_messages,
)
return self._handle_chat_generate_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model,
credentials=credentials,
response=result,
tools=tools,
prompt_messages=prompt_messages,
)
def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None:
def _check_chatglm_parameters(
self, model: str, model_parameters: dict, tools: list[PromptMessageTool]
) -> None:
if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
raise InvokeBadRequestError("ChatGLM2 does not support function calling")
@ -212,7 +247,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
@ -223,12 +258,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def _extract_response_tool_calls(self,
response_function_calls: list[FunctionCall]) \
-> list[AssistantPromptMessage.ToolCall]:
def _extract_response_tool_calls(
self, response_function_calls: list[FunctionCall]
) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -239,19 +274,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
if response_function_calls:
for response_tool_call in response_function_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.name,
arguments=response_tool_call.arguments
name=response_tool_call.name, arguments=response_tool_call.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=0,
type='function',
function=function
id=0, type="function", function=function
)
tool_calls.append(tool_call)
return tool_calls
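
Illustratively, a single function_call returned by the ChatGLM3 endpoint is wrapped like this; the tool name and argument payload are placeholders:

fc = FunctionCall(name="get_weather", arguments='{"city": "Beijing"}')
tool_calls = self._extract_response_tool_calls([fc])
# tool_calls[0].type == "function"
# tool_calls[0].function.name == "get_weather"
# tool_calls[0].function.arguments == '{"city": "Beijing"}'
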
def _to_client_kwargs(self, credentials: dict) -> dict:
"""
Convert invoke kwargs to client kwargs
@ -265,17 +297,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": join(credentials['api_base'], 'v1')
"base_url": join(credentials["api_base"], "v1"),
}
return client_kwargs
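
A minimal sketch of how these kwargs feed an OpenAI-compatible client pointed at a local ChatGLM server; the URL is a placeholder for join(credentials["api_base"], "v1"):

from httpx import Timeout
from openai import OpenAI

client = OpenAI(
    timeout=Timeout(315.0, read=300.0, write=10.0, connect=5.0),
    api_key="1",                          # dummy key, mirroring the "1" used above
    base_url="http://127.0.0.1:8000/v1",  # placeholder for the configured api_base + /v1
)
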
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) \
-> Generator:
full_response = ''
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> Generator:
full_response = ""
for chunk in response:
if len(chunk.choices) == 0:
@ -283,35 +318,46 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
if delta.finish_reason is None and (
delta.delta.content is None or delta.delta.content == ""
):
continue
# check if there is a tool call in the response
function_calls = None
if delta.delta.function_call:
function_calls = [delta.delta.function_call]
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
assistant_message_tool_calls = self._extract_response_tool_calls(
function_calls if function_calls else []
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=assistant_message_tool_calls
content=delta.delta.content if delta.delta.content else "",
tool_calls=assistant_message_tool_calls,
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=assistant_message_tool_calls
content=full_response, tool_calls=assistant_message_tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
prompt_tokens = self._num_tokens_from_messages(
messages=prompt_messages, tools=tools
)
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@ -320,7 +366,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
usage=usage,
),
)
else:
@ -335,11 +381,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
)
full_response += delta.delta.content
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) \
-> LLMResult:
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> LLMResult:
"""
Handle llm chat response
@ -356,18 +406,28 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
# convert function call to tool call
function_calls = assistant_message.function_call
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
tool_calls = self._extract_response_tool_calls(
[function_calls] if function_calls else []
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content,
tool_calls=tool_calls
content=assistant_message.content, tool_calls=tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
prompt_tokens = self._num_tokens_from_messages(
messages=prompt_messages, tools=tools
)
completion_tokens = self._num_tokens_from_messages(
messages=[assistant_prompt_message], tools=tools
)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response = LLMResult(
model=model,
@ -378,8 +438,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
)
return response
def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_string(
self, text: str, tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
@ -395,17 +457,21 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
return num_tokens
def _num_tokens_from_messages(self, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
        It's too complex to calculate num tokens for chatglm2 and chatglm3 with the ChatGLM tokenizer,
        so as a temporary solution we use the GPT2 tokenizer instead.
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
@ -414,10 +480,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "function_call":
@ -452,36 +518,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling
:return: number of tokens
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
num_tokens = 0
for tool in tools:
# calculate num tokens for function object
num_tokens += tokens('name')
num_tokens += tokens("name")
num_tokens += tokens(tool.name)
num_tokens += tokens('description')
num_tokens += tokens("description")
num_tokens += tokens(tool.description)
parameters = tool.parameters
num_tokens += tokens('parameters')
num_tokens += tokens('type')
num_tokens += tokens("parameters")
num_tokens += tokens("type")
num_tokens += tokens(parameters.get("type"))
if 'properties' in parameters:
num_tokens += tokens('properties')
for key, value in parameters.get('properties').items():
if "properties" in parameters:
num_tokens += tokens("properties")
for key, value in parameters.get("properties").items():
num_tokens += tokens(key)
for field_key, field_value in value.items():
num_tokens += tokens(field_key)
if field_key == 'enum':
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += tokens(enum_field)
else:
num_tokens += tokens(field_key)
num_tokens += tokens(str(field_value))
if 'required' in parameters:
num_tokens += tokens('required')
for required_field in parameters['required']:
if "required" in parameters:
num_tokens += tokens("required")
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += tokens(required_field)

View File

@ -1,8 +1,12 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
@ -21,11 +25,12 @@ class CohereProvider(ModelProvider):
            # Use `rerank-english-v2.0` model for validation,
model_instance.validate_credentials(
model='rerank-english-v2.0',
credentials=credentials
model="rerank-english-v2.0", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -7,7 +7,12 @@ from cohere.responses import Chat, Generations
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
from cohere.responses.generation import StreamingGenerations, StreamingText
from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMMode,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -17,7 +22,12 @@ from model_providers.core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
I18nObject,
ModelType,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -26,8 +36,12 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
logger = logging.getLogger(__name__)
@ -37,11 +51,17 @@ class CohereLargeLanguageModel(LargeLanguageModel):
Model class for Cohere large language model.
"""
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -66,7 +86,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user
user=user,
)
else:
return self._generate(
@ -76,11 +96,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -95,9 +120,13 @@ class CohereLargeLanguageModel(LargeLanguageModel):
try:
if model_mode == LLMMode.CHAT:
return self._num_tokens_from_messages(model, credentials, prompt_messages)
return self._num_tokens_from_messages(
model, credentials, prompt_messages
)
else:
return self._num_tokens_from_string(model, credentials, prompt_messages[0].content)
return self._num_tokens_from_string(
model, credentials, prompt_messages[0].content
)
except Exception as e:
raise self._transform_invoke_error(e)
@ -117,30 +146,37 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content='ping')],
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={
'max_tokens': 20,
'temperature': 0,
"max_tokens": 20,
"temperature": 0,
},
stream=False
stream=False,
)
else:
self._generate(
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content='ping')],
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={
'max_tokens': 20,
'temperature': 0,
"max_tokens": 20,
"temperature": 0,
},
stream=False
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm model
@ -154,10 +190,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get("api_key"))
if stop:
model_parameters['end_sequences'] = stop
model_parameters["end_sequences"] = stop
response = client.generate(
prompt=prompt_messages[0].content,
@ -167,13 +203,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_stream_response(
model, credentials, response, prompt_messages
)
return self._handle_generate_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(
model, credentials, response, prompt_messages
)
def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
def _handle_generate_response(
self,
model: str,
credentials: dict,
response: Generations,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@ -186,29 +230,34 @@ class CohereLargeLanguageModel(LargeLanguageModel):
assistant_text = response.generations[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
prompt_tokens = response.meta['billed_units']['input_tokens']
completion_tokens = response.meta['billed_units']['output_tokens']
prompt_tokens = response.meta["billed_units"]["input_tokens"]
completion_tokens = response.meta["billed_units"]["output_tokens"]
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
# transform response
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
usage=usage,
)
return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: StreamingGenerations,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -218,7 +267,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: llm response chunk generator
"""
index = 1
full_assistant_content = ''
full_assistant_content = ""
for chunk in response:
if isinstance(chunk, StreamingText):
chunk = cast(StreamingText, chunk)
@ -228,9 +277,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=text
)
assistant_prompt_message = AssistantPromptMessage(content=text)
full_assistant_content += text
@ -240,33 +287,42 @@ class CohereLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
),
)
index += 1
elif chunk is None:
# calculate num tokens
prompt_tokens = response.meta['billed_units']['input_tokens']
completion_tokens = response.meta['billed_units']['output_tokens']
prompt_tokens = response.meta["billed_units"]["input_tokens"]
completion_tokens = response.meta["billed_units"]["output_tokens"]
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
message=AssistantPromptMessage(content=""),
finish_reason=response.finish_reason,
usage=usage
)
usage=usage,
),
)
break
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
@ -280,17 +336,23 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get("api_key"))
if user:
model_parameters['user_name'] = user
model_parameters["user_name"] = user
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
(
message,
chat_histories,
) = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
# chat model
real_model = model
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
real_model = model.removesuffix('-chat')
if (
self.get_model_schema(model, credentials).fetch_from
== FetchFrom.PREDEFINED_MODEL
):
real_model = model.removesuffix("-chat")
response = client.chat(
message=message,
@ -302,13 +364,22 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
return self._handle_chat_generate_stream_response(
model, credentials, response, prompt_messages, stop
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
return self._handle_chat_generate_response(
model, credentials, response, prompt_messages, stop
)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
-> LLMResult:
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: Chat,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> LLMResult:
"""
Handle llm chat response
@ -322,23 +393,25 @@ class CohereLargeLanguageModel(LargeLanguageModel):
assistant_text = response.text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message])
prompt_tokens = self._num_tokens_from_messages(
model, credentials, prompt_messages
)
completion_tokens = self._num_tokens_from_messages(
model, credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
if stop:
# enforce stop tokens
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# transform response
response = LLMResult(
@ -346,14 +419,19 @@ class CohereLargeLanguageModel(LargeLanguageModel):
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.preamble
system_fingerprint=response.preamble,
)
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) -> Generator:
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: StreamingChat,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator:
"""
Handle llm chat stream response
@ -364,18 +442,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: llm response chunk generator
"""
def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
preamble: Optional[str] = None) -> LLMResultChunk:
def final_response(
full_text: str,
index: int,
finish_reason: Optional[str] = None,
preamble: Optional[str] = None,
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
full_assistant_prompt_message = AssistantPromptMessage(
content=full_text
prompt_tokens = self._num_tokens_from_messages(
model, credentials, prompt_messages
)
full_assistant_prompt_message = AssistantPromptMessage(content=full_text)
completion_tokens = self._num_tokens_from_messages(
model, credentials, [full_assistant_prompt_message]
)
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
return LLMResultChunk(
model=model,
@ -383,14 +469,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
system_fingerprint=preamble,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason,
usage=usage
)
usage=usage,
),
)
index = 1
full_assistant_content = ''
full_assistant_content = ""
for chunk in response:
if isinstance(chunk, StreamTextGeneration):
chunk = cast(StreamTextGeneration, chunk)
@ -400,14 +486,12 @@ class CohereLargeLanguageModel(LargeLanguageModel):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=text
)
assistant_prompt_message = AssistantPromptMessage(content=text)
# stop
            # notice: this logic can only cover a few stop scenarios
if stop and text in stop:
yield final_response(full_assistant_content, index, 'stop')
yield final_response(full_assistant_content, index, "stop")
break
full_assistant_content += text
@ -418,17 +502,23 @@ class CohereLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
),
)
index += 1
elif isinstance(chunk, StreamEnd):
chunk = cast(StreamEnd, chunk)
yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
yield final_response(
full_assistant_content,
index,
chunk.finish_reason,
response.preamble,
)
index += 1
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-> tuple[str, list[dict]]:
def _convert_prompt_messages_to_message_and_chat_histories(
self, prompt_messages: list[PromptMessage]
) -> tuple[str, list[dict]]:
"""
Convert prompt messages to message and chat histories
:param prompt_messages: prompt messages
@ -441,9 +531,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# get latest message from chat histories and pop it
if len(chat_histories) > 0:
latest_message = chat_histories.pop()
message = latest_message['message']
message = latest_message["message"]
else:
raise ValueError('Prompt messages is empty')
raise ValueError("Prompt messages is empty")
return message, chat_histories
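
A hedged sketch of the split performed above: the latest USER entry becomes the live message and the earlier turns stay as chat history; the CHATBOT role for assistant turns is an assumption, since only the USER branch of _convert_prompt_message_to_dict is visible in this hunk.

prompt_messages = [
    UserPromptMessage(content="Hi"),
    AssistantPromptMessage(content="Hello! How can I help?"),
    UserPromptMessage(content="Summarise our chat."),
]
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(
    prompt_messages
)
# message == "Summarise our chat."
# chat_histories keeps the earlier turns, e.g.
# [{"role": "USER", "message": "Hi"},
#  {"role": "CHATBOT", "message": "Hello! How can I help?"}]  # CHATBOT role assumed
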
@ -456,10 +546,12 @@ class CohereLargeLanguageModel(LargeLanguageModel):
if isinstance(message.content, str):
message_dict = {"role": "USER", "message": message.content}
else:
sub_message_text = ''
sub_message_text = ""
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
message_content = cast(
TextPromptMessageContent, message_content
)
sub_message_text += message_content.data
message_dict = {"role": "USER", "message": sub_message_text}
@ -487,47 +579,53 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: number of tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get("api_key"))
response = client.tokenize(
text=text,
model=model
)
response = client.tokenize(text=text, model=model)
return response.length
def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
def _num_tokens_from_messages(
self, model: str, credentials: dict, messages: list[PromptMessage]
) -> int:
"""Calculate num tokens Cohere model."""
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
message_strs = [f"{message['role']}: {message['message']}" for message in messages]
message_strs = [
f"{message['role']}: {message['message']}" for message in messages
]
message_str = "\n".join(message_strs)
real_model = model
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
real_model = model.removesuffix('-chat')
if (
self.get_model_schema(model, credentials).fetch_from
== FetchFrom.PREDEFINED_MODEL
):
real_model = model.removesuffix("-chat")
return self._num_tokens_from_string(real_model, credentials, message_str)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity:
"""
Cohere supports fine-tuning of their models. This method returns the schema of the base model
but renamed to the fine-tuned model name.
Cohere supports fine-tuning of their models. This method returns the schema of the base model
but renamed to the fine-tuned model name.
:param model: model name
:param credentials: credentials
:param model: model name
:param credentials: credentials
:return: model schema
:return: model schema
"""
# get model schema
models = self.predefined_models()
model_map = {model.model: model for model in models}
mode = credentials.get('mode')
mode = credentials.get("mode")
if mode == 'chat':
base_model_schema = model_map['command-light-chat']
if mode == "chat":
base_model_schema = model_map["command-light-chat"]
else:
base_model_schema = model_map['command-light']
base_model_schema = model_map["command-light"]
base_model_schema = cast(AIModelEntity, base_model_schema)
@ -537,18 +635,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
entity = AIModelEntity(
model=model,
label=I18nObject(
zh_Hans=model,
en_US=model
),
label=I18nObject(zh_Hans=model, en_US=model),
model_type=ModelType.LLM,
features=[feature for feature in base_model_schema_features],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
key: property for key, property in base_model_schema_model_properties.items()
key: property
for key, property in base_model_schema_model_properties.items()
},
parameter_rules=[rule for rule in base_model_schema_parameters_rules],
pricing=base_model_schema.pricing
pricing=base_model_schema.pricing,
)
return entity
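
Illustratively, a fine-tuned Cohere chat model is registered under its own name while borrowing the command-light-chat schema; the model id and key are placeholders, and `llm` stands for an already-initialised instance of this class:

llm = CohereLargeLanguageModel()
entity = llm.get_customizable_model_schema(
    model="command-light-ft-1234",                      # placeholder fine-tune id
    credentials={"api_key": "co-...", "mode": "chat"},  # placeholder key
)
# entity.model == "command-light-ft-1234"
# entity.fetch_from == FetchFrom.CUSTOMIZABLE_MODEL
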
@ -564,14 +660,12 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError
],
InvokeConnectionError: [cohere.CohereConnectionError],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
]
],
}

View File

@ -2,7 +2,10 @@ from typing import Optional
import cohere
from model_providers.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from model_providers.core.model_runtime.entities.rerank_entities import (
RerankDocument,
RerankResult,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -11,8 +14,12 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.rerank_model import (
RerankModel,
)
class CohereRerankModel(RerankModel):
@ -20,10 +27,16 @@ class CohereRerankModel(RerankModel):
Model class for Cohere rerank model.
"""
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -37,26 +50,18 @@ class CohereRerankModel(RerankModel):
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(
model=model,
docs=docs
)
return RerankResult(model=model, docs=docs)
# initialize client
client = cohere.Client(credentials.get('api_key'))
results = client.rerank(
query=query,
documents=docs,
model=model,
top_n=top_n
)
client = cohere.Client(credentials.get("api_key"))
results = client.rerank(query=query, documents=docs, model=model, top_n=top_n)
rerank_documents = []
for idx, result in enumerate(results):
# format document
rerank_document = RerankDocument(
index=result.index,
text=result.document['text'],
text=result.document["text"],
score=result.relevance_score,
)
@ -67,10 +72,7 @@ class CohereRerankModel(RerankModel):
else:
rerank_documents.append(rerank_document)
return RerankResult(
model=model,
docs=rerank_documents
)
return RerankResult(model=model, docs=rerank_documents)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -91,7 +93,7 @@ class CohereRerankModel(RerankModel):
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -116,5 +118,5 @@ class CohereRerankModel(RerankModel):
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
]
],
}
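
A hedged usage sketch of the rerank flow above, reusing the client.rerank call and the result fields (index, document['text'], relevance_score) that _invoke iterates over; rerank_top is an illustrative helper, not part of the module:

import cohere

def rerank_top(api_key: str, model: str, query: str, docs: list[str],
               score_threshold: float = 0.0) -> list[dict]:
    client = cohere.Client(api_key)
    results = client.rerank(query=query, documents=docs, model=model, top_n=len(docs))
    # keep only results at or above the threshold, like the loop in _invoke
    return [
        {"index": r.index, "text": r.document["text"], "score": r.relevance_score}
        for r in results
        if r.relevance_score >= score_threshold
    ]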

View File

@ -6,7 +6,10 @@ import numpy as np
from cohere.responses import Tokens
from model_providers.core.model_runtime.entities.model_entities import PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from model_providers.core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -15,8 +18,12 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
class CohereTextEmbeddingModel(TextEmbeddingModel):
@ -24,9 +31,13 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
Model class for Cohere text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -47,13 +58,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
for i, text in enumerate(texts):
tokenize_response = self._tokenize(
model=model,
credentials=credentials,
text=text
model=model, credentials=credentials, text=text
)
for j in range(0, tokenize_response.length, context_size):
tokens += [tokenize_response.token_strings[j: j + context_size]]
tokens += [tokenize_response.token_strings[j : j + context_size]]
indices += [i]
batched_embeddings = []
@ -64,7 +73,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=["".join(token) for token in tokens[i: i + max_chunks]]
texts=["".join(token) for token in tokens[i : i + max_chunks]],
)
used_tokens += embedding_used_tokens
@ -80,9 +89,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
_result = results[i]
if len(_result) == 0:
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=[" "]
model=model, credentials=credentials, texts=[" "]
)
used_tokens += embedding_used_tokens
@ -93,16 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
model=model, credentials=credentials, tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=embeddings,
usage=usage,
model=model
)
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
@ -116,13 +117,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
if len(texts) == 0:
return 0
full_text = ' '.join(texts)
full_text = " ".join(texts)
try:
response = self._tokenize(
model=model,
credentials=credentials,
text=full_text
model=model, credentials=credentials, text=full_text
)
except Exception as e:
raise self._transform_invoke_error(e)
@ -141,12 +140,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return Tokens([], [], {})
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get("api_key"))
response = client.tokenize(
text=text,
model=model
)
response = client.tokenize(text=text, model=model)
return response
@ -160,15 +156,13 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
"""
try:
# call embedding model
self._embedding_invoke(
model=model,
credentials=credentials,
texts=['ping']
)
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
def _embedding_invoke(
self, model: str, credentials: dict, texts: list[str]
) -> tuple[list[list[float]], int]:
"""
Invoke embedding model
@ -178,18 +172,20 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings and used tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get("api_key"))
# call embedding model
response = client.embed(
texts=texts,
model=model,
input_type='search_document' if len(texts) > 1 else 'search_query'
input_type="search_document" if len(texts) > 1 else "search_query",
)
return response.embeddings, response.meta['billed_units']['input_tokens']
return response.embeddings, response.meta["billed_units"]["input_tokens"]
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
"""
Calculate response usage
@ -203,7 +199,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
tokens=tokens,
)
# transform usage
@ -214,7 +210,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@ -230,14 +226,12 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError
],
InvokeConnectionError: [cohere.CohereConnectionError],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
]
],
}
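
The embedding path above first slices every input into windows of context_size tokens and then embeds them in groups of max_chunks. A stand-alone sketch of that slicing, with hypothetical limits and a stand-in for the tokenizer output:

context_size, max_chunks = 512, 48                        # hypothetical model limits
tokenized_texts = [["Hel", "lo", " wor", "ld"]]           # stand-in for _tokenize(...).token_strings

tokens, indices = [], []
for i, token_strings in enumerate(tokenized_texts):
    for j in range(0, len(token_strings), context_size):
        tokens.append(token_strings[j : j + context_size])
        indices.append(i)                                 # remember which input each chunk came from

batches = [["".join(t) for t in tokens[i : i + max_chunks]]
           for i in range(0, len(tokens), max_chunks)]
# each entry in `batches` would be passed to one _embedding_invoke call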

View File

@ -1,8 +1,12 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
@ -21,11 +25,12 @@ class GoogleProvider(ModelProvider):
# Use `gemini-pro` model for validate,
model_instance.validate_credentials(
model='gemini-pro',
credentials=credentials
model="gemini-pro", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -5,10 +5,19 @@ from typing import Optional, Union
import google.api_core.exceptions as exceptions
import google.generativeai as genai
import google.generativeai.client as client
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types import (
ContentType,
GenerateContentResponse,
HarmBlockThreshold,
HarmCategory,
)
from google.generativeai.types.content_types import to_part
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -26,8 +35,12 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
logger = logging.getLogger(__name__)
@ -42,12 +55,17 @@ if you are not sure about the structure.
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -62,10 +80,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._generate(
model, credentials, prompt_messages, model_parameters, stop, stream, user
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -89,8 +114,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
self._convert_one_message_to_text(message) for message in messages
)
return text.rstrip()
@ -106,16 +130,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
try:
ping_message = PromptMessage(content="ping", role="system")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
self._generate(
model, credentials, [ping_message], {"max_tokens_to_sample": 5}
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -129,14 +160,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
config_kwargs["max_output_tokens"] = config_kwargs.pop(
"max_tokens_to_sample", None
)
if stop:
config_kwargs["stop_sequences"] = stop
google_model = genai.GenerativeModel(
model_name=model
)
google_model = genai.GenerativeModel(model_name=model)
history = []
@ -146,14 +177,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = client._ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
@ -161,7 +191,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
google_model._client = new_custom_client
safety_settings={
safety_settings = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
@ -170,20 +200,27 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(
**config_kwargs
),
generation_config=genai.types.GenerationConfig(**config_kwargs),
stream=stream,
safety_settings=safety_settings
safety_settings=safety_settings,
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_stream_response(
model, credentials, response, prompt_messages
)
return self._handle_generate_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(
model, credentials, response, prompt_messages
)
def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@ -194,16 +231,18 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.text
)
assistant_prompt_message = AssistantPromptMessage(content=response.text)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
completion_tokens = self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
# transform response
result = LLMResult(
@ -215,8 +254,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -232,28 +276,29 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
index += 1
assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
content=content if content else "",
)
if not response._done:
# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
index=index, message=assistant_prompt_message
),
)
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
completion_tokens = self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=model,
@ -262,8 +307,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
index=index,
message=assistant_prompt_message,
finish_reason=chunk.candidates[0].finish_reason,
usage=usage
)
usage=usage,
),
)
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
@ -302,21 +347,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
parts = []
if (isinstance(message.content, str)):
if isinstance(message.content, str):
parts.append(to_part(message.content))
else:
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
parts.append(to_part(c.data))
else:
metadata, data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]
blob = {"inline_data":{"mime_type":mime_type,"data":data}}
metadata, data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
blob = {"inline_data": {"mime_type": mime_type, "data": data}}
parts.append(blob)
glm_content = {
"role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
"parts": parts
"role": "user"
if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM)
else "model",
"parts": parts,
}
return glm_content
@ -332,25 +379,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
exceptions.RetryError
],
InvokeConnectionError: [exceptions.RetryError],
InvokeServerUnavailableError: [
exceptions.ServiceUnavailable,
exceptions.InternalServerError,
exceptions.BadGateway,
exceptions.GatewayTimeout,
exceptions.DeadlineExceeded
exceptions.DeadlineExceeded,
],
InvokeRateLimitError: [
exceptions.ResourceExhausted,
exceptions.TooManyRequests
exceptions.TooManyRequests,
],
InvokeAuthorizationError: [
exceptions.Unauthenticated,
exceptions.PermissionDenied,
exceptions.Unauthenticated,
exceptions.Forbidden
exceptions.Forbidden,
],
InvokeBadRequestError: [
exceptions.BadRequest,
@ -366,5 +411,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
exceptions.PreconditionFailed,
exceptions.RequestRangeNotSatisfiable,
exceptions.Cancelled,
]
],
}
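
A small, self-contained sketch of the image branch in _format_message_to_glm_content above: a data URL is split into its mime type and base64 payload and wrapped into the inline_data blob that generate_content consumes (the data URL here is a made-up example):

data_url = "data:image/png;base64,iVBORw0KGgo="           # illustrative only
metadata, data = data_url.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
blob = {"inline_data": {"mime_type": mime_type, "data": data}}
assert blob["inline_data"]["mime_type"] == "image/png"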

View File

@ -1,13 +1,17 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class GroqProvider(ModelProvider):
class GroqProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -19,11 +23,12 @@ class GroqProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='llama2-70b-4096',
credentials=credentials
model="llama2-70b-4096", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -2,18 +2,31 @@ from collections.abc import Generator
from typing import Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
from model_providers.core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
OAIAPICompatLargeLanguageModel,
)
class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
return super()._invoke(
model, credentials, prompt_messages, model_parameters, tools, stop, stream
)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
@ -21,6 +34,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.groq.com/openai/v1"

View File

@ -1,15 +1,12 @@
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
from model_providers.core.model_runtime.errors.invoke import (
InvokeBadRequestError,
InvokeError,
)
class _CommonHuggingfaceHub:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
HfHubHTTPError,
BadRequestError
]
}
return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}

View File

@ -1,11 +1,12 @@
import logging
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class HuggingfaceHubProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -7,7 +7,12 @@ from huggingface_hub.utils import BadRequestError
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMMode,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -23,22 +28,35 @@ from model_providers.core.model_runtime.entities.model_entities import (
ModelType,
ParameterRule,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.huggingface_hub._common import (
_CommonHuggingfaceHub,
)
class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
model = credentials["huggingfacehub_endpoint_url"]
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
model = credentials['huggingfacehub_endpoint_url']
if 'baichuan' in model.lower():
if "baichuan" in model.lower():
stream = False
response = client.text_generation(
@ -47,71 +65,97 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
stream=stream,
model=model,
stop_sequences=stop,
**model_parameters)
**model_parameters,
)
if stream:
return self._handle_generate_stream_response(model, credentials, prompt_messages, response)
return self._handle_generate_stream_response(
model, credentials, prompt_messages, response
)
return self._handle_generate_response(model, credentials, prompt_messages, response)
return self._handle_generate_response(
model, credentials, prompt_messages, response
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
try:
if 'huggingfacehub_api_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.')
if "huggingfacehub_api_type" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub Endpoint Type must be provided."
)
if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'):
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.')
if credentials["huggingfacehub_api_type"] not in (
"inference_endpoints",
"hosted_inference_api",
):
raise CredentialsValidateFailedError(
"Huggingface Hub Endpoint Type is invalid."
)
if 'huggingfacehub_api_token' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.')
if "huggingfacehub_api_token" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub Access Token must be provided."
)
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.')
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
if "huggingfacehub_endpoint_url" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub Endpoint URL must be provided."
)
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.')
elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'],
model)
if "task_type" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub Task Type must be provided."
)
elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
credentials["task_type"] = self._get_hosted_model_task_type(
credentials["huggingfacehub_api_token"], model
)
if credentials['task_type'] not in ("text2text-generation", "text-generation"):
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, '
'text-generation.')
if credentials["task_type"] not in (
"text2text-generation",
"text-generation",
):
raise CredentialsValidateFailedError(
"Huggingface Hub Task Type must be one of text2text-generation, "
"text-generation."
)
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
model = credentials['huggingfacehub_endpoint_url']
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
model = credentials["huggingfacehub_endpoint_url"]
try:
client.text_generation(
prompt='Who are you?',
stream=True,
model=model)
client.text_generation(prompt="Who are you?", stream=True, model=model)
except BadRequestError as e:
raise CredentialsValidateFailedError('Only available for models running with the `text-generation-inference`. '
'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.')
raise CredentialsValidateFailedError(
"Only available for models running on with the `text-generation-inference`. "
"To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference."
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
ModelPropertyKey.MODE: LLMMode.COMPLETION.value
},
parameter_rules=self._get_customizable_model_parameter_rules()
model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value},
parameter_rules=self._get_customizable_model_parameter_rules(),
)
return entity
@ -119,26 +163,27 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
@staticmethod
def _get_customizable_model_parameter_rules() -> list[ParameterRule]:
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(
DefaultParameterName.TEMPERATURE).copy()
temperature_rule_dict['name'] = 'temperature'
DefaultParameterName.TEMPERATURE
).copy()
temperature_rule_dict["name"] = "temperature"
temperature_rule = ParameterRule(**temperature_rule_dict)
temperature_rule.default = 0.5
top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy()
top_p_rule_dict['name'] = 'top_p'
top_p_rule_dict["name"] = "top_p"
top_p_rule = ParameterRule(**top_p_rule_dict)
top_p_rule.default = 0.5
top_k_rule = ParameterRule(
name='top_k',
name="top_k",
label={
'en_US': 'Top K',
'zh_Hans': 'Top K',
"en_US": "Top K",
"zh_Hans": "Top K",
},
type='int',
type="int",
help={
'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.',
'zh_Hans': '保留的最高概率词汇标记的数量。',
"en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
"zh_Hans": "保留的最高概率词汇标记的数量。",
},
required=False,
default=2,
@ -148,15 +193,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
)
max_new_tokens = ParameterRule(
name='max_new_tokens',
name="max_new_tokens",
label={
'en_US': 'Max New Tokens',
'zh_Hans': '最大新标记',
"en_US": "Max New Tokens",
"zh_Hans": "最大新标记",
},
type='int',
type="int",
help={
'en_US': 'Maximum number of generated tokens.',
'zh_Hans': '生成的标记的最大数量。',
"en_US": "Maximum number of generated tokens.",
"zh_Hans": "生成的标记的最大数量。",
},
required=False,
default=20,
@ -166,42 +211,51 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
)
seed = ParameterRule(
name='seed',
name="seed",
label={
'en_US': 'Random sampling seed',
'zh_Hans': '随机采样种子',
"en_US": "Random sampling seed",
"zh_Hans": "随机采样种子",
},
type='int',
type="int",
help={
'en_US': 'Random sampling seed.',
'zh_Hans': '随机采样种子。',
"en_US": "Random sampling seed.",
"zh_Hans": "随机采样种子。",
},
required=False,
precision=0,
)
repetition_penalty = ParameterRule(
name='repetition_penalty',
name="repetition_penalty",
label={
'en_US': 'Repetition Penalty',
'zh_Hans': '重复惩罚',
"en_US": "Repetition Penalty",
"zh_Hans": "重复惩罚",
},
type='float',
type="float",
help={
'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.',
'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。',
"en_US": "The parameter for repetition penalty. 1.0 means no penalty.",
"zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。",
},
required=False,
precision=1,
)
return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty]
return [
temperature_rule,
top_k_rule,
top_p_rule,
max_new_tokens,
seed,
repetition_penalty,
]
def _handle_generate_stream_response(self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
response: Generator) -> Generator:
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
response: Generator,
) -> Generator:
index = -1
for chunk in response:
# skip special tokens
@ -210,15 +264,17 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
index += 1
assistant_prompt_message = AssistantPromptMessage(
content=chunk.token.text
)
assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text)
if chunk.details:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
completion_tokens = self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=model,
@ -240,20 +296,28 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
),
)
def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
def _handle_generate_response(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
response: any,
) -> LLMResult:
if isinstance(response, str):
content = response
else:
content = response.generated_text
assistant_prompt_message = AssistantPromptMessage(
content=content
)
assistant_prompt_message = AssistantPromptMessage(content=content)
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
completion_tokens = self.get_num_tokens(
model, credentials, [assistant_prompt_message]
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
result = LLMResult(
model=model,
@ -270,15 +334,22 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
try:
if not model_info:
raise ValueError(f'Model {model_name} not found.')
raise ValueError(f"Model {model_name} not found.")
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
if (
"inference" in model_info.cardData
and not model_info.cardData["inference"]
):
raise ValueError(
f"Inference API has been turned off for this model {model_name}."
)
valid_tasks = ("text2text-generation", "text-generation")
if model_info.pipeline_tag not in valid_tasks:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {valid_tasks}.")
raise ValueError(
f"Model {model_name} is not a valid task, "
f"must be one of {valid_tasks}."
)
except Exception as e:
raise CredentialsValidateFailedError(f"{str(e)}")
@ -288,8 +359,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
self._convert_one_message_to_text(message) for message in messages
)
return text.rstrip()
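
A hedged sketch of the streaming path that _invoke and _handle_generate_stream_response implement above, using only InferenceClient calls this file already relies on; the token and model id are placeholders:

from huggingface_hub import InferenceClient

client = InferenceClient(token="hf_placeholder")
stream = client.text_generation(
    "Who are you?",
    stream=True,
    details=True,
    model="HuggingFaceH4/zephyr-7b-beta",     # placeholder repo id
    max_new_tokens=20,
)
for chunk in stream:
    if chunk.token.special:                   # the handler above skips special tokens too
        continue
    print(chunk.token.text, end="")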

View File

@ -7,35 +7,51 @@ import requests
from huggingface_hub import HfApi, InferenceClient
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelType,
PriceType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.huggingface_hub._common import (
_CommonHuggingfaceHub,
)
HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/"
class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
def _invoke(self, model: str, credentials: dict, texts: list[str],
user: Optional[str] = None) -> TextEmbeddingResult:
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
execute_model = model
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
execute_model = credentials['huggingfacehub_endpoint_url']
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
execute_model = credentials["huggingfacehub_endpoint_url"]
output = client.post(
json={
"inputs": texts,
"options": {
"wait_for_model": False,
"use_cache": False
}
"options": {"wait_for_model": False, "use_cache": False},
},
model=execute_model)
model=execute_model,
)
embeddings = json.loads(output.decode())
@ -43,9 +59,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
usage = self._calc_response_usage(model, credentials, tokens)
return TextEmbeddingResult(
embeddings=self._mean_pooling(embeddings),
usage=usage,
model=model
embeddings=self._mean_pooling(embeddings), usage=usage, model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
@ -56,52 +70,64 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
def validate_credentials(self, model: str, credentials: dict) -> None:
try:
if 'huggingfacehub_api_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.')
if "huggingfacehub_api_type" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub Endpoint Type must be provided."
)
if 'huggingfacehub_api_token' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.')
if "huggingfacehub_api_token" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub API Token must be provided."
)
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'huggingface_namespace' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.')
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
if "huggingface_namespace" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub User Name / Organization Name must be provided."
)
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.')
if "huggingfacehub_endpoint_url" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub Endpoint URL must be provided."
)
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.')
if "task_type" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub Task Type must be provided."
)
if credentials['task_type'] != 'feature-extraction':
raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.')
if credentials["task_type"] != "feature-extraction":
raise CredentialsValidateFailedError(
"Huggingface Hub Task Type is invalid."
)
self._check_endpoint_url_model_repository_name(credentials, model)
model = credentials['huggingfacehub_endpoint_url']
model = credentials["huggingfacehub_endpoint_url"]
elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'],
model)
elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
self._check_hosted_model_task_type(
credentials["huggingfacehub_api_token"], model
)
else:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.')
raise CredentialsValidateFailedError(
"Huggingface Hub Endpoint Type is invalid."
)
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
client.feature_extraction(text='hello world', model=model)
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
client.feature_extraction(text="hello world", model=model)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
'context_size': 10000,
'max_chunks': 1
}
model_properties={"context_size": 10000, "max_chunks": 1},
)
return entity
@ -118,34 +144,47 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
return embeddings
# For example two: List[List[List[float]]], need to mean_pooling.
sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings]
sentence_embeddings = [
np.mean(embedding[0], axis=0).tolist() for embedding in embeddings
]
return sentence_embeddings
@staticmethod
def _check_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str) -> None:
def _check_hosted_model_task_type(
huggingfacehub_api_token: str, model_name: str
) -> None:
hf_api = HfApi(token=huggingfacehub_api_token)
model_info = hf_api.model_info(repo_id=model_name)
try:
if not model_info:
raise ValueError(f'Model {model_name} not found.')
raise ValueError(f"Model {model_name} not found.")
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
if (
"inference" in model_info.cardData
and not model_info.cardData["inference"]
):
raise ValueError(
f"Inference API has been turned off for this model {model_name}."
)
valid_tasks = "feature-extraction"
if model_info.pipeline_tag not in valid_tasks:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {valid_tasks}.")
raise ValueError(
f"Model {model_name} is not a valid task, "
f"must be one of {valid_tasks}."
)
except Exception as e:
raise CredentialsValidateFailedError(f"{str(e)}")
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
tokens=tokens,
)
# transform usage
@ -156,7 +195,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@ -166,25 +205,29 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
try:
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
headers = {
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
'Content-Type': 'application/json'
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
"Content-Type": "application/json",
}
response = requests.get(url=url, headers=headers)
if response.status_code != 200:
raise ValueError('User Name or Organization Name is invalid.')
raise ValueError("User Name or Organization Name is invalid.")
model_repository_name = ''
model_repository_name = ""
for item in response.json().get("items", []):
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
if (
item.get("status", {}).get("url")
== credentials["huggingfacehub_endpoint_url"]
):
model_repository_name = item.get("model", {}).get("repository")
break
if model_repository_name != model_name:
raise ValueError(
f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
f"Model Name {model_name} is invalid. Please check it on the inference endpoints console."
)
except Exception as e:
raise ValueError(str(e))
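
Mean pooling, as used by _mean_pooling above, collapses per-token vectors into one sentence vector. A tiny numpy sketch under the assumption that one input comes back from the Inference API as a tokens x dim array (the exact nesting can vary, which is what the shape check in the method handles):

import numpy as np

token_vectors = np.array([[1.0, 2.0], [3.0, 4.0]])    # 2 tokens, embedding dim 2
sentence_embedding = token_vectors.mean(axis=0)        # -> array([2., 3.])
print(sentence_embedding.tolist())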

View File

@ -1,14 +1,17 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class JinaProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -22,11 +25,12 @@ class JinaProvider(ModelProvider):
# Use `jina-embeddings-v2-base-en` model for validate,
# no matter what model you pass in, text completion model or chat model
model_instance.validate_credentials(
model='jina-embeddings-v2-base-en',
credentials=credentials
model="jina-embeddings-v2-base-en", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -2,7 +2,10 @@ from typing import Optional
import httpx
from model_providers.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from model_providers.core.model_runtime.entities.rerank_entities import (
RerankDocument,
RerankResult,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -11,8 +14,12 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.rerank_model import (
RerankModel,
)
class JinaRerankModel(RerankModel):
@ -20,9 +27,16 @@ class JinaRerankModel(RerankModel):
Model class for Jina rerank model.
"""
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -45,26 +59,29 @@ class JinaRerankModel(RerankModel):
"model": model,
"query": query,
"documents": docs,
"top_n": top_n
"top_n": top_n,
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
headers={"Authorization": f"Bearer {credentials.get('api_key')}"},
)
response.raise_for_status()
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results['results']:
for result in results["results"]:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
index=result["index"],
text=result["document"]["text"],
score=result["relevance_score"],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
if (
score_threshold is None
or result["relevance_score"] >= score_threshold
):
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -75,7 +92,6 @@ class JinaRerankModel(RerankModel):
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
@ -86,7 +102,7 @@ class JinaRerankModel(RerankModel):
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -99,7 +115,7 @@ class JinaRerankModel(RerankModel):
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
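
A hedged sketch of the HTTP round trip behind _invoke above; the result fields match the parsing shown in the hunks, while the endpoint URL is an assumption (the base URL is configured outside the lines shown) and the API key is a placeholder:

import httpx

def jina_rerank(api_key: str, model: str, query: str, docs: list[str], top_n: int):
    response = httpx.post(
        "https://api.jina.ai/v1/rerank",              # assumed endpoint, not shown above
        json={"model": model, "query": query, "documents": docs, "top_n": top_n},
        headers={"Authorization": f"Bearer {api_key}"},
    )
    response.raise_for_status()
    return [
        (r["index"], r["relevance_score"], r["document"]["text"])
        for r in response.json()["results"]
    ]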

View File

@ -14,19 +14,19 @@ class JinaTokenizer:
with cls._lock:
if cls._tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
gpt2_tokenizer_path = join(dirname(base_path), "tokenizer")
cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
return cls._tokenizer
@classmethod
def _get_num_tokens_by_jina_base(cls, text: str) -> int:
"""
use jina tokenizer to get num tokens
use jina tokenizer to get num tokens
"""
tokenizer = cls._get_tokenizer()
tokens = tokenizer.encode(text)
return len(tokens)
@classmethod
def get_num_tokens(cls, text: str) -> int:
return cls._get_num_tokens_by_jina_base(text)
return cls._get_num_tokens_by_jina_base(text)
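
The tokenizer class above lazily loads a tokenizer directory shipped next to the module and counts tokens with encode(). A minimal equivalent, using a public tokenizer id as a stand-in for that bundled directory:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # stand-in for the bundled tokenizer dir
num_tokens = len(tokenizer.encode("hello world"))
print(num_tokens)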

View File

@ -5,7 +5,10 @@ from typing import Optional
from requests import post
from model_providers.core.model_runtime.entities.model_entities import PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from model_providers.core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -14,21 +17,37 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import JinaTokenizer
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import (
JinaTokenizer,
)
class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Jina text embedding model.
"""
api_base: str = 'https://api.jina.ai/v1/embeddings'
models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de']
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "https://api.jina.ai/v1/embeddings"
models: list[str] = [
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-small-en",
"jina-embeddings-v2-base-zh",
"jina-embeddings-v2-base-de",
]
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -38,31 +57,28 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
api_key = credentials["api_key"]
if model not in self.models:
raise InvokeBadRequestError('Invalid model name')
raise InvokeBadRequestError("Invalid model name")
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
raise CredentialsValidateFailedError("api_key is required")
url = self.api_base
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
"Authorization": "Bearer " + api_key,
"Content-Type": "application/json",
}
data = {
'model': model,
'input': texts
}
data = {"model": model, "input": texts}
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
msg = resp['detail']
msg = resp["detail"]
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
@ -72,23 +88,27 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
else:
raise InvokeError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=usage["total_tokens"]
)
result = TextEmbeddingResult(
model=model,
embeddings=[[
float(data) for data in x['embedding']
] for x in embeddings],
usage=usage
embeddings=[[float(data) for data in x["embedding"]] for x in embeddings],
usage=usage,
)
return result
@ -117,31 +137,23 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvokeAuthorizationError:
raise CredentialsValidateFailedError('Invalid api key')
raise CredentialsValidateFailedError("Invalid api key")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
"""
Calculate response usage
@ -155,7 +167,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
tokens=tokens,
)
# transform usage
@ -166,7 +178,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
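Outside the provider abstraction, the Jina endpoint wrapped by the class above can be exercised directly; a rough sketch using the request and response shape shown in this diff, assuming a valid JINA_API_KEY environment variable and an example model/input:
import os
from json import dumps
from requests import post
api_key = os.environ["JINA_API_KEY"]  # assumed to be set
resp = post(
    "https://api.jina.ai/v1/embeddings",
    headers={"Authorization": "Bearer " + api_key, "Content-Type": "application/json"},
    data=dumps({"model": "jina-embeddings-v2-base-en", "input": ["hello world"]}),
)
resp.raise_for_status()
body = resp.json()
vectors = [[float(v) for v in item["embedding"]] for item in body["data"]]
print(len(vectors), body["usage"]["total_tokens"])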

View File

@ -21,7 +21,12 @@ from openai.types.completion import Completion
from yarl import URL
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from model_providers.core.model_runtime.entities.llm_entities import (
LLMMode,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -45,34 +50,60 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.utils import helper
class LocalAILarguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
# tools is not supported yet
return self._num_tokens_from_messages(prompt_messages, tools=tools)
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
def _num_tokens_from_messages(
self, messages: list[PromptMessage], tools: list[PromptMessageTool]
) -> int:
"""
Calculate num tokens for baichuan model
LocalAI does not support this
Calculate num tokens for baichuan model
LocalAI does not support this
"""
def tokens(text: str):
"""
We could not determine which tokenizer to use, because the model is customized.
So we use the gpt2 tokenizer to calculate the num tokens for convenience.
We could not determine which tokenizer to use, because the model is customized.
So we use the gpt2 tokenizer to calculate the num tokens for convenience.
"""
return self._get_num_tokens_by_gpt2(text)
@ -85,10 +116,10 @@ class LocalAILarguageModel(LargeLanguageModel):
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
@ -124,7 +155,7 @@ class LocalAILarguageModel(LargeLanguageModel):
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
@ -133,36 +164,37 @@ class LocalAILarguageModel(LargeLanguageModel):
:param tools: tools for tool calling
:return: number of tokens
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
num_tokens = 0
for tool in tools:
# calculate num tokens for function object
num_tokens += tokens('name')
num_tokens += tokens("name")
num_tokens += tokens(tool.name)
num_tokens += tokens('description')
num_tokens += tokens("description")
num_tokens += tokens(tool.description)
parameters = tool.parameters
num_tokens += tokens('parameters')
num_tokens += tokens('type')
num_tokens += tokens("parameters")
num_tokens += tokens("type")
num_tokens += tokens(parameters.get("type"))
if 'properties' in parameters:
num_tokens += tokens('properties')
for key, value in parameters.get('properties').items():
if "properties" in parameters:
num_tokens += tokens("properties")
for key, value in parameters.get("properties").items():
num_tokens += tokens(key)
for field_key, field_value in value.items():
num_tokens += tokens(field_key)
if field_key == 'enum':
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += tokens(enum_field)
else:
num_tokens += tokens(field_key)
num_tokens += tokens(str(field_value))
if 'required' in parameters:
num_tokens += tokens('required')
for required_field in parameters['required']:
if "required" in parameters:
num_tokens += tokens("required")
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += tokens(required_field)
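To make the accounting in the hunk above concrete, here is a simplified walk-through of the same bookkeeping for a single hypothetical tool, using the Hugging Face GPT-2 tokenizer as a stand-in for the bundled one and skipping the enum special case (tool name, description and schema below are made up for illustration):
from transformers import GPT2TokenizerFast
_tok = GPT2TokenizerFast.from_pretrained("gpt2")
def tokens(text: str) -> int:
    return len(_tok.encode(text))
# hypothetical tool definition, not taken from the repository
tool_name = "weather"
tool_description = "Look up the current weather for a city"
parameters = {
    "type": "object",
    "properties": {"city": {"type": "string", "description": "City name"}},
    "required": ["city"],
}
num_tokens = tokens("name") + tokens(tool_name)
num_tokens += tokens("description") + tokens(tool_description)
num_tokens += tokens("parameters") + tokens("type") + tokens(parameters["type"])
if "properties" in parameters:
    num_tokens += tokens("properties")
    for key, value in parameters["properties"].items():
        num_tokens += tokens(key)
        for field_key, field_value in value.items():
            num_tokens += tokens(field_key) + tokens(str(field_value))
if "required" in parameters:
    num_tokens += tokens("required")
    for required_field in parameters["required"]:
        num_tokens += 3 + tokens(required_field)
print(num_tokens)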
@ -177,141 +209,166 @@ class LocalAILarguageModel(LargeLanguageModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, prompt_messages=[
UserPromptMessage(content='ping')
], model_parameters={
'max_tokens': 10,
}, stop=[], stream=False)
self._invoke(
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={
"max_tokens": 10,
},
stop=[],
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}")
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity | None:
completion_model = None
if credentials['completion_type'] == 'chat_completion':
if credentials["completion_type"] == "chat_completion":
completion_model = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
elif credentials["completion_type"] == "completion":
completion_model = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
raise ValueError(
f"Unknown completion type {credentials['completion_type']}"
)
rules = [
ParameterRule(
name='temperature',
name="temperature",
type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
)
use_template="temperature",
label=I18nObject(zh_Hans="温度", en_US="Temperature"),
),
ParameterRule(
name='top_p',
name="top_p",
type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P',
en_US='Top P'
)
use_template="top_p",
label=I18nObject(zh_Hans="Top P", en_US="Top P"),
),
ParameterRule(
name='max_tokens',
name="max_tokens",
type=ParameterType.INT,
use_template='max_tokens',
use_template="max_tokens",
min=1,
max=2048,
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
en_US='Max Tokens'
)
)
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
),
]
model_properties = {
ModelPropertyKey.MODE: completion_model,
} if completion_model else {}
model_properties = (
{
ModelPropertyKey.MODE: completion_model,
}
if completion_model
else {}
)
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get("context_size", "2048")
)
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties=model_properties,
parameter_rules=rules
parameter_rules=rules,
)
return entity
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
kwargs = self._to_client_kwargs(credentials)
# init model client
client = OpenAI(**kwargs)
model_name = model
completion_type = credentials['completion_type']
completion_type = credentials["completion_type"]
extra_model_kwargs = {
"timeout": 60,
}
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
if tools and len(tools) > 0:
extra_model_kwargs['functions'] = [
extra_model_kwargs["functions"] = [
helper.dump_model(tool) for tool in tools
]
if completion_type == 'chat_completion':
if completion_type == "chat_completion":
result = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
messages=[
self._convert_prompt_message_to_dict(m) for m in prompt_messages
],
model=model_name,
stream=stream,
**model_parameters,
**extra_model_kwargs,
)
elif completion_type == 'completion':
elif completion_type == "completion":
result = client.completions.create(
prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages),
prompt=self._convert_prompt_message_to_completion_prompts(
prompt_messages
),
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
else:
raise ValueError(f"Unknown completion type {completion_type}")
if stream:
if completion_type == 'completion':
if completion_type == "completion":
return self._handle_completion_generate_stream_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model,
credentials=credentials,
response=result,
tools=tools,
prompt_messages=prompt_messages,
)
return self._handle_chat_generate_stream_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model,
credentials=credentials,
response=result,
tools=tools,
prompt_messages=prompt_messages,
)
if completion_type == 'completion':
if completion_type == "completion":
return self._handle_completion_generate_response(
model=model, credentials=credentials, response=result,
prompt_messages=prompt_messages
model=model,
credentials=credentials,
response=result,
prompt_messages=prompt_messages,
)
return self._handle_chat_generate_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model,
credentials=credentials,
response=result,
tools=tools,
prompt_messages=prompt_messages,
)
def _to_client_kwargs(self, credentials: dict) -> dict:
"""
Convert invoke kwargs to client kwargs
@ -319,13 +376,13 @@ class LocalAILarguageModel(LargeLanguageModel):
:param credentials: credentials dict
:return: client kwargs
"""
if not credentials['server_url'].endswith('/'):
credentials['server_url'] += '/'
if not credentials["server_url"].endswith("/"):
credentials["server_url"] += "/"
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": str(URL(credentials['server_url']) / 'v1'),
"base_url": str(URL(credentials["server_url"]) / "v1"),
}
return client_kwargs
@ -346,41 +403,45 @@ class LocalAILarguageModel(LargeLanguageModel):
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def _convert_prompt_message_to_completion_prompts(self, messages: list[PromptMessage]) -> str:
def _convert_prompt_message_to_completion_prompts(
self, messages: list[PromptMessage]
) -> str:
"""
Convert PromptMessage to completion prompts
"""
prompts = ''
prompts = ""
for message in messages:
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
prompts += f'{message.content}\n'
prompts += f"{message.content}\n"
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
prompts += f'{message.content}\n'
prompts += f"{message.content}\n"
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
prompts += f'{message.content}\n'
prompts += f"{message.content}\n"
else:
raise ValueError(f"Unknown message type {type(message)}")
return prompts
def _handle_completion_generate_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Completion,
) -> LLMResult:
def _handle_completion_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Completion,
) -> LLMResult:
"""
Handle llm completion response
@ -393,21 +454,27 @@ class LocalAILarguageModel(LargeLanguageModel):
"""
if len(response.choices) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_message = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message,
tool_calls=[]
content=assistant_message, tool_calls=[]
)
prompt_tokens = self._get_num_tokens_by_gpt2(
self._convert_prompt_message_to_completion_prompts(prompt_messages)
)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])
completion_tokens = self._num_tokens_from_messages(
messages=[assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response = LLMResult(
model=model,
@ -419,11 +486,14 @@ class LocalAILarguageModel(LargeLanguageModel):
return response
def _handle_chat_generate_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: ChatCompletion,
tools: list[PromptMessageTool]) -> LLMResult:
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: ChatCompletion,
tools: list[PromptMessageTool],
) -> LLMResult:
"""
Handle llm chat response
@ -436,23 +506,33 @@ class LocalAILarguageModel(LargeLanguageModel):
"""
if len(response.choices) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_message = response.choices[0].message
# convert function call to tool call
function_calls = assistant_message.function_call
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
tool_calls = self._extract_response_tool_calls(
[function_calls] if function_calls else []
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content,
tool_calls=tool_calls
content=assistant_message.content, tool_calls=tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
prompt_tokens = self._num_tokens_from_messages(
messages=prompt_messages, tools=tools
)
completion_tokens = self._num_tokens_from_messages(
messages=[assistant_prompt_message], tools=tools
)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response = LLMResult(
model=model,
@ -464,12 +544,15 @@ class LocalAILarguageModel(LargeLanguageModel):
return response
def _handle_completion_generate_stream_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[Completion],
tools: list[PromptMessageTool]) -> Generator:
full_response = ''
def _handle_completion_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[Completion],
tools: list[PromptMessageTool],
) -> Generator:
full_response = ""
for chunk in response:
if len(chunk.choices) == 0:
@ -479,26 +562,30 @@ class LocalAILarguageModel(LargeLanguageModel):
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.text if delta.text else '',
tool_calls=[]
content=delta.text if delta.text else "", tool_calls=[]
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=[]
content=full_response, tool_calls=[]
)
prompt_tokens = self._get_num_tokens_by_gpt2(
self._convert_prompt_message_to_completion_prompts(prompt_messages)
)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@ -507,7 +594,7 @@ class LocalAILarguageModel(LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
usage=usage,
),
)
else:
@ -523,12 +610,15 @@ class LocalAILarguageModel(LargeLanguageModel):
full_response += delta.text
def _handle_chat_generate_stream_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[ChatCompletionChunk],
tools: list[PromptMessageTool]) -> Generator:
full_response = ''
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[ChatCompletionChunk],
tools: list[PromptMessageTool],
) -> Generator:
full_response = ""
for chunk in response:
if len(chunk.choices) == 0:
@ -536,35 +626,46 @@ class LocalAILarguageModel(LargeLanguageModel):
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
if delta.finish_reason is None and (
delta.delta.content is None or delta.delta.content == ""
):
continue
# check if there is a tool call in the response
function_calls = None
if delta.delta.function_call:
function_calls = [delta.delta.function_call]
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
assistant_message_tool_calls = self._extract_response_tool_calls(
function_calls if function_calls else []
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=assistant_message_tool_calls
content=delta.delta.content if delta.delta.content else "",
tool_calls=assistant_message_tool_calls,
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=assistant_message_tool_calls
content=full_response, tool_calls=assistant_message_tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
prompt_tokens = self._num_tokens_from_messages(
messages=prompt_messages, tools=tools
)
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@ -573,7 +674,7 @@ class LocalAILarguageModel(LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
usage=usage,
),
)
else:
@ -589,9 +690,9 @@ class LocalAILarguageModel(LargeLanguageModel):
full_response += delta.delta.content
def _extract_response_tool_calls(self,
response_function_calls: list[FunctionCall]) \
-> list[AssistantPromptMessage.ToolCall]:
def _extract_response_tool_calls(
self, response_function_calls: list[FunctionCall]
) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -602,18 +703,15 @@ class LocalAILarguageModel(LargeLanguageModel):
if response_function_calls:
for response_tool_call in response_function_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.name,
arguments=response_tool_call.arguments
name=response_tool_call.name, arguments=response_tool_call.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=0,
type='function',
function=function
id=0, type="function", function=function
)
tool_calls.append(tool_call)
return tool_calls
return tool_calls
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
@ -635,15 +733,9 @@ class LocalAILarguageModel(LargeLanguageModel):
ConflictError,
NotFoundError,
UnprocessableEntityError,
PermissionDeniedError
PermissionDeniedError,
],
InvokeRateLimitError: [
RateLimitError
],
InvokeAuthorizationError: [
AuthenticationError
],
InvokeBadRequestError: [
ValueError
]
InvokeRateLimitError: [RateLimitError],
InvokeAuthorizationError: [AuthenticationError],
InvokeBadRequestError: [ValueError],
}
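A hedged sketch of how the client construction in _to_client_kwargs above can be used on its own, assuming a LocalAI server is reachable at the given URL (the server URL and model name below are placeholders, not values from this repository):
from httpx import Timeout
from openai import OpenAI
from yarl import URL
server_url = "http://localhost:8080/"  # placeholder LocalAI endpoint
client = OpenAI(
    api_key="1",  # LocalAI does not validate the key; a dummy value is enough
    base_url=str(URL(server_url) / "v1"),
    timeout=Timeout(315.0, read=300.0, write=10.0, connect=5.0),
)
completion = client.chat.completions.create(
    model="my-local-model",  # placeholder model name
    messages=[{"role": "user", "content": "ping"}],
    max_tokens=10,
    stream=False,
)
print(completion.choices[0].message.content)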

View File

@ -1,11 +1,12 @@
import logging
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class LocalAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -6,8 +6,17 @@ from requests import post
from yarl import URL
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
PriceType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@ -16,17 +25,26 @@ from model_providers.core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
class LocalAITextEmbeddingModel(TextEmbeddingModel):
"""
Model class for LocalAI text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -37,39 +55,38 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
:return: embeddings result
"""
if len(texts) != 1:
raise InvokeBadRequestError('Only one text is supported')
raise InvokeBadRequestError("Only one text is supported")
server_url = credentials['server_url']
server_url = credentials["server_url"]
model_name = model
if not server_url:
raise CredentialsValidateFailedError('server_url is required')
raise CredentialsValidateFailedError("server_url is required")
if not model_name:
raise CredentialsValidateFailedError('model_name is required')
url = server_url
headers = {
'Authorization': 'Bearer 123',
'Content-Type': 'application/json'
}
raise CredentialsValidateFailedError("model_name is required")
data = {
'model': model_name,
'input': texts[0]
}
url = server_url
headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"}
data = {"model": model_name, "input": texts[0]}
try:
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
response = post(
str(URL(url) / "embeddings"),
headers=headers,
data=dumps(data),
timeout=10,
)
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
code = resp['error']['code']
msg = resp['error']['message']
code = resp["error"]["code"]
msg = resp["error"]["message"]
if code == 500:
raise InvokeServerUnavailableError(msg)
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
@ -79,23 +96,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
else:
raise InvokeError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=usage["total_tokens"]
)
result = TextEmbeddingResult(
model=model,
embeddings=[[
float(data) for data in x['embedding']
] for x in embeddings],
usage=usage
embeddings=[[float(data) for data in x["embedding"]] for x in embeddings],
usage=usage,
)
return result
@ -114,8 +135,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
# use GPT2Tokenizer to get num tokens
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
def _get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity | None:
"""
Get customizable model schema
@ -130,10 +153,12 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
features=[],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
ModelPropertyKey.CONTEXT_SIZE: int(
credentials.get("context_size", "512")
),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[]
parameter_rules=[],
)
def validate_credentials(self, model: str, credentials: dict) -> None:
@ -145,33 +170,25 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvokeAuthorizationError:
raise CredentialsValidateFailedError('Invalid credentials')
raise CredentialsValidateFailedError("Invalid credentials")
except InvokeConnectionError as e:
raise CredentialsValidateFailedError(f'Invalid credentials: {e}')
raise CredentialsValidateFailedError(f"Invalid credentials: {e}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
"""
Calculate response usage
@ -185,7 +202,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
tokens=tokens,
)
# transform usage
@ -196,7 +213,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
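For reference, the raw request issued by the LocalAI embedding model above can be reproduced directly; a sketch under the same assumptions as the diff (dummy bearer token, a single input text, and a placeholder server URL and model name):
from json import dumps
from requests import post
from yarl import URL
server_url = "http://localhost:8080"  # placeholder LocalAI endpoint
resp = post(
    str(URL(server_url) / "embeddings"),
    headers={"Authorization": "Bearer 123", "Content-Type": "application/json"},
    data=dumps({"model": "my-embedding-model", "input": "hello world"}),
    timeout=10,
)
resp.raise_for_status()
body = resp.json()
embedding = [float(v) for v in body["data"][0]["embedding"]]
print(len(embedding), body["usage"]["total_tokens"])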

Some files were not shown because too many files have changed in this diff.