Some compatibility code for pydantic<3,>=1.9.0

glide-the 2024-05-20 11:35:14 +08:00
parent 2bb24d8b6d
commit d0267bf66b
31 changed files with 2204 additions and 611 deletions


@@ -0,0 +1,222 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self
import pydantic
from pydantic.fields import FieldInfo
from ._types import StrBytesIntFloat
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
# --------------- Pydantic v2 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
# v1 re-exports
if TYPE_CHECKING:
def parse_date(value: Union[date, StrBytesIntFloat]) -> date: # noqa: ARG001
...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
...
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
...
def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
...
def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
...
def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
...
def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
...
else:
if PYDANTIC_V2:
from pydantic.v1.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
else:
from pydantic.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict as ConfigDict
else:
if PYDANTIC_V2:
from pydantic import ConfigDict
else:
# TODO: provide an error message here?
ConfigDict = None
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(value)
else:
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
def field_is_required(field: FieldInfo) -> bool:
if PYDANTIC_V2:
return field.is_required()
return field.required # type: ignore
def field_get_default(field: FieldInfo) -> Any:
value = field.get_default()
if PYDANTIC_V2:
from pydantic_core import PydanticUndefined
if value == PydanticUndefined:
return None
return value
return value
def field_outer_type(field: FieldInfo) -> Any:
if PYDANTIC_V2:
return field.annotation
return field.outer_type_ # type: ignore
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
if PYDANTIC_V2:
return model.model_config
return model.__config__ # type: ignore
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V2:
return model.model_fields
return model.__fields__ # type: ignore
def model_copy(model: _ModelT) -> _ModelT:
if PYDANTIC_V2:
return model.model_copy()
return model.copy() # type: ignore
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
if PYDANTIC_V2:
return model.model_dump_json(indent=indent)
return model.json(indent=indent) # type: ignore
def model_dump(
model: pydantic.BaseModel,
*,
exclude_unset: bool = False,
exclude_defaults: bool = False,
) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_dump(
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
)
return cast(
"dict[str, Any]",
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
),
)
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(data)
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
# generic models
if TYPE_CHECKING:
class GenericModel(pydantic.BaseModel):
...
else:
if PYDANTIC_V2:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel):
...
else:
import pydantic.generics
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):
...
# cached properties
if TYPE_CHECKING:
cached_property = property
# we define a separate type (copied from typeshed)
# that represents that `cached_property` is `set`able
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.
class typed_cached_property(Generic[_T]):
func: Callable[[Any], _T]
attrname: str | None
def __init__(self, func: Callable[[Any], _T]) -> None:
...
@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self:
...
@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None:
...
# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None:
...
else:
try:
from functools import cached_property as cached_property
except ImportError:
from cached_property import cached_property as cached_property
typed_cached_property = cached_property
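For reference, a minimal usage sketch (not part of the diff) of how the shims above keep call sites identical across pydantic v1 and v2; the `User` model is hypothetical:

```py
# Hypothetical usage of the compatibility helpers defined above.
import pydantic

class User(pydantic.BaseModel):
    name: str

user = parse_obj(User, {"name": "alice"})  # model_validate() on v2, parse_obj() on v1
clone = model_copy(user)                   # model_copy() on v2, copy() on v1
assert model_dump(user, exclude_unset=True) == {"name": "alice"}
```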


@@ -0,0 +1,127 @@
from __future__ import annotations
import io
import os
import pathlib
from typing import overload
from typing_extensions import TypeGuard
import anyio
from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
)
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None
@overload
def to_httpx_files(files: None) -> None:
...
@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
...
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: _transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
return files
def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = pathlib.Path(file)
return (path.name, path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
def _read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return file
@overload
async def async_to_httpx_files(files: None) -> None:
...
@overload
async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
...
async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: await _async_transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
return files
async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = anyio.Path(file)
return (path.name, await path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], await _async_read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return await anyio.Path(file).read_bytes()
return file
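A short sketch (not from the diff, and assuming a `report.pdf` exists on disk) of how the helpers above normalize user-supplied files into httpx's expected shape:

```py
# Hypothetical: a PathLike entry is read eagerly and becomes (filename, bytes).
import pathlib

httpx_files = to_httpx_files({"file": pathlib.Path("report.pdf")})
# -> {"file": ("report.pdf", b"<contents of report.pdf>")}
```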


@@ -0,0 +1,657 @@
from __future__ import annotations
import os
import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
from datetime import date, datetime
from typing_extensions import (
Unpack,
Literal,
ClassVar,
Protocol,
Required,
TypedDict,
TypeGuard,
final,
override,
runtime_checkable,
)
import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from ._types import (
IncEx,
ModelT,
)
from ._utils import (
PropertyInfo,
is_list,
is_given,
lru_cache,
is_mapping,
parse_date,
coerce_boolean,
parse_datetime,
strip_not_given,
extract_type_arg,
is_annotated_type,
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
is_union,
parse_obj,
get_origin,
is_literal_type,
get_model_config,
get_model_fields,
field_get_default,
)
if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
_T = TypeVar("_T")
@runtime_checkable
class _ConfigProtocol(Protocol):
allow_population_by_field_name: bool
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
else:
@property
@override
def model_fields_set(self) -> set[str]:
# a forwards-compat shim for pydantic v2
return self.__fields_set__ # type: ignore
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-03-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
"""
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
"""
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
# Override the 'construct' method in a way that supports recursive parsing without validation.
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
@classmethod
@override
def construct(
cls: Type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
) -> ModelT:
m = cls.__new__(cls)
fields_values: dict[str, object] = {}
config = get_model_config(cls)
populate_by_name = (
config.allow_population_by_field_name
if isinstance(config, _ConfigProtocol)
else config.get("populate_by_name")
)
if _fields_set is None:
_fields_set = set()
model_fields = get_model_fields(cls)
for name, field in model_fields.items():
key = field.alias
if key is None or (key not in values and populate_by_name):
key = name
if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)
_extra = {}
for key, value in values.items():
if key not in model_fields:
if PYDANTIC_V2:
_extra[key] = value
else:
_fields_set.add(key)
fields_values[key] = value
object.__setattr__(m, "__dict__", fields_values)
if PYDANTIC_V2:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None)
object.__setattr__(m, "__pydantic_extra__", _extra)
object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
else:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set)
return m
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
# because the type signatures are technically different
# although not in practice
model_construct = construct
if not PYDANTIC_V2:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
# pydantic version they are currently using
@override
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
Args:
mode: The mode in which `to_python` should run.
If mode is 'json', the dictionary will only contain JSON serializable types.
If mode is 'python', the dictionary may contain any Python objects.
include: A list of fields to include in the output.
exclude: A list of fields to exclude from the output.
by_alias: Whether to use the field's alias in the dictionary key if defined.
exclude_unset: Whether to exclude fields that are unset or None from the output.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
round_trip: Whether to enable serialization and deserialization round-trip support.
warnings: Whether to log warnings when invalid fields are encountered.
Returns:
A dictionary representation of the model.
"""
if mode != "python":
raise ValueError("mode is only supported in Pydantic v2")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
@override
def model_dump_json(
self,
*,
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
Generates a JSON representation of the model using Pydantic's `to_json` method.
Args:
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
by_alias: Whether to serialize using field aliases.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to use serialization/deserialization between JSON and class instance.
warnings: Whether to show any warnings that occurred during serialization.
Returns:
A JSON string representation of the model.
"""
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().json( # type: ignore[reportDeprecated]
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
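A small sketch (not in the source) of what the `construct()` override above enables: building models from raw API data without validation, while keeping unknown keys. `Item` is a made-up model:

```py
# Hypothetical: construct() performs no coercion and raises no ValidationError.
class Item(BaseModel):
    name: str
    price: float

item = Item.construct(name="pen", price="not-a-float", color="blue")
assert item.price == "not-a-float"  # value kept as-is; "color" is retained as an extra
```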
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
if PYDANTIC_V2:
type_ = field.annotation
else:
type_ = cast(type, field.outer_type_) # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
return construct_type(value=value, type_=type_)
def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
if is_union(type_):
for variant in get_args(type_):
if is_basemodel(variant):
return True
return False
return is_basemodel_type(type_)
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
def construct_type(*, value: object, type_: object) -> object:
"""Loose coercion to the expected type with construction of nested values.
If the given value does not match the expected type then it is returned as-is.
"""
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
meta: tuple[Any, ...] = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = tuple()
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
origin = get_origin(type_) or type_
args = get_args(type_)
if is_union(origin):
try:
return validate_type(type_=cast("type[object]", type_), value=value)
except Exception:
pass
# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)
# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
return construct_type(value=value, type_=variant)
except Exception:
continue
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
if origin == dict:
if not is_mapping(value):
return value
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
if is_mapping(value):
if issubclass(type_, BaseModel):
return type_.construct(**value) # type: ignore[arg-type]
return cast(Any, type_).construct(**value)
if origin == list:
if not is_list(value):
return value
inner_type = args[0] # List[inner_type]
return [construct_type(value=entry, type_=inner_type) for entry in value]
if origin == float:
if isinstance(value, int):
coerced = float(value)
if coerced != value:
return value
return coerced
return value
if type_ == datetime:
try:
return parse_datetime(value) # type: ignore
except Exception:
return value
if type_ == date:
try:
return parse_date(value) # type: ignore
except Exception:
return value
return value
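A few illustrative calls (not in the source) showing the loose-coercion behavior of `construct_type()`: matching values are converted, mismatches are returned untouched:

```py
# Hypothetical examples of construct_type()'s graceful fallback.
from datetime import datetime
from typing import List

assert construct_type(value="2024-05-20T11:35:14", type_=datetime) == datetime(2024, 5, 20, 11, 35, 14)
assert construct_type(value="not-a-date", type_=datetime) == "not-a-date"  # parse failed, value returned as-is
assert construct_type(value=[1, 2], type_=List[int]) == [1, 2]
```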
@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
```py
class Foo(BaseModel):
type: Literal['foo']
```
Will result in field_name='type'
"""
field_alias_from: str | None
"""The name of the discriminator field in the API response, e.g.
```py
class Foo(BaseModel):
type: Literal['foo'] = Field(alias='type_from_api')
```
Will result in field_alias_from='type_from_api'
"""
mapping: dict[str, type]
"""Mapping of discriminator value to variant type, e.g.
{'foo': FooVariant, 'bar': BarVariant}
"""
def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__
discriminator_field_name: str | None = None
for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
discriminator_field_name = annotation.discriminator
break
if not discriminator_field_name:
return None
mapping: dict[str, type] = {}
discriminator_alias: str | None = None
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V2:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field.get("serialization_alias")
field_schema = field["schema"]
if field_schema["type"] == "literal":
for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if field_info.annotation and is_literal_type(field_info.annotation):
for entry in get_args(field_info.annotation):
if isinstance(entry, str):
mapping[entry] = variant
if not mapping:
return None
details = DiscriminatorDetails(
mapping=mapping,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
cast(CachedDiscriminatorType, union).__discriminator__ = details
return details
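A sketch of the discriminated-union path, using the `BaseModel`, `PropertyInfo`, and `construct_type` defined in this module; `Cat`/`Dog` are hypothetical variants:

```py
# Hypothetical: the "kind" discriminator selects the right variant.
from typing import Union
from typing_extensions import Annotated, Literal

class Cat(BaseModel):
    kind: Literal["cat"]

class Dog(BaseModel):
    kind: Literal["dog"]

Pet = Annotated[Union[Cat, Dog], PropertyInfo(discriminator="kind")]
assert isinstance(construct_type(value={"kind": "dog"}, type_=Pet), Dog)
```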
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None
fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None
fields_schema = cast("ModelFieldsSchema", fields_schema)
field = fields_schema["fields"].get(field_name)
if not field:
return None
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
return cast(_T, parse_obj(type_, value))
return cast(_T, _validate_non_model_type(type_=type_, value=value))
# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
GenericModel = BaseModel
else:
class GenericModel(BaseGenericModel, BaseModel):
pass
if PYDANTIC_V2:
from pydantic import TypeAdapter as _TypeAdapter
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
if TYPE_CHECKING:
from pydantic import TypeAdapter
else:
TypeAdapter = _CachedTypeAdapter
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)
elif not TYPE_CHECKING: # TODO: condition is weird
class RootModel(GenericModel, Generic[_T]):
"""Used as a placeholder to easily convert runtime types to a Pydantic format
to provide validation.
For example:
```py
validated = RootModel[int](__root__="5").__root__
# validated: 5
```
"""
__root__: _T
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
model = _create_pydantic_model(type_).validate(value)
return cast(_T, model.__root__)
def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
return RootModel[type_] # type: ignore


@@ -0,0 +1,220 @@
from __future__ import annotations
from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Optional,
Sequence,
)
from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
import httpx
import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
if TYPE_CHECKING:
from ._models import BaseModel
Transport = BaseTransport
AsyncTransport = AsyncBaseTransport
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")
# Approximates httpx internal ProxiesTypes and RequestFiles types
# while adding support for `PathLike` instances
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
if TYPE_CHECKING:
Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
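For illustration, the shapes this union admits (a hypothetical mapping; the referenced path is assumed to exist):

```py
# Hypothetical examples of valid RequestFiles entries.
import pathlib

files: RequestFiles = {
    "raw": b"raw bytes",                             # bare file content
    "named": ("data.csv", b"a,b\n1,2"),              # (filename, content)
    "typed": ("data.csv", b"a,b\n1,2", "text/csv"),  # (filename, content, content_type)
    "path": pathlib.Path("data.csv"),                # PathLike, resolved by _files.py
}
```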
# duplicate of the above but without our custom file support
HttpxFileContent = Union[IO[bytes], bytes]
HttpxFileTypes = Union[
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
Tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly
# passing `None`, overloads would have to be defined for every
# method that uses `ResponseT` which would lead to an unacceptable
# amount of code duplication and make it unreadable. See _base_client.py
# for example usage.
#
# This unfortunately means that you will either have
# to import this type and pass it explicitly:
#
# from openai import NoneType
# client.get('/foo', cast_to=NoneType)
#
# or build it yourself:
#
# client.get('/foo', cast_to=type(None))
if TYPE_CHECKING:
NoneType: Type[None]
else:
NoneType = type(None)
class RequestOptions(TypedDict, total=False):
headers: Headers
max_retries: int
timeout: float | Timeout | None
params: Query
extra_json: AnyMapping
idempotency_key: str
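For example, a hypothetical per-request options payload; since the TypedDict is `total=False`, every key is optional:

```py
# Hypothetical RequestOptions value.
opts: RequestOptions = {"max_retries": 2, "timeout": 10.0, "idempotency_key": "key-123"}
```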
# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
...
get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()
class Omit:
"""In certain situations you need to be able to represent a case where a default value has
to be explicitly removed and `None` is not an appropriate substitute, for example:
```py
# by default, a `Content-Type` header of `application/json` will be sent
client.post("/upload/files", files={"file": b"my raw file content"})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={"Content-Type": "multipart/form-data"})
# instead you can remove the default `application/json` header by passing Omit
client.post(..., headers={"Content-Type": Omit()})
```
"""
def __bool__(self) -> Literal[False]:
return False
@runtime_checkable
class ModelBuilderProtocol(Protocol):
@classmethod
def build(
cls: type[_T],
*,
response: Response,
data: object,
) -> _T:
...
Headers = Mapping[str, Union[str, Omit]]
class HeadersLikeProtocol(Protocol):
def get(self, __key: str) -> str | None:
...
HeadersLike = Union[Headers, HeadersLikeProtocol]
ResponseT = TypeVar(
"ResponseT",
bound=Union[
object,
str,
None,
"BaseModel",
List[Any],
Dict[str, Any],
Response,
ModelBuilderProtocol,
"APIResponse[Any]",
"AsyncAPIResponse[Any]",
"HttpxBinaryResponseContent",
],
)
StrBytesIntFloat = Union[str, bytes, int, float]
# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
PostParser = Callable[[Any], Any]
@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""
__orig_bases__: tuple[_GenericAlias]
class _GenericAlias(Protocol):
__origin__: type[object]
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth


@@ -0,0 +1,48 @@
from ._utils import (
flatten as flatten,
is_dict as is_dict,
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
parse_date as parse_date,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
is_mapping_t as is_mapping_t,
removeprefix as removeprefix,
removesuffix as removesuffix,
extract_files as extract_files,
is_sequence_t as is_sequence_t,
required_args as required_args,
coerce_boolean as coerce_boolean,
coerce_integer as coerce_integer,
file_from_path as file_from_path,
parse_datetime as parse_datetime,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
)
from ._typing import (
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
is_annotated_type as is_annotated_type,
strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
async_transform as async_transform,
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)


@@ -0,0 +1,382 @@
from __future__ import annotations
import io
import base64
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints
import anyio
import pydantic
from ._utils import (
is_list,
is_mapping,
is_iterable,
)
from .._files import is_base64_file_input
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
is_iterable_type,
is_required_type,
is_annotated_type,
strip_annotated_type,
)
from .._compat import model_dump, is_typeddict
_T = TypeVar("_T")
# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases
PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
"""Metadata class to be used in Annotated types to provide information about a given type.
For example:
class MyParams(TypedDict):
account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
"""
alias: str | None
format: PropertyFormat | None
format_template: str | None
discriminator: str | None
def __init__(
self,
*,
alias: str | None = None,
format: PropertyFormat | None = None,
format_template: str | None = None,
discriminator: str | None = None,
) -> None:
self.alias = alias
self.format = format
self.format_template = format_template
self.discriminator = discriminator
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
def maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `transform()` that allows `None` to be passed.
See `transform()` for more details.
"""
if data is None:
return None
return transform(data, expected_type)
# Wrapper over _transform_recursive providing fake types
def transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that do not have type information will be included as-is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = _transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
"""
if is_required_type(type_):
# Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
type_ = get_args(type_)[0]
if is_annotated_type(type_):
return type_
return None
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.
Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
# no `Annotated` definition for this type, no transformation needed
return key
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
return annotation.alias
return key
def _transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return _format_data(data, annotation.format, annotation.format_template)
return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
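A sketch of how `PropertyInfo(format=...)` drives `_format_data()` through `transform()`; `EventParams` is invented for the example:

```py
# Hypothetical: a datetime field serialized as ISO 8601 during transform().
from datetime import datetime
from typing_extensions import Annotated, TypedDict

class EventParams(TypedDict, total=False):
    starts_at: Annotated[datetime, PropertyInfo(format="iso8601")]

assert transform({"starts_at": datetime(2024, 5, 20, 11, 35)}, EventParams) == {"starts_at": "2024-05-20T11:35:00"}
```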
def _transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
return result
async def async_maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `async_transform()` that allows `None` to be passed.
See `async_transform()` for more details.
"""
if data is None:
return None
return await async_transform(data, expected_type)
async def async_transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = await async_transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that do not have type information will be included as-is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
async def _async_transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
async def _async_transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result


@@ -0,0 +1,120 @@
from __future__ import annotations
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc
from typing_extensions import Required, Annotated, get_args, get_origin
from .._types import InheritsGeneric
from .._compat import is_union as _is_union
def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated
def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list
def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
return origin == Iterable or origin == _c_abc.Iterable
def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))
def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required
def is_typevar(typ: type) -> bool:
# type ignore is required because type checkers
# think this expression will always return False
return type(typ) == TypeVar # type: ignore
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
return typ
def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
def extract_type_var_from_base(
typ: type,
*,
generic_bases: tuple[type, ...],
index: int,
failure_message: str | None = None,
) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(Foo[bytes]):
...
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```
And where a generic subclass is given:
```py
_T = TypeVar('_T')
class MyResponse(Foo[_T]):
...
extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases:
# we're given the class directly
return extract_type_arg(typ, index)
# if a subclass is given
# ---
# this is needed as __orig_bases__ is not present in the typeshed stubs
# because it is intended to be for internal use only, however there does
# not seem to be a way to resolve generic TypeVars for inherited subclasses
# without using it.
if isinstance(cls, InheritsGeneric):
target_base_class: Any | None = None
for base in cls.__orig_bases__:
if base.__origin__ in generic_bases:
target_base_class = base
break
if target_base_class is None:
raise RuntimeError(
"Could not find the generic base class;\n"
"This should never happen;\n"
f"Does {cls} inherit from one of {generic_bases} ?"
)
extracted = extract_type_arg(target_base_class, index)
if is_typevar(extracted):
# If the extracted type argument is itself a type variable
# then that means the subclass itself is generic, so we have
# to resolve the type argument from the class itself, not
# the base class.
#
# Note: if there is more than 1 type argument, the subclass could
# change the ordering of the type arguments, this is not currently
# supported.
return extract_type_arg(typ, index)
return extracted
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")


@@ -0,0 +1,403 @@
from __future__ import annotations
import os
import re
import inspect
import functools
from typing import (
Any,
Tuple,
Mapping,
TypeVar,
Callable,
Iterable,
Sequence,
cast,
overload,
)
from pathlib import Path
from typing_extensions import TypeGuard
import sniffio
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
def extract_files(
# TODO: this needs to take Dict but variance issues.....
# create protocol type ?
query: Mapping[str, object],
*,
paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
"""Recursively extract files from the given dictionary based on specified paths.
A path may look like this ['foo', 'files', '<array>', 'data'].
Note: this mutates the given dictionary.
"""
files: list[tuple[str, FileTypes]] = []
for path in paths:
files.extend(_extract_items(query, path, index=0, flattened_key=None))
return files
def _extract_items(
obj: object,
path: Sequence[str],
*,
index: int,
flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
try:
key = path[index]
except IndexError:
if isinstance(obj, NotGiven):
# no value was provided - we can safely ignore
return []
# cyclical import
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
assert_is_file_content(obj, key=flattened_key)
assert flattened_key is not None
return [(flattened_key, cast(FileTypes, obj))]
index += 1
if is_dict(obj):
try:
# We are at the last entry in the path so we must remove the field
if (len(path)) == index:
item = obj.pop(key)
else:
item = obj[key]
except KeyError:
# Key was not present in the dictionary, this is not indicative of an error
# as the given path may not point to a required field. We also do not want
# to enforce required fields as the API may differ from the spec in some cases.
return []
if flattened_key is None:
flattened_key = key
else:
flattened_key += f"[{key}]"
return _extract_items(
item,
path,
index=index,
flattened_key=flattened_key,
)
elif is_list(obj):
if key != "<array>":
return []
return flatten(
[
_extract_items(
item,
path,
index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
)
for item in obj
]
)
# Something unexpected was passed, just ignore it.
return []
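A sketch of the flattened-key output (hypothetical payload); note that `extract_files()` mutates the query it is given:

```py
# Hypothetical: pulling a file out of a nested payload.
query = {"foo": {"files": [{"data": b"binary", "keep": 1}]}}
extracted = extract_files(query, paths=[["foo", "files", "<array>", "data"]])
assert extracted == [("foo[files][][data]", b"binary")]
assert query == {"foo": {"files": [{"keep": 1}]}}  # "data" was popped
```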
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
return not isinstance(obj, NotGiven)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
return isinstance(obj, tuple)
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
return isinstance(obj, tuple)
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
return isinstance(obj, Sequence)
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
return isinstance(obj, Sequence)
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
return isinstance(obj, Mapping)
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
return isinstance(obj, Mapping)
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)
def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
return isinstance(obj, Iterable)
def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
- mappings, e.g. `dict`
- list
This is done for performance reasons.
"""
if is_mapping(item):
return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
if is_list(item):
return cast(_T, [deepcopy_minimal(entry) for entry in item])
return item
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
size = len(seq)
if size == 0:
return ""
if size == 1:
return seq[0]
if size == 2:
return f"{seq[0]} {final} {seq[1]}"
return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
def quote(string: str) -> str:
"""Add single quotation marks around the given string. Does *not* do any escaping."""
return f"'{string}'"
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Example usage:
```py
@overload
def foo(*, a: str) -> str:
...
@overload
def foo(*, b: bool) -> str:
...
# This enforces the same constraints that a static type checker would
# i.e. that either a or b must be passed to the function
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str:
...
```
"""
def inner(func: CallableT) -> CallableT:
params = inspect.signature(func).parameters
positional = [
name
for name, param in params.items()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
}
]
@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i, _ in enumerate(args):
try:
given_params.add(positional[i])
except IndexError:
raise TypeError(
f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
) from None
for key in kwargs.keys():
given_params.add(key)
for variant in variants:
matches = all((param in given_params for param in variant))
if matches:
break
else: # no break
if len(variants) > 1:
variations = human_join(
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
assert len(variants) > 0
# TODO: this error message is not deterministic
missing = list(set(variants[0]) - given_params)
if len(missing) > 1:
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
else:
msg = f"Missing required argument: {quote(missing[0])}"
raise TypeError(msg)
return func(*args, **kwargs)
return wrapper # type: ignore
return inner
_K = TypeVar("_K")
_V = TypeVar("_V")
@overload
def strip_not_given(obj: None) -> None:
...
@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]:
...
@overload
def strip_not_given(obj: object) -> object:
...
def strip_not_given(obj: object | None) -> object:
"""Remove all top-level keys where their values are instances of `NotGiven`"""
if obj is None:
return None
if not is_mapping(obj):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
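# Examples (a sketch, assuming `NOT_GIVEN` is this module's `NotGiven`
# sentinel): only top-level values are stripped, nested ones are untouched:
#
#     strip_not_given({"a": 1, "b": NOT_GIVEN})  -> {"a": 1}
#     strip_not_given({"a": {"b": NOT_GIVEN}})   -> {"a": {"b": NOT_GIVEN}}
#     strip_not_given(None)                      -> None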
def coerce_integer(val: str) -> int:
return int(val, base=10)
def coerce_float(val: str) -> float:
return float(val)
def coerce_boolean(val: str) -> bool:
return val == "true" or val == "1" or val == "on"
def maybe_coerce_integer(val: str | None) -> int | None:
if val is None:
return None
return coerce_integer(val)
def maybe_coerce_float(val: str | None) -> float | None:
if val is None:
return None
return coerce_float(val)
def maybe_coerce_boolean(val: str | None) -> bool | None:
if val is None:
return None
return coerce_boolean(val)
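# Note that coerce_boolean is strict: only "true", "1" and "on" map to True,
# so e.g. "True" or "yes" coerce to False:
#
#     coerce_boolean("true")      -> True
#     coerce_boolean("True")      -> False
#     maybe_coerce_integer(None)  -> None
#     maybe_coerce_integer("42")  -> 42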
def removeprefix(string: str, prefix: str) -> str:
"""Remove a prefix from a string.
Backport of `str.removeprefix` for Python < 3.9
"""
if string.startswith(prefix):
return string[len(prefix) :]
return string
def removesuffix(string: str, suffix: str) -> str:
"""Remove a suffix from a string.
Backport of `str.removesuffix` for Python < 3.9
"""
if string.endswith(suffix):
return string[: -len(suffix)]
return string
def file_from_path(path: str) -> FileTypes:
contents = Path(path).read_bytes()
file_name = os.path.basename(path)
return (file_name, contents)
def get_required_header(headers: HeadersLike, header: str) -> str:
lower_header = header.lower()
if isinstance(headers, Mapping):
headers = cast(Headers, headers)
for k, v in headers.items():
if k.lower() == lower_header and isinstance(v, str):
return v
""" to deal with the case where the header looks like Stainless-Event-Id """
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
value = headers.get(normalized_header)
if value:
return value
raise ValueError(f"Could not find {header} header")
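# Example, derived from the implementation above: a mapping is scanned
# case-insensitively, and the fallback loop also tries an intercaps spelling
# ("x-event-id" -> "X-Event-Id"):
#
#     get_required_header({"X-Event-Id": "abc"}, "x-event-id")  -> "abc"
#     get_required_header({}, "x-event-id")                     -> ValueError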
def get_async_library() -> str:
try:
return sniffio.current_async_library()
except Exception:
return "false"
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
"""A version of functools.lru_cache that retains the type signature
for the wrapped function arguments.
"""
wrapper = functools.lru_cache( # noqa: TID251
maxsize=maxsize,
)
return cast(Any, wrapper) # type: ignore[no-any-return]
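# Usage sketch: behaves exactly like functools.lru_cache, but the decorated
# function keeps its original signature for type checkers:
#
#     @lru_cache(maxsize=32)
#     def fib(n: int) -> int:
#         return n if n < 2 else fib(n - 1) + fib(n - 2)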

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Literal, Optional
from pydantic import BaseModel
from ..._models import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,

View File

@ -1,8 +1,8 @@
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator
from ..._models import BaseModel
from pydantic import Field as FieldInfo
from typing_extensions import Literal
@ -81,7 +81,7 @@ class ModelCard(BaseModel):
"tts",
"text2img",
] = "llm"
created: int = Field(default_factory=lambda: int(time.time()))
created: int = FieldInfo(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner"
@ -171,7 +171,7 @@ class ChatCompletionStreamResponseChoice(BaseModel):
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
created: int = FieldInfo(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
@ -180,7 +180,7 @@ class ChatCompletionResponse(BaseModel):
class ChatCompletionStreamResponse(BaseModel):
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
created: int = FieldInfo(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionStreamResponseChoice]

View File

@ -1,331 +0,0 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.file.file_obj import FileObj
from model_providers.core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
class ModelConfigEntity(BaseModel):
"""
Model Config Entity.
"""
provider: str
model: str
model_schema: AIModelEntity
mode: str
provider_model_bundle: ProviderModelBundle
credentials: Dict[str, Any] = {}
parameters: Dict[str, Any] = {}
stop: List[str] = []
class AdvancedChatMessageEntity(BaseModel):
"""
Advanced Chat Message Entity.
"""
text: str
role: PromptMessageRole
class AdvancedChatPromptTemplateEntity(BaseModel):
"""
Advanced Chat Prompt Template Entity.
"""
messages: List[AdvancedChatMessageEntity]
class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Advanced Completion Prompt Template Entity.
"""
class RolePrefixEntity(BaseModel):
"""
Role Prefix Entity.
"""
user: str
assistant: str
prompt: str
role_prefix: Optional[RolePrefixEntity] = None
class PromptTemplateEntity(BaseModel):
"""
Prompt Template Entity.
"""
class PromptType(Enum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = "simple"
ADVANCED = "advanced"
@classmethod
def value_of(cls, value: str) -> "PromptType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[
AdvancedCompletionPromptTemplateEntity
] = None
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
"""
variable: str
type: str
config: Dict[str, Any] = {}
class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = "single"
MULTIPLE = "multiple"
@classmethod
def value_of(cls, value: str) -> "RetrieveStrategy":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
single_strategy: Optional[str] = None # for temp
top_k: Optional[int] = None
score_threshold: Optional[float] = None
reranking_model: Optional[dict] = None
class DatasetEntity(BaseModel):
"""
Dataset Config Entity.
"""
dataset_ids: List[str]
retrieve_config: DatasetRetrieveConfigEntity
class SensitiveWordAvoidanceEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
type: str
config: Dict[str, Any] = {}
class TextToSpeechEntity(BaseModel):
"""
    Text To Speech Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
class FileUploadEntity(BaseModel):
"""
File Upload Entity.
"""
    image_config: Optional[Dict[str, Any]] = None
class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
tool_parameters: Dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
class AgentScratchpadUnit(BaseModel):
"""
Agent First Prompt Entity.
"""
class Action(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
agent_response: Optional[str] = None
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
class AgentEntity(BaseModel):
"""
Agent Entity.
"""
class Strategy(Enum):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"
provider: str
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
    tools: Optional[List[AgentToolEntity]] = None
max_iteration: int = 5
class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
"""
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
external_data_variables: List[ExternalDataVariableEntity] = []
agent: Optional[AgentEntity] = None
# features
dataset: Optional[DatasetEntity] = None
file_upload: Optional[FileUploadEntity] = None
opening_statement: Optional[str] = None
suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: dict = {}
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
class InvokeFrom(Enum):
"""
Invoke From.
"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
EXPLORE = "explore"
DEBUGGER = "debugger"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid invoke from value {value}")
def to_source(self) -> str:
"""
Get source of invoke from.
:return: source
"""
if self == InvokeFrom.WEB_APP:
return "web_app"
elif self == InvokeFrom.DEBUGGER:
return "dev"
elif self == InvokeFrom.EXPLORE:
return "explore_app"
elif self == InvokeFrom.SERVICE_API:
return "api"
return "dev"
class ApplicationGenerateEntity(BaseModel):
"""
Application Generate Entity.
"""
task_id: str
tenant_id: str
app_id: str
app_model_config_id: str
# for save
app_model_config_dict: dict
app_model_config_override: bool
# Converted from app_model_config to Entity object, or directly covered by external input
app_orchestration_config_entity: AppOrchestrationConfigEntity
conversation_id: Optional[str] = None
inputs: Dict[str, str]
query: Optional[str] = None
files: List[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom
# extra parameters, like: auto_generate_conversation_name
extras: Dict[str, Any] = {}

View File

@ -1,5 +1,5 @@
import enum
from typing import Any, cast
from typing import Any, cast, List
from langchain.schema import (
AIMessage,
@ -8,7 +8,7 @@ from langchain.schema import (
HumanMessage,
SystemMessage,
)
from pydantic import BaseModel
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (

View File

@ -4,7 +4,7 @@ import logging
from json import JSONDecodeError
from typing import Dict, Iterator, List, Optional
from pydantic import BaseModel
from ..._models import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel
from ...._models import BaseModel
class I18nObject(BaseModel):

View File

@ -2,7 +2,7 @@ from decimal import Decimal
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,

View File

@ -2,7 +2,7 @@ from abc import ABC
from enum import Enum
from typing import List, Optional, Union
from pydantic import BaseModel
from ...._models import BaseModel
class PromptMessageRole(Enum):

View File

@ -2,7 +2,7 @@ from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (

View File

@ -1,6 +1,6 @@
from typing import List
from pydantic import BaseModel
from ...._models import BaseModel
class RerankDocument(BaseModel):

View File

@ -1,7 +1,7 @@
from decimal import Decimal
from typing import List
from pydantic import BaseModel
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelUsage

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel
from ....._models import BaseModel
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from model_providers.core.model_runtime.entities.llm_entities import LLMMode

View File

@ -3,7 +3,7 @@ import logging
import os
from typing import Dict, List, Optional, Union
from pydantic import BaseModel
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.provider_entities import (

View File

@ -1,21 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
if PYDANTIC_V2:
from pydantic_core import Url as Url
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.model_dump(mode=mode, **kwargs)
else:
from pydantic import AnyUrl as Url # noqa: F401
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.dict(**kwargs)
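# Usage sketch for the shim above: the same call works on either major
# version (the v1 branch simply ignores the json/python `mode` distinction):
#
#     payload = _model_dump(model, mode="json", exclude_none=True)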

View File

@ -1,234 +0,0 @@
import dataclasses
import datetime
from collections import defaultdict, deque
from collections.abc import Callable
from decimal import Decimal
from enum import Enum
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import Any, Optional, Union, Dict, Type, List, Tuple
from uuid import UUID
from pydantic import BaseModel
from pydantic.color import Color
from pydantic.networks import AnyUrl, NameEmail
from pydantic.types import SecretBytes, SecretStr
from ._compat import PYDANTIC_V2, Url, _model_dump
# Taken from Pydantic v1 as is
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
return o.isoformat()
# Taken from Pydantic v1 as is
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""
    Encodes a Decimal as int if there's no exponent, otherwise float
    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
>>> decimal_encoder(Decimal("1.0"))
1.0
>>> decimal_encoder(Decimal("1"))
1
"""
if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
return int(dec_value)
else:
return float(dec_value)
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
Url: str,
AnyUrl: str,
}
def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable[[Any], Any]],
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
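# Shape sketch of the inversion above: types are grouped by their shared
# encoder so lookups can fall back to per-encoder isinstance() checks, e.g.
#
#     encoders_by_class_tuples[str]   # contains (Color, NameEmail, Path, UUID, ...)
#     encoders_by_class_tuples[list]  # contains (frozenset, deque, GeneratorType, set)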
def jsonable_encoder(
obj: Any,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if isinstance(obj, BaseModel):
# TODO: remove when deprecating Pydantic v1
encoders: Dict[Any, Any] = {}
if not PYDANTIC_V2:
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
obj_dict = _model_dump(
obj,
mode="json",
include=None,
exclude=None,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
# TODO: remove when deprecating Pydantic v1
custom_encoder=encoders,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
    if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, Decimal):
return format(obj, "f")
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and key in allowed_keys
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
try:
data = dict(obj)
except Exception as e:
errors: List[Exception] = [e]
try:
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
return jsonable_encoder(
data,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
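# Usage sketch (here `Item` is a hypothetical pydantic model with `name: str`
# and `created: datetime.datetime` fields):
#
#     jsonable_encoder(Item(name="x", created=datetime.datetime(2024, 1, 1)))
#     # -> {"name": "x", "created": "2024-01-01T00:00:00"}
#     jsonable_encoder({1: Decimal("1.5")})
#     # -> {1: "1.5"}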

View File

@ -1,5 +1,5 @@
import pydantic
from pydantic import BaseModel
from ...._models import BaseModel
def dump_model(model: BaseModel) -> dict:

View File

@ -2,7 +2,7 @@ import json
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING:
from pydantic import BaseModel
from ..._models import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:

View File

@ -1,7 +1,7 @@
import os
import orjson
from pydantic import BaseModel
from ..._models import BaseModel
def json_dumps(o):

View File

@ -136,7 +136,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
yield f"http://127.0.0.1:20000"
finally:
print("")
# boot.destroy()
boot.destroy()
except SystemExit: