diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/_common.py
new file mode 100644
index 00000000..c910ac97
--- /dev/null
+++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/_common.py
@@ -0,0 +1,60 @@
+from typing import Dict, List, Type
+
+import openai
+from httpx import Timeout
+
+from model_providers.core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+
+
+class _CommonOllama:
+ def _to_credential_kwargs(self, credentials: dict) -> dict:
+ """
+ Transform credentials to kwargs for model instance
+
+ :param credentials:
+ :return:
+ """
+ credentials_kwargs = {
+ "openai_api_key": "Empty",
+ "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
+ "max_retries": 1,
+ }
+
+ if "openai_api_base" in credentials and credentials["openai_api_base"]:
+ credentials["openai_api_base"] = credentials["openai_api_base"].rstrip("/")
+ credentials_kwargs["base_url"] = credentials["openai_api_base"] + "/v1"
+
+ return credentials_kwargs
+
+ @property
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
+ InvokeServerUnavailableError: [openai.InternalServerError],
+ InvokeRateLimitError: [openai.RateLimitError],
+ InvokeAuthorizationError: [
+ openai.AuthenticationError,
+ openai.PermissionDeniedError,
+ ],
+ InvokeBadRequestError: [
+ openai.BadRequestError,
+ openai.NotFoundError,
+ openai.UnprocessableEntityError,
+ openai.APIError,
+ ],
+ }
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py
index 3ce6f73f..0848f410 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py
@@ -1,13 +1,22 @@
-import json
import logging
-import re
from collections.abc import Generator
-from decimal import Decimal
-from typing import Dict, List, Optional, Type, Union, cast
-from urllib.parse import urljoin
+from typing import List, Optional, Union, cast
-import requests
+import tiktoken
+from openai import OpenAI, Stream
+from openai.types import Completion
+from openai.types.chat import (
+ ChatCompletion,
+ ChatCompletionChunk,
+ ChatCompletionMessageToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaFunctionCall,
+ ChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_message import FunctionCall
+from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import (
LLMMode,
LLMResult,
@@ -22,41 +31,39 @@ from model_providers.core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
+ ToolPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
- DefaultParameterName,
FetchFrom,
I18nObject,
- ModelFeature,
- ModelPropertyKey,
ModelType,
- ParameterRule,
- ParameterType,
PriceConfig,
)
-from model_providers.core.model_runtime.errors.invoke import (
- InvokeAuthorizationError,
- InvokeBadRequestError,
- InvokeConnectionError,
- InvokeError,
- InvokeRateLimitError,
- InvokeServerUnavailableError,
-)
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
+from model_providers.core.model_runtime.model_providers.ollama._common import _CommonOllama
logger = logging.getLogger(__name__)
+OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+You can find the structure of the {{block}} object in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
-class OllamaLargeLanguageModel(LargeLanguageModel):
+
+{{instructions}}
+
+"""
+
+
+class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
"""
     Model class for Ollama large language model.
"""
def _invoke(
@@ -83,16 +90,197 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
- return self._generate(
+ # handle fine tune remote models
+ base_model = model
+ if model.startswith("ft:"):
+ base_model = model.split(":")[1]
+
+ # get model mode
+ model_mode = self.get_model_mode(base_model, credentials)
+
+ if model_mode == LLMMode.CHAT:
+ # chat model
+ return self._chat_generate(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ )
+ else:
+ # text completion model
+ return self._generate(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ stop=stop,
+ stream=stream,
+ user=user,
+ )
+
+ def _code_block_mode_wrapper(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: List[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ callbacks: List[Callback] = None,
+ ) -> Union[LLMResult, Generator]:
+ """
+ Code block mode wrapper for invoking large language model
+ """
+ # handle fine tune remote models
+ base_model = model
+ if model.startswith("ft:"):
+ base_model = model.split(":")[1]
+
+ # get model mode
+ model_mode = self.get_model_mode(base_model, credentials)
+
+ # transform response format
+ if "response_format" in model_parameters and model_parameters[
+ "response_format"
+ ] in ["JSON", "XML"]:
+ stop = stop or []
+ if model_mode == LLMMode.CHAT:
+ # chat model
+ self._transform_chat_json_prompts(
+ model=base_model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ response_format=model_parameters["response_format"],
+ )
+ else:
+ self._transform_completion_json_prompts(
+ model=base_model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ response_format=model_parameters["response_format"],
+ )
+ model_parameters.pop("response_format")
+
+ return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
+ tools=tools,
stop=stop,
stream=stream,
user=user,
)
+ def _transform_chat_json_prompts(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: List[PromptMessage],
+ model_parameters: dict,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
+ stream: bool = True,
+ user: Union[str, None] = None,
+ response_format: str = "JSON",
+ ) -> None:
+ """
+ Transform json prompts
+ """
+ if "```\n" not in stop:
+ stop.append("```\n")
+ if "\n```" not in stop:
+ stop.append("\n```")
+
+ # check if there is a system message
+ if len(prompt_messages) > 0 and isinstance(
+ prompt_messages[0], SystemPromptMessage
+ ):
+ # override the system message
+ prompt_messages[0] = SystemPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT.replace(
+ "{{instructions}}", prompt_messages[0].content
+ ).replace("{{block}}", response_format)
+ )
+ prompt_messages.append(
+ AssistantPromptMessage(content=f"\n```{response_format}\n")
+ )
+ else:
+ # insert the system message
+ prompt_messages.insert(
+ 0,
+ SystemPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT.replace(
+ "{{instructions}}",
+ f"Please output a valid {response_format} object.",
+ ).replace("{{block}}", response_format)
+ ),
+ )
+ prompt_messages.append(
+ AssistantPromptMessage(content=f"\n```{response_format}")
+ )
+
+ def _transform_completion_json_prompts(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: List[PromptMessage],
+ model_parameters: dict,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
+ stream: bool = True,
+ user: Union[str, None] = None,
+ response_format: str = "JSON",
+ ) -> None:
+ """
+ Transform json prompts
+ """
+ if "```\n" not in stop:
+ stop.append("```\n")
+ if "\n```" not in stop:
+ stop.append("\n```")
+
+ # override the last user message
+ user_message = None
+ for i in range(len(prompt_messages) - 1, -1, -1):
+ if isinstance(prompt_messages[i], UserPromptMessage):
+ user_message = prompt_messages[i]
+ break
+
+ if user_message:
+ if prompt_messages[i].content[-11:] == "Assistant: ":
+                # the chat app appended a trailing "Assistant: " suffix; strip it before rewriting the prompt
+ prompt_messages[i].content = prompt_messages[i].content[:-11]
+ prompt_messages[i] = UserPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT.replace(
+ "{{instructions}}", user_message.content
+ ).replace("{{block}}", response_format)
+ )
+ prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
+ else:
+ prompt_messages[i] = UserPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT.replace(
+ "{{instructions}}", user_message.content
+ ).replace("{{block}}", response_format)
+ )
+ prompt_messages[i].content += f"\n```{response_format}\n"
+
def get_num_tokens(
self,
model: str,
@@ -109,26 +297,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling
:return:
"""
+ # handle fine tune remote models
+ if model.startswith("ft:"):
+ base_model = model.split(":")[1]
+ else:
+ base_model = model
+
# get model mode
- model_mode = self.get_model_mode(model, credentials)
+ model_mode = self.get_model_mode(model)
if model_mode == LLMMode.CHAT:
# chat model
- return self._num_tokens_from_messages(prompt_messages)
+ return self._num_tokens_from_messages(base_model, prompt_messages, tools)
else:
- first_prompt_message = prompt_messages[0]
- if isinstance(first_prompt_message.content, str):
- text = first_prompt_message.content
- else:
- text = ""
- for message_content in first_prompt_message.content:
- if message_content.type == PromptMessageContentType.TEXT:
- message_content = cast(
- TextPromptMessageContent, message_content
- )
- text = message_content.data
- break
- return self._get_num_tokens_by_gpt2(text)
+ # text completion model, do not support tool calling
+ return self._num_tokens_from_string(base_model, prompt_messages[0].content)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@@ -139,22 +322,102 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
- self._generate(
- model=model,
- credentials=credentials,
- prompt_messages=[UserPromptMessage(content="ping")],
- model_parameters={"num_predict": 5},
- stream=False,
- )
- except InvokeError as ex:
- raise CredentialsValidateFailedError(
- f"An error occurred during credentials validation: {ex.description}"
- )
+ # transform credentials to kwargs for model instance
+ credentials_kwargs = self._to_credential_kwargs(credentials)
+ client = OpenAI(**credentials_kwargs)
+
+ # handle fine tune remote models
+ base_model = model
+            # a fine-tuned model name looks like ft:gpt-3.5-turbo-0613:personal::xxxxx
+ if model.startswith("ft:"):
+ base_model = model.split(":")[1]
+
+                # check if the fine-tuned model exists
+                remote_models = self.remote_models(credentials)
+                remote_model_map = {model.model: model for model in remote_models}
+                if model not in remote_model_map:
+                    raise CredentialsValidateFailedError(
+                        f"Fine-tuned model {model} not found"
+                    )
+
+ # get model mode
+ model_mode = self.get_model_mode(base_model, credentials)
+
+ if model_mode == LLMMode.CHAT:
+ # chat model
+ client.chat.completions.create(
+ messages=[{"role": "user", "content": "ping"}],
+ model=model,
+ temperature=0,
+ max_tokens=20,
+ stream=False,
+ )
+ else:
+ # text completion model
+ client.completions.create(
+ prompt="ping",
+ model=model,
+ temperature=0,
+ max_tokens=20,
+ stream=False,
+ )
except Exception as ex:
- raise CredentialsValidateFailedError(
- f"An error occurred during credentials validation: {str(ex)}"
+ raise CredentialsValidateFailedError(str(ex))
+
+ def remote_models(self, credentials: dict) -> List[AIModelEntity]:
+ """
+ Return remote models if credentials are provided.
+
+ :param credentials: provider credentials
+ :return:
+ """
+ # get predefined models
+ predefined_models = self.predefined_models()
+ predefined_models_map = {model.model: model for model in predefined_models}
+
+ # transform credentials to kwargs for model instance
+ credentials_kwargs = self._to_credential_kwargs(credentials)
+ client = OpenAI(**credentials_kwargs)
+
+ # get all remote models
+ remote_models = client.models.list()
+
+ fine_tune_models = [
+ model for model in remote_models if model.id.startswith("ft:")
+ ]
+
+ ai_model_entities = []
+ for model in fine_tune_models:
+ base_model = model.id.split(":")[1]
+
+ base_model_schema = None
+ for (
+ predefined_model_name,
+ predefined_model,
+ ) in predefined_models_map.items():
+ if predefined_model_name in base_model:
+ base_model_schema = predefined_model
+
+ if not base_model_schema:
+ continue
+
+ ai_model_entity = AIModelEntity(
+ model=model.id,
+ label=I18nObject(zh_Hans=model.id, en_US=model.id),
+ model_type=ModelType.LLM,
+ features=base_model_schema.features,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties=base_model_schema.model_properties,
+ parameter_rules=base_model_schema.parameter_rules,
+ pricing=PriceConfig(
+ input=0.003, output=0.006, unit=0.001, currency="USD"
+ ),
)
+ ai_model_entities.append(ai_model_entity)
+
+ return ai_model_entities
+
def _generate(
self,
model: str,
@@ -177,87 +440,43 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
- headers = {"Content-Type": "application/json"}
+ # transform credentials to kwargs for model instance
+ credentials_kwargs = self._to_credential_kwargs(credentials)
- endpoint_url = credentials["base_url"]
- if not endpoint_url.endswith("/"):
- endpoint_url += "/"
+ # init model client
+ client = OpenAI(**credentials_kwargs)
- # prepare the payload for a simple ping to the model
- data = {"model": model, "stream": stream}
-
- if "format" in model_parameters:
- data["format"] = model_parameters["format"]
- del model_parameters["format"]
-
- data["options"] = model_parameters or {}
+ extra_model_kwargs = {}
if stop:
- data["stop"] = "\n".join(stop)
+ extra_model_kwargs["stop"] = stop
- completion_type = LLMMode.value_of(credentials["mode"])
+ if user:
+ extra_model_kwargs["user"] = user
- if completion_type is LLMMode.CHAT:
- endpoint_url = urljoin(endpoint_url, "api/chat")
- data["messages"] = [
- self._convert_prompt_message_to_dict(m) for m in prompt_messages
- ]
- else:
- endpoint_url = urljoin(endpoint_url, "api/generate")
- first_prompt_message = prompt_messages[0]
- if isinstance(first_prompt_message, UserPromptMessage):
- first_prompt_message = cast(UserPromptMessage, first_prompt_message)
- if isinstance(first_prompt_message.content, str):
- data["prompt"] = first_prompt_message.content
- else:
- text = ""
- images = []
- for message_content in first_prompt_message.content:
- if message_content.type == PromptMessageContentType.TEXT:
- message_content = cast(
- TextPromptMessageContent, message_content
- )
- text = message_content.data
- elif message_content.type == PromptMessageContentType.IMAGE:
- message_content = cast(
- ImagePromptMessageContent, message_content
- )
- image_data = re.sub(
- r"^data:image\/[a-zA-Z]+;base64,",
- "",
- message_content.data,
- )
- images.append(image_data)
-
- data["prompt"] = text
- data["images"] = images
-
- # send a post request to validate the credentials
- response = requests.post(
- endpoint_url, headers=headers, json=data, timeout=(10, 60), stream=stream
+ # text completion model
+ response = client.completions.create(
+ prompt=prompt_messages[0].content,
+ model=model,
+ stream=stream,
+ **model_parameters,
+ **extra_model_kwargs,
)
- response.encoding = "utf-8"
- if response.status_code != 200:
- raise InvokeError(
- f"API request failed with status code {response.status_code}: {response.text}"
- )
-
if stream:
return self._handle_generate_stream_response(
- model, credentials, completion_type, response, prompt_messages
+ model, credentials, response, prompt_messages
)
return self._handle_generate_response(
- model, credentials, completion_type, response, prompt_messages
+ model, credentials, response, prompt_messages
)
def _handle_generate_response(
self,
model: str,
credentials: dict,
- completion_type: LLMMode,
- response: requests.Response,
+ response: Completion,
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
@@ -265,29 +484,26 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:param model: model name
:param credentials: model credentials
- :param completion_type: completion type
:param response: response
:param prompt_messages: prompt messages
:return: llm result
"""
- response_json = response.json()
+ assistant_text = response.choices[0].text
- if completion_type is LLMMode.CHAT:
- message = response_json.get("message", {})
- response_content = message.get("content", "")
- else:
- response_content = response_json["response"]
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
- assistant_message = AssistantPromptMessage(content=response_content)
-
- if "prompt_eval_count" in response_json and "eval_count" in response_json:
+ # calculate num tokens
+ if response.usage:
# transform usage
- prompt_tokens = response_json["prompt_eval_count"]
- completion_tokens = response_json["eval_count"]
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
- prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
- completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
+ prompt_tokens = self._num_tokens_from_string(
+ model, prompt_messages[0].content
+ )
+ completion_tokens = self._num_tokens_from_string(model, assistant_text)
# transform usage
usage = self._calc_response_usage(
@@ -296,10 +512,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
# transform response
result = LLMResult(
- model=response_json["model"],
+ model=response.model,
prompt_messages=prompt_messages,
- message=assistant_message,
+ message=assistant_prompt_message,
usage=usage,
+ system_fingerprint=response.system_fingerprint,
)
return result
@@ -308,8 +525,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- completion_type: LLMMode,
- response: requests.Response,
+ response: Stream[Completion],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
@@ -317,85 +533,38 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:param model: model name
:param credentials: model credentials
- :param completion_type: completion type
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
full_text = ""
- chunk_index = 0
-
- def create_final_llm_result_chunk(
- index: int, message: AssistantPromptMessage, finish_reason: str
- ) -> LLMResultChunk:
- # calculate num tokens
- prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
- completion_tokens = self._get_num_tokens_by_gpt2(full_text)
-
- # transform usage
- usage = self._calc_response_usage(
- model, credentials, prompt_tokens, completion_tokens
- )
-
- return LLMResultChunk(
- model=model,
- prompt_messages=prompt_messages,
- delta=LLMResultChunkDelta(
- index=index,
- message=message,
- finish_reason=finish_reason,
- usage=usage,
- ),
- )
-
- for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"):
- if not chunk:
+ for chunk in response:
+ if len(chunk.choices) == 0:
continue
- try:
- chunk_json = json.loads(chunk)
- # stream ended
- except json.JSONDecodeError as e:
- yield create_final_llm_result_chunk(
- index=chunk_index,
- message=AssistantPromptMessage(content=""),
- finish_reason="Non-JSON encountered.",
- )
+ delta = chunk.choices[0]
- chunk_index += 1
- break
-
- if completion_type is LLMMode.CHAT:
- if not chunk_json:
- continue
-
- if "message" not in chunk_json:
- text = ""
- else:
- text = chunk_json.get("message").get("content", "")
- else:
- if not chunk_json:
- continue
-
- # transform assistant message to prompt message
- text = chunk_json["response"]
+ if delta.finish_reason is None and (delta.text is None or delta.text == ""):
+ continue
+ # transform assistant message to prompt message
+ text = delta.text if delta.text else ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
- if chunk_json["done"]:
+ if delta.finish_reason is not None:
# calculate num tokens
- if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json:
+ if chunk.usage:
# transform usage
- prompt_tokens = chunk_json["prompt_eval_count"]
- completion_tokens = chunk_json["eval_count"]
+ prompt_tokens = chunk.usage.prompt_tokens
+ completion_tokens = chunk.usage.completion_tokens
else:
# calculate num tokens
- prompt_tokens = self._get_num_tokens_by_gpt2(
- prompt_messages[0].content
+ prompt_tokens = self._num_tokens_from_string(
+ model, prompt_messages[0].content
)
- completion_tokens = self._get_num_tokens_by_gpt2(full_text)
+ completion_tokens = self._num_tokens_from_string(model, full_text)
# transform usage
usage = self._calc_response_usage(
@@ -403,77 +572,546 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
)
yield LLMResultChunk(
- model=chunk_json["model"],
+ model=chunk.model,
prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
- index=chunk_index,
+ index=delta.index,
message=assistant_prompt_message,
- finish_reason="stop",
+ finish_reason=delta.finish_reason,
usage=usage,
),
)
else:
yield LLMResultChunk(
- model=chunk_json["model"],
+ model=chunk.model,
prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
- index=chunk_index,
+ index=delta.index,
message=assistant_prompt_message,
),
)
- chunk_index += 1
+ def _chat_generate(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: List[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> Union[LLMResult, Generator]:
+ """
+ Invoke llm chat model
+
+ :param model: model name
+ :param credentials: credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+ # transform credentials to kwargs for model instance
+ credentials_kwargs = self._to_credential_kwargs(credentials)
+
+ # init model client
+ client = OpenAI(**credentials_kwargs)
+
+ response_format = model_parameters.get("response_format")
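+        # the OpenAI-compatible chat API expects response_format as an object, e.g. {"type": "json_object"}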
+ if response_format:
+ if response_format == "json_object":
+ response_format = {"type": "json_object"}
+ else:
+ response_format = {"type": "text"}
+
+ model_parameters["response_format"] = response_format
+
+ extra_model_kwargs = {}
+
+ if tools:
+ # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+ extra_model_kwargs["functions"] = [
+ {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": tool.parameters,
+ }
+ for tool in tools
+ ]
+
+ if stop:
+ extra_model_kwargs["stop"] = stop
+
+ if user:
+ extra_model_kwargs["user"] = user
+
+ # chat model
+ response = client.chat.completions.create(
+ messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+ model=model,
+ stream=stream,
+ **model_parameters,
+ **extra_model_kwargs,
+ )
+
+ if stream:
+ return self._handle_chat_generate_stream_response(
+ model, credentials, response, prompt_messages, tools
+ )
+
+ return self._handle_chat_generate_response(
+ model, credentials, response, prompt_messages, tools
+ )
+
+ def _handle_chat_generate_response(
+ self,
+ model: str,
+ credentials: dict,
+ response: ChatCompletion,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
+ ) -> LLMResult:
+ """
+ Handle llm chat response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return: llm response
+ """
+ assistant_message = response.choices[0].message
+ # assistant_message_tool_calls = assistant_message.tool_calls
+ assistant_message_function_call = assistant_message.function_call
+
+ # extract tool calls from response
+ # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+ function_call = self._extract_response_function_call(
+ assistant_message_function_call
+ )
+ tool_calls = [function_call] if function_call else []
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=assistant_message.content, tool_calls=tool_calls
+ )
+
+ # calculate num tokens
+ if response.usage:
+ # transform usage
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
+ else:
+ # calculate num tokens
+ prompt_tokens = self._num_tokens_from_messages(
+ model, prompt_messages, tools
+ )
+ completion_tokens = self._num_tokens_from_messages(
+ model, [assistant_prompt_message]
+ )
+
+ # transform usage
+ usage = self._calc_response_usage(
+ model, credentials, prompt_tokens, completion_tokens
+ )
+
+ # transform response
+ response = LLMResult(
+ model=response.model,
+ prompt_messages=prompt_messages,
+ message=assistant_prompt_message,
+ usage=usage,
+ system_fingerprint=response.system_fingerprint,
+ )
+
+ return response
+
+ def _handle_chat_generate_stream_response(
+ self,
+ model: str,
+ credentials: dict,
+ response: Stream[ChatCompletionChunk],
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
+ ) -> Generator:
+ """
+ Handle llm chat stream response
+
+ :param model: model name
+ :param response: response
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return: llm response chunk generator
+ """
+ full_assistant_content = ""
+ delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
+ for chunk in response:
+ if len(chunk.choices) == 0:
+ continue
+
+ delta = chunk.choices[0]
+ has_finish_reason = delta.finish_reason is not None
+
+ if (
+ not has_finish_reason
+ and (delta.delta.content is None or delta.delta.content == "")
+ and delta.delta.function_call is None
+ ):
+ continue
+
+ # assistant_message_tool_calls = delta.delta.tool_calls
+ assistant_message_function_call = delta.delta.function_call
+
+ # extract tool calls from response
+ if delta_assistant_message_function_call_storage is not None:
+ # handle process of stream function call
+ if assistant_message_function_call:
+                    # message has not ended yet
+ delta_assistant_message_function_call_storage.arguments += (
+ assistant_message_function_call.arguments
+ )
+ continue
+ else:
+ # message has ended
+ assistant_message_function_call = (
+ delta_assistant_message_function_call_storage
+ )
+ delta_assistant_message_function_call_storage = None
+ else:
+ if assistant_message_function_call:
+ # start of stream function call
+ delta_assistant_message_function_call_storage = (
+ assistant_message_function_call
+ )
+ if not has_finish_reason:
+ continue
+
+ # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+ function_call = self._extract_response_function_call(
+ assistant_message_function_call
+ )
+ tool_calls = [function_call] if function_call else []
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=delta.delta.content if delta.delta.content else "",
+ tool_calls=tool_calls,
+ )
+
+ full_assistant_content += delta.delta.content if delta.delta.content else ""
+
+ if has_finish_reason:
+ # calculate num tokens
+ prompt_tokens = self._num_tokens_from_messages(
+ model, prompt_messages, tools
+ )
+
+ full_assistant_prompt_message = AssistantPromptMessage(
+ content=full_assistant_content, tool_calls=tool_calls
+ )
+ completion_tokens = self._num_tokens_from_messages(
+ model, [full_assistant_prompt_message]
+ )
+
+ # transform usage
+ usage = self._calc_response_usage(
+ model, credentials, prompt_tokens, completion_tokens
+ )
+
+ yield LLMResultChunk(
+ model=chunk.model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=delta.index,
+ message=assistant_prompt_message,
+ finish_reason=delta.finish_reason,
+ usage=usage,
+ ),
+ )
+ else:
+ yield LLMResultChunk(
+ model=chunk.model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=delta.index,
+ message=assistant_prompt_message,
+ ),
+ )
+
+ def _extract_response_tool_calls(
+ self,
+ response_tool_calls: List[
+ Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
+ ],
+ ) -> List[AssistantPromptMessage.ToolCall]:
+ """
+ Extract tool calls from response
+
+ :param response_tool_calls: response tool calls
+ :return: list of tool calls
+ """
+ tool_calls = []
+ if response_tool_calls:
+ for response_tool_call in response_tool_calls:
+ function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=response_tool_call.function.name,
+ arguments=response_tool_call.function.arguments,
+ )
+
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=response_tool_call.id,
+ type=response_tool_call.type,
+ function=function,
+ )
+ tool_calls.append(tool_call)
+
+ return tool_calls
+
+ def _extract_response_function_call(
+ self, response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall]
+ ) -> AssistantPromptMessage.ToolCall:
+ """
+ Extract function call from response
+
+ :param response_function_call: response function call
+ :return: tool call
+ """
+ tool_call = None
+ if response_function_call:
+ function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=response_function_call.name,
+ arguments=response_function_call.arguments,
+ )
+
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=response_function_call.name, type="function", function=function
+ )
+
+ return tool_call
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
- Convert PromptMessage to dict for Ollama API
+        Convert PromptMessage to dict for the OpenAI-compatible API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
- text = ""
- images = []
+ sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
- text = message_content.data
+ sub_message_dict = {
+ "type": "text",
+ "text": message_content.data,
+ }
+ sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content
)
- image_data = re.sub(
- r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
- )
- images.append(image_data)
+ sub_message_dict = {
+ "type": "image_url",
+ "image_url": {
+ "url": message_content.data,
+ "detail": message_content.detail.value,
+ },
+ }
+ sub_messages.append(sub_message_dict)
- message_dict = {"role": "user", "content": text, "images": images}
+ message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
+ if message.tool_calls:
+ # message_dict["tool_calls"] = [tool_call.dict() for tool_call in
+ # message.tool_calls]
+ function_call = message.tool_calls[0]
+ message_dict["function_call"] = {
+ "name": function_call.function.name,
+ "arguments": function_call.function.arguments,
+ }
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, ToolPromptMessage):
+ message = cast(ToolPromptMessage, message)
+ # message_dict = {
+ # "role": "tool",
+ # "content": message.content,
+ # "tool_call_id": message.tool_call_id
+ # }
+ message_dict = {
+ "role": "function",
+ "content": message.content,
+ "name": message.tool_call_id,
+ }
else:
raise ValueError(f"Got unknown type {message}")
+ if message.name:
+ message_dict["name"] = message.name
+
return message_dict
- def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
+ def _num_tokens_from_string(
+ self, model: str, text: str, tools: Optional[List[PromptMessageTool]] = None
+ ) -> int:
"""
- Calculate num tokens.
+ Calculate num tokens for text completion model with tiktoken package.
- :param messages: messages
+ :param model: model name
+ :param text: prompt text
+ :param tools: tools for tool calling
+ :return: number of tokens
"""
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ num_tokens = len(encoding.encode(text))
+
+ if tools:
+ num_tokens += self._num_tokens_for_tools(encoding, tools)
+
+ return num_tokens
+
+ def _num_tokens_from_messages(
+ self,
+ model: str,
+ messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
+ ) -> int:
+ """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+ if model.startswith("ft:"):
+ model = model.split(":")[1]
+
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ logger.warning("Warning: model not found. Using cl100k_base encoding.")
+ model = "cl100k_base"
+ encoding = tiktoken.get_encoding(model)
+
+ if model.startswith("gpt-3.5-turbo-0301"):
+ # every message follows {role/name}\n{content}\n
+ tokens_per_message = 4
+ # if there's a name, the role is omitted
+ tokens_per_name = -1
+ elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+ tokens_per_message = 3
+ tokens_per_name = 1
+        else:
+            # Ollama model names do not follow the OpenAI naming scheme; fall back to the
+            # gpt-3.5-turbo/gpt-4 per-message overhead as an approximation
+            tokens_per_message = 3
+            tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
+ num_tokens += tokens_per_message
for key, value in message.items():
- num_tokens += self._get_num_tokens_by_gpt2(str(key))
- num_tokens += self._get_num_tokens_by_gpt2(str(value))
+ # Cast str(value) in case the message value is not a string
+ # This occurs with function messages
+                # TODO: token counting for image content is not implemented; it would require
+                #       downloading the image to read its resolution, which adds request latency
+ if isinstance(value, list):
+ text = ""
+ for item in value:
+ if isinstance(item, dict) and item["type"] == "text":
+ text += item["text"]
+
+ value = text
+
+ if key == "tool_calls":
+ for tool_call in value:
+ for t_key, t_value in tool_call.items():
+ num_tokens += len(encoding.encode(t_key))
+ if t_key == "function":
+ for f_key, f_value in t_value.items():
+ num_tokens += len(encoding.encode(f_key))
+ num_tokens += len(encoding.encode(f_value))
+ else:
+ num_tokens += len(encoding.encode(t_key))
+ num_tokens += len(encoding.encode(t_value))
+ else:
+ num_tokens += len(encoding.encode(str(value)))
+
+ if key == "name":
+ num_tokens += tokens_per_name
+
+ # every reply is primed with assistant
+ num_tokens += 3
+
+ if tools:
+ num_tokens += self._num_tokens_for_tools(encoding, tools)
+
+ return num_tokens
+
+ def _num_tokens_for_tools(
+ self, encoding: tiktoken.Encoding, tools: List[PromptMessageTool]
+ ) -> int:
+ """
+ Calculate num tokens for tool calling with tiktoken package.
+
+ :param encoding: encoding
+ :param tools: tools for tool calling
+ :return: number of tokens
+ """
+ num_tokens = 0
+ for tool in tools:
+ num_tokens += len(encoding.encode("type"))
+ num_tokens += len(encoding.encode("function"))
+
+ # calculate num tokens for function object
+ num_tokens += len(encoding.encode("name"))
+ num_tokens += len(encoding.encode(tool.name))
+ num_tokens += len(encoding.encode("description"))
+ num_tokens += len(encoding.encode(tool.description))
+ parameters = tool.parameters
+ num_tokens += len(encoding.encode("parameters"))
+ if "title" in parameters:
+ num_tokens += len(encoding.encode("title"))
+ num_tokens += len(encoding.encode(parameters.get("title")))
+ num_tokens += len(encoding.encode("type"))
+ num_tokens += len(encoding.encode(parameters.get("type")))
+ if "properties" in parameters:
+ num_tokens += len(encoding.encode("properties"))
+ for key, value in parameters.get("properties").items():
+ num_tokens += len(encoding.encode(key))
+ for field_key, field_value in value.items():
+ num_tokens += len(encoding.encode(field_key))
+ if field_key == "enum":
+ for enum_field in field_value:
+ num_tokens += 3
+ num_tokens += len(encoding.encode(enum_field))
+ else:
+ num_tokens += len(encoding.encode(field_key))
+ num_tokens += len(encoding.encode(str(field_value)))
+ if "required" in parameters:
+ num_tokens += len(encoding.encode("required"))
+ for required_field in parameters["required"]:
+ num_tokens += 3
+ num_tokens += len(encoding.encode(required_field))
return num_tokens
@@ -481,251 +1119,44 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self, model: str, credentials: dict
) -> AIModelEntity:
"""
- Get customizable model schema.
+        Get the customizable model schema. For fine-tuned models (named like ft:<base_model>:...),
+        this returns the schema of the base model renamed to the fine-tuned model name.
:param model: model name
:param credentials: credentials
:return: model schema
"""
- extras = {}
+ if not model.startswith("ft:"):
+ base_model = model
+ else:
+ # get base_model
+ base_model = model.split(":")[1]
- if "vision_support" in credentials and credentials["vision_support"] == "true":
- extras["features"] = [ModelFeature.VISION]
+ # get model schema
+ models = self.predefined_models()
+ model_map = {model.model: model for model in models}
+ if base_model not in model_map:
+ raise ValueError(f"Base model {base_model} not found")
+
+ base_model_schema = model_map[base_model]
+
+ base_model_schema_features = base_model_schema.features or []
+ base_model_schema_model_properties = base_model_schema.model_properties or {}
+ base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
entity = AIModelEntity(
model=model,
label=I18nObject(zh_Hans=model, en_US=model),
model_type=ModelType.LLM,
+            features=list(base_model_schema_features),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
- ModelPropertyKey.MODE: credentials.get("mode"),
- ModelPropertyKey.CONTEXT_SIZE: int(
- credentials.get("context_size", 4096)
- ),
+                **base_model_schema_model_properties
},
- parameter_rules=[
- ParameterRule(
- name=DefaultParameterName.TEMPERATURE.value,
- use_template=DefaultParameterName.TEMPERATURE.value,
- label=I18nObject(en_US="Temperature"),
- type=ParameterType.FLOAT,
- help=I18nObject(
- en_US="The temperature of the model. "
- "Increasing the temperature will make the model answer "
- "more creatively. (Default: 0.8)"
- ),
- default=0.8,
- min=0,
- max=2,
- ),
- ParameterRule(
- name=DefaultParameterName.TOP_P.value,
- use_template=DefaultParameterName.TOP_P.value,
- label=I18nObject(en_US="Top P"),
- type=ParameterType.FLOAT,
- help=I18nObject(
- en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
- "more diverse text, while a lower value (e.g., 0.5) will generate more "
- "focused and conservative text. (Default: 0.9)"
- ),
- default=0.9,
- min=0,
- max=1,
- ),
- ParameterRule(
- name="top_k",
- label=I18nObject(en_US="Top K"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="Reduces the probability of generating nonsense. "
- "A higher value (e.g. 100) will give more diverse answers, "
- "while a lower value (e.g. 10) will be more conservative. (Default: 40)"
- ),
- default=40,
- min=1,
- max=100,
- ),
- ParameterRule(
- name="repeat_penalty",
- label=I18nObject(en_US="Repeat Penalty"),
- type=ParameterType.FLOAT,
- help=I18nObject(
- en_US="Sets how strongly to penalize repetitions. "
- "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
- "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
- ),
- default=1.1,
- min=-2,
- max=2,
- ),
- ParameterRule(
- name="num_predict",
- use_template="max_tokens",
- label=I18nObject(en_US="Num Predict"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="Maximum number of tokens to predict when generating text. "
- "(Default: 128, -1 = infinite generation, -2 = fill context)"
- ),
- default=128,
- min=-2,
- max=int(credentials.get("max_tokens", 4096)),
- ),
- ParameterRule(
- name="mirostat",
- label=I18nObject(en_US="Mirostat sampling"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="Enable Mirostat sampling for controlling perplexity. "
- "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
- ),
- default=0,
- min=0,
- max=2,
- ),
- ParameterRule(
- name="mirostat_eta",
- label=I18nObject(en_US="Mirostat Eta"),
- type=ParameterType.FLOAT,
- help=I18nObject(
- en_US="Influences how quickly the algorithm responds to feedback from "
- "the generated text. A lower learning rate will result in slower adjustments, "
- "while a higher learning rate will make the algorithm more responsive. "
- "(Default: 0.1)"
- ),
- default=0.1,
- precision=1,
- ),
- ParameterRule(
- name="mirostat_tau",
- label=I18nObject(en_US="Mirostat Tau"),
- type=ParameterType.FLOAT,
- help=I18nObject(
- en_US="Controls the balance between coherence and diversity of the output. "
- "A lower value will result in more focused and coherent text. (Default: 5.0)"
- ),
- default=5.0,
- precision=1,
- ),
- ParameterRule(
- name="num_ctx",
- label=I18nObject(en_US="Size of context window"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="Sets the size of the context window used to generate the next token. "
- "(Default: 2048)"
- ),
- default=2048,
- min=1,
- ),
- ParameterRule(
- name="num_gpu",
- label=I18nObject(en_US="Num GPU"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="The number of layers to send to the GPU(s). "
- "On macOS it defaults to 1 to enable metal support, 0 to disable."
- ),
- default=1,
- min=0,
- max=1,
- ),
- ParameterRule(
- name="num_thread",
- label=I18nObject(en_US="Num Thread"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="Sets the number of threads to use during computation. "
- "By default, Ollama will detect this for optimal performance. "
- "It is recommended to set this value to the number of physical CPU cores "
- "your system has (as opposed to the logical number of cores)."
- ),
- min=1,
- ),
- ParameterRule(
- name="repeat_last_n",
- label=I18nObject(en_US="Repeat last N"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="Sets how far back for the model to look back to prevent repetition. "
- "(Default: 64, 0 = disabled, -1 = num_ctx)"
- ),
- default=64,
- min=-1,
- ),
- ParameterRule(
- name="tfs_z",
- label=I18nObject(en_US="TFS Z"),
- type=ParameterType.FLOAT,
- help=I18nObject(
- en_US="Tail free sampling is used to reduce the impact of less probable tokens "
- "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
- "while a value of 1.0 disables this setting. (default: 1)"
- ),
- default=1,
- precision=1,
- ),
- ParameterRule(
- name="seed",
- label=I18nObject(en_US="Seed"),
- type=ParameterType.INT,
- help=I18nObject(
- en_US="Sets the random number seed to use for generation. Setting this to "
- "a specific number will make the model generate the same text for "
- "the same prompt. (Default: 0)"
- ),
- default=0,
- ),
- ParameterRule(
- name="format",
- label=I18nObject(en_US="Format"),
- type=ParameterType.STRING,
- help=I18nObject(
- en_US="the format to return a response in."
- " Currently the only accepted value is json."
- ),
- options=["json"],
- ),
- ],
- pricing=PriceConfig(
- input=Decimal(credentials.get("input_price", 0)),
- output=Decimal(credentials.get("output_price", 0)),
- unit=Decimal(credentials.get("unit", 0)),
- currency=credentials.get("currency", "USD"),
- ),
- **extras,
+            parameter_rules=list(base_model_schema_parameters_rules),
+ pricing=base_model_schema.pricing,
)
return entity
-
- @property
- def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
- """
- Map model invoke error to unified error
- The key is the error type thrown to the caller
- The value is the error type thrown by the model,
- which needs to be converted into a unified error type for the caller.
-
- :return: Invoke error mapping
- """
- return {
- InvokeAuthorizationError: [
- requests.exceptions.InvalidHeader, # Missing or Invalid API Key
- ],
- InvokeBadRequestError: [
- requests.exceptions.HTTPError, # Invalid Endpoint URL or model name
- requests.exceptions.InvalidURL, # Misconfigured request or other API error
- ],
- InvokeRateLimitError: [
- requests.exceptions.RetryError # Too many requests sent in a short period of time
- ],
- InvokeServerUnavailableError: [
- requests.exceptions.ConnectionError, # Engine Overloaded
- requests.exceptions.HTTPError, # Server Error
- ],
- InvokeConnectionError: [
- requests.exceptions.ConnectTimeout, # Timeout
- requests.exceptions.ReadTimeout, # Timeout
- ],
- }
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py
index c6a78011..b4d68230 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py
@@ -7,7 +7,7 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im
logger = logging.getLogger(__name__)
-class OpenAIProvider(ModelProvider):
+class OllamaProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.yaml b/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.yaml
index 82c25fb2..d69c4ca9 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.yaml
+++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.yaml
@@ -11,7 +11,7 @@ help:
en_US: How to integrate with Ollama
zh_Hans: 如何集成 Ollama
url:
- en_US: https://docs.dify.ai/tutorials/model-configuration/ollama
+ en_US: "ollama"
supported_model_types:
- llm
- text-embedding
@@ -26,73 +26,13 @@ model_credential_schema:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- - variable: base_url
+
+ - variable: openai_api_base
label:
- zh_Hans: 基础 URL
- en_US: Base URL
+ zh_Hans: API Base
+ en_US: API Base
type: text-input
- required: true
- placeholder:
- zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
- en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
- - variable: mode
- show_on:
- - variable: __model_type
- value: llm
- label:
- zh_Hans: 模型类型
- en_US: Completion mode
- type: select
- required: true
- default: chat
- placeholder:
- zh_Hans: 选择对话类型
- en_US: Select completion mode
- options:
- - value: completion
- label:
- en_US: Completion
- zh_Hans: 补全
- - value: chat
- label:
- en_US: Chat
- zh_Hans: 对话
- - variable: context_size
- label:
- zh_Hans: 模型上下文长度
- en_US: Model context size
- required: true
- type: text-input
- default: '4096'
- placeholder:
- zh_Hans: 在此输入您的模型上下文长度
- en_US: Enter your Model context size
- - variable: max_tokens
- label:
- zh_Hans: 最大 token 上限
- en_US: Upper bound for max tokens
- show_on:
- - variable: __model_type
- value: llm
- default: '4096'
- type: text-input
- required: true
- - variable: vision_support
- label:
- zh_Hans: 是否支持 Vision
- en_US: Vision support
- show_on:
- - variable: __model_type
- value: llm
- default: 'false'
- type: radio
required: false
- options:
- - value: 'true'
- label:
- en_US: "Yes"
- zh_Hans: 是
- - value: 'false'
- label:
- en_US: "No"
- zh_Hans: 否
+ placeholder:
+ zh_Hans: 在此输入您的 API Base
+ en_US: Enter your API Base
\ No newline at end of file
diff --git a/model-providers/tests/conftest.py b/model-providers/tests/conftest.py
index a4508b81..c3191f07 100644
--- a/model-providers/tests/conftest.py
+++ b/model-providers/tests/conftest.py
@@ -136,7 +136,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
yield f"http://127.0.0.1:20000"
finally:
print("")
- boot.destroy()
+ # boot.destroy()
except SystemExit:
diff --git a/model-providers/tests/ollama_providers_test/model_providers.yaml b/model-providers/tests/ollama_providers_test/model_providers.yaml
new file mode 100644
index 00000000..9ef23a0d
--- /dev/null
+++ b/model-providers/tests/ollama_providers_test/model_providers.yaml
@@ -0,0 +1,10 @@
+
+ollama:
+ model_credential:
+ - model: 'llama3'
+ model_type: 'llm'
+ model_credentials:
+ openai_api_base: 'http://172.21.80.1:11434'
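+      # 11434 is Ollama's default port; adjust the host to point at your local Ollama server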
+
+
+
diff --git a/model-providers/tests/ollama_providers_test/test_ollama_service.py b/model-providers/tests/ollama_providers_test/test_ollama_service.py
new file mode 100644
index 00000000..5338efb9
--- /dev/null
+++ b/model-providers/tests/ollama_providers_test/test_ollama_service.py
@@ -0,0 +1,34 @@
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+import pytest
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.requires("openai")
+def test_llm(init_server: str):
+ llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1")
+ template = """Question: {question}
+
+ Answer: Let's think step by step."""
+
+ prompt = PromptTemplate.from_template(template)
+
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ responses = llm_chain.run("你好")
+ logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+
+
+@pytest.mark.requires("openai")
+def test_embedding(init_server: str):
+ embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
+ openai_api_key="YOUR_API_KEY",
+ openai_api_base=f"{init_server}/zhipuai/v1")
+
+ text = "你好"
+
+ query_result = embeddings.embed_query(text)
+
+ logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
diff --git a/model-providers/tests/openai_providers_test/model_providers.yaml b/model-providers/tests/openai_providers_test/model_providers.yaml
index b98d2924..9ef23a0d 100644
--- a/model-providers/tests/openai_providers_test/model_providers.yaml
+++ b/model-providers/tests/openai_providers_test/model_providers.yaml
@@ -1,5 +1,10 @@
-openai:
- provider_credential:
- openai_api_key: 'sk-'
- openai_organization: ''
- openai_api_base: ''
+
+ollama:
+ model_credential:
+ - model: 'llama3'
+ model_type: 'llm'
+ model_credentials:
+ openai_api_base: 'http://172.21.80.1:11434'
+
+
+
diff --git a/model-providers/tests/openai_providers_test/test_openai_service.py b/model-providers/tests/openai_providers_test/test_ollama_service.py
similarity index 99%
rename from model-providers/tests/openai_providers_test/test_openai_service.py
rename to model-providers/tests/openai_providers_test/test_ollama_service.py
index 958fa108..0bdf455e 100644
--- a/model-providers/tests/openai_providers_test/test_openai_service.py
+++ b/model-providers/tests/openai_providers_test/test_ollama_service.py
@@ -6,6 +6,7 @@ import logging
logger = logging.getLogger(__name__)
+
@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
diff --git a/model-providers/tests/unit_tests/ollama/model_providers.yaml b/model-providers/tests/unit_tests/ollama/model_providers.yaml
new file mode 100644
index 00000000..9ef23a0d
--- /dev/null
+++ b/model-providers/tests/unit_tests/ollama/model_providers.yaml
@@ -0,0 +1,10 @@
+
+ollama:
+ model_credential:
+ - model: 'llama3'
+ model_type: 'llm'
+ model_credentials:
+ openai_api_base: 'http://172.21.80.1:11434'
+
+
+
diff --git a/model-providers/tests/unit_tests/ollama/test_provider_manager_models.py b/model-providers/tests/unit_tests/ollama/test_provider_manager_models.py
new file mode 100644
index 00000000..e60d082d
--- /dev/null
+++ b/model-providers/tests/unit_tests/ollama/test_provider_manager_models.py
@@ -0,0 +1,42 @@
+import asyncio
+import logging
+import logging.config
+
+import pytest
+from omegaconf import OmegaConf
+
+from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
+from model_providers.core.model_manager import ModelManager
+from model_providers.core.model_runtime.entities.model_entities import ModelType
+from model_providers.core.provider_manager import ProviderManager
+
+logger = logging.getLogger(__name__)
+
+
+def test_ollama_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
+ logging.config.dictConfig(logging_conf) # type: ignore
+    # Load the providers configuration file
+ cfg = OmegaConf.load(
+ providers_file
+ )
+    # Convert the configuration into provider credential records
+ (
+ provider_name_to_provider_records_dict,
+ provider_name_to_provider_model_records_dict,
+ ) = _to_custom_provide_configuration(cfg)
+    # Create the provider manager
+ provider_manager = ProviderManager(
+ provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
+ provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
+ )
+
+ provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
+ provider="ollama", model_type=ModelType.LLM
+ )
+ provider_model_bundle_emb = provider_manager.get_provider_model_bundle(
+ provider="ollama", model_type=ModelType.TEXT_EMBEDDING
+ )
+ predefined_models = (
+ provider_model_bundle_llm.model_type_instance.predefined_models()
+ )
+
+ logger.info(f"predefined_models: {predefined_models}")