From d0267bf66b286527f525da53613e7f74f2a554fd Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Mon, 20 May 2024 11:35:14 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=80=E4=BA=9B=E5=85=BC=E5=AE=B9=20pydantic?= =?UTF-8?q?<3,>=3D1.9.0=20=20=E7=9A=84=E4=BB=A3=E7=A0=81=EF=BC=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model-providers/model_providers/_compat.py | 222 ++++++ model-providers/model_providers/_files.py | 127 ++++ model-providers/model_providers/_models.py | 657 ++++++++++++++++++ model-providers/model_providers/_types.py | 220 ++++++ .../model_providers/_utils/__init__.py | 48 ++ .../model_providers/_utils/_transform.py | 382 ++++++++++ .../model_providers/_utils/_typing.py | 120 ++++ .../model_providers/_utils/_utils.py | 403 +++++++++++ .../entities/model_provider_entities.py | 2 +- .../core/bootstrap/openai_protocol.py | 10 +- .../core/entities/application_entities.py | 331 --------- .../core/entities/message_entities.py | 4 +- .../core/entities/model_entities.py | 2 +- .../core/entities/provider_configuration.py | 2 +- .../core/entities/provider_entities.py | 2 +- .../core/entities/queue_entities.py | 2 +- .../model_runtime/entities/common_entities.py | 2 +- .../model_runtime/entities/llm_entities.py | 2 +- .../entities/message_entities.py | 2 +- .../model_runtime/entities/model_entities.py | 2 +- .../entities/provider_entities.py | 2 +- .../model_runtime/entities/rerank_entities.py | 2 +- .../entities/text_embedding_entities.py | 2 +- .../model_providers/azure_openai/_constant.py | 2 +- .../model_providers/model_provider_factory.py | 2 +- .../core/model_runtime/utils/_compat.py | 21 - .../core/model_runtime/utils/encoders.py | 234 ------- .../core/model_runtime/utils/helper.py | 2 +- .../model_providers/core/utils/generic.py | 2 +- .../model_providers/core/utils/json_dumps.py | 2 +- model-providers/tests/conftest.py | 2 +- 31 files changed, 2204 insertions(+), 611 deletions(-) create 
mode 100644 model-providers/model_providers/_compat.py create mode 100644 model-providers/model_providers/_files.py create mode 100644 model-providers/model_providers/_models.py create mode 100644 model-providers/model_providers/_types.py create mode 100644 model-providers/model_providers/_utils/__init__.py create mode 100644 model-providers/model_providers/_utils/_transform.py create mode 100644 model-providers/model_providers/_utils/_typing.py create mode 100644 model-providers/model_providers/_utils/_utils.py delete mode 100644 model-providers/model_providers/core/entities/application_entities.py delete mode 100644 model-providers/model_providers/core/model_runtime/utils/_compat.py delete mode 100644 model-providers/model_providers/core/model_runtime/utils/encoders.py diff --git a/model-providers/model_providers/_compat.py b/model-providers/model_providers/_compat.py new file mode 100644 index 00000000..0339d10a --- /dev/null +++ b/model-providers/model_providers/_compat.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload +from datetime import date, datetime +from typing_extensions import Self + +import pydantic +from pydantic.fields import FieldInfo + +from ._types import StrBytesIntFloat + +_T = TypeVar("_T") +_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) + +# --------------- Pydantic v2 compatibility --------------- + +# Pyright incorrectly reports some of our functions as overriding a method when they don't +# pyright: reportIncompatibleMethodOverride=false + +PYDANTIC_V2 = pydantic.VERSION.startswith("2.") + +# v1 re-exports +if TYPE_CHECKING: + + def parse_date(value: Union[date, StrBytesIntFloat]) -> date: # noqa: ARG001 + ... + + def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001 + ... + + def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001 + ... 
+ + def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001 + ... + + def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001 + ... + + def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001 + ... + + def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 + ... + +else: + if PYDANTIC_V2: + from pydantic.v1.typing import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + is_typeddict as is_typeddict, + is_literal_type as is_literal_type, + ) + from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + else: + from pydantic.typing import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + is_typeddict as is_typeddict, + is_literal_type as is_literal_type, + ) + from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + + +# refactored config +if TYPE_CHECKING: + from pydantic import ConfigDict as ConfigDict +else: + if PYDANTIC_V2: + from pydantic import ConfigDict + else: + # TODO: provide an error message here? 
+ ConfigDict = None + + +# renamed methods / properties +def parse_obj(model: type[_ModelT], value: object) -> _ModelT: + if PYDANTIC_V2: + return model.model_validate(value) + else: + return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + + +def field_is_required(field: FieldInfo) -> bool: + if PYDANTIC_V2: + return field.is_required() + return field.required # type: ignore + + +def field_get_default(field: FieldInfo) -> Any: + value = field.get_default() + if PYDANTIC_V2: + from pydantic_core import PydanticUndefined + + if value == PydanticUndefined: + return None + return value + return value + + +def field_outer_type(field: FieldInfo) -> Any: + if PYDANTIC_V2: + return field.annotation + return field.outer_type_ # type: ignore + + +def get_model_config(model: type[pydantic.BaseModel]) -> Any: + if PYDANTIC_V2: + return model.model_config + return model.__config__ # type: ignore + + +def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: + if PYDANTIC_V2: + return model.model_fields + return model.__fields__ # type: ignore + + +def model_copy(model: _ModelT) -> _ModelT: + if PYDANTIC_V2: + return model.model_copy() + return model.copy() # type: ignore + + +def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: + if PYDANTIC_V2: + return model.model_dump_json(indent=indent) + return model.json(indent=indent) # type: ignore + + +def model_dump( + model: pydantic.BaseModel, + *, + exclude_unset: bool = False, + exclude_defaults: bool = False, +) -> dict[str, Any]: + if PYDANTIC_V2: + return model.model_dump( + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ) + return cast( + "dict[str, Any]", + model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ), + ) + + +def model_parse(model: type[_ModelT], data: Any) -> _ModelT: + if PYDANTIC_V2: + return 
model.model_validate(data) + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + + +# generic models +if TYPE_CHECKING: + + class GenericModel(pydantic.BaseModel): + ... + +else: + if PYDANTIC_V2: + # there no longer needs to be a distinction in v2 but + # we still have to create our own subclass to avoid + # inconsistent MRO ordering errors + class GenericModel(pydantic.BaseModel): + ... + + else: + import pydantic.generics + + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): + ... + + +# cached properties +if TYPE_CHECKING: + cached_property = property + + # we define a separate type (copied from typeshed) + # that represents that `cached_property` is `set`able + # at runtime, which differs from `@property`. + # + # this is a separate type as editors likely special case + # `@property` and we don't want to cause issues just to have + # more helpful internal types. + + class typed_cached_property(Generic[_T]): + func: Callable[[Any], _T] + attrname: str | None + + def __init__(self, func: Callable[[Any], _T]) -> None: + ... + + @overload + def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: + ... + + @overload + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: + ... + + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: + raise NotImplementedError() + + def __set_name__(self, owner: type[Any], name: str) -> None: + ... + + # __set__ is not defined at runtime, but @cached_property is designed to be settable + def __set__(self, instance: object, value: _T) -> None: + ... 
+else: + try: + from functools import cached_property as cached_property + except ImportError: + from cached_property import cached_property as cached_property + + typed_cached_property = cached_property diff --git a/model-providers/model_providers/_files.py b/model-providers/model_providers/_files.py new file mode 100644 index 00000000..ad7b668b --- /dev/null +++ b/model-providers/model_providers/_files.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import io +import os +import pathlib +from typing import overload +from typing_extensions import TypeGuard + +import anyio + +from ._types import ( + FileTypes, + FileContent, + RequestFiles, + HttpxFileTypes, + Base64FileInput, + HttpxFileContent, + HttpxRequestFiles, +) +from ._utils import is_tuple_t, is_mapping_t, is_sequence_t + + +def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: + return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + + +def is_file_content(obj: object) -> TypeGuard[FileContent]: + return ( + isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + ) + + +def assert_is_file_content(obj: object, *, key: str | None = None) -> None: + if not is_file_content(obj): + prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" + raise RuntimeError( + f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads" + ) from None + + +@overload +def to_httpx_files(files: None) -> None: + ... + + +@overload +def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: + ... 
+ + +def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: + if files is None: + return None + + if is_mapping_t(files): + files = {key: _transform_file(file) for key, file in files.items()} + elif is_sequence_t(files): + files = [(key, _transform_file(file)) for key, file in files] + else: + raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") + + return files + + +def _transform_file(file: FileTypes) -> HttpxFileTypes: + if is_file_content(file): + if isinstance(file, os.PathLike): + path = pathlib.Path(file) + return (path.name, path.read_bytes()) + + return file + + if is_tuple_t(file): + return (file[0], _read_file_content(file[1]), *file[2:]) + + raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + + +def _read_file_content(file: FileContent) -> HttpxFileContent: + if isinstance(file, os.PathLike): + return pathlib.Path(file).read_bytes() + return file + + +@overload +async def async_to_httpx_files(files: None) -> None: + ... + + +@overload +async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: + ... 
+ + +async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: + if files is None: + return None + + if is_mapping_t(files): + files = {key: await _async_transform_file(file) for key, file in files.items()} + elif is_sequence_t(files): + files = [(key, await _async_transform_file(file)) for key, file in files] + else: + raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") + + return files + + +async def _async_transform_file(file: FileTypes) -> HttpxFileTypes: + if is_file_content(file): + if isinstance(file, os.PathLike): + path = anyio.Path(file) + return (path.name, await path.read_bytes()) + + return file + + if is_tuple_t(file): + return (file[0], await _async_read_file_content(file[1]), *file[2:]) + + raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + + +async def _async_read_file_content(file: FileContent) -> HttpxFileContent: + if isinstance(file, os.PathLike): + return await anyio.Path(file).read_bytes() + + return file diff --git a/model-providers/model_providers/_models.py b/model-providers/model_providers/_models.py new file mode 100644 index 00000000..e7aa662a --- /dev/null +++ b/model-providers/model_providers/_models.py @@ -0,0 +1,657 @@ +from __future__ import annotations + +import os +import inspect +from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast +from datetime import date, datetime +from typing_extensions import ( + Unpack, + Literal, + ClassVar, + Protocol, + Required, + TypedDict, + TypeGuard, + final, + override, + runtime_checkable, +) + +import pydantic +import pydantic.generics +from pydantic.fields import FieldInfo + +from ._types import ( + IncEx, + ModelT, +) +from ._utils import ( + PropertyInfo, + is_list, + is_given, + lru_cache, + is_mapping, + parse_date, + coerce_boolean, + parse_datetime, + strip_not_given, + extract_type_arg, + is_annotated_type, + strip_annotated_type, +) +from ._compat 
import ( + PYDANTIC_V2, + ConfigDict, + GenericModel as BaseGenericModel, + get_args, + is_union, + parse_obj, + get_origin, + is_literal_type, + get_model_config, + get_model_fields, + field_get_default, +) + +if TYPE_CHECKING: + from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema + +__all__ = ["BaseModel", "GenericModel"] + +_T = TypeVar("_T") + + +@runtime_checkable +class _ConfigProtocol(Protocol): + allow_population_by_field_name: bool + + +class BaseModel(pydantic.BaseModel): + if PYDANTIC_V2: + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + ) + else: + + @property + @override + def model_fields_set(self) -> set[str]: + # a forwards-compat shim for pydantic v2 + return self.__fields_set__ # type: ignore + + class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] + extra: Any = pydantic.Extra.allow # type: ignore + + def to_dict( + self, + *, + mode: Literal["json", "python"] = "python", + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> dict[str, object]: + """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + mode: + If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-03-22T18:11:19.117000Z"`. + If mode is 'python', the dictionary may contain any Python objects. e.g. 
`datetime(2024, 3, 22)` + + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value from the output. + exclude_none: Whether to exclude fields that have a value of `None` from the output. + warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2. + """ + return self.model_dump( + mode=mode, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + + def to_json( + self, + *, + indent: int | None = 2, + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> str: + """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation). + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2` + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that have the default value. + exclude_none: Whether to exclude fields that have a value of `None`. + warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2. 
+ """ + return self.model_dump_json( + indent=indent, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + + @override + def __str__(self) -> str: + # mypy complains about an invalid self arg + return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc] + + # Override the 'construct' method in a way that supports recursive parsing without validation. + # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. + @classmethod + @override + def construct( + cls: Type[ModelT], + _fields_set: set[str] | None = None, + **values: object, + ) -> ModelT: + m = cls.__new__(cls) + fields_values: dict[str, object] = {} + + config = get_model_config(cls) + populate_by_name = ( + config.allow_population_by_field_name + if isinstance(config, _ConfigProtocol) + else config.get("populate_by_name") + ) + + if _fields_set is None: + _fields_set = set() + + model_fields = get_model_fields(cls) + for name, field in model_fields.items(): + key = field.alias + if key is None or (key not in values and populate_by_name): + key = name + + if key in values: + fields_values[name] = _construct_field(value=values[key], field=field, key=key) + _fields_set.add(name) + else: + fields_values[name] = field_get_default(field) + + _extra = {} + for key, value in values.items(): + if key not in model_fields: + if PYDANTIC_V2: + _extra[key] = value + else: + _fields_set.add(key) + fields_values[key] = value + + object.__setattr__(m, "__dict__", fields_values) + + if PYDANTIC_V2: + # these properties are copied from Pydantic's `model_construct()` method + object.__setattr__(m, "__pydantic_private__", None) + object.__setattr__(m, "__pydantic_extra__", _extra) + object.__setattr__(m, "__pydantic_fields_set__", _fields_set) + else: + # init_private_attributes() does not exist in v2 + m._init_private_attributes() # type: ignore + + # copied from Pydantic v1's 
`construct()` method + object.__setattr__(m, "__fields_set__", _fields_set) + + return m + + if not TYPE_CHECKING: + # type checkers incorrectly complain about this assignment + # because the type signatures are technically different + # although not in practice + model_construct = construct + + if not PYDANTIC_V2: + # we define aliases for some of the new pydantic v2 methods so + # that we can just document these methods without having to specify + # a specific pydantic version as some users may not know which + # pydantic version they are currently using + + @override + def model_dump( + self, + *, + mode: Literal["json", "python"] | str = "python", + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, + ) -> dict[str, Any]: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump + + Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + Args: + mode: The mode in which `to_python` should run. + If mode is 'json', the dictionary will only contain JSON serializable types. + If mode is 'python', the dictionary may contain any Python objects. + include: A list of fields to include in the output. + exclude: A list of fields to exclude from the output. + by_alias: Whether to use the field's alias in the dictionary key if defined. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value from the output. + exclude_none: Whether to exclude fields that have a value of `None` from the output. + round_trip: Whether to enable serialization and deserialization round-trip support. 
+ warnings: Whether to log warnings when invalid fields are encountered. + + Returns: + A dictionary representation of the model. + """ + if mode != "python": + raise ValueError("mode is only supported in Pydantic v2") + if round_trip != False: + raise ValueError("round_trip is only supported in Pydantic v2") + if warnings != True: + raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") + return super().dict( # pyright: ignore[reportDeprecated] + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + @override + def model_dump_json( + self, + *, + indent: int | None = None, + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, + ) -> str: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json + + Generates a JSON representation of the model using Pydantic's `to_json` method. + + Args: + indent: Indentation to use in the JSON output. If None is passed, the output will be compact. + include: Field(s) to include in the JSON output. Can take either a string or set of strings. + exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. + by_alias: Whether to serialize using field aliases. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that have the default value. + exclude_none: Whether to exclude fields that have a value of `None`. 
+ round_trip: Whether to use serialization/deserialization between JSON and class instance. + warnings: Whether to show any warnings that occurred during serialization. + + Returns: + A JSON string representation of the model. + """ + if round_trip != False: + raise ValueError("round_trip is only supported in Pydantic v2") + if warnings != True: + raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") + return super().json( # type: ignore[reportDeprecated] + indent=indent, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + +def _construct_field(value: object, field: FieldInfo, key: str) -> object: + if value is None: + return field_get_default(field) + + if PYDANTIC_V2: + type_ = field.annotation + else: + type_ = cast(type, field.outer_type_) # type: ignore + + if type_ is None: + raise RuntimeError(f"Unexpected field type is None for {key}") + + return construct_type(value=value, type_=type_) + + +def is_basemodel(type_: type) -> bool: + """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" + if is_union(type_): + for variant in get_args(type_): + if is_basemodel(variant): + return True + + return False + + return is_basemodel_type(type_) + + +def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: + origin = get_origin(type_) or type_ + return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) + + +def construct_type(*, value: object, type_: object) -> object: + """Loose coercion to the expected type with construction of nested values. + + If the given value does not match the expected type then it is returned as-is. 
+ """ + # we allow `object` as the input type because otherwise, passing things like + # `Literal['value']` will be reported as a type error by type checkers + type_ = cast("type[object]", type_) + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + meta: tuple[Any, ...] = get_args(type_)[1:] + type_ = extract_type_arg(type_, 0) + else: + meta = tuple() + + # we need to use the origin class for any types that are subscripted generics + # e.g. Dict[str, object] + origin = get_origin(type_) or type_ + args = get_args(type_) + + if is_union(origin): + try: + return validate_type(type_=cast("type[object]", type_), value=value) + except Exception: + pass + + # if the type is a discriminated union then we want to construct the right variant + # in the union, even if the data doesn't match exactly, otherwise we'd break code + # that relies on the constructed class types, e.g. + # + # class FooType: + # kind: Literal['foo'] + # value: str + # + # class BarType: + # kind: Literal['bar'] + # value: int + # + # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then + # we'd end up constructing `FooType` when it should be `BarType`. 
+ discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) + if discriminator and is_mapping(value): + variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) + if variant_value and isinstance(variant_value, str): + variant_type = discriminator.mapping.get(variant_value) + if variant_type: + return construct_type(type_=variant_type, value=value) + + # if the data is not valid, use the first variant that doesn't fail while deserializing + for variant in args: + try: + return construct_type(value=value, type_=variant) + except Exception: + continue + + raise RuntimeError(f"Could not convert data into a valid instance of {type_}") + + if origin == dict: + if not is_mapping(value): + return value + + _, items_type = get_args(type_) # Dict[_, items_type] + return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} + + if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): + if is_list(value): + return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] + + if is_mapping(value): + if issubclass(type_, BaseModel): + return type_.construct(**value) # type: ignore[arg-type] + + return cast(Any, type_).construct(**value) + + if origin == list: + if not is_list(value): + return value + + inner_type = args[0] # List[inner_type] + return [construct_type(value=entry, type_=inner_type) for entry in value] + + if origin == float: + if isinstance(value, int): + coerced = float(value) + if coerced != value: + return value + return coerced + + return value + + if type_ == datetime: + try: + return parse_datetime(value) # type: ignore + except Exception: + return value + + if type_ == date: + try: + return parse_date(value) # type: ignore + except Exception: + return value + + return value + + +@runtime_checkable +class CachedDiscriminatorType(Protocol): + __discriminator__: DiscriminatorDetails + + +class 
DiscriminatorDetails: + field_name: str + """The name of the discriminator field in the variant class, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] + ``` + + Will result in field_name='type' + """ + + field_alias_from: str | None + """The name of the discriminator field in the API response, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] = Field(alias='type_from_api') + ``` + + Will result in field_alias_from='type_from_api' + """ + + mapping: dict[str, type] + """Mapping of discriminator value to variant type, e.g. + + {'foo': FooVariant, 'bar': BarVariant} + """ + + def __init__( + self, + *, + mapping: dict[str, type], + discriminator_field: str, + discriminator_alias: str | None, + ) -> None: + self.mapping = mapping + self.field_name = discriminator_field + self.field_alias_from = discriminator_alias + + +def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: + if isinstance(union, CachedDiscriminatorType): + return union.__discriminator__ + + discriminator_field_name: str | None = None + + for annotation in meta_annotations: + if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: + discriminator_field_name = annotation.discriminator + break + + if not discriminator_field_name: + return None + + mapping: dict[str, type] = {} + discriminator_alias: str | None = None + + for variant in get_args(union): + variant = strip_annotated_type(variant) + if is_basemodel_type(variant): + if PYDANTIC_V2: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field.get("serialization_alias") + + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in cast("LiteralSchema", field_schema)["expected"]: + if isinstance(entry, str): + mapping[entry] = variant + else: + field_info = cast("dict[str, 
FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field_info.alias + + if field_info.annotation and is_literal_type(field_info.annotation): + for entry in get_args(field_info.annotation): + if isinstance(entry, str): + mapping[entry] = variant + + if not mapping: + return None + + details = DiscriminatorDetails( + mapping=mapping, + discriminator_field=discriminator_field_name, + discriminator_alias=discriminator_alias, + ) + cast(CachedDiscriminatorType, union).__discriminator__ = details + return details + + +def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: + schema = model.__pydantic_core_schema__ + if schema["type"] != "model": + return None + + fields_schema = schema["schema"] + if fields_schema["type"] != "model-fields": + return None + + fields_schema = cast("ModelFieldsSchema", fields_schema) + + field = fields_schema["fields"].get(field_name) + if not field: + return None + + return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] + + +def validate_type(*, type_: type[_T], value: object) -> _T: + """Strict validation that the given value matches the expected type""" + if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): + return cast(_T, parse_obj(type_, value)) + + return cast(_T, _validate_non_model_type(type_=type_, value=value)) + + +# our use of subclassing here causes weirdness for type checkers, +# so we just pretend that we don't subclass +if TYPE_CHECKING: + GenericModel = BaseModel +else: + + class GenericModel(BaseGenericModel, BaseModel): + pass + + +if PYDANTIC_V2: + from pydantic import TypeAdapter as _TypeAdapter + + _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) + + if TYPE_CHECKING: + from pydantic import TypeAdapter + else: + 
TypeAdapter = _CachedTypeAdapter + + def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: + return TypeAdapter(type_).validate_python(value) + +elif not TYPE_CHECKING: # TODO: condition is weird + + class RootModel(GenericModel, Generic[_T]): + """Used as a placeholder to easily convert runtime types to a Pydantic format + to provide validation. + + For example: + ```py + validated = RootModel[int](__root__="5").__root__ + # validated: 5 + ``` + """ + + __root__: _T + + def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: + model = _create_pydantic_model(type_).validate(value) + return cast(_T, model.__root__) + + def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]: + return RootModel[type_] # type: ignore + + + diff --git a/model-providers/model_providers/_types.py b/model-providers/model_providers/_types.py new file mode 100644 index 00000000..6fce8e09 --- /dev/null +++ b/model-providers/model_providers/_types.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from os import PathLike +from typing import ( + IO, + TYPE_CHECKING, + Any, + Dict, + List, + Type, + Tuple, + Union, + Mapping, + TypeVar, + Callable, + Optional, + Sequence, +) +from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable + +import httpx +import pydantic +from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport + +if TYPE_CHECKING: + from ._models import BaseModel + +Transport = BaseTransport +AsyncTransport = AsyncBaseTransport +Query = Mapping[str, object] +Body = object +AnyMapping = Mapping[str, object] +ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) +_T = TypeVar("_T") + + +# Approximates httpx internal ProxiesTypes and RequestFiles types +# while adding support for `PathLike` instances +ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] +ProxiesTypes = Union[str, Proxy, ProxiesDict] +if TYPE_CHECKING: + Base64FileInput = Union[IO[bytes], 
PathLike[str]] + FileContent = Union[IO[bytes], bytes, PathLike[str]] +else: + Base64FileInput = Union[IO[bytes], PathLike] + FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. +FileTypes = Union[ + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], +] +RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] + +# duplicate of the above but without our custom file support +HttpxFileContent = Union[IO[bytes], bytes] +HttpxFileTypes = Union[ + # file (or bytes) + HttpxFileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], HttpxFileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], HttpxFileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], +] +HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] + +# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT +# where ResponseT includes `None`. In order to support directly +# passing `None`, overloads would have to be defined for every +# method that uses `ResponseT` which would lead to an unacceptable +# amount of code duplication and make it unreadable. See _base_client.py +# for example usage. 
#
# This unfortunately means that you will either have
# to import this type and pass it explicitly:
#
#   from openai import NoneType
#   client.get('/foo', cast_to=NoneType)
#
# or build it yourself:
#
#   client.get('/foo', cast_to=type(None))
#
# NOTE(review): the example above imports from `openai`, while this module
# lives in the `model_providers` package -- confirm the intended import path.
if TYPE_CHECKING:
    # Type checkers only see the declaration; the value is bound in the else branch.
    NoneType: Type[None]
else:
    NoneType = type(None)


class RequestOptions(TypedDict, total=False):
    """Per-request overrides; `total=False` makes every key optional."""

    # `Headers` is defined further down in this module. The forward reference
    # is fine because the module uses `from __future__ import annotations`,
    # so annotations are not evaluated when the class body runs.
    headers: Headers
    max_retries: int
    timeout: float | Timeout | None
    params: Query
    extra_json: AnyMapping
    idempotency_key: str


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
    """
    A sentinel class used to distinguish omitted keyword arguments
    from those passed in with the value None (which may have different behavior).

    Instances are always falsy and render as `NOT_GIVEN`. The class itself does
    not enforce a single instance; `NOT_GIVEN` below is the shared module-level
    instance.

    For example:

    ```py
    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
        ...


    get(timeout=1)  # 1s timeout
    get(timeout=None)  # No timeout
    get()  # Default timeout behavior, which may not be statically known at the method definition.
    ```
    """

    def __bool__(self) -> Literal[False]:
        # Always falsy, so `if value:` treats an omitted argument like an empty one.
        return False

    @override
    def __repr__(self) -> str:
        return "NOT_GIVEN"


# Generic alias for "a value of type _T, or the omitted-argument sentinel".
NotGivenOr = Union[_T, NotGiven]
# Shared sentinel instance intended for use as a default argument value.
NOT_GIVEN = NotGiven()


class Omit:
    """In certain situations you need to be able to represent a case where a default value has
    to be explicitly removed and `None` is not an appropriate substitute, for example:

    ```py
    # as the default `Content-Type` header is `application/json` that will be sent
    client.post("/upload/files", files={"file": b"my raw file content"})

    # you can't explicitly override the header as it has to be dynamically generated
    # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
    client.post(..., headers={"Content-Type": "multipart/form-data"})

    # instead you can remove the default `application/json` header by passing Omit
    client.post(..., headers={"Content-Type": Omit()})
    ```
    """

    def __bool__(self) -> Literal[False]:
        # Always falsy, mirroring `NotGiven`.
        return False


# `runtime_checkable` allows `isinstance(obj, ModelBuilderProtocol)` checks;
# such checks only verify that a `build` attribute exists, not its signature.
@runtime_checkable
class ModelBuilderProtocol(Protocol):
    """Structural type for classes exposing a `build` classmethod that takes
    the keyword-only arguments `response` and `data` and returns an instance."""

    @classmethod
    def build(
        cls: type[_T],
        *,
        response: Response,
        data: object,
    ) -> _T:
        ...


# Header values may be a string, or an `Omit` instance to drop a default header.
Headers = Mapping[str, Union[str, Omit]]


class HeadersLikeProtocol(Protocol):
    """Structural type for header containers that only need to support `get`."""

    # The double-underscore prefix is the pre-PEP 570 typing convention for a
    # positional-only parameter.
    def get(self, __key: str) -> str | None:
        ...
+ + +HeadersLike = Union[Headers, HeadersLikeProtocol] + +ResponseT = TypeVar( + "ResponseT", + bound=Union[ + object, + str, + None, + "BaseModel", + List[Any], + Dict[str, Any], + Response, + ModelBuilderProtocol, + "APIResponse[Any]", + "AsyncAPIResponse[Any]", + "HttpxBinaryResponseContent", + ], +) + +StrBytesIntFloat = Union[str, bytes, int, float] + +# Note: copied from Pydantic +# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 +IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" + +PostParser = Callable[[Any], Any] + + +@runtime_checkable +class InheritsGeneric(Protocol): + """Represents a type that has inherited from `Generic` + + The `__orig_bases__` property can be used to determine the resolved + type variable for a given base class. + """ + + __orig_bases__: tuple[_GenericAlias] + + +class _GenericAlias(Protocol): + __origin__: type[object] + + +class HttpxSendArgs(TypedDict, total=False): + auth: httpx.Auth diff --git a/model-providers/model_providers/_utils/__init__.py b/model-providers/model_providers/_utils/__init__.py new file mode 100644 index 00000000..c5e3bee8 --- /dev/null +++ b/model-providers/model_providers/_utils/__init__.py @@ -0,0 +1,48 @@ +from ._utils import ( + flatten as flatten, + is_dict as is_dict, + is_list as is_list, + is_given as is_given, + is_tuple as is_tuple, + lru_cache as lru_cache, + is_mapping as is_mapping, + is_tuple_t as is_tuple_t, + parse_date as parse_date, + is_iterable as is_iterable, + is_sequence as is_sequence, + coerce_float as coerce_float, + is_mapping_t as is_mapping_t, + removeprefix as removeprefix, + removesuffix as removesuffix, + extract_files as extract_files, + is_sequence_t as is_sequence_t, + required_args as required_args, + coerce_boolean as coerce_boolean, + coerce_integer as coerce_integer, + file_from_path as file_from_path, + parse_datetime as parse_datetime, + strip_not_given as strip_not_given, + 
deepcopy_minimal as deepcopy_minimal, + get_async_library as get_async_library, + maybe_coerce_float as maybe_coerce_float, + get_required_header as get_required_header, + maybe_coerce_boolean as maybe_coerce_boolean, + maybe_coerce_integer as maybe_coerce_integer, +) +from ._typing import ( + is_list_type as is_list_type, + is_union_type as is_union_type, + extract_type_arg as extract_type_arg, + is_iterable_type as is_iterable_type, + is_required_type as is_required_type, + is_annotated_type as is_annotated_type, + strip_annotated_type as strip_annotated_type, + extract_type_var_from_base as extract_type_var_from_base, +) +from ._transform import ( + PropertyInfo as PropertyInfo, + transform as transform, + async_transform as async_transform, + maybe_transform as maybe_transform, + async_maybe_transform as async_maybe_transform, +) diff --git a/model-providers/model_providers/_utils/_transform.py b/model-providers/model_providers/_utils/_transform.py new file mode 100644 index 00000000..47e262a5 --- /dev/null +++ b/model-providers/model_providers/_utils/_transform.py @@ -0,0 +1,382 @@ +from __future__ import annotations + +import io +import base64 +import pathlib +from typing import Any, Mapping, TypeVar, cast +from datetime import date, datetime +from typing_extensions import Literal, get_args, override, get_type_hints + +import anyio +import pydantic + +from ._utils import ( + is_list, + is_mapping, + is_iterable, +) +from .._files import is_base64_file_input +from ._typing import ( + is_list_type, + is_union_type, + extract_type_arg, + is_iterable_type, + is_required_type, + is_annotated_type, + strip_annotated_type, +) +from .._compat import model_dump, is_typeddict + +_T = TypeVar("_T") + + +# TODO: support for drilling globals() and locals() +# TODO: ensure works correctly with forward references in all cases + + +PropertyFormat = Literal["iso8601", "base64", "custom"] + + +class PropertyInfo: + """Metadata class to be used in Annotated types to provide 
information about a given type. + + For example: + + class MyParams(TypedDict): + account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')] + + This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API. + """ + + alias: str | None + format: PropertyFormat | None + format_template: str | None + discriminator: str | None + + def __init__( + self, + *, + alias: str | None = None, + format: PropertyFormat | None = None, + format_template: str | None = None, + discriminator: str | None = None, + ) -> None: + self.alias = alias + self.format = format + self.format_template = format_template + self.discriminator = discriminator + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" + + +def maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `transform()` that allows `None` to be passed. + + See `transform()` for more details. + """ + if data is None: + return None + return transform(data, expected_type) + + +# Wrapper over _transform_recursive providing fake types +def transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. 
+ """ + transformed = _transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +def _get_annotated_type(type_: type) -> type | None: + """If the given type is an `Annotated` type then it is returned, if not `None` is returned. + + This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]` + """ + if is_required_type(type_): + # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]` + type_ = get_args(type_)[0] + + if is_annotated_type(type_): + return type_ + + return None + + +def _maybe_transform_key(key: str, type_: type) -> str: + """Transform the given `data` based on the annotations provided in `type_`. + + Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata. + """ + annotated_type = _get_annotated_type(type_) + if annotated_type is None: + # no `Annotated` definition for this type, no transformation needed + return key + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.alias is not None: + return annotation.alias + + return key + + +def _transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. + + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. 
+ """ + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + if is_typeddict(stripped_type) and is_mapping(data): + return _transform_typeddict(data, stripped_type) + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): + inner_type = extract_type_arg(stripped_type, 0) + return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. + for subtype in get_args(stripped_type): + data = _transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True) + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return _format_data(data, annotation.format, annotation.format_template) + + return data + + +def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, (date, datetime)): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = data.read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, 
str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +def _transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) + return result + + +async def async_maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `async_transform()` that allows `None` to be passed. + + See `async_transform()` for more details. + """ + if data is None: + return None + return await async_transform(data, expected_type) + + +async def async_transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. + """ + transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +async def _async_transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. 
+ + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. + """ + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + if is_typeddict(stripped_type) and is_mapping(data): + return await _async_transform_typeddict(data, stripped_type) + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): + inner_type = extract_type_arg(stripped_type, 0) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. 
+ for subtype in get_args(stripped_type): + data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True) + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return await _async_format_data(data, annotation.format, annotation.format_template) + + return data + + +async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, (date, datetime)): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = await anyio.Path(data).read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +async def _async_transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) + return result 
diff --git a/model-providers/model_providers/_utils/_typing.py b/model-providers/model_providers/_utils/_typing.py new file mode 100644 index 00000000..003ca84a --- /dev/null +++ b/model-providers/model_providers/_utils/_typing.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import Any, TypeVar, Iterable, cast +from collections import abc as _c_abc +from typing_extensions import Required, Annotated, get_args, get_origin + +from .._types import InheritsGeneric +from .._compat import is_union as _is_union + + +def is_annotated_type(typ: type) -> bool: + return get_origin(typ) == Annotated + + +def is_list_type(typ: type) -> bool: + return (get_origin(typ) or typ) == list + + +def is_iterable_type(typ: type) -> bool: + """If the given type is `typing.Iterable[T]`""" + origin = get_origin(typ) or typ + return origin == Iterable or origin == _c_abc.Iterable + + +def is_union_type(typ: type) -> bool: + return _is_union(get_origin(typ)) + + +def is_required_type(typ: type) -> bool: + return get_origin(typ) == Required + + +def is_typevar(typ: type) -> bool: + # type ignore is required because type checkers + # think this expression will always return False + return type(typ) == TypeVar # type: ignore + + +# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] +def strip_annotated_type(typ: type) -> type: + if is_required_type(typ) or is_annotated_type(typ): + return strip_annotated_type(cast(type, get_args(typ)[0])) + + return typ + + +def extract_type_arg(typ: type, index: int) -> type: + args = get_args(typ) + try: + return cast(type, args[index]) + except IndexError as err: + raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err + + +def extract_type_var_from_base( + typ: type, + *, + generic_bases: tuple[type, ...], + index: int, + failure_message: str | None = None, +) -> type: + """Given a type like `Foo[T]`, returns the generic type variable `T`. 
+ + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(Foo[bytes]): + ... + + extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes + ``` + + And where a generic subclass is given: + ```py + _T = TypeVar('_T') + class MyResponse(Foo[_T]): + ... + + extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes + ``` + """ + cls = cast(object, get_origin(typ) or typ) + if cls in generic_bases: + # we're given the class directly + return extract_type_arg(typ, index) + + # if a subclass is given + # --- + # this is needed as __orig_bases__ is not present in the typeshed stubs + # because it is intended to be for internal use only, however there does + # not seem to be a way to resolve generic TypeVars for inherited subclasses + # without using it. + if isinstance(cls, InheritsGeneric): + target_base_class: Any | None = None + for base in cls.__orig_bases__: + if base.__origin__ in generic_bases: + target_base_class = base + break + + if target_base_class is None: + raise RuntimeError( + "Could not find the generic base class;\n" + "This should never happen;\n" + f"Does {cls} inherit from one of {generic_bases} ?" + ) + + extracted = extract_type_arg(target_base_class, index) + if is_typevar(extracted): + # If the extracted type argument is itself a type variable + # then that means the subclass itself is generic, so we have + # to resolve the type argument from the class itself, not + # the base class. + # + # Note: if there is more than 1 type argument, the subclass could + # change the ordering of the type arguments, this is not currently + # supported. 
+ return extract_type_arg(typ, index) + + return extracted + + raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/model-providers/model_providers/_utils/_utils.py b/model-providers/model_providers/_utils/_utils.py new file mode 100644 index 00000000..17904ce6 --- /dev/null +++ b/model-providers/model_providers/_utils/_utils.py @@ -0,0 +1,403 @@ +from __future__ import annotations + +import os +import re +import inspect +import functools +from typing import ( + Any, + Tuple, + Mapping, + TypeVar, + Callable, + Iterable, + Sequence, + cast, + overload, +) +from pathlib import Path +from typing_extensions import TypeGuard + +import sniffio + +from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike +from .._compat import parse_date as parse_date, parse_datetime as parse_datetime + +_T = TypeVar("_T") +_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) +_MappingT = TypeVar("_MappingT", bound=Mapping[str, object]) +_SequenceT = TypeVar("_SequenceT", bound=Sequence[object]) +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: + return [item for sublist in t for item in sublist] + + +def extract_files( + # TODO: this needs to take Dict but variance issues..... + # create protocol type ? + query: Mapping[str, object], + *, + paths: Sequence[Sequence[str]], +) -> list[tuple[str, FileTypes]]: + """Recursively extract files from the given dictionary based on specified paths. + + A path may look like this ['foo', 'files', '', 'data']. + + Note: this mutates the given dictionary. 
+ """ + files: list[tuple[str, FileTypes]] = [] + for path in paths: + files.extend(_extract_items(query, path, index=0, flattened_key=None)) + return files + + +def _extract_items( + obj: object, + path: Sequence[str], + *, + index: int, + flattened_key: str | None, +) -> list[tuple[str, FileTypes]]: + try: + key = path[index] + except IndexError: + if isinstance(obj, NotGiven): + # no value was provided - we can safely ignore + return [] + + # cyclical import + from .._files import assert_is_file_content + + # We have exhausted the path, return the entry we found. + assert_is_file_content(obj, key=flattened_key) + assert flattened_key is not None + return [(flattened_key, cast(FileTypes, obj))] + + index += 1 + if is_dict(obj): + try: + # We are at the last entry in the path so we must remove the field + if (len(path)) == index: + item = obj.pop(key) + else: + item = obj[key] + except KeyError: + # Key was not present in the dictionary, this is not indicative of an error + # as the given path may not point to a required field. We also do not want + # to enforce required fields as the API may differ from the spec in some cases. + return [] + if flattened_key is None: + flattened_key = key + else: + flattened_key += f"[{key}]" + return _extract_items( + item, + path, + index=index, + flattened_key=flattened_key, + ) + elif is_list(obj): + if key != "": + return [] + + return flatten( + [ + _extract_items( + item, + path, + index=index, + flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", + ) + for item in obj + ] + ) + + # Something unexpected was passed, just ignore it. + return [] + + +def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: + return not isinstance(obj, NotGiven) + + +# Type safe methods for narrowing types with TypeVars. +# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], +# however this cause Pyright to rightfully report errors. 
As we know we don't +# care about the contained types we can safely use `object` in it's place. +# +# There are two separate functions defined, `is_*` and `is_*_t` for different use cases. +# `is_*` is for when you're dealing with an unknown input +# `is_*_t` is for when you're narrowing a known union type to a specific subset + + +def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]: + return isinstance(obj, tuple) + + +def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]: + return isinstance(obj, tuple) + + +def is_sequence(obj: object) -> TypeGuard[Sequence[object]]: + return isinstance(obj, Sequence) + + +def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]: + return isinstance(obj, Sequence) + + +def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]: + return isinstance(obj, Mapping) + + +def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]: + return isinstance(obj, Mapping) + + +def is_dict(obj: object) -> TypeGuard[dict[object, object]]: + return isinstance(obj, dict) + + +def is_list(obj: object) -> TypeGuard[list[object]]: + return isinstance(obj, list) + + +def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: + return isinstance(obj, Iterable) + + +def deepcopy_minimal(item: _T) -> _T: + """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: + + - mappings, e.g. `dict` + - list + + This is done for performance reasons. 
+ """ + if is_mapping(item): + return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()}) + if is_list(item): + return cast(_T, [deepcopy_minimal(entry) for entry in item]) + return item + + +# copied from https://github.com/Rapptz/RoboDanny +def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: + size = len(seq) + if size == 0: + return "" + + if size == 1: + return seq[0] + + if size == 2: + return f"{seq[0]} {final} {seq[1]}" + + return delim.join(seq[:-1]) + f" {final} {seq[-1]}" + + +def quote(string: str) -> str: + """Add single quotation marks around the given string. Does *not* do any escaping.""" + return f"'{string}'" + + +def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: + """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. + + Useful for enforcing runtime validation of overloaded functions. + + Example usage: + ```py + @overload + def foo(*, a: str) -> str: + ... + + + @overload + def foo(*, b: bool) -> str: + ... + + + # This enforces the same constraints that a static type checker would + # i.e. that either a or b must be passed to the function + @required_args(["a"], ["b"]) + def foo(*, a: str | None = None, b: bool | None = None) -> str: + ... 
+ ``` + """ + + def inner(func: CallableT) -> CallableT: + params = inspect.signature(func).parameters + positional = [ + name + for name, param in params.items() + if param.kind + in { + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + } + ] + + @functools.wraps(func) + def wrapper(*args: object, **kwargs: object) -> object: + given_params: set[str] = set() + for i, _ in enumerate(args): + try: + given_params.add(positional[i]) + except IndexError: + raise TypeError( + f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" + ) from None + + for key in kwargs.keys(): + given_params.add(key) + + for variant in variants: + matches = all((param in given_params for param in variant)) + if matches: + break + else: # no break + if len(variants) > 1: + variations = human_join( + ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] + ) + msg = f"Missing required arguments; Expected either {variations} arguments to be given" + else: + assert len(variants) > 0 + + # TODO: this error message is not deterministic + missing = list(set(variants[0]) - given_params) + if len(missing) > 1: + msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" + else: + msg = f"Missing required argument: {quote(missing[0])}" + raise TypeError(msg) + return func(*args, **kwargs) + + return wrapper # type: ignore + + return inner + + +_K = TypeVar("_K") +_V = TypeVar("_V") + + +@overload +def strip_not_given(obj: None) -> None: + ... + + +@overload +def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: + ... + + +@overload +def strip_not_given(obj: object) -> object: + ... 
+ + +def strip_not_given(obj: object | None) -> object: + """Remove all top-level keys where their values are instances of `NotGiven`""" + if obj is None: + return None + + if not is_mapping(obj): + return obj + + return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} + + +def coerce_integer(val: str) -> int: + return int(val, base=10) + + +def coerce_float(val: str) -> float: + return float(val) + + +def coerce_boolean(val: str) -> bool: + return val == "true" or val == "1" or val == "on" + + +def maybe_coerce_integer(val: str | None) -> int | None: + if val is None: + return None + return coerce_integer(val) + + +def maybe_coerce_float(val: str | None) -> float | None: + if val is None: + return None + return coerce_float(val) + + +def maybe_coerce_boolean(val: str | None) -> bool | None: + if val is None: + return None + return coerce_boolean(val) + + +def removeprefix(string: str, prefix: str) -> str: + """Remove a prefix from a string. + + Backport of `str.removeprefix` for Python < 3.9 + """ + if string.startswith(prefix): + return string[len(prefix) :] + return string + + +def removesuffix(string: str, suffix: str) -> str: + """Remove a suffix from a string. 
+ + Backport of `str.removesuffix` for Python < 3.9 + """ + if string.endswith(suffix): + return string[: -len(suffix)] + return string + + +def file_from_path(path: str) -> FileTypes: + contents = Path(path).read_bytes() + file_name = os.path.basename(path) + return (file_name, contents) + + +def get_required_header(headers: HeadersLike, header: str) -> str: + lower_header = header.lower() + if isinstance(headers, Mapping): + headers = cast(Headers, headers) + for k, v in headers.items(): + if k.lower() == lower_header and isinstance(v, str): + return v + + """ to deal with the case where the header looks like Stainless-Event-Id """ + intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) + + for normalized_header in [header, lower_header, header.upper(), intercaps_header]: + value = headers.get(normalized_header) + if value: + return value + + raise ValueError(f"Could not find {header} header") + + +def get_async_library() -> str: + try: + return sniffio.current_async_library() + except Exception: + return "false" + + +def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]: + """A version of functools.lru_cache that retains the type signature + for the wrapped function arguments. 
+ """ + wrapper = functools.lru_cache( # noqa: TID251 + maxsize=maxsize, + ) + return cast(Any, wrapper) # type: ignore[no-any-return] diff --git a/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py index 77819c8a..36a901dc 100644 --- a/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py +++ b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Literal, Optional -from pydantic import BaseModel +from ..._models import BaseModel from model_providers.core.entities.model_entities import ( ModelStatus, diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 2bd364f3..5f30c1cc 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -1,8 +1,8 @@ import time from enum import Enum from typing import Any, Dict, List, Optional, Union - -from pydantic import BaseModel, Field, root_validator +from ..._models import BaseModel +from pydantic import Field as FieldInfo from typing_extensions import Literal @@ -81,7 +81,7 @@ class ModelCard(BaseModel): "tts", "text2img", ] = "llm" - created: int = Field(default_factory=lambda: int(time.time())) + created: int = FieldInfo(default_factory=lambda: int(time.time())) owned_by: Literal["owner"] = "owner" @@ -171,7 +171,7 @@ class ChatCompletionStreamResponseChoice(BaseModel): class ChatCompletionResponse(BaseModel): id: str object: Literal["chat.completion"] = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) + created: int = FieldInfo(default_factory=lambda: int(time.time())) model: str choices: List[ChatCompletionResponseChoice] usage: UsageInfo @@ -180,7 +180,7 @@ class 
ChatCompletionResponse(BaseModel): class ChatCompletionStreamResponse(BaseModel): id: str object: Literal["chat.completion.chunk"] = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) + created: int = FieldInfo(default_factory=lambda: int(time.time())) model: str choices: List[ChatCompletionStreamResponseChoice] diff --git a/model-providers/model_providers/core/entities/application_entities.py b/model-providers/model_providers/core/entities/application_entities.py deleted file mode 100644 index 263693cc..00000000 --- a/model-providers/model_providers/core/entities/application_entities.py +++ /dev/null @@ -1,331 +0,0 @@ -from enum import Enum -from typing import Any, Literal, Optional, Union - -from pydantic import BaseModel - -from model_providers.core.entities.provider_configuration import ProviderModelBundle -from model_providers.core.file.file_obj import FileObj -from model_providers.core.model_runtime.entities.message_entities import ( - PromptMessageRole, -) -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity - - -class ModelConfigEntity(BaseModel): - """ - Model Config Entity. - """ - - provider: str - model: str - model_schema: AIModelEntity - mode: str - provider_model_bundle: ProviderModelBundle - credentials: Dict[str, Any] = {} - parameters: Dict[str, Any] = {} - stop: List[str] = [] - - -class AdvancedChatMessageEntity(BaseModel): - """ - Advanced Chat Message Entity. - """ - - text: str - role: PromptMessageRole - - -class AdvancedChatPromptTemplateEntity(BaseModel): - """ - Advanced Chat Prompt Template Entity. - """ - - messages: List[AdvancedChatMessageEntity] - - -class AdvancedCompletionPromptTemplateEntity(BaseModel): - """ - Advanced Completion Prompt Template Entity. - """ - - class RolePrefixEntity(BaseModel): - """ - Role Prefix Entity. 
- """ - - user: str - assistant: str - - prompt: str - role_prefix: Optional[RolePrefixEntity] = None - - -class PromptTemplateEntity(BaseModel): - """ - Prompt Template Entity. - """ - - class PromptType(Enum): - """ - Prompt Type. - 'simple', 'advanced' - """ - - SIMPLE = "simple" - ADVANCED = "advanced" - - @classmethod - def value_of(cls, value: str) -> "PromptType": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid prompt type value {value}") - - prompt_type: PromptType - simple_prompt_template: Optional[str] = None - advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None - advanced_completion_prompt_template: Optional[ - AdvancedCompletionPromptTemplateEntity - ] = None - - -class ExternalDataVariableEntity(BaseModel): - """ - External Data Variable Entity. - """ - - variable: str - type: str - config: Dict[str, Any] = {} - - -class DatasetRetrieveConfigEntity(BaseModel): - """ - Dataset Retrieve Config Entity. - """ - - class RetrieveStrategy(Enum): - """ - Dataset Retrieve Strategy. - 'single' or 'multiple' - """ - - SINGLE = "single" - MULTIPLE = "multiple" - - @classmethod - def value_of(cls, value: str) -> "RetrieveStrategy": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid retrieve strategy value {value}") - - query_variable: Optional[str] = None # Only when app mode is completion - - retrieve_strategy: RetrieveStrategy - single_strategy: Optional[str] = None # for temp - top_k: Optional[int] = None - score_threshold: Optional[float] = None - reranking_model: Optional[dict] = None - - -class DatasetEntity(BaseModel): - """ - Dataset Config Entity. 
- """ - - dataset_ids: List[str] - retrieve_config: DatasetRetrieveConfigEntity - - -class SensitiveWordAvoidanceEntity(BaseModel): - """ - Sensitive Word Avoidance Entity. - """ - - type: str - config: Dict[str, Any] = {} - - -class TextToSpeechEntity(BaseModel): - """ - Sensitive Word Avoidance Entity. - """ - - enabled: bool - voice: Optional[str] = None - language: Optional[str] = None - - -class FileUploadEntity(BaseModel): - """ - File Upload Entity. - """ - - image_config: Optional[dict[str, Any]] = None - - -class AgentToolEntity(BaseModel): - """ - Agent Tool Entity. - """ - - provider_type: Literal["builtin", "api"] - provider_id: str - tool_name: str - tool_parameters: Dict[str, Any] = {} - - -class AgentPromptEntity(BaseModel): - """ - Agent Prompt Entity. - """ - - first_prompt: str - next_iteration: str - - -class AgentScratchpadUnit(BaseModel): - """ - Agent First Prompt Entity. - """ - - class Action(BaseModel): - """ - Action Entity. - """ - - action_name: str - action_input: Union[dict, str] - - agent_response: Optional[str] = None - thought: Optional[str] = None - action_str: Optional[str] = None - observation: Optional[str] = None - action: Optional[Action] = None - - -class AgentEntity(BaseModel): - """ - Agent Entity. - """ - - class Strategy(Enum): - """ - Agent Strategy. - """ - - CHAIN_OF_THOUGHT = "chain-of-thought" - FUNCTION_CALLING = "function-calling" - - provider: str - model: str - strategy: Strategy - prompt: Optional[AgentPromptEntity] = None - tools: List[AgentToolEntity] = None - max_iteration: int = 5 - - -class AppOrchestrationConfigEntity(BaseModel): - """ - App Orchestration Config Entity. 
- """ - - model_config: ModelConfigEntity - prompt_template: PromptTemplateEntity - external_data_variables: List[ExternalDataVariableEntity] = [] - agent: Optional[AgentEntity] = None - - # features - dataset: Optional[DatasetEntity] = None - file_upload: Optional[FileUploadEntity] = None - opening_statement: Optional[str] = None - suggested_questions_after_answer: bool = False - show_retrieve_source: bool = False - more_like_this: bool = False - speech_to_text: bool = False - text_to_speech: dict = {} - sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None - - -class InvokeFrom(Enum): - """ - Invoke From. - """ - - SERVICE_API = "service-api" - WEB_APP = "web-app" - EXPLORE = "explore" - DEBUGGER = "debugger" - - @classmethod - def value_of(cls, value: str) -> "InvokeFrom": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid invoke from value {value}") - - def to_source(self) -> str: - """ - Get source of invoke from. - - :return: source - """ - if self == InvokeFrom.WEB_APP: - return "web_app" - elif self == InvokeFrom.DEBUGGER: - return "dev" - elif self == InvokeFrom.EXPLORE: - return "explore_app" - elif self == InvokeFrom.SERVICE_API: - return "api" - - return "dev" - - -class ApplicationGenerateEntity(BaseModel): - """ - Application Generate Entity. 
- """ - - task_id: str - tenant_id: str - - app_id: str - app_model_config_id: str - # for save - app_model_config_dict: dict - app_model_config_override: bool - - # Converted from app_model_config to Entity object, or directly covered by external input - app_orchestration_config_entity: AppOrchestrationConfigEntity - - conversation_id: Optional[str] = None - inputs: Dict[str, str] - query: Optional[str] = None - files: List[FileObj] = [] - user_id: str - # extras - stream: bool - invoke_from: InvokeFrom - - # extra parameters, like: auto_generate_conversation_name - extras: Dict[str, Any] = {} diff --git a/model-providers/model_providers/core/entities/message_entities.py b/model-providers/model_providers/core/entities/message_entities.py index 52aa3fa0..c768c0e9 100644 --- a/model-providers/model_providers/core/entities/message_entities.py +++ b/model-providers/model_providers/core/entities/message_entities.py @@ -1,5 +1,5 @@ import enum -from typing import Any, cast +from typing import Any, cast, List from langchain.schema import ( AIMessage, @@ -8,7 +8,7 @@ from langchain.schema import ( HumanMessage, SystemMessage, ) -from pydantic import BaseModel +from ..._models import BaseModel from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/model-providers/model_providers/core/entities/model_entities.py b/model-providers/model_providers/core/entities/model_entities.py index cfaf6b82..c95be8ae 100644 --- a/model-providers/model_providers/core/entities/model_entities.py +++ b/model-providers/model_providers/core/entities/model_entities.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from ..._models import BaseModel from model_providers.core.model_runtime.entities.common_entities import I18nObject from model_providers.core.model_runtime.entities.model_entities import ( diff --git a/model-providers/model_providers/core/entities/provider_configuration.py 
b/model-providers/model_providers/core/entities/provider_configuration.py index a068bd9e..6e887a14 100644 --- a/model-providers/model_providers/core/entities/provider_configuration.py +++ b/model-providers/model_providers/core/entities/provider_configuration.py @@ -4,7 +4,7 @@ import logging from json import JSONDecodeError from typing import Dict, Iterator, List, Optional -from pydantic import BaseModel +from ..._models import BaseModel from model_providers.core.entities.model_entities import ( ModelStatus, diff --git a/model-providers/model_providers/core/entities/provider_entities.py b/model-providers/model_providers/core/entities/provider_entities.py index 7b0705db..1e6dc7bc 100644 --- a/model-providers/model_providers/core/entities/provider_entities.py +++ b/model-providers/model_providers/core/entities/provider_entities.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from ..._models import BaseModel from model_providers.core.model_runtime.entities.model_entities import ModelType diff --git a/model-providers/model_providers/core/entities/queue_entities.py b/model-providers/model_providers/core/entities/queue_entities.py index f72080cb..3dba6d29 100644 --- a/model-providers/model_providers/core/entities/queue_entities.py +++ b/model-providers/model_providers/core/entities/queue_entities.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any -from pydantic import BaseModel +from ..._models import BaseModel from model_providers.core.model_runtime.entities.llm_entities import ( LLMResult, diff --git a/model-providers/model_providers/core/model_runtime/entities/common_entities.py b/model-providers/model_providers/core/model_runtime/entities/common_entities.py index 659ad59b..944d2fd5 100644 --- a/model-providers/model_providers/core/model_runtime/entities/common_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/common_entities.py @@ -1,6 +1,6 @@ from typing import 
Optional -from pydantic import BaseModel +from ...._models import BaseModel class I18nObject(BaseModel): diff --git a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py index 8976ff8c..c9d09f86 100644 --- a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from ...._models import BaseModel from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/model-providers/model_providers/core/model_runtime/entities/message_entities.py b/model-providers/model_providers/core/model_runtime/entities/message_entities.py index 0fd0ed17..87a271b9 100644 --- a/model-providers/model_providers/core/model_runtime/entities/message_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/message_entities.py @@ -2,7 +2,7 @@ from abc import ABC from enum import Enum from typing import List, Optional, Union -from pydantic import BaseModel +from ...._models import BaseModel class PromptMessageRole(Enum): diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py index 5cc6e80b..10e6d01d 100644 --- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import BaseModel +from ...._models import BaseModel from model_providers.core.model_runtime.entities.common_entities import I18nObject diff --git 
a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py index 2bfee500..923bcab4 100644 --- a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from ...._models import BaseModel from model_providers.core.model_runtime.entities.common_entities import I18nObject from model_providers.core.model_runtime.entities.model_entities import ( diff --git a/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py b/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py index 034a7286..0785ddaf 100644 --- a/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py @@ -1,6 +1,6 @@ from typing import List -from pydantic import BaseModel +from ...._models import BaseModel class RerankDocument(BaseModel): diff --git a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py index 454e41ee..b222e649 100644 --- a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py @@ -1,7 +1,7 @@ from decimal import Decimal from typing import List -from pydantic import BaseModel +from ...._models import BaseModel from model_providers.core.model_runtime.entities.model_entities import ModelUsage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py 
b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py index b4a4cbba..2d53dcbc 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from ....._models import BaseModel from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from model_providers.core.model_runtime.entities.llm_entities import LLMMode diff --git a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py index 6adc96e2..740560a1 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py @@ -3,7 +3,7 @@ import logging import os from typing import Dict, List, Optional, Union -from pydantic import BaseModel +from ...._models import BaseModel from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.model_runtime.entities.provider_entities import ( diff --git a/model-providers/model_providers/core/model_runtime/utils/_compat.py b/model-providers/model_providers/core/model_runtime/utils/_compat.py deleted file mode 100644 index 5c341527..00000000 --- a/model-providers/model_providers/core/model_runtime/utils/_compat.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any, Literal - -from pydantic import BaseModel -from pydantic.version import VERSION as PYDANTIC_VERSION - -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - -if PYDANTIC_V2: - from pydantic_core import Url as Url - - def _model_dump( - model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any - ) -> Any: - return 
model.model_dump(mode=mode, **kwargs) -else: - from pydantic import AnyUrl as Url # noqa: F401 - - def _model_dump( - model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any - ) -> Any: - return model.dict(**kwargs) diff --git a/model-providers/model_providers/core/model_runtime/utils/encoders.py b/model-providers/model_providers/core/model_runtime/utils/encoders.py deleted file mode 100644 index fe5836c8..00000000 --- a/model-providers/model_providers/core/model_runtime/utils/encoders.py +++ /dev/null @@ -1,234 +0,0 @@ -import dataclasses -import datetime -from collections import defaultdict, deque -from collections.abc import Callable -from decimal import Decimal -from enum import Enum -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, -) -from pathlib import Path, PurePath -from re import Pattern -from types import GeneratorType -from typing import Any, Optional, Union, Dict, Type, List, Tuple -from uuid import UUID - -from pydantic import BaseModel -from pydantic.color import Color -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr - -from ._compat import PYDANTIC_V2, Url, _model_dump - - -# Taken from Pydantic v1 as is -def isoformat(o: Union[datetime.date, datetime.time]) -> str: - return o.isoformat() - - -# Taken from Pydantic v1 as is -# TODO: pv2 should this return strings instead? -def decimal_encoder(dec_value: Decimal) -> Union[int, float]: - """ - Encodes a Decimal as int of there's no exponent, otherwise float - - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where a integer (but not int typed) is used. Encoding this as a float - results in failed round-tripping between encode and parse. - Our Id type is a prime example of this. 
- - >>> decimal_encoder(Decimal("1.0")) - 1.0 - - >>> decimal_encoder(Decimal("1")) - 1 - """ - if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] - return int(dec_value) - else: - return float(dec_value) - - -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { - bytes: lambda o: o.decode(), - Color: str, - datetime.date: isoformat, - datetime.datetime: isoformat, - datetime.time: isoformat, - datetime.timedelta: lambda td: td.total_seconds(), - Decimal: decimal_encoder, - Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, - IPv4Address: str, - IPv4Interface: str, - IPv4Network: str, - IPv6Address: str, - IPv6Interface: str, - IPv6Network: str, - NameEmail: str, - Path: str, - Pattern: lambda o: o.pattern, - SecretBytes: str, - SecretStr: str, - set: list, - UUID: str, - Url: str, - AnyUrl: str, -} - - -def generate_encoders_by_class_tuples( - type_encoder_map: Dict[Any, Callable[[Any], Any]], -) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: - encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict( - tuple - ) - for type_, encoder in type_encoder_map.items(): - encoders_by_class_tuples[encoder] += (type_,) - return encoders_by_class_tuples - - -encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) - - -def jsonable_encoder( - obj: Any, - by_alias: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None, - sqlalchemy_safe: bool = True, -) -> Any: - custom_encoder = custom_encoder or {} - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder_instance in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder_instance(obj) - if isinstance(obj, BaseModel): - # TODO: remove when deprecating Pydantic v1 - encoders: Dict[Any, Any] = {} - if not 
PYDANTIC_V2: - encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] - if custom_encoder: - encoders.update(custom_encoder) - obj_dict = _model_dump( - obj, - mode="json", - include=None, - exclude=None, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] - return jsonable_encoder( - obj_dict, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - # TODO: remove when deprecating Pydantic v1 - custom_encoder=encoders, - sqlalchemy_safe=sqlalchemy_safe, - ) - if dataclasses.is_dataclass(obj): - obj_dict = dataclasses.asdict(obj) - return jsonable_encoder( - obj_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, str | int | float | type(None)): - return obj - if isinstance(obj, Decimal): - return format(obj, "f") - if isinstance(obj, dict): - encoded_dict = {} - allowed_keys = set(obj.keys()) - for key, value in obj.items(): - if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) - and (value is not None or not exclude_none) - and key in allowed_keys - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): - 
encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - ) - return encoded_list - - if type(obj) in ENCODERS_BY_TYPE: - return ENCODERS_BY_TYPE[type(obj)](obj) - for encoder, classes_tuple in encoders_by_class_tuples.items(): - if isinstance(obj, classes_tuple): - return encoder(obj) - - try: - data = dict(obj) - except Exception as e: - errors: List[Exception] = [e] - try: - data = vars(obj) - except Exception as e: - errors.append(e) - raise ValueError(errors) from e - return jsonable_encoder( - data, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) diff --git a/model-providers/model_providers/core/model_runtime/utils/helper.py b/model-providers/model_providers/core/model_runtime/utils/helper.py index 7774868f..c6e97c83 100644 --- a/model-providers/model_providers/core/model_runtime/utils/helper.py +++ b/model-providers/model_providers/core/model_runtime/utils/helper.py @@ -1,5 +1,5 @@ import pydantic -from pydantic import BaseModel +from ...._models import BaseModel def dump_model(model: BaseModel) -> dict: diff --git a/model-providers/model_providers/core/utils/generic.py b/model-providers/model_providers/core/utils/generic.py index 06eb2a7a..a713df09 100644 --- a/model-providers/model_providers/core/utils/generic.py +++ b/model-providers/model_providers/core/utils/generic.py @@ -2,7 +2,7 @@ import json from typing import TYPE_CHECKING, Any, Dict if TYPE_CHECKING: - from pydantic import BaseModel + from ..._models import BaseModel def dictify(data: "BaseModel") -> Dict[str, Any]: diff --git a/model-providers/model_providers/core/utils/json_dumps.py 
b/model-providers/model_providers/core/utils/json_dumps.py index 2f39ea5a..20b44de5 100644 --- a/model-providers/model_providers/core/utils/json_dumps.py +++ b/model-providers/model_providers/core/utils/json_dumps.py @@ -1,7 +1,7 @@ import os import orjson -from pydantic import BaseModel +from ..._models import BaseModel def json_dumps(o): diff --git a/model-providers/tests/conftest.py b/model-providers/tests/conftest.py index c3191f07..a4508b81 100644 --- a/model-providers/tests/conftest.py +++ b/model-providers/tests/conftest.py @@ -136,7 +136,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None: yield f"http://127.0.0.1:20000" finally: print("") - # boot.destroy() + boot.destroy() except SystemExit: