mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 07:23:29 +08:00
一些兼容 pydantic<3,>=1.9.0 的代码,
This commit is contained in:
parent
2bb24d8b6d
commit
d0267bf66b
222
model-providers/model_providers/_compat.py
Normal file
222
model-providers/model_providers/_compat.py
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
|
||||||
|
from datetime import date, datetime
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
|
from ._types import StrBytesIntFloat
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
|
||||||
|
|
||||||
|
# --------------- Pydantic v2 compatibility ---------------
|
||||||
|
|
||||||
|
# Pyright incorrectly reports some of our functions as overriding a method when they don't
|
||||||
|
# pyright: reportIncompatibleMethodOverride=false
|
||||||
|
|
||||||
|
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
|
||||||
|
|
||||||
|
# v1 re-exports
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
def parse_date(value: Union[date, StrBytesIntFloat]) -> date: # noqa: ARG001
|
||||||
|
...
|
||||||
|
|
||||||
|
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
|
||||||
|
...
|
||||||
|
|
||||||
|
def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
|
||||||
|
...
|
||||||
|
|
||||||
|
def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
|
||||||
|
...
|
||||||
|
|
||||||
|
def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
|
||||||
|
...
|
||||||
|
|
||||||
|
else:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
from pydantic.v1.typing import (
|
||||||
|
get_args as get_args,
|
||||||
|
is_union as is_union,
|
||||||
|
get_origin as get_origin,
|
||||||
|
is_typeddict as is_typeddict,
|
||||||
|
is_literal_type as is_literal_type,
|
||||||
|
)
|
||||||
|
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
|
||||||
|
else:
|
||||||
|
from pydantic.typing import (
|
||||||
|
get_args as get_args,
|
||||||
|
is_union as is_union,
|
||||||
|
get_origin as get_origin,
|
||||||
|
is_typeddict as is_typeddict,
|
||||||
|
is_literal_type as is_literal_type,
|
||||||
|
)
|
||||||
|
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
|
||||||
|
|
||||||
|
|
||||||
|
# refactored config
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pydantic import ConfigDict as ConfigDict
|
||||||
|
else:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
else:
|
||||||
|
# TODO: provide an error message here?
|
||||||
|
ConfigDict = None
|
||||||
|
|
||||||
|
|
||||||
|
# renamed methods / properties
|
||||||
|
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return model.model_validate(value)
|
||||||
|
else:
|
||||||
|
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||||
|
|
||||||
|
|
||||||
|
def field_is_required(field: FieldInfo) -> bool:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return field.is_required()
|
||||||
|
return field.required # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def field_get_default(field: FieldInfo) -> Any:
|
||||||
|
value = field.get_default()
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
|
if value == PydanticUndefined:
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def field_outer_type(field: FieldInfo) -> Any:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return field.annotation
|
||||||
|
return field.outer_type_ # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return model.model_config
|
||||||
|
return model.__config__ # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return model.model_fields
|
||||||
|
return model.__fields__ # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def model_copy(model: _ModelT) -> _ModelT:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return model.model_copy()
|
||||||
|
return model.copy() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return model.model_dump_json(indent=indent)
|
||||||
|
return model.json(indent=indent) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def model_dump(
|
||||||
|
model: pydantic.BaseModel,
|
||||||
|
*,
|
||||||
|
exclude_unset: bool = False,
|
||||||
|
exclude_defaults: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return model.model_dump(
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
)
|
||||||
|
return cast(
|
||||||
|
"dict[str, Any]",
|
||||||
|
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
return model.model_validate(data)
|
||||||
|
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
|
||||||
|
|
||||||
|
|
||||||
|
# generic models
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
class GenericModel(pydantic.BaseModel):
|
||||||
|
...
|
||||||
|
|
||||||
|
else:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
# there no longer needs to be a distinction in v2 but
|
||||||
|
# we still have to create our own subclass to avoid
|
||||||
|
# inconsistent MRO ordering errors
|
||||||
|
class GenericModel(pydantic.BaseModel):
|
||||||
|
...
|
||||||
|
|
||||||
|
else:
|
||||||
|
import pydantic.generics
|
||||||
|
|
||||||
|
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# cached properties
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
cached_property = property
|
||||||
|
|
||||||
|
# we define a separate type (copied from typeshed)
|
||||||
|
# that represents that `cached_property` is `set`able
|
||||||
|
# at runtime, which differs from `@property`.
|
||||||
|
#
|
||||||
|
# this is a separate type as editors likely special case
|
||||||
|
# `@property` and we don't want to cause issues just to have
|
||||||
|
# more helpful internal types.
|
||||||
|
|
||||||
|
class typed_cached_property(Generic[_T]):
|
||||||
|
func: Callable[[Any], _T]
|
||||||
|
attrname: str | None
|
||||||
|
|
||||||
|
def __init__(self, func: Callable[[Any], _T]) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def __set_name__(self, owner: type[Any], name: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
# __set__ is not defined at runtime, but @cached_property is designed to be settable
|
||||||
|
def __set__(self, instance: object, value: _T) -> None:
|
||||||
|
...
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from functools import cached_property as cached_property
|
||||||
|
except ImportError:
|
||||||
|
from cached_property import cached_property as cached_property
|
||||||
|
|
||||||
|
typed_cached_property = cached_property
|
||||||
127
model-providers/model_providers/_files.py
Normal file
127
model-providers/model_providers/_files.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
from typing import overload
|
||||||
|
from typing_extensions import TypeGuard
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from ._types import (
|
||||||
|
FileTypes,
|
||||||
|
FileContent,
|
||||||
|
RequestFiles,
|
||||||
|
HttpxFileTypes,
|
||||||
|
Base64FileInput,
|
||||||
|
HttpxFileContent,
|
||||||
|
HttpxRequestFiles,
|
||||||
|
)
|
||||||
|
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
|
||||||
|
|
||||||
|
|
||||||
|
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
|
||||||
|
return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
|
||||||
|
|
||||||
|
|
||||||
|
def is_file_content(obj: object) -> TypeGuard[FileContent]:
|
||||||
|
return (
|
||||||
|
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
|
||||||
|
if not is_file_content(obj):
|
||||||
|
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def to_httpx_files(files: None) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
|
||||||
|
if files is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if is_mapping_t(files):
|
||||||
|
files = {key: _transform_file(file) for key, file in files.items()}
|
||||||
|
elif is_sequence_t(files):
|
||||||
|
files = [(key, _transform_file(file)) for key, file in files]
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_file(file: FileTypes) -> HttpxFileTypes:
|
||||||
|
if is_file_content(file):
|
||||||
|
if isinstance(file, os.PathLike):
|
||||||
|
path = pathlib.Path(file)
|
||||||
|
return (path.name, path.read_bytes())
|
||||||
|
|
||||||
|
return file
|
||||||
|
|
||||||
|
if is_tuple_t(file):
|
||||||
|
return (file[0], _read_file_content(file[1]), *file[2:])
|
||||||
|
|
||||||
|
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_file_content(file: FileContent) -> HttpxFileContent:
|
||||||
|
if isinstance(file, os.PathLike):
|
||||||
|
return pathlib.Path(file).read_bytes()
|
||||||
|
return file
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def async_to_httpx_files(files: None) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
|
||||||
|
if files is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if is_mapping_t(files):
|
||||||
|
files = {key: await _async_transform_file(file) for key, file in files.items()}
|
||||||
|
elif is_sequence_t(files):
|
||||||
|
files = [(key, await _async_transform_file(file)) for key, file in files]
|
||||||
|
else:
|
||||||
|
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
|
||||||
|
if is_file_content(file):
|
||||||
|
if isinstance(file, os.PathLike):
|
||||||
|
path = anyio.Path(file)
|
||||||
|
return (path.name, await path.read_bytes())
|
||||||
|
|
||||||
|
return file
|
||||||
|
|
||||||
|
if is_tuple_t(file):
|
||||||
|
return (file[0], await _async_read_file_content(file[1]), *file[2:])
|
||||||
|
|
||||||
|
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
|
||||||
|
if isinstance(file, os.PathLike):
|
||||||
|
return await anyio.Path(file).read_bytes()
|
||||||
|
|
||||||
|
return file
|
||||||
657
model-providers/model_providers/_models.py
Normal file
657
model-providers/model_providers/_models.py
Normal file
@ -0,0 +1,657 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
|
||||||
|
from datetime import date, datetime
|
||||||
|
from typing_extensions import (
|
||||||
|
Unpack,
|
||||||
|
Literal,
|
||||||
|
ClassVar,
|
||||||
|
Protocol,
|
||||||
|
Required,
|
||||||
|
TypedDict,
|
||||||
|
TypeGuard,
|
||||||
|
final,
|
||||||
|
override,
|
||||||
|
runtime_checkable,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
import pydantic.generics
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
|
from ._types import (
|
||||||
|
IncEx,
|
||||||
|
ModelT,
|
||||||
|
)
|
||||||
|
from ._utils import (
|
||||||
|
PropertyInfo,
|
||||||
|
is_list,
|
||||||
|
is_given,
|
||||||
|
lru_cache,
|
||||||
|
is_mapping,
|
||||||
|
parse_date,
|
||||||
|
coerce_boolean,
|
||||||
|
parse_datetime,
|
||||||
|
strip_not_given,
|
||||||
|
extract_type_arg,
|
||||||
|
is_annotated_type,
|
||||||
|
strip_annotated_type,
|
||||||
|
)
|
||||||
|
from ._compat import (
|
||||||
|
PYDANTIC_V2,
|
||||||
|
ConfigDict,
|
||||||
|
GenericModel as BaseGenericModel,
|
||||||
|
get_args,
|
||||||
|
is_union,
|
||||||
|
parse_obj,
|
||||||
|
get_origin,
|
||||||
|
is_literal_type,
|
||||||
|
get_model_config,
|
||||||
|
get_model_fields,
|
||||||
|
field_get_default,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
|
||||||
|
|
||||||
|
__all__ = ["BaseModel", "GenericModel"]
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class _ConfigProtocol(Protocol):
|
||||||
|
allow_population_by_field_name: bool
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(pydantic.BaseModel):
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||||
|
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
@property
|
||||||
|
@override
|
||||||
|
def model_fields_set(self) -> set[str]:
|
||||||
|
# a forwards-compat shim for pydantic v2
|
||||||
|
return self.__fields_set__ # type: ignore
|
||||||
|
|
||||||
|
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
|
||||||
|
extra: Any = pydantic.Extra.allow # type: ignore
|
||||||
|
|
||||||
|
def to_dict(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
mode: Literal["json", "python"] = "python",
|
||||||
|
use_api_names: bool = True,
|
||||||
|
exclude_unset: bool = True,
|
||||||
|
exclude_defaults: bool = False,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
warnings: bool = True,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
|
||||||
|
|
||||||
|
By default, fields that were not set by the API will not be included,
|
||||||
|
and keys will match the API response, *not* the property names from the model.
|
||||||
|
|
||||||
|
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
|
||||||
|
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode:
|
||||||
|
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
|
||||||
|
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
|
||||||
|
|
||||||
|
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
|
||||||
|
exclude_unset: Whether to exclude fields that have not been explicitly set.
|
||||||
|
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
|
||||||
|
exclude_none: Whether to exclude fields that have a value of `None` from the output.
|
||||||
|
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
|
||||||
|
"""
|
||||||
|
return self.model_dump(
|
||||||
|
mode=mode,
|
||||||
|
by_alias=use_api_names,
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
warnings=warnings,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_json(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
indent: int | None = 2,
|
||||||
|
use_api_names: bool = True,
|
||||||
|
exclude_unset: bool = True,
|
||||||
|
exclude_defaults: bool = False,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
warnings: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
|
||||||
|
|
||||||
|
By default, fields that were not set by the API will not be included,
|
||||||
|
and keys will match the API response, *not* the property names from the model.
|
||||||
|
|
||||||
|
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
|
||||||
|
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
|
||||||
|
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
|
||||||
|
exclude_unset: Whether to exclude fields that have not been explicitly set.
|
||||||
|
exclude_defaults: Whether to exclude fields that have the default value.
|
||||||
|
exclude_none: Whether to exclude fields that have a value of `None`.
|
||||||
|
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
|
||||||
|
"""
|
||||||
|
return self.model_dump_json(
|
||||||
|
indent=indent,
|
||||||
|
by_alias=use_api_names,
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
warnings=warnings,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# mypy complains about an invalid self arg
|
||||||
|
return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
|
||||||
|
|
||||||
|
# Override the 'construct' method in a way that supports recursive parsing without validation.
|
||||||
|
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
|
||||||
|
@classmethod
|
||||||
|
@override
|
||||||
|
def construct(
|
||||||
|
cls: Type[ModelT],
|
||||||
|
_fields_set: set[str] | None = None,
|
||||||
|
**values: object,
|
||||||
|
) -> ModelT:
|
||||||
|
m = cls.__new__(cls)
|
||||||
|
fields_values: dict[str, object] = {}
|
||||||
|
|
||||||
|
config = get_model_config(cls)
|
||||||
|
populate_by_name = (
|
||||||
|
config.allow_population_by_field_name
|
||||||
|
if isinstance(config, _ConfigProtocol)
|
||||||
|
else config.get("populate_by_name")
|
||||||
|
)
|
||||||
|
|
||||||
|
if _fields_set is None:
|
||||||
|
_fields_set = set()
|
||||||
|
|
||||||
|
model_fields = get_model_fields(cls)
|
||||||
|
for name, field in model_fields.items():
|
||||||
|
key = field.alias
|
||||||
|
if key is None or (key not in values and populate_by_name):
|
||||||
|
key = name
|
||||||
|
|
||||||
|
if key in values:
|
||||||
|
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
|
||||||
|
_fields_set.add(name)
|
||||||
|
else:
|
||||||
|
fields_values[name] = field_get_default(field)
|
||||||
|
|
||||||
|
_extra = {}
|
||||||
|
for key, value in values.items():
|
||||||
|
if key not in model_fields:
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
_extra[key] = value
|
||||||
|
else:
|
||||||
|
_fields_set.add(key)
|
||||||
|
fields_values[key] = value
|
||||||
|
|
||||||
|
object.__setattr__(m, "__dict__", fields_values)
|
||||||
|
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
# these properties are copied from Pydantic's `model_construct()` method
|
||||||
|
object.__setattr__(m, "__pydantic_private__", None)
|
||||||
|
object.__setattr__(m, "__pydantic_extra__", _extra)
|
||||||
|
object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
|
||||||
|
else:
|
||||||
|
# init_private_attributes() does not exist in v2
|
||||||
|
m._init_private_attributes() # type: ignore
|
||||||
|
|
||||||
|
# copied from Pydantic v1's `construct()` method
|
||||||
|
object.__setattr__(m, "__fields_set__", _fields_set)
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
|
if not TYPE_CHECKING:
|
||||||
|
# type checkers incorrectly complain about this assignment
|
||||||
|
# because the type signatures are technically different
|
||||||
|
# although not in practice
|
||||||
|
model_construct = construct
|
||||||
|
|
||||||
|
if not PYDANTIC_V2:
|
||||||
|
# we define aliases for some of the new pydantic v2 methods so
|
||||||
|
# that we can just document these methods without having to specify
|
||||||
|
# a specific pydantic version as some users may not know which
|
||||||
|
# pydantic version they are currently using
|
||||||
|
|
||||||
|
@override
|
||||||
|
def model_dump(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
mode: Literal["json", "python"] | str = "python",
|
||||||
|
include: IncEx = None,
|
||||||
|
exclude: IncEx = None,
|
||||||
|
by_alias: bool = False,
|
||||||
|
exclude_unset: bool = False,
|
||||||
|
exclude_defaults: bool = False,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
round_trip: bool = False,
|
||||||
|
warnings: bool | Literal["none", "warn", "error"] = True,
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
serialize_as_any: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
|
||||||
|
|
||||||
|
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode: The mode in which `to_python` should run.
|
||||||
|
If mode is 'json', the dictionary will only contain JSON serializable types.
|
||||||
|
If mode is 'python', the dictionary may contain any Python objects.
|
||||||
|
include: A list of fields to include in the output.
|
||||||
|
exclude: A list of fields to exclude from the output.
|
||||||
|
by_alias: Whether to use the field's alias in the dictionary key if defined.
|
||||||
|
exclude_unset: Whether to exclude fields that are unset or None from the output.
|
||||||
|
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
|
||||||
|
exclude_none: Whether to exclude fields that have a value of `None` from the output.
|
||||||
|
round_trip: Whether to enable serialization and deserialization round-trip support.
|
||||||
|
warnings: Whether to log warnings when invalid fields are encountered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary representation of the model.
|
||||||
|
"""
|
||||||
|
if mode != "python":
|
||||||
|
raise ValueError("mode is only supported in Pydantic v2")
|
||||||
|
if round_trip != False:
|
||||||
|
raise ValueError("round_trip is only supported in Pydantic v2")
|
||||||
|
if warnings != True:
|
||||||
|
raise ValueError("warnings is only supported in Pydantic v2")
|
||||||
|
if context is not None:
|
||||||
|
raise ValueError("context is only supported in Pydantic v2")
|
||||||
|
if serialize_as_any != False:
|
||||||
|
raise ValueError("serialize_as_any is only supported in Pydantic v2")
|
||||||
|
return super().dict( # pyright: ignore[reportDeprecated]
|
||||||
|
include=include,
|
||||||
|
exclude=exclude,
|
||||||
|
by_alias=by_alias,
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def model_dump_json(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
indent: int | None = None,
|
||||||
|
include: IncEx = None,
|
||||||
|
exclude: IncEx = None,
|
||||||
|
by_alias: bool = False,
|
||||||
|
exclude_unset: bool = False,
|
||||||
|
exclude_defaults: bool = False,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
round_trip: bool = False,
|
||||||
|
warnings: bool | Literal["none", "warn", "error"] = True,
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
serialize_as_any: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
|
||||||
|
|
||||||
|
Generates a JSON representation of the model using Pydantic's `to_json` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
|
||||||
|
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
|
||||||
|
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
|
||||||
|
by_alias: Whether to serialize using field aliases.
|
||||||
|
exclude_unset: Whether to exclude fields that have not been explicitly set.
|
||||||
|
exclude_defaults: Whether to exclude fields that have the default value.
|
||||||
|
exclude_none: Whether to exclude fields that have a value of `None`.
|
||||||
|
round_trip: Whether to use serialization/deserialization between JSON and class instance.
|
||||||
|
warnings: Whether to show any warnings that occurred during serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON string representation of the model.
|
||||||
|
"""
|
||||||
|
if round_trip != False:
|
||||||
|
raise ValueError("round_trip is only supported in Pydantic v2")
|
||||||
|
if warnings != True:
|
||||||
|
raise ValueError("warnings is only supported in Pydantic v2")
|
||||||
|
if context is not None:
|
||||||
|
raise ValueError("context is only supported in Pydantic v2")
|
||||||
|
if serialize_as_any != False:
|
||||||
|
raise ValueError("serialize_as_any is only supported in Pydantic v2")
|
||||||
|
return super().json( # type: ignore[reportDeprecated]
|
||||||
|
indent=indent,
|
||||||
|
include=include,
|
||||||
|
exclude=exclude,
|
||||||
|
by_alias=by_alias,
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
|
||||||
|
if value is None:
|
||||||
|
return field_get_default(field)
|
||||||
|
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
type_ = field.annotation
|
||||||
|
else:
|
||||||
|
type_ = cast(type, field.outer_type_) # type: ignore
|
||||||
|
|
||||||
|
if type_ is None:
|
||||||
|
raise RuntimeError(f"Unexpected field type is None for {key}")
|
||||||
|
|
||||||
|
return construct_type(value=value, type_=type_)
|
||||||
|
|
||||||
|
|
||||||
|
def is_basemodel(type_: type) -> bool:
|
||||||
|
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
|
||||||
|
if is_union(type_):
|
||||||
|
for variant in get_args(type_):
|
||||||
|
if is_basemodel(variant):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
return is_basemodel_type(type_)
|
||||||
|
|
||||||
|
|
||||||
|
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
|
||||||
|
origin = get_origin(type_) or type_
|
||||||
|
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
|
||||||
|
|
||||||
|
|
||||||
|
def construct_type(*, value: object, type_: object) -> object:
|
||||||
|
"""Loose coercion to the expected type with construction of nested values.
|
||||||
|
|
||||||
|
If the given value does not match the expected type then it is returned as-is.
|
||||||
|
"""
|
||||||
|
# we allow `object` as the input type because otherwise, passing things like
|
||||||
|
# `Literal['value']` will be reported as a type error by type checkers
|
||||||
|
type_ = cast("type[object]", type_)
|
||||||
|
|
||||||
|
# unwrap `Annotated[T, ...]` -> `T`
|
||||||
|
if is_annotated_type(type_):
|
||||||
|
meta: tuple[Any, ...] = get_args(type_)[1:]
|
||||||
|
type_ = extract_type_arg(type_, 0)
|
||||||
|
else:
|
||||||
|
meta = tuple()
|
||||||
|
|
||||||
|
# we need to use the origin class for any types that are subscripted generics
|
||||||
|
# e.g. Dict[str, object]
|
||||||
|
origin = get_origin(type_) or type_
|
||||||
|
args = get_args(type_)
|
||||||
|
|
||||||
|
if is_union(origin):
|
||||||
|
try:
|
||||||
|
return validate_type(type_=cast("type[object]", type_), value=value)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# if the type is a discriminated union then we want to construct the right variant
|
||||||
|
# in the union, even if the data doesn't match exactly, otherwise we'd break code
|
||||||
|
# that relies on the constructed class types, e.g.
|
||||||
|
#
|
||||||
|
# class FooType:
|
||||||
|
# kind: Literal['foo']
|
||||||
|
# value: str
|
||||||
|
#
|
||||||
|
# class BarType:
|
||||||
|
# kind: Literal['bar']
|
||||||
|
# value: int
|
||||||
|
#
|
||||||
|
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
|
||||||
|
# we'd end up constructing `FooType` when it should be `BarType`.
|
||||||
|
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
|
||||||
|
if discriminator and is_mapping(value):
|
||||||
|
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
|
||||||
|
if variant_value and isinstance(variant_value, str):
|
||||||
|
variant_type = discriminator.mapping.get(variant_value)
|
||||||
|
if variant_type:
|
||||||
|
return construct_type(type_=variant_type, value=value)
|
||||||
|
|
||||||
|
# if the data is not valid, use the first variant that doesn't fail while deserializing
|
||||||
|
for variant in args:
|
||||||
|
try:
|
||||||
|
return construct_type(value=value, type_=variant)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
|
||||||
|
|
||||||
|
if origin == dict:
|
||||||
|
if not is_mapping(value):
|
||||||
|
return value
|
||||||
|
|
||||||
|
_, items_type = get_args(type_) # Dict[_, items_type]
|
||||||
|
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
|
||||||
|
|
||||||
|
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
|
||||||
|
if is_list(value):
|
||||||
|
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
|
||||||
|
|
||||||
|
if is_mapping(value):
|
||||||
|
if issubclass(type_, BaseModel):
|
||||||
|
return type_.construct(**value) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
return cast(Any, type_).construct(**value)
|
||||||
|
|
||||||
|
if origin == list:
|
||||||
|
if not is_list(value):
|
||||||
|
return value
|
||||||
|
|
||||||
|
inner_type = args[0] # List[inner_type]
|
||||||
|
return [construct_type(value=entry, type_=inner_type) for entry in value]
|
||||||
|
|
||||||
|
if origin == float:
|
||||||
|
if isinstance(value, int):
|
||||||
|
coerced = float(value)
|
||||||
|
if coerced != value:
|
||||||
|
return value
|
||||||
|
return coerced
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
if type_ == datetime:
|
||||||
|
try:
|
||||||
|
return parse_datetime(value) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
return value
|
||||||
|
|
||||||
|
if type_ == date:
|
||||||
|
try:
|
||||||
|
return parse_date(value) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
return value
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class CachedDiscriminatorType(Protocol):
|
||||||
|
__discriminator__: DiscriminatorDetails
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorDetails:
|
||||||
|
field_name: str
|
||||||
|
"""The name of the discriminator field in the variant class, e.g.
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Foo(BaseModel):
|
||||||
|
type: Literal['foo']
|
||||||
|
```
|
||||||
|
|
||||||
|
Will result in field_name='type'
|
||||||
|
"""
|
||||||
|
|
||||||
|
field_alias_from: str | None
|
||||||
|
"""The name of the discriminator field in the API response, e.g.
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Foo(BaseModel):
|
||||||
|
type: Literal['foo'] = Field(alias='type_from_api')
|
||||||
|
```
|
||||||
|
|
||||||
|
Will result in field_alias_from='type_from_api'
|
||||||
|
"""
|
||||||
|
|
||||||
|
mapping: dict[str, type]
|
||||||
|
"""Mapping of discriminator value to variant type, e.g.
|
||||||
|
|
||||||
|
{'foo': FooVariant, 'bar': BarVariant}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
mapping: dict[str, type],
|
||||||
|
discriminator_field: str,
|
||||||
|
discriminator_alias: str | None,
|
||||||
|
) -> None:
|
||||||
|
self.mapping = mapping
|
||||||
|
self.field_name = discriminator_field
|
||||||
|
self.field_alias_from = discriminator_alias
|
||||||
|
|
||||||
|
|
||||||
|
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
|
||||||
|
if isinstance(union, CachedDiscriminatorType):
|
||||||
|
return union.__discriminator__
|
||||||
|
|
||||||
|
discriminator_field_name: str | None = None
|
||||||
|
|
||||||
|
for annotation in meta_annotations:
|
||||||
|
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
|
||||||
|
discriminator_field_name = annotation.discriminator
|
||||||
|
break
|
||||||
|
|
||||||
|
if not discriminator_field_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
mapping: dict[str, type] = {}
|
||||||
|
discriminator_alias: str | None = None
|
||||||
|
|
||||||
|
for variant in get_args(union):
|
||||||
|
variant = strip_annotated_type(variant)
|
||||||
|
if is_basemodel_type(variant):
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
field = _extract_field_schema_pv2(variant, discriminator_field_name)
|
||||||
|
if not field:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Note: if one variant defines an alias then they all should
|
||||||
|
discriminator_alias = field.get("serialization_alias")
|
||||||
|
|
||||||
|
field_schema = field["schema"]
|
||||||
|
|
||||||
|
if field_schema["type"] == "literal":
|
||||||
|
for entry in cast("LiteralSchema", field_schema)["expected"]:
|
||||||
|
if isinstance(entry, str):
|
||||||
|
mapping[entry] = variant
|
||||||
|
else:
|
||||||
|
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||||
|
if not field_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Note: if one variant defines an alias then they all should
|
||||||
|
discriminator_alias = field_info.alias
|
||||||
|
|
||||||
|
if field_info.annotation and is_literal_type(field_info.annotation):
|
||||||
|
for entry in get_args(field_info.annotation):
|
||||||
|
if isinstance(entry, str):
|
||||||
|
mapping[entry] = variant
|
||||||
|
|
||||||
|
if not mapping:
|
||||||
|
return None
|
||||||
|
|
||||||
|
details = DiscriminatorDetails(
|
||||||
|
mapping=mapping,
|
||||||
|
discriminator_field=discriminator_field_name,
|
||||||
|
discriminator_alias=discriminator_alias,
|
||||||
|
)
|
||||||
|
cast(CachedDiscriminatorType, union).__discriminator__ = details
|
||||||
|
return details
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
|
||||||
|
schema = model.__pydantic_core_schema__
|
||||||
|
if schema["type"] != "model":
|
||||||
|
return None
|
||||||
|
|
||||||
|
fields_schema = schema["schema"]
|
||||||
|
if fields_schema["type"] != "model-fields":
|
||||||
|
return None
|
||||||
|
|
||||||
|
fields_schema = cast("ModelFieldsSchema", fields_schema)
|
||||||
|
|
||||||
|
field = fields_schema["fields"].get(field_name)
|
||||||
|
if not field:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_type(*, type_: type[_T], value: object) -> _T:
|
||||||
|
"""Strict validation that the given value matches the expected type"""
|
||||||
|
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
|
||||||
|
return cast(_T, parse_obj(type_, value))
|
||||||
|
|
||||||
|
return cast(_T, _validate_non_model_type(type_=type_, value=value))
|
||||||
|
|
||||||
|
|
||||||
|
# our use of subclasssing here causes weirdness for type checkers,
|
||||||
|
# so we just pretend that we don't subclass
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
GenericModel = BaseModel
|
||||||
|
else:
|
||||||
|
|
||||||
|
class GenericModel(BaseGenericModel, BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
from pydantic import TypeAdapter as _TypeAdapter
|
||||||
|
|
||||||
|
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
else:
|
||||||
|
TypeAdapter = _CachedTypeAdapter
|
||||||
|
|
||||||
|
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
|
||||||
|
return TypeAdapter(type_).validate_python(value)
|
||||||
|
|
||||||
|
elif not TYPE_CHECKING: # TODO: condition is weird
|
||||||
|
|
||||||
|
class RootModel(GenericModel, Generic[_T]):
|
||||||
|
"""Used as a placeholder to easily convert runtime types to a Pydantic format
|
||||||
|
to provide validation.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
```py
|
||||||
|
validated = RootModel[int](__root__="5").__root__
|
||||||
|
# validated: 5
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
__root__: _T
|
||||||
|
|
||||||
|
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
|
||||||
|
model = _create_pydantic_model(type_).validate(value)
|
||||||
|
return cast(_T, model.__root__)
|
||||||
|
|
||||||
|
def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
|
||||||
|
return RootModel[type_] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
220
model-providers/model_providers/_types.py
Normal file
220
model-providers/model_providers/_types.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from os import PathLike
|
||||||
|
from typing import (
|
||||||
|
IO,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Type,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
Mapping,
|
||||||
|
TypeVar,
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
|
from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pydantic
|
||||||
|
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ._models import BaseModel
|
||||||
|
|
||||||
|
Transport = BaseTransport
|
||||||
|
AsyncTransport = AsyncBaseTransport
|
||||||
|
Query = Mapping[str, object]
|
||||||
|
Body = object
|
||||||
|
AnyMapping = Mapping[str, object]
|
||||||
|
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
# Approximates httpx internal ProxiesTypes and RequestFiles types
|
||||||
|
# while adding support for `PathLike` instances
|
||||||
|
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
|
||||||
|
ProxiesTypes = Union[str, Proxy, ProxiesDict]
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
Base64FileInput = Union[IO[bytes], PathLike[str]]
|
||||||
|
FileContent = Union[IO[bytes], bytes, PathLike[str]]
|
||||||
|
else:
|
||||||
|
Base64FileInput = Union[IO[bytes], PathLike]
|
||||||
|
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
|
||||||
|
FileTypes = Union[
|
||||||
|
# file (or bytes)
|
||||||
|
FileContent,
|
||||||
|
# (filename, file (or bytes))
|
||||||
|
Tuple[Optional[str], FileContent],
|
||||||
|
# (filename, file (or bytes), content_type)
|
||||||
|
Tuple[Optional[str], FileContent, Optional[str]],
|
||||||
|
# (filename, file (or bytes), content_type, headers)
|
||||||
|
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
|
||||||
|
]
|
||||||
|
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
|
||||||
|
|
||||||
|
# duplicate of the above but without our custom file support
|
||||||
|
HttpxFileContent = Union[IO[bytes], bytes]
|
||||||
|
HttpxFileTypes = Union[
|
||||||
|
# file (or bytes)
|
||||||
|
HttpxFileContent,
|
||||||
|
# (filename, file (or bytes))
|
||||||
|
Tuple[Optional[str], HttpxFileContent],
|
||||||
|
# (filename, file (or bytes), content_type)
|
||||||
|
Tuple[Optional[str], HttpxFileContent, Optional[str]],
|
||||||
|
# (filename, file (or bytes), content_type, headers)
|
||||||
|
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
|
||||||
|
]
|
||||||
|
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
|
||||||
|
|
||||||
|
# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
|
||||||
|
# where ResponseT includes `None`. In order to support directly
|
||||||
|
# passing `None`, overloads would have to be defined for every
|
||||||
|
# method that uses `ResponseT` which would lead to an unacceptable
|
||||||
|
# amount of code duplication and make it unreadable. See _base_client.py
|
||||||
|
# for example usage.
|
||||||
|
#
|
||||||
|
# This unfortunately means that you will either have
|
||||||
|
# to import this type and pass it explicitly:
|
||||||
|
#
|
||||||
|
# from openai import NoneType
|
||||||
|
# client.get('/foo', cast_to=NoneType)
|
||||||
|
#
|
||||||
|
# or build it yourself:
|
||||||
|
#
|
||||||
|
# client.get('/foo', cast_to=type(None))
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
NoneType: Type[None]
|
||||||
|
else:
|
||||||
|
NoneType = type(None)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestOptions(TypedDict, total=False):
|
||||||
|
headers: Headers
|
||||||
|
max_retries: int
|
||||||
|
timeout: float | Timeout | None
|
||||||
|
params: Query
|
||||||
|
extra_json: AnyMapping
|
||||||
|
idempotency_key: str
|
||||||
|
|
||||||
|
|
||||||
|
# Sentinel class used until PEP 0661 is accepted
|
||||||
|
class NotGiven:
|
||||||
|
"""
|
||||||
|
A sentinel singleton class used to distinguish omitted keyword arguments
|
||||||
|
from those passed in with the value None (which may have different behavior).
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
get(timeout=1) # 1s timeout
|
||||||
|
get(timeout=None) # No timeout
|
||||||
|
get() # Default timeout behavior, which may not be statically known at the method definition.
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __bool__(self) -> Literal[False]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "NOT_GIVEN"
|
||||||
|
|
||||||
|
|
||||||
|
NotGivenOr = Union[_T, NotGiven]
|
||||||
|
NOT_GIVEN = NotGiven()
|
||||||
|
|
||||||
|
|
||||||
|
class Omit:
|
||||||
|
"""In certain situations you need to be able to represent a case where a default value has
|
||||||
|
to be explicitly removed and `None` is not an appropriate substitute, for example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
# as the default `Content-Type` header is `application/json` that will be sent
|
||||||
|
client.post("/upload/files", files={"file": b"my raw file content"})
|
||||||
|
|
||||||
|
# you can't explicitly override the header as it has to be dynamically generated
|
||||||
|
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
|
||||||
|
client.post(..., headers={"Content-Type": "multipart/form-data"})
|
||||||
|
|
||||||
|
# instead you can remove the default `application/json` header by passing Omit
|
||||||
|
client.post(..., headers={"Content-Type": Omit()})
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __bool__(self) -> Literal[False]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ModelBuilderProtocol(Protocol):
|
||||||
|
@classmethod
|
||||||
|
def build(
|
||||||
|
cls: type[_T],
|
||||||
|
*,
|
||||||
|
response: Response,
|
||||||
|
data: object,
|
||||||
|
) -> _T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
Headers = Mapping[str, Union[str, Omit]]
|
||||||
|
|
||||||
|
|
||||||
|
class HeadersLikeProtocol(Protocol):
|
||||||
|
def get(self, __key: str) -> str | None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
HeadersLike = Union[Headers, HeadersLikeProtocol]
|
||||||
|
|
||||||
|
ResponseT = TypeVar(
|
||||||
|
"ResponseT",
|
||||||
|
bound=Union[
|
||||||
|
object,
|
||||||
|
str,
|
||||||
|
None,
|
||||||
|
"BaseModel",
|
||||||
|
List[Any],
|
||||||
|
Dict[str, Any],
|
||||||
|
Response,
|
||||||
|
ModelBuilderProtocol,
|
||||||
|
"APIResponse[Any]",
|
||||||
|
"AsyncAPIResponse[Any]",
|
||||||
|
"HttpxBinaryResponseContent",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
StrBytesIntFloat = Union[str, bytes, int, float]
|
||||||
|
|
||||||
|
# Note: copied from Pydantic
|
||||||
|
# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
|
||||||
|
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
|
||||||
|
|
||||||
|
PostParser = Callable[[Any], Any]
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class InheritsGeneric(Protocol):
|
||||||
|
"""Represents a type that has inherited from `Generic`
|
||||||
|
|
||||||
|
The `__orig_bases__` property can be used to determine the resolved
|
||||||
|
type variable for a given base class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__orig_bases__: tuple[_GenericAlias]
|
||||||
|
|
||||||
|
|
||||||
|
class _GenericAlias(Protocol):
|
||||||
|
__origin__: type[object]
|
||||||
|
|
||||||
|
|
||||||
|
class HttpxSendArgs(TypedDict, total=False):
|
||||||
|
auth: httpx.Auth
|
||||||
48
model-providers/model_providers/_utils/__init__.py
Normal file
48
model-providers/model_providers/_utils/__init__.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from ._utils import (
|
||||||
|
flatten as flatten,
|
||||||
|
is_dict as is_dict,
|
||||||
|
is_list as is_list,
|
||||||
|
is_given as is_given,
|
||||||
|
is_tuple as is_tuple,
|
||||||
|
lru_cache as lru_cache,
|
||||||
|
is_mapping as is_mapping,
|
||||||
|
is_tuple_t as is_tuple_t,
|
||||||
|
parse_date as parse_date,
|
||||||
|
is_iterable as is_iterable,
|
||||||
|
is_sequence as is_sequence,
|
||||||
|
coerce_float as coerce_float,
|
||||||
|
is_mapping_t as is_mapping_t,
|
||||||
|
removeprefix as removeprefix,
|
||||||
|
removesuffix as removesuffix,
|
||||||
|
extract_files as extract_files,
|
||||||
|
is_sequence_t as is_sequence_t,
|
||||||
|
required_args as required_args,
|
||||||
|
coerce_boolean as coerce_boolean,
|
||||||
|
coerce_integer as coerce_integer,
|
||||||
|
file_from_path as file_from_path,
|
||||||
|
parse_datetime as parse_datetime,
|
||||||
|
strip_not_given as strip_not_given,
|
||||||
|
deepcopy_minimal as deepcopy_minimal,
|
||||||
|
get_async_library as get_async_library,
|
||||||
|
maybe_coerce_float as maybe_coerce_float,
|
||||||
|
get_required_header as get_required_header,
|
||||||
|
maybe_coerce_boolean as maybe_coerce_boolean,
|
||||||
|
maybe_coerce_integer as maybe_coerce_integer,
|
||||||
|
)
|
||||||
|
from ._typing import (
|
||||||
|
is_list_type as is_list_type,
|
||||||
|
is_union_type as is_union_type,
|
||||||
|
extract_type_arg as extract_type_arg,
|
||||||
|
is_iterable_type as is_iterable_type,
|
||||||
|
is_required_type as is_required_type,
|
||||||
|
is_annotated_type as is_annotated_type,
|
||||||
|
strip_annotated_type as strip_annotated_type,
|
||||||
|
extract_type_var_from_base as extract_type_var_from_base,
|
||||||
|
)
|
||||||
|
from ._transform import (
|
||||||
|
PropertyInfo as PropertyInfo,
|
||||||
|
transform as transform,
|
||||||
|
async_transform as async_transform,
|
||||||
|
maybe_transform as maybe_transform,
|
||||||
|
async_maybe_transform as async_maybe_transform,
|
||||||
|
)
|
||||||
382
model-providers/model_providers/_utils/_transform.py
Normal file
382
model-providers/model_providers/_utils/_transform.py
Normal file
@ -0,0 +1,382 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
import pathlib
|
||||||
|
from typing import Any, Mapping, TypeVar, cast
|
||||||
|
from datetime import date, datetime
|
||||||
|
from typing_extensions import Literal, get_args, override, get_type_hints
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
from ._utils import (
|
||||||
|
is_list,
|
||||||
|
is_mapping,
|
||||||
|
is_iterable,
|
||||||
|
)
|
||||||
|
from .._files import is_base64_file_input
|
||||||
|
from ._typing import (
|
||||||
|
is_list_type,
|
||||||
|
is_union_type,
|
||||||
|
extract_type_arg,
|
||||||
|
is_iterable_type,
|
||||||
|
is_required_type,
|
||||||
|
is_annotated_type,
|
||||||
|
strip_annotated_type,
|
||||||
|
)
|
||||||
|
from .._compat import model_dump, is_typeddict
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: support for drilling globals() and locals()
|
||||||
|
# TODO: ensure works correctly with forward references in all cases
|
||||||
|
|
||||||
|
|
||||||
|
PropertyFormat = Literal["iso8601", "base64", "custom"]
|
||||||
|
|
||||||
|
|
||||||
|
class PropertyInfo:
|
||||||
|
"""Metadata class to be used in Annotated types to provide information about a given type.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
class MyParams(TypedDict):
|
||||||
|
account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
|
||||||
|
|
||||||
|
This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
alias: str | None
|
||||||
|
format: PropertyFormat | None
|
||||||
|
format_template: str | None
|
||||||
|
discriminator: str | None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
alias: str | None = None,
|
||||||
|
format: PropertyFormat | None = None,
|
||||||
|
format_template: str | None = None,
|
||||||
|
discriminator: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.alias = alias
|
||||||
|
self.format = format
|
||||||
|
self.format_template = format_template
|
||||||
|
self.discriminator = discriminator
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_transform(
|
||||||
|
data: object,
|
||||||
|
expected_type: object,
|
||||||
|
) -> Any | None:
|
||||||
|
"""Wrapper over `transform()` that allows `None` to be passed.
|
||||||
|
|
||||||
|
See `transform()` for more details.
|
||||||
|
"""
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
return transform(data, expected_type)
|
||||||
|
|
||||||
|
|
||||||
|
# Wrapper over _transform_recursive providing fake types
|
||||||
|
def transform(
|
||||||
|
data: _T,
|
||||||
|
expected_type: object,
|
||||||
|
) -> _T:
|
||||||
|
"""Transform dictionaries based off of type information from the given type, for example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Params(TypedDict, total=False):
|
||||||
|
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
|
||||||
|
|
||||||
|
|
||||||
|
transformed = transform({"card_id": "<my card ID>"}, Params)
|
||||||
|
# {'cardID': '<my card ID>'}
|
||||||
|
```
|
||||||
|
|
||||||
|
Any keys / data that does not have type information given will be included as is.
|
||||||
|
|
||||||
|
It should be noted that the transformations that this function does are not represented in the type system.
|
||||||
|
"""
|
||||||
|
transformed = _transform_recursive(data, annotation=cast(type, expected_type))
|
||||||
|
return cast(_T, transformed)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_annotated_type(type_: type) -> type | None:
|
||||||
|
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
|
||||||
|
|
||||||
|
This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
|
||||||
|
"""
|
||||||
|
if is_required_type(type_):
|
||||||
|
# Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
|
||||||
|
type_ = get_args(type_)[0]
|
||||||
|
|
||||||
|
if is_annotated_type(type_):
|
||||||
|
return type_
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_transform_key(key: str, type_: type) -> str:
|
||||||
|
"""Transform the given `data` based on the annotations provided in `type_`.
|
||||||
|
|
||||||
|
Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
|
||||||
|
"""
|
||||||
|
annotated_type = _get_annotated_type(type_)
|
||||||
|
if annotated_type is None:
|
||||||
|
# no `Annotated` definition for this type, no transformation needed
|
||||||
|
return key
|
||||||
|
|
||||||
|
# ignore the first argument as it is the actual type
|
||||||
|
annotations = get_args(annotated_type)[1:]
|
||||||
|
for annotation in annotations:
|
||||||
|
if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
|
||||||
|
return annotation.alias
|
||||||
|
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_recursive(
|
||||||
|
data: object,
|
||||||
|
*,
|
||||||
|
annotation: type,
|
||||||
|
inner_type: type | None = None,
|
||||||
|
) -> object:
|
||||||
|
"""Transform the given data against the expected type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotation: The direct type annotation given to the particular piece of data.
|
||||||
|
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
|
||||||
|
|
||||||
|
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
|
||||||
|
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
|
||||||
|
the list can be transformed using the metadata from the container type.
|
||||||
|
|
||||||
|
Defaults to the same value as the `annotation` argument.
|
||||||
|
"""
|
||||||
|
if inner_type is None:
|
||||||
|
inner_type = annotation
|
||||||
|
|
||||||
|
stripped_type = strip_annotated_type(inner_type)
|
||||||
|
if is_typeddict(stripped_type) and is_mapping(data):
|
||||||
|
return _transform_typeddict(data, stripped_type)
|
||||||
|
|
||||||
|
if (
|
||||||
|
# List[T]
|
||||||
|
(is_list_type(stripped_type) and is_list(data))
|
||||||
|
# Iterable[T]
|
||||||
|
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
|
||||||
|
):
|
||||||
|
inner_type = extract_type_arg(stripped_type, 0)
|
||||||
|
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
|
||||||
|
|
||||||
|
if is_union_type(stripped_type):
|
||||||
|
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
|
||||||
|
#
|
||||||
|
# TODO: there may be edge cases where the same normalized field name will transform to two different names
|
||||||
|
# in different subtypes.
|
||||||
|
for subtype in get_args(stripped_type):
|
||||||
|
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
|
||||||
|
return data
|
||||||
|
|
||||||
|
if isinstance(data, pydantic.BaseModel):
|
||||||
|
return model_dump(data, exclude_unset=True)
|
||||||
|
|
||||||
|
annotated_type = _get_annotated_type(annotation)
|
||||||
|
if annotated_type is None:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# ignore the first argument as it is the actual type
|
||||||
|
annotations = get_args(annotated_type)[1:]
|
||||||
|
for annotation in annotations:
|
||||||
|
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
|
||||||
|
return _format_data(data, annotation.format, annotation.format_template)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||||
|
if isinstance(data, (date, datetime)):
|
||||||
|
if format_ == "iso8601":
|
||||||
|
return data.isoformat()
|
||||||
|
|
||||||
|
if format_ == "custom" and format_template is not None:
|
||||||
|
return data.strftime(format_template)
|
||||||
|
|
||||||
|
if format_ == "base64" and is_base64_file_input(data):
|
||||||
|
binary: str | bytes | None = None
|
||||||
|
|
||||||
|
if isinstance(data, pathlib.Path):
|
||||||
|
binary = data.read_bytes()
|
||||||
|
elif isinstance(data, io.IOBase):
|
||||||
|
binary = data.read()
|
||||||
|
|
||||||
|
if isinstance(binary, str): # type: ignore[unreachable]
|
||||||
|
binary = binary.encode()
|
||||||
|
|
||||||
|
if not isinstance(binary, bytes):
|
||||||
|
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||||
|
|
||||||
|
return base64.b64encode(binary).decode("ascii")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_typeddict(
|
||||||
|
data: Mapping[str, object],
|
||||||
|
expected_type: type,
|
||||||
|
) -> Mapping[str, object]:
|
||||||
|
result: dict[str, object] = {}
|
||||||
|
annotations = get_type_hints(expected_type, include_extras=True)
|
||||||
|
for key, value in data.items():
|
||||||
|
type_ = annotations.get(key)
|
||||||
|
if type_ is None:
|
||||||
|
# we do not have a type annotation for this field, leave it as is
|
||||||
|
result[key] = value
|
||||||
|
else:
|
||||||
|
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def async_maybe_transform(
|
||||||
|
data: object,
|
||||||
|
expected_type: object,
|
||||||
|
) -> Any | None:
|
||||||
|
"""Wrapper over `async_transform()` that allows `None` to be passed.
|
||||||
|
|
||||||
|
See `async_transform()` for more details.
|
||||||
|
"""
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
return await async_transform(data, expected_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_transform(
|
||||||
|
data: _T,
|
||||||
|
expected_type: object,
|
||||||
|
) -> _T:
|
||||||
|
"""Transform dictionaries based off of type information from the given type, for example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Params(TypedDict, total=False):
|
||||||
|
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
|
||||||
|
|
||||||
|
|
||||||
|
transformed = transform({"card_id": "<my card ID>"}, Params)
|
||||||
|
# {'cardID': '<my card ID>'}
|
||||||
|
```
|
||||||
|
|
||||||
|
Any keys / data that does not have type information given will be included as is.
|
||||||
|
|
||||||
|
It should be noted that the transformations that this function does are not represented in the type system.
|
||||||
|
"""
|
||||||
|
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
|
||||||
|
return cast(_T, transformed)
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_transform_recursive(
|
||||||
|
data: object,
|
||||||
|
*,
|
||||||
|
annotation: type,
|
||||||
|
inner_type: type | None = None,
|
||||||
|
) -> object:
|
||||||
|
"""Transform the given data against the expected type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotation: The direct type annotation given to the particular piece of data.
|
||||||
|
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
|
||||||
|
|
||||||
|
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
|
||||||
|
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
|
||||||
|
the list can be transformed using the metadata from the container type.
|
||||||
|
|
||||||
|
Defaults to the same value as the `annotation` argument.
|
||||||
|
"""
|
||||||
|
if inner_type is None:
|
||||||
|
inner_type = annotation
|
||||||
|
|
||||||
|
stripped_type = strip_annotated_type(inner_type)
|
||||||
|
if is_typeddict(stripped_type) and is_mapping(data):
|
||||||
|
return await _async_transform_typeddict(data, stripped_type)
|
||||||
|
|
||||||
|
if (
|
||||||
|
# List[T]
|
||||||
|
(is_list_type(stripped_type) and is_list(data))
|
||||||
|
# Iterable[T]
|
||||||
|
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
|
||||||
|
):
|
||||||
|
inner_type = extract_type_arg(stripped_type, 0)
|
||||||
|
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
|
||||||
|
|
||||||
|
if is_union_type(stripped_type):
|
||||||
|
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
|
||||||
|
#
|
||||||
|
# TODO: there may be edge cases where the same normalized field name will transform to two different names
|
||||||
|
# in different subtypes.
|
||||||
|
for subtype in get_args(stripped_type):
|
||||||
|
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
|
||||||
|
return data
|
||||||
|
|
||||||
|
if isinstance(data, pydantic.BaseModel):
|
||||||
|
return model_dump(data, exclude_unset=True)
|
||||||
|
|
||||||
|
annotated_type = _get_annotated_type(annotation)
|
||||||
|
if annotated_type is None:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# ignore the first argument as it is the actual type
|
||||||
|
annotations = get_args(annotated_type)[1:]
|
||||||
|
for annotation in annotations:
|
||||||
|
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
|
||||||
|
return await _async_format_data(data, annotation.format, annotation.format_template)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||||
|
if isinstance(data, (date, datetime)):
|
||||||
|
if format_ == "iso8601":
|
||||||
|
return data.isoformat()
|
||||||
|
|
||||||
|
if format_ == "custom" and format_template is not None:
|
||||||
|
return data.strftime(format_template)
|
||||||
|
|
||||||
|
if format_ == "base64" and is_base64_file_input(data):
|
||||||
|
binary: str | bytes | None = None
|
||||||
|
|
||||||
|
if isinstance(data, pathlib.Path):
|
||||||
|
binary = await anyio.Path(data).read_bytes()
|
||||||
|
elif isinstance(data, io.IOBase):
|
||||||
|
binary = data.read()
|
||||||
|
|
||||||
|
if isinstance(binary, str): # type: ignore[unreachable]
|
||||||
|
binary = binary.encode()
|
||||||
|
|
||||||
|
if not isinstance(binary, bytes):
|
||||||
|
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||||
|
|
||||||
|
return base64.b64encode(binary).decode("ascii")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_transform_typeddict(
|
||||||
|
data: Mapping[str, object],
|
||||||
|
expected_type: type,
|
||||||
|
) -> Mapping[str, object]:
|
||||||
|
result: dict[str, object] = {}
|
||||||
|
annotations = get_type_hints(expected_type, include_extras=True)
|
||||||
|
for key, value in data.items():
|
||||||
|
type_ = annotations.get(key)
|
||||||
|
if type_ is None:
|
||||||
|
# we do not have a type annotation for this field, leave it as is
|
||||||
|
result[key] = value
|
||||||
|
else:
|
||||||
|
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
|
||||||
|
return result
|
||||||
120
model-providers/model_providers/_utils/_typing.py
Normal file
120
model-providers/model_providers/_utils/_typing.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, TypeVar, Iterable, cast
|
||||||
|
from collections import abc as _c_abc
|
||||||
|
from typing_extensions import Required, Annotated, get_args, get_origin
|
||||||
|
|
||||||
|
from .._types import InheritsGeneric
|
||||||
|
from .._compat import is_union as _is_union
|
||||||
|
|
||||||
|
|
||||||
|
def is_annotated_type(typ: type) -> bool:
|
||||||
|
return get_origin(typ) == Annotated
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_type(typ: type) -> bool:
|
||||||
|
return (get_origin(typ) or typ) == list
|
||||||
|
|
||||||
|
|
||||||
|
def is_iterable_type(typ: type) -> bool:
|
||||||
|
"""If the given type is `typing.Iterable[T]`"""
|
||||||
|
origin = get_origin(typ) or typ
|
||||||
|
return origin == Iterable or origin == _c_abc.Iterable
|
||||||
|
|
||||||
|
|
||||||
|
def is_union_type(typ: type) -> bool:
|
||||||
|
return _is_union(get_origin(typ))
|
||||||
|
|
||||||
|
|
||||||
|
def is_required_type(typ: type) -> bool:
|
||||||
|
return get_origin(typ) == Required
|
||||||
|
|
||||||
|
|
||||||
|
def is_typevar(typ: type) -> bool:
|
||||||
|
# type ignore is required because type checkers
|
||||||
|
# think this expression will always return False
|
||||||
|
return type(typ) == TypeVar # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
|
||||||
|
def strip_annotated_type(typ: type) -> type:
|
||||||
|
if is_required_type(typ) or is_annotated_type(typ):
|
||||||
|
return strip_annotated_type(cast(type, get_args(typ)[0]))
|
||||||
|
|
||||||
|
return typ
|
||||||
|
|
||||||
|
|
||||||
|
def extract_type_arg(typ: type, index: int) -> type:
|
||||||
|
args = get_args(typ)
|
||||||
|
try:
|
||||||
|
return cast(type, args[index])
|
||||||
|
except IndexError as err:
|
||||||
|
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
|
||||||
|
|
||||||
|
|
||||||
|
def extract_type_var_from_base(
|
||||||
|
typ: type,
|
||||||
|
*,
|
||||||
|
generic_bases: tuple[type, ...],
|
||||||
|
index: int,
|
||||||
|
failure_message: str | None = None,
|
||||||
|
) -> type:
|
||||||
|
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
|
||||||
|
|
||||||
|
This also handles the case where a concrete subclass is given, e.g.
|
||||||
|
```py
|
||||||
|
class MyResponse(Foo[bytes]):
|
||||||
|
...
|
||||||
|
|
||||||
|
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
|
||||||
|
```
|
||||||
|
|
||||||
|
And where a generic subclass is given:
|
||||||
|
```py
|
||||||
|
_T = TypeVar('_T')
|
||||||
|
class MyResponse(Foo[_T]):
|
||||||
|
...
|
||||||
|
|
||||||
|
extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
cls = cast(object, get_origin(typ) or typ)
|
||||||
|
if cls in generic_bases:
|
||||||
|
# we're given the class directly
|
||||||
|
return extract_type_arg(typ, index)
|
||||||
|
|
||||||
|
# if a subclass is given
|
||||||
|
# ---
|
||||||
|
# this is needed as __orig_bases__ is not present in the typeshed stubs
|
||||||
|
# because it is intended to be for internal use only, however there does
|
||||||
|
# not seem to be a way to resolve generic TypeVars for inherited subclasses
|
||||||
|
# without using it.
|
||||||
|
if isinstance(cls, InheritsGeneric):
|
||||||
|
target_base_class: Any | None = None
|
||||||
|
for base in cls.__orig_bases__:
|
||||||
|
if base.__origin__ in generic_bases:
|
||||||
|
target_base_class = base
|
||||||
|
break
|
||||||
|
|
||||||
|
if target_base_class is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Could not find the generic base class;\n"
|
||||||
|
"This should never happen;\n"
|
||||||
|
f"Does {cls} inherit from one of {generic_bases} ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted = extract_type_arg(target_base_class, index)
|
||||||
|
if is_typevar(extracted):
|
||||||
|
# If the extracted type argument is itself a type variable
|
||||||
|
# then that means the subclass itself is generic, so we have
|
||||||
|
# to resolve the type argument from the class itself, not
|
||||||
|
# the base class.
|
||||||
|
#
|
||||||
|
# Note: if there is more than 1 type argument, the subclass could
|
||||||
|
# change the ordering of the type arguments, this is not currently
|
||||||
|
# supported.
|
||||||
|
return extract_type_arg(typ, index)
|
||||||
|
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
|
||||||
403
model-providers/model_providers/_utils/_utils.py
Normal file
403
model-providers/model_providers/_utils/_utils.py
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import inspect
|
||||||
|
import functools
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Tuple,
|
||||||
|
Mapping,
|
||||||
|
TypeVar,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Sequence,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
from typing_extensions import TypeGuard
|
||||||
|
|
||||||
|
import sniffio
|
||||||
|
|
||||||
|
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
|
||||||
|
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
|
||||||
|
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
|
||||||
|
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
|
||||||
|
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
|
||||||
|
return [item for sublist in t for item in sublist]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_files(
|
||||||
|
# TODO: this needs to take Dict but variance issues.....
|
||||||
|
# create protocol type ?
|
||||||
|
query: Mapping[str, object],
|
||||||
|
*,
|
||||||
|
paths: Sequence[Sequence[str]],
|
||||||
|
) -> list[tuple[str, FileTypes]]:
|
||||||
|
"""Recursively extract files from the given dictionary based on specified paths.
|
||||||
|
|
||||||
|
A path may look like this ['foo', 'files', '<array>', 'data'].
|
||||||
|
|
||||||
|
Note: this mutates the given dictionary.
|
||||||
|
"""
|
||||||
|
files: list[tuple[str, FileTypes]] = []
|
||||||
|
for path in paths:
|
||||||
|
files.extend(_extract_items(query, path, index=0, flattened_key=None))
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_items(
|
||||||
|
obj: object,
|
||||||
|
path: Sequence[str],
|
||||||
|
*,
|
||||||
|
index: int,
|
||||||
|
flattened_key: str | None,
|
||||||
|
) -> list[tuple[str, FileTypes]]:
|
||||||
|
try:
|
||||||
|
key = path[index]
|
||||||
|
except IndexError:
|
||||||
|
if isinstance(obj, NotGiven):
|
||||||
|
# no value was provided - we can safely ignore
|
||||||
|
return []
|
||||||
|
|
||||||
|
# cyclical import
|
||||||
|
from .._files import assert_is_file_content
|
||||||
|
|
||||||
|
# We have exhausted the path, return the entry we found.
|
||||||
|
assert_is_file_content(obj, key=flattened_key)
|
||||||
|
assert flattened_key is not None
|
||||||
|
return [(flattened_key, cast(FileTypes, obj))]
|
||||||
|
|
||||||
|
index += 1
|
||||||
|
if is_dict(obj):
|
||||||
|
try:
|
||||||
|
# We are at the last entry in the path so we must remove the field
|
||||||
|
if (len(path)) == index:
|
||||||
|
item = obj.pop(key)
|
||||||
|
else:
|
||||||
|
item = obj[key]
|
||||||
|
except KeyError:
|
||||||
|
# Key was not present in the dictionary, this is not indicative of an error
|
||||||
|
# as the given path may not point to a required field. We also do not want
|
||||||
|
# to enforce required fields as the API may differ from the spec in some cases.
|
||||||
|
return []
|
||||||
|
if flattened_key is None:
|
||||||
|
flattened_key = key
|
||||||
|
else:
|
||||||
|
flattened_key += f"[{key}]"
|
||||||
|
return _extract_items(
|
||||||
|
item,
|
||||||
|
path,
|
||||||
|
index=index,
|
||||||
|
flattened_key=flattened_key,
|
||||||
|
)
|
||||||
|
elif is_list(obj):
|
||||||
|
if key != "<array>":
|
||||||
|
return []
|
||||||
|
|
||||||
|
return flatten(
|
||||||
|
[
|
||||||
|
_extract_items(
|
||||||
|
item,
|
||||||
|
path,
|
||||||
|
index=index,
|
||||||
|
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
|
||||||
|
)
|
||||||
|
for item in obj
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Something unexpected was passed, just ignore it.
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
|
||||||
|
return not isinstance(obj, NotGiven)
|
||||||
|
|
||||||
|
|
||||||
|
# Type safe methods for narrowing types with TypeVars.
|
||||||
|
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
|
||||||
|
# however this cause Pyright to rightfully report errors. As we know we don't
|
||||||
|
# care about the contained types we can safely use `object` in it's place.
|
||||||
|
#
|
||||||
|
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
|
||||||
|
# `is_*` is for when you're dealing with an unknown input
|
||||||
|
# `is_*_t` is for when you're narrowing a known union type to a specific subset
|
||||||
|
|
||||||
|
|
||||||
|
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
|
||||||
|
return isinstance(obj, tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
|
||||||
|
return isinstance(obj, tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
|
||||||
|
return isinstance(obj, Sequence)
|
||||||
|
|
||||||
|
|
||||||
|
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
|
||||||
|
return isinstance(obj, Sequence)
|
||||||
|
|
||||||
|
|
||||||
|
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
|
||||||
|
return isinstance(obj, Mapping)
|
||||||
|
|
||||||
|
|
||||||
|
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
|
||||||
|
return isinstance(obj, Mapping)
|
||||||
|
|
||||||
|
|
||||||
|
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
|
||||||
|
return isinstance(obj, dict)
|
||||||
|
|
||||||
|
|
||||||
|
def is_list(obj: object) -> TypeGuard[list[object]]:
|
||||||
|
return isinstance(obj, list)
|
||||||
|
|
||||||
|
|
||||||
|
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
|
||||||
|
return isinstance(obj, Iterable)
|
||||||
|
|
||||||
|
|
||||||
|
def deepcopy_minimal(item: _T) -> _T:
|
||||||
|
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
|
||||||
|
|
||||||
|
- mappings, e.g. `dict`
|
||||||
|
- list
|
||||||
|
|
||||||
|
This is done for performance reasons.
|
||||||
|
"""
|
||||||
|
if is_mapping(item):
|
||||||
|
return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
|
||||||
|
if is_list(item):
|
||||||
|
return cast(_T, [deepcopy_minimal(entry) for entry in item])
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
# copied from https://github.com/Rapptz/RoboDanny
|
||||||
|
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
|
||||||
|
size = len(seq)
|
||||||
|
if size == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if size == 1:
|
||||||
|
return seq[0]
|
||||||
|
|
||||||
|
if size == 2:
|
||||||
|
return f"{seq[0]} {final} {seq[1]}"
|
||||||
|
|
||||||
|
return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
|
||||||
|
|
||||||
|
|
||||||
|
def quote(string: str) -> str:
|
||||||
|
"""Add single quotation marks around the given string. Does *not* do any escaping."""
|
||||||
|
return f"'{string}'"
|
||||||
|
|
||||||
|
|
||||||
|
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
|
||||||
|
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
|
||||||
|
|
||||||
|
Useful for enforcing runtime validation of overloaded functions.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```py
|
||||||
|
@overload
|
||||||
|
def foo(*, a: str) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def foo(*, b: bool) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# This enforces the same constraints that a static type checker would
|
||||||
|
# i.e. that either a or b must be passed to the function
|
||||||
|
@required_args(["a"], ["b"])
|
||||||
|
def foo(*, a: str | None = None, b: bool | None = None) -> str:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def inner(func: CallableT) -> CallableT:
|
||||||
|
params = inspect.signature(func).parameters
|
||||||
|
positional = [
|
||||||
|
name
|
||||||
|
for name, param in params.items()
|
||||||
|
if param.kind
|
||||||
|
in {
|
||||||
|
param.POSITIONAL_ONLY,
|
||||||
|
param.POSITIONAL_OR_KEYWORD,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: object, **kwargs: object) -> object:
|
||||||
|
given_params: set[str] = set()
|
||||||
|
for i, _ in enumerate(args):
|
||||||
|
try:
|
||||||
|
given_params.add(positional[i])
|
||||||
|
except IndexError:
|
||||||
|
raise TypeError(
|
||||||
|
f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
for key in kwargs.keys():
|
||||||
|
given_params.add(key)
|
||||||
|
|
||||||
|
for variant in variants:
|
||||||
|
matches = all((param in given_params for param in variant))
|
||||||
|
if matches:
|
||||||
|
break
|
||||||
|
else: # no break
|
||||||
|
if len(variants) > 1:
|
||||||
|
variations = human_join(
|
||||||
|
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
|
||||||
|
)
|
||||||
|
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
|
||||||
|
else:
|
||||||
|
assert len(variants) > 0
|
||||||
|
|
||||||
|
# TODO: this error message is not deterministic
|
||||||
|
missing = list(set(variants[0]) - given_params)
|
||||||
|
if len(missing) > 1:
|
||||||
|
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
|
||||||
|
else:
|
||||||
|
msg = f"Missing required argument: {quote(missing[0])}"
|
||||||
|
raise TypeError(msg)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
_K = TypeVar("_K")
|
||||||
|
_V = TypeVar("_V")
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def strip_not_given(obj: None) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def strip_not_given(obj: object) -> object:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def strip_not_given(obj: object | None) -> object:
|
||||||
|
"""Remove all top-level keys where their values are instances of `NotGiven`"""
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not is_mapping(obj):
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_integer(val: str) -> int:
|
||||||
|
return int(val, base=10)
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_float(val: str) -> float:
|
||||||
|
return float(val)
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_boolean(val: str) -> bool:
|
||||||
|
return val == "true" or val == "1" or val == "on"
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_coerce_integer(val: str | None) -> int | None:
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
return coerce_integer(val)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_coerce_float(val: str | None) -> float | None:
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
return coerce_float(val)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_coerce_boolean(val: str | None) -> bool | None:
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
return coerce_boolean(val)
|
||||||
|
|
||||||
|
|
||||||
|
def removeprefix(string: str, prefix: str) -> str:
|
||||||
|
"""Remove a prefix from a string.
|
||||||
|
|
||||||
|
Backport of `str.removeprefix` for Python < 3.9
|
||||||
|
"""
|
||||||
|
if string.startswith(prefix):
|
||||||
|
return string[len(prefix) :]
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def removesuffix(string: str, suffix: str) -> str:
|
||||||
|
"""Remove a suffix from a string.
|
||||||
|
|
||||||
|
Backport of `str.removesuffix` for Python < 3.9
|
||||||
|
"""
|
||||||
|
if string.endswith(suffix):
|
||||||
|
return string[: -len(suffix)]
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def file_from_path(path: str) -> FileTypes:
|
||||||
|
contents = Path(path).read_bytes()
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
return (file_name, contents)
|
||||||
|
|
||||||
|
|
||||||
|
def get_required_header(headers: HeadersLike, header: str) -> str:
|
||||||
|
lower_header = header.lower()
|
||||||
|
if isinstance(headers, Mapping):
|
||||||
|
headers = cast(Headers, headers)
|
||||||
|
for k, v in headers.items():
|
||||||
|
if k.lower() == lower_header and isinstance(v, str):
|
||||||
|
return v
|
||||||
|
|
||||||
|
""" to deal with the case where the header looks like Stainless-Event-Id """
|
||||||
|
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
|
||||||
|
|
||||||
|
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
|
||||||
|
value = headers.get(normalized_header)
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
|
||||||
|
raise ValueError(f"Could not find {header} header")
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_library() -> str:
|
||||||
|
try:
|
||||||
|
return sniffio.current_async_library()
|
||||||
|
except Exception:
|
||||||
|
return "false"
|
||||||
|
|
||||||
|
|
||||||
|
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
|
||||||
|
"""A version of functools.lru_cache that retains the type signature
|
||||||
|
for the wrapped function arguments.
|
||||||
|
"""
|
||||||
|
wrapper = functools.lru_cache( # noqa: TID251
|
||||||
|
maxsize=maxsize,
|
||||||
|
)
|
||||||
|
return cast(Any, wrapper) # type: ignore[no-any-return]
|
||||||
@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Literal, Optional
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.entities.model_entities import (
|
from model_providers.core.entities.model_entities import (
|
||||||
ModelStatus,
|
ModelStatus,
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from ..._models import BaseModel
|
||||||
from pydantic import BaseModel, Field, root_validator
|
from pydantic import Field as FieldInfo
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ class ModelCard(BaseModel):
|
|||||||
"tts",
|
"tts",
|
||||||
"text2img",
|
"text2img",
|
||||||
] = "llm"
|
] = "llm"
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
created: int = FieldInfo(default_factory=lambda: int(time.time()))
|
||||||
owned_by: Literal["owner"] = "owner"
|
owned_by: Literal["owner"] = "owner"
|
||||||
|
|
||||||
|
|
||||||
@ -171,7 +171,7 @@ class ChatCompletionStreamResponseChoice(BaseModel):
|
|||||||
class ChatCompletionResponse(BaseModel):
|
class ChatCompletionResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
object: Literal["chat.completion"] = "chat.completion"
|
object: Literal["chat.completion"] = "chat.completion"
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
created: int = FieldInfo(default_factory=lambda: int(time.time()))
|
||||||
model: str
|
model: str
|
||||||
choices: List[ChatCompletionResponseChoice]
|
choices: List[ChatCompletionResponseChoice]
|
||||||
usage: UsageInfo
|
usage: UsageInfo
|
||||||
@ -180,7 +180,7 @@ class ChatCompletionResponse(BaseModel):
|
|||||||
class ChatCompletionStreamResponse(BaseModel):
|
class ChatCompletionStreamResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
created: int = FieldInfo(default_factory=lambda: int(time.time()))
|
||||||
model: str
|
model: str
|
||||||
choices: List[ChatCompletionStreamResponseChoice]
|
choices: List[ChatCompletionStreamResponseChoice]
|
||||||
|
|
||||||
|
|||||||
@ -1,331 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Any, Literal, Optional, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from model_providers.core.entities.provider_configuration import ProviderModelBundle
|
|
||||||
from model_providers.core.file.file_obj import FileObj
|
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
|
||||||
PromptMessageRole,
|
|
||||||
)
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Model Config Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
provider: str
|
|
||||||
model: str
|
|
||||||
model_schema: AIModelEntity
|
|
||||||
mode: str
|
|
||||||
provider_model_bundle: ProviderModelBundle
|
|
||||||
credentials: Dict[str, Any] = {}
|
|
||||||
parameters: Dict[str, Any] = {}
|
|
||||||
stop: List[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatMessageEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Advanced Chat Message Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
role: PromptMessageRole
|
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatPromptTemplateEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Advanced Chat Prompt Template Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
messages: List[AdvancedChatMessageEntity]
|
|
||||||
|
|
||||||
|
|
||||||
class AdvancedCompletionPromptTemplateEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Advanced Completion Prompt Template Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class RolePrefixEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Role Prefix Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
user: str
|
|
||||||
assistant: str
|
|
||||||
|
|
||||||
prompt: str
|
|
||||||
role_prefix: Optional[RolePrefixEntity] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Prompt Template Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class PromptType(Enum):
|
|
||||||
"""
|
|
||||||
Prompt Type.
|
|
||||||
'simple', 'advanced'
|
|
||||||
"""
|
|
||||||
|
|
||||||
SIMPLE = "simple"
|
|
||||||
ADVANCED = "advanced"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "PromptType":
|
|
||||||
"""
|
|
||||||
Get value of given mode.
|
|
||||||
|
|
||||||
:param value: mode value
|
|
||||||
:return: mode
|
|
||||||
"""
|
|
||||||
for mode in cls:
|
|
||||||
if mode.value == value:
|
|
||||||
return mode
|
|
||||||
raise ValueError(f"invalid prompt type value {value}")
|
|
||||||
|
|
||||||
prompt_type: PromptType
|
|
||||||
simple_prompt_template: Optional[str] = None
|
|
||||||
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
|
|
||||||
advanced_completion_prompt_template: Optional[
|
|
||||||
AdvancedCompletionPromptTemplateEntity
|
|
||||||
] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataVariableEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
External Data Variable Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
variable: str
|
|
||||||
type: str
|
|
||||||
config: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieveConfigEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Dataset Retrieve Config Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class RetrieveStrategy(Enum):
|
|
||||||
"""
|
|
||||||
Dataset Retrieve Strategy.
|
|
||||||
'single' or 'multiple'
|
|
||||||
"""
|
|
||||||
|
|
||||||
SINGLE = "single"
|
|
||||||
MULTIPLE = "multiple"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "RetrieveStrategy":
|
|
||||||
"""
|
|
||||||
Get value of given mode.
|
|
||||||
|
|
||||||
:param value: mode value
|
|
||||||
:return: mode
|
|
||||||
"""
|
|
||||||
for mode in cls:
|
|
||||||
if mode.value == value:
|
|
||||||
return mode
|
|
||||||
raise ValueError(f"invalid retrieve strategy value {value}")
|
|
||||||
|
|
||||||
query_variable: Optional[str] = None # Only when app mode is completion
|
|
||||||
|
|
||||||
retrieve_strategy: RetrieveStrategy
|
|
||||||
single_strategy: Optional[str] = None # for temp
|
|
||||||
top_k: Optional[int] = None
|
|
||||||
score_threshold: Optional[float] = None
|
|
||||||
reranking_model: Optional[dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Dataset Config Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
dataset_ids: List[str]
|
|
||||||
retrieve_config: DatasetRetrieveConfigEntity
|
|
||||||
|
|
||||||
|
|
||||||
class SensitiveWordAvoidanceEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Sensitive Word Avoidance Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: str
|
|
||||||
config: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeechEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Sensitive Word Avoidance Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
enabled: bool
|
|
||||||
voice: Optional[str] = None
|
|
||||||
language: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FileUploadEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
File Upload Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
image_config: Optional[dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgentToolEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Agent Tool Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
provider_type: Literal["builtin", "api"]
|
|
||||||
provider_id: str
|
|
||||||
tool_name: str
|
|
||||||
tool_parameters: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
|
|
||||||
class AgentPromptEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Agent Prompt Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
first_prompt: str
|
|
||||||
next_iteration: str
|
|
||||||
|
|
||||||
|
|
||||||
class AgentScratchpadUnit(BaseModel):
|
|
||||||
"""
|
|
||||||
Agent First Prompt Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Action(BaseModel):
|
|
||||||
"""
|
|
||||||
Action Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
action_name: str
|
|
||||||
action_input: Union[dict, str]
|
|
||||||
|
|
||||||
agent_response: Optional[str] = None
|
|
||||||
thought: Optional[str] = None
|
|
||||||
action_str: Optional[str] = None
|
|
||||||
observation: Optional[str] = None
|
|
||||||
action: Optional[Action] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgentEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Agent Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Strategy(Enum):
|
|
||||||
"""
|
|
||||||
Agent Strategy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
CHAIN_OF_THOUGHT = "chain-of-thought"
|
|
||||||
FUNCTION_CALLING = "function-calling"
|
|
||||||
|
|
||||||
provider: str
|
|
||||||
model: str
|
|
||||||
strategy: Strategy
|
|
||||||
prompt: Optional[AgentPromptEntity] = None
|
|
||||||
tools: List[AgentToolEntity] = None
|
|
||||||
max_iteration: int = 5
|
|
||||||
|
|
||||||
|
|
||||||
class AppOrchestrationConfigEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
App Orchestration Config Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config: ModelConfigEntity
|
|
||||||
prompt_template: PromptTemplateEntity
|
|
||||||
external_data_variables: List[ExternalDataVariableEntity] = []
|
|
||||||
agent: Optional[AgentEntity] = None
|
|
||||||
|
|
||||||
# features
|
|
||||||
dataset: Optional[DatasetEntity] = None
|
|
||||||
file_upload: Optional[FileUploadEntity] = None
|
|
||||||
opening_statement: Optional[str] = None
|
|
||||||
suggested_questions_after_answer: bool = False
|
|
||||||
show_retrieve_source: bool = False
|
|
||||||
more_like_this: bool = False
|
|
||||||
speech_to_text: bool = False
|
|
||||||
text_to_speech: dict = {}
|
|
||||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeFrom(Enum):
|
|
||||||
"""
|
|
||||||
Invoke From.
|
|
||||||
"""
|
|
||||||
|
|
||||||
SERVICE_API = "service-api"
|
|
||||||
WEB_APP = "web-app"
|
|
||||||
EXPLORE = "explore"
|
|
||||||
DEBUGGER = "debugger"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "InvokeFrom":
|
|
||||||
"""
|
|
||||||
Get value of given mode.
|
|
||||||
|
|
||||||
:param value: mode value
|
|
||||||
:return: mode
|
|
||||||
"""
|
|
||||||
for mode in cls:
|
|
||||||
if mode.value == value:
|
|
||||||
return mode
|
|
||||||
raise ValueError(f"invalid invoke from value {value}")
|
|
||||||
|
|
||||||
def to_source(self) -> str:
|
|
||||||
"""
|
|
||||||
Get source of invoke from.
|
|
||||||
|
|
||||||
:return: source
|
|
||||||
"""
|
|
||||||
if self == InvokeFrom.WEB_APP:
|
|
||||||
return "web_app"
|
|
||||||
elif self == InvokeFrom.DEBUGGER:
|
|
||||||
return "dev"
|
|
||||||
elif self == InvokeFrom.EXPLORE:
|
|
||||||
return "explore_app"
|
|
||||||
elif self == InvokeFrom.SERVICE_API:
|
|
||||||
return "api"
|
|
||||||
|
|
||||||
return "dev"
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationGenerateEntity(BaseModel):
|
|
||||||
"""
|
|
||||||
Application Generate Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
task_id: str
|
|
||||||
tenant_id: str
|
|
||||||
|
|
||||||
app_id: str
|
|
||||||
app_model_config_id: str
|
|
||||||
# for save
|
|
||||||
app_model_config_dict: dict
|
|
||||||
app_model_config_override: bool
|
|
||||||
|
|
||||||
# Converted from app_model_config to Entity object, or directly covered by external input
|
|
||||||
app_orchestration_config_entity: AppOrchestrationConfigEntity
|
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
|
||||||
inputs: Dict[str, str]
|
|
||||||
query: Optional[str] = None
|
|
||||||
files: List[FileObj] = []
|
|
||||||
user_id: str
|
|
||||||
# extras
|
|
||||||
stream: bool
|
|
||||||
invoke_from: InvokeFrom
|
|
||||||
|
|
||||||
# extra parameters, like: auto_generate_conversation_name
|
|
||||||
extras: Dict[str, Any] = {}
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import Any, cast
|
from typing import Any, cast, List
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -8,7 +8,7 @@ from langchain.schema import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||||
from model_providers.core.model_runtime.entities.model_entities import (
|
from model_providers.core.model_runtime.entities.model_entities import (
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import logging
|
|||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Dict, Iterator, List, Optional
|
from typing import Dict, Iterator, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.entities.model_entities import (
|
from model_providers.core.entities.model_entities import (
|
||||||
ModelStatus,
|
ModelStatus,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
LLMResult,
|
LLMResult,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class I18nObject(BaseModel):
|
class I18nObject(BaseModel):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from decimal import Decimal
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from abc import ABC
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageRole(Enum):
|
class PromptMessageRole(Enum):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from decimal import Decimal
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||||
from model_providers.core.model_runtime.entities.model_entities import (
|
from model_providers.core.model_runtime.entities.model_entities import (
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class RerankDocument(BaseModel):
|
class RerankDocument(BaseModel):
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelUsage
|
from model_providers.core.model_runtime.entities.model_entities import ModelUsage
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel
|
from ....._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMMode
|
from model_providers.core.model_runtime.entities.llm_entities import LLMMode
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.entities.provider_entities import (
|
from model_providers.core.model_runtime.entities.provider_entities import (
|
||||||
|
|||||||
@ -1,21 +0,0 @@
|
|||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
|
||||||
|
|
||||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
|
||||||
|
|
||||||
if PYDANTIC_V2:
|
|
||||||
from pydantic_core import Url as Url
|
|
||||||
|
|
||||||
def _model_dump(
|
|
||||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
|
||||||
) -> Any:
|
|
||||||
return model.model_dump(mode=mode, **kwargs)
|
|
||||||
else:
|
|
||||||
from pydantic import AnyUrl as Url # noqa: F401
|
|
||||||
|
|
||||||
def _model_dump(
|
|
||||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
|
||||||
) -> Any:
|
|
||||||
return model.dict(**kwargs)
|
|
||||||
@ -1,234 +0,0 @@
|
|||||||
import dataclasses
|
|
||||||
import datetime
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
from collections.abc import Callable
|
|
||||||
from decimal import Decimal
|
|
||||||
from enum import Enum
|
|
||||||
from ipaddress import (
|
|
||||||
IPv4Address,
|
|
||||||
IPv4Interface,
|
|
||||||
IPv4Network,
|
|
||||||
IPv6Address,
|
|
||||||
IPv6Interface,
|
|
||||||
IPv6Network,
|
|
||||||
)
|
|
||||||
from pathlib import Path, PurePath
|
|
||||||
from re import Pattern
|
|
||||||
from types import GeneratorType
|
|
||||||
from typing import Any, Optional, Union, Dict, Type, List, Tuple
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic.color import Color
|
|
||||||
from pydantic.networks import AnyUrl, NameEmail
|
|
||||||
from pydantic.types import SecretBytes, SecretStr
|
|
||||||
|
|
||||||
from ._compat import PYDANTIC_V2, Url, _model_dump
|
|
||||||
|
|
||||||
|
|
||||||
# Taken from Pydantic v1 as is
|
|
||||||
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
|
|
||||||
return o.isoformat()
|
|
||||||
|
|
||||||
|
|
||||||
# Taken from Pydantic v1 as is
|
|
||||||
# TODO: pv2 should this return strings instead?
|
|
||||||
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
|
|
||||||
"""
|
|
||||||
Encodes a Decimal as int of there's no exponent, otherwise float
|
|
||||||
|
|
||||||
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
|
|
||||||
where a integer (but not int typed) is used. Encoding this as a float
|
|
||||||
results in failed round-tripping between encode and parse.
|
|
||||||
Our Id type is a prime example of this.
|
|
||||||
|
|
||||||
>>> decimal_encoder(Decimal("1.0"))
|
|
||||||
1.0
|
|
||||||
|
|
||||||
>>> decimal_encoder(Decimal("1"))
|
|
||||||
1
|
|
||||||
"""
|
|
||||||
if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
|
|
||||||
return int(dec_value)
|
|
||||||
else:
|
|
||||||
return float(dec_value)
|
|
||||||
|
|
||||||
|
|
||||||
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
|
||||||
bytes: lambda o: o.decode(),
|
|
||||||
Color: str,
|
|
||||||
datetime.date: isoformat,
|
|
||||||
datetime.datetime: isoformat,
|
|
||||||
datetime.time: isoformat,
|
|
||||||
datetime.timedelta: lambda td: td.total_seconds(),
|
|
||||||
Decimal: decimal_encoder,
|
|
||||||
Enum: lambda o: o.value,
|
|
||||||
frozenset: list,
|
|
||||||
deque: list,
|
|
||||||
GeneratorType: list,
|
|
||||||
IPv4Address: str,
|
|
||||||
IPv4Interface: str,
|
|
||||||
IPv4Network: str,
|
|
||||||
IPv6Address: str,
|
|
||||||
IPv6Interface: str,
|
|
||||||
IPv6Network: str,
|
|
||||||
NameEmail: str,
|
|
||||||
Path: str,
|
|
||||||
Pattern: lambda o: o.pattern,
|
|
||||||
SecretBytes: str,
|
|
||||||
SecretStr: str,
|
|
||||||
set: list,
|
|
||||||
UUID: str,
|
|
||||||
Url: str,
|
|
||||||
AnyUrl: str,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def generate_encoders_by_class_tuples(
|
|
||||||
type_encoder_map: Dict[Any, Callable[[Any], Any]],
|
|
||||||
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
|
|
||||||
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
|
|
||||||
tuple
|
|
||||||
)
|
|
||||||
for type_, encoder in type_encoder_map.items():
|
|
||||||
encoders_by_class_tuples[encoder] += (type_,)
|
|
||||||
return encoders_by_class_tuples
|
|
||||||
|
|
||||||
|
|
||||||
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
|
|
||||||
|
|
||||||
|
|
||||||
def jsonable_encoder(
|
|
||||||
obj: Any,
|
|
||||||
by_alias: bool = True,
|
|
||||||
exclude_unset: bool = False,
|
|
||||||
exclude_defaults: bool = False,
|
|
||||||
exclude_none: bool = False,
|
|
||||||
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
|
|
||||||
sqlalchemy_safe: bool = True,
|
|
||||||
) -> Any:
|
|
||||||
custom_encoder = custom_encoder or {}
|
|
||||||
if custom_encoder:
|
|
||||||
if type(obj) in custom_encoder:
|
|
||||||
return custom_encoder[type(obj)](obj)
|
|
||||||
else:
|
|
||||||
for encoder_type, encoder_instance in custom_encoder.items():
|
|
||||||
if isinstance(obj, encoder_type):
|
|
||||||
return encoder_instance(obj)
|
|
||||||
if isinstance(obj, BaseModel):
|
|
||||||
# TODO: remove when deprecating Pydantic v1
|
|
||||||
encoders: Dict[Any, Any] = {}
|
|
||||||
if not PYDANTIC_V2:
|
|
||||||
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
|
|
||||||
if custom_encoder:
|
|
||||||
encoders.update(custom_encoder)
|
|
||||||
obj_dict = _model_dump(
|
|
||||||
obj,
|
|
||||||
mode="json",
|
|
||||||
include=None,
|
|
||||||
exclude=None,
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_unset=exclude_unset,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
exclude_defaults=exclude_defaults,
|
|
||||||
)
|
|
||||||
if "__root__" in obj_dict:
|
|
||||||
obj_dict = obj_dict["__root__"]
|
|
||||||
return jsonable_encoder(
|
|
||||||
obj_dict,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
exclude_defaults=exclude_defaults,
|
|
||||||
# TODO: remove when deprecating Pydantic v1
|
|
||||||
custom_encoder=encoders,
|
|
||||||
sqlalchemy_safe=sqlalchemy_safe,
|
|
||||||
)
|
|
||||||
if dataclasses.is_dataclass(obj):
|
|
||||||
obj_dict = dataclasses.asdict(obj)
|
|
||||||
return jsonable_encoder(
|
|
||||||
obj_dict,
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_unset=exclude_unset,
|
|
||||||
exclude_defaults=exclude_defaults,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
custom_encoder=custom_encoder,
|
|
||||||
sqlalchemy_safe=sqlalchemy_safe,
|
|
||||||
)
|
|
||||||
if isinstance(obj, Enum):
|
|
||||||
return obj.value
|
|
||||||
if isinstance(obj, PurePath):
|
|
||||||
return str(obj)
|
|
||||||
if isinstance(obj, str | int | float | type(None)):
|
|
||||||
return obj
|
|
||||||
if isinstance(obj, Decimal):
|
|
||||||
return format(obj, "f")
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
encoded_dict = {}
|
|
||||||
allowed_keys = set(obj.keys())
|
|
||||||
for key, value in obj.items():
|
|
||||||
if (
|
|
||||||
(
|
|
||||||
not sqlalchemy_safe
|
|
||||||
or (not isinstance(key, str))
|
|
||||||
or (not key.startswith("_sa"))
|
|
||||||
)
|
|
||||||
and (value is not None or not exclude_none)
|
|
||||||
and key in allowed_keys
|
|
||||||
):
|
|
||||||
encoded_key = jsonable_encoder(
|
|
||||||
key,
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_unset=exclude_unset,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
custom_encoder=custom_encoder,
|
|
||||||
sqlalchemy_safe=sqlalchemy_safe,
|
|
||||||
)
|
|
||||||
encoded_value = jsonable_encoder(
|
|
||||||
value,
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_unset=exclude_unset,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
custom_encoder=custom_encoder,
|
|
||||||
sqlalchemy_safe=sqlalchemy_safe,
|
|
||||||
)
|
|
||||||
encoded_dict[encoded_key] = encoded_value
|
|
||||||
return encoded_dict
|
|
||||||
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
|
|
||||||
encoded_list = []
|
|
||||||
for item in obj:
|
|
||||||
encoded_list.append(
|
|
||||||
jsonable_encoder(
|
|
||||||
item,
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_unset=exclude_unset,
|
|
||||||
exclude_defaults=exclude_defaults,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
custom_encoder=custom_encoder,
|
|
||||||
sqlalchemy_safe=sqlalchemy_safe,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return encoded_list
|
|
||||||
|
|
||||||
if type(obj) in ENCODERS_BY_TYPE:
|
|
||||||
return ENCODERS_BY_TYPE[type(obj)](obj)
|
|
||||||
for encoder, classes_tuple in encoders_by_class_tuples.items():
|
|
||||||
if isinstance(obj, classes_tuple):
|
|
||||||
return encoder(obj)
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = dict(obj)
|
|
||||||
except Exception as e:
|
|
||||||
errors: List[Exception] = [e]
|
|
||||||
try:
|
|
||||||
data = vars(obj)
|
|
||||||
except Exception as e:
|
|
||||||
errors.append(e)
|
|
||||||
raise ValueError(errors) from e
|
|
||||||
return jsonable_encoder(
|
|
||||||
data,
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_unset=exclude_unset,
|
|
||||||
exclude_defaults=exclude_defaults,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
custom_encoder=custom_encoder,
|
|
||||||
sqlalchemy_safe=sqlalchemy_safe,
|
|
||||||
)
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import BaseModel
|
from ...._models import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def dump_model(model: BaseModel) -> dict:
|
def dump_model(model: BaseModel) -> dict:
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import json
|
|||||||
from typing import TYPE_CHECKING, Any, Dict
|
from typing import TYPE_CHECKING, Any, Dict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from pydantic import BaseModel
|
from ..._models import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def json_dumps(o):
|
def json_dumps(o):
|
||||||
|
|||||||
@ -136,7 +136,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
|
|||||||
yield f"http://127.0.0.1:20000"
|
yield f"http://127.0.0.1:20000"
|
||||||
finally:
|
finally:
|
||||||
print("")
|
print("")
|
||||||
# boot.destroy()
|
boot.destroy()
|
||||||
|
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user