make format

This commit is contained in:
glide-the 2024-05-20 11:58:30 +08:00
parent 6dd00b5d94
commit 4e9b1d6edf
49 changed files with 631 additions and 407 deletions

View File

@ -1,11 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
import pydantic
from pydantic.fields import FieldInfo
from typing_extensions import Self
from ._types import StrBytesIntFloat
@ -45,23 +45,49 @@ if TYPE_CHECKING:
else:
if PYDANTIC_V2:
from pydantic.v1.datetime_parse import (
parse_date as parse_date,
)
from pydantic.v1.datetime_parse import (
parse_datetime as parse_datetime,
)
from pydantic.v1.typing import (
get_args as get_args,
is_union as is_union,
)
from pydantic.v1.typing import (
get_origin as get_origin,
is_typeddict as is_typeddict,
)
from pydantic.v1.typing import (
is_literal_type as is_literal_type,
)
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
from pydantic.v1.typing import (
is_typeddict as is_typeddict,
)
from pydantic.v1.typing import (
is_union as is_union,
)
else:
from pydantic.datetime_parse import (
parse_date as parse_date,
)
from pydantic.datetime_parse import (
parse_datetime as parse_datetime,
)
from pydantic.typing import (
get_args as get_args,
is_union as is_union,
)
from pydantic.typing import (
get_origin as get_origin,
is_typeddict as is_typeddict,
)
from pydantic.typing import (
is_literal_type as is_literal_type,
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
from pydantic.typing import (
is_typeddict as is_typeddict,
)
from pydantic.typing import (
is_union as is_union,
)
# refactored config
@ -204,7 +230,9 @@ if TYPE_CHECKING:
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
def __get__(
self, instance: object, owner: type[Any] | None = None
) -> _T | Self:
raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None:

View File

@ -4,20 +4,20 @@ import io
import os
import pathlib
from typing import overload
from typing_extensions import TypeGuard
import anyio
from typing_extensions import TypeGuard
from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
FileContent,
FileTypes,
HttpxFileContent,
HttpxFileTypes,
HttpxRequestFiles,
RequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
from ._utils import is_mapping_t, is_sequence_t, is_tuple_t
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
@ -26,13 +26,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
isinstance(obj, bytes)
or isinstance(obj, tuple)
or isinstance(obj, io.IOBase)
or isinstance(obj, os.PathLike)
)
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
prefix = (
f"Expected entry at `{key}`"
if key is not None
else f"Expected file input `{obj!r}`"
)
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None
@ -57,7 +64,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(
f"Unexpected file type input {type(files)}, expected mapping or sequence"
)
return files
@ -73,7 +82,9 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)
def _read_file_content(file: FileContent) -> HttpxFileContent:
@ -101,7 +112,9 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(
"Unexpected file type input {type(files)}, expected mapping or sequence"
)
return files
@ -117,7 +130,9 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_tuple_t(file):
return (file[0], await _async_read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)
async def _async_read_file_content(file: FileContent) -> HttpxFileContent:

View File

@ -1,60 +1,62 @@
from __future__ import annotations
import os
import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
import os
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast
import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from typing_extensions import (
Unpack,
Literal,
ClassVar,
Literal,
Protocol,
Required,
TypedDict,
TypeGuard,
Unpack,
final,
override,
runtime_checkable,
)
import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from ._compat import (
PYDANTIC_V2,
ConfigDict,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._compat import (
GenericModel as BaseGenericModel,
)
from ._types import (
IncEx,
ModelT,
)
from ._utils import (
PropertyInfo,
is_list,
is_given,
lru_cache,
is_mapping,
parse_date,
coerce_boolean,
parse_datetime,
strip_not_given,
extract_type_arg,
is_annotated_type,
is_given,
is_list,
is_mapping,
lru_cache,
parse_date,
parse_datetime,
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
is_union,
parse_obj,
get_origin,
is_literal_type,
get_model_config,
get_model_fields,
field_get_default,
strip_not_given,
)
if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
@ -69,7 +71,8 @@ class _ConfigProtocol(Protocol):
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
extra="allow",
defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")),
)
else:
@ -189,7 +192,9 @@ class BaseModel(pydantic.BaseModel):
key = name
if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
fields_values[name] = _construct_field(
value=values[key], field=field, key=key
)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)
@ -412,9 +417,13 @@ def construct_type(*, value: object, type_: object) -> object:
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
discriminator = _build_discriminated_union_meta(
union=type_, meta_annotations=meta
)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
variant_value = value.get(
discriminator.field_alias_from or discriminator.field_name
)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
@ -434,11 +443,19 @@ def construct_type(*, value: object, type_: object) -> object:
return value
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
return {
key: construct_type(value=item, type_=items_type)
for key, item in value.items()
}
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
if not is_literal_type(type_) and (
issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
return [
cast(Any, type_).construct(**entry) if is_mapping(entry) else entry
for entry in value
]
if is_mapping(value):
if issubclass(type_, BaseModel):
@ -523,14 +540,19 @@ class DiscriminatorDetails:
self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
def _build_discriminated_union_meta(
*, union: type, meta_annotations: tuple[Any, ...]
) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__
discriminator_field_name: str | None = None
for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
if (
isinstance(annotation, PropertyInfo)
and annotation.discriminator is not None
):
discriminator_field_name = annotation.discriminator
break
@ -558,7 +580,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(
discriminator_field_name
) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
@ -582,7 +606,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
def _extract_field_schema_pv2(
model: type[BaseModel], field_name: str
) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None
@ -621,7 +647,9 @@ else:
if PYDANTIC_V2:
from pydantic import TypeAdapter as _TypeAdapter
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
_CachedTypeAdapter = cast(
"TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)
)
if TYPE_CHECKING:
from pydantic import TypeAdapter
@ -652,6 +680,3 @@ elif not TYPE_CHECKING: # TODO: condition is weird
def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
return RootModel[type_] # type: ignore

View File

@ -5,22 +5,29 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
import httpx
import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
from httpx import URL, AsyncBaseTransport, BaseTransport, Proxy, Response, Timeout
from typing_extensions import (
Literal,
Protocol,
TypeAlias,
TypedDict,
override,
runtime_checkable,
)
if TYPE_CHECKING:
from ._models import BaseModel
@ -43,7 +50,9 @@ if TYPE_CHECKING:
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
FileContent = Union[
IO[bytes], bytes, PathLike
] # PathLike is not subscriptable in Python 3.8.
FileTypes = Union[
# file (or bytes)
FileContent,
@ -68,7 +77,9 @@ HttpxFileTypes = Union[
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
HttpxRequestFiles = Union[
Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]
]
# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly

View File

@ -1,48 +1,126 @@
from ._utils import (
flatten as flatten,
is_dict as is_dict,
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
parse_date as parse_date,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
is_mapping_t as is_mapping_t,
removeprefix as removeprefix,
removesuffix as removesuffix,
extract_files as extract_files,
is_sequence_t as is_sequence_t,
required_args as required_args,
coerce_boolean as coerce_boolean,
coerce_integer as coerce_integer,
file_from_path as file_from_path,
parse_datetime as parse_datetime,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
from ._transform import (
PropertyInfo as PropertyInfo,
)
from ._transform import (
async_maybe_transform as async_maybe_transform,
)
from ._transform import (
async_transform as async_transform,
)
from ._transform import (
maybe_transform as maybe_transform,
)
from ._transform import (
transform as transform,
)
from ._typing import (
extract_type_arg as extract_type_arg,
)
from ._typing import (
extract_type_var_from_base as extract_type_var_from_base,
)
from ._typing import (
is_annotated_type as is_annotated_type,
)
from ._typing import (
is_iterable_type as is_iterable_type,
)
from ._typing import (
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
)
from ._typing import (
is_required_type as is_required_type,
is_annotated_type as is_annotated_type,
)
from ._typing import (
is_union_type as is_union_type,
)
from ._typing import (
strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
async_transform as async_transform,
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
from ._utils import (
coerce_boolean as coerce_boolean,
)
from ._utils import (
coerce_float as coerce_float,
)
from ._utils import (
coerce_integer as coerce_integer,
)
from ._utils import (
deepcopy_minimal as deepcopy_minimal,
)
from ._utils import (
extract_files as extract_files,
)
from ._utils import (
file_from_path as file_from_path,
)
from ._utils import (
flatten as flatten,
)
from ._utils import (
get_async_library as get_async_library,
)
from ._utils import (
get_required_header as get_required_header,
)
from ._utils import (
is_dict as is_dict,
)
from ._utils import (
is_given as is_given,
)
from ._utils import (
is_iterable as is_iterable,
)
from ._utils import (
is_list as is_list,
)
from ._utils import (
is_mapping as is_mapping,
)
from ._utils import (
is_mapping_t as is_mapping_t,
)
from ._utils import (
is_sequence as is_sequence,
)
from ._utils import (
is_sequence_t as is_sequence_t,
)
from ._utils import (
is_tuple as is_tuple,
)
from ._utils import (
is_tuple_t as is_tuple_t,
)
from ._utils import (
lru_cache as lru_cache,
)
from ._utils import (
maybe_coerce_boolean as maybe_coerce_boolean,
)
from ._utils import (
maybe_coerce_float as maybe_coerce_float,
)
from ._utils import (
maybe_coerce_integer as maybe_coerce_integer,
)
from ._utils import (
parse_date as parse_date,
)
from ._utils import (
parse_datetime as parse_datetime,
)
from ._utils import (
removeprefix as removeprefix,
)
from ._utils import (
removesuffix as removesuffix,
)
from ._utils import (
required_args as required_args,
)
from ._utils import (
strip_not_given as strip_not_given,
)

View File

@ -1,31 +1,31 @@
from __future__ import annotations
import io
import base64
import io
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints
from typing import Any, Mapping, TypeVar, cast
import anyio
import pydantic
from typing_extensions import Literal, get_args, get_type_hints, override
from ._utils import (
is_list,
is_mapping,
is_iterable,
)
from .._compat import is_typeddict, model_dump
from .._files import is_base64_file_input
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
is_iterable_type,
is_required_type,
is_annotated_type,
is_iterable_type,
is_list_type,
is_required_type,
is_union_type,
strip_annotated_type,
)
from .._compat import model_dump, is_typeddict
from ._utils import (
is_iterable,
is_list,
is_mapping,
)
_T = TypeVar("_T")
@ -171,10 +171,17 @@ def _transform_recursive(
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
or (
is_iterable_type(stripped_type)
and is_iterable(data)
and not isinstance(data, str)
)
):
inner_type = extract_type_arg(stripped_type, 0)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
return [
_transform_recursive(d, annotation=annotation, inner_type=inner_type)
for d in data
]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
@ -201,7 +208,9 @@ def _transform_recursive(
return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
def _format_data(
data: object, format_: PropertyFormat, format_template: str | None
) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
@ -221,7 +230,9 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
raise RuntimeError(
f"Could not read bytes from {data}; Received {type(binary)}"
)
return base64.b64encode(binary).decode("ascii")
@ -240,7 +251,9 @@ def _transform_typeddict(
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
result[_maybe_transform_key(key, type_)] = _transform_recursive(
value, annotation=type_
)
return result
@ -276,7 +289,9 @@ async def async_transform(
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
transformed = await _async_transform_recursive(
data, annotation=cast(type, expected_type)
)
return cast(_T, transformed)
@ -309,10 +324,19 @@ async def _async_transform_recursive(
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
or (
is_iterable_type(stripped_type)
and is_iterable(data)
and not isinstance(data, str)
)
):
inner_type = extract_type_arg(stripped_type, 0)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
return [
await _async_transform_recursive(
d, annotation=annotation, inner_type=inner_type
)
for d in data
]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
@ -320,7 +344,9 @@ async def _async_transform_recursive(
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
data = await _async_transform_recursive(
data, annotation=annotation, inner_type=subtype
)
return data
if isinstance(data, pydantic.BaseModel):
@ -334,12 +360,16 @@ async def _async_transform_recursive(
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return await _async_format_data(
data, annotation.format, annotation.format_template
)
return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
async def _async_format_data(
data: object, format_: PropertyFormat, format_template: str | None
) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
@ -359,7 +389,9 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
raise RuntimeError(
f"Could not read bytes from {data}; Received {type(binary)}"
)
return base64.b64encode(binary).decode("ascii")
@ -378,5 +410,7 @@ async def _async_transform_typeddict(
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(
value, annotation=type_
)
return result

View File

@ -1,11 +1,12 @@
from __future__ import annotations
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc
from typing_extensions import Required, Annotated, get_args, get_origin
from typing import Any, Iterable, TypeVar, cast
from typing_extensions import Annotated, Required, get_args, get_origin
from .._types import InheritsGeneric
from .._compat import is_union as _is_union
from .._types import InheritsGeneric
def is_annotated_type(typ: type) -> bool:
@ -49,15 +50,17 @@ def extract_type_arg(typ: type, index: int) -> type:
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
raise RuntimeError(
f"Expected type {typ} to have a type argument at index {index} but it did not"
) from err
def extract_type_var_from_base(
typ: type,
*,
generic_bases: tuple[type, ...],
index: int,
failure_message: str | None = None,
typ: type,
*,
generic_bases: tuple[type, ...],
index: int,
failure_message: str | None = None,
) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
@ -117,4 +120,7 @@ def extract_type_var_from_base(
return extracted
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
raise RuntimeError(
failure_message
or f"Could not resolve inner type variable at index {index} for {typ}"
)

View File

@ -1,27 +1,28 @@
from __future__ import annotations
import functools
import inspect
import os
import re
import inspect
import functools
from pathlib import Path
from typing import (
Any,
Tuple,
Mapping,
TypeVar,
Callable,
Iterable,
Mapping,
Sequence,
Tuple,
TypeVar,
cast,
overload,
)
from pathlib import Path
from typing_extensions import TypeGuard
import sniffio
from typing_extensions import TypeGuard
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
from .._compat import parse_date as parse_date
from .._compat import parse_datetime as parse_datetime
from .._types import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
@ -108,7 +109,9 @@ def _extract_items(
item,
path,
index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
flattened_key=flattened_key + "[]"
if flattened_key is not None
else "[]",
)
for item in obj
]
@ -261,7 +264,12 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
else: # no break
if len(variants) > 1:
variations = human_join(
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
[
"("
+ human_join([quote(arg) for arg in variant], final="and")
+ ")"
for variant in variants
]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
@ -376,7 +384,11 @@ def get_required_header(headers: HeadersLike, header: str) -> str:
return v
""" to deal with the case where the header looks like Stainless-Event-Id """
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
intercaps_header = re.sub(
r"([^\w])(\w)",
lambda pat: pat.group(1) + pat.group(2).upper(),
header.capitalize(),
)
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
value = headers.get(normalized_header)

View File

@ -1,9 +1,6 @@
from enum import Enum
from typing import List, Literal, Optional
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,
ModelWithProviderEntity,
@ -26,6 +23,9 @@ from model_providers.core.model_runtime.entities.provider_entities import (
SimpleProviderEntity,
)
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
class CustomConfigurationStatus(Enum):
"""
@ -74,10 +74,9 @@ class ProviderResponse(BaseModel):
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=()
)
model_config = ConfigDict(protected_namespaces=())
else:
class Config:
protected_namespaces = ()
@ -180,10 +179,9 @@ class DefaultModelResponse(BaseModel):
provider: SimpleProviderEntityResponse
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=()
)
model_config = ConfigDict(protected_namespaces=())
else:
class Config:
protected_namespaces = ()

View File

@ -73,7 +73,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._server = None
self._server_thread = None
def logging_conf(self,logging_conf: Optional[dict] = None):
def logging_conf(self, logging_conf: Optional[dict] = None):
self._logging_conf = logging_conf
@classmethod
@ -132,7 +132,10 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._app.include_router(self._router)
config = Config(
app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
app=self._app,
host=self._host,
port=self._port,
log_config=self._logging_conf,
)
self._server = Server(config)
@ -145,7 +148,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._server_thread.start()
def destroy(self):
logger.info("Shutting down server")
self._server.should_exit = True # 设置退出标志
self._server.shutdown() # 停止服务器
@ -155,7 +157,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def join(self):
self._server_thread.join()
def set_app_event(self, started_event: mp.Event = None):
@self._app.on_event("startup")
async def on_startup():
@ -190,12 +191,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
provider_model_bundle.model_type_instance.predefined_models()
)
# 获取自定义模型
for model in provider_model_bundle.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
for (
model
) in provider_model_bundle.configuration.custom_configuration.models:
llm_models.append(
provider_model_bundle.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)
except Exception as e:
logger.error(
f"Error while fetching models for provider: {provider}, model_type: {model_type}"
@ -225,13 +229,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
# 判断embeddings_request.input是否为list
input = ''
input = ""
if isinstance(embeddings_request.input, list):
tokens = embeddings_request.input
try:
encoding = tiktoken.encoding_for_model(embeddings_request.model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
logger.warning(
"Warning: model not found. Using cl100k_base encoding."
)
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, token in enumerate(tokens):
@ -241,7 +247,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
else:
input = embeddings_request.input
response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
response = model_instance.invoke_text_embedding(
texts=[input], user="abc-123"
)
return await openai_embedding_text(response)
except ValueError as e:

View File

@ -1,10 +1,12 @@
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from ..._models import BaseModel
from pydantic import Field as FieldInfo
from typing_extensions import Literal
from ..._models import BaseModel
class Role(str, Enum):
USER = "user"

View File

@ -1,5 +1,5 @@
import enum
from typing import Any, cast, List
from typing import Any, List, cast
from langchain.schema import (
AIMessage,
@ -8,7 +8,6 @@ from langchain.schema import (
HumanMessage,
SystemMessage,
)
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -20,6 +19,8 @@ from model_providers.core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from ..._models import BaseModel
class PromptMessageFileType(enum.Enum):
IMAGE = "image"

View File

@ -1,9 +1,6 @@
from enum import Enum
from typing import List, Optional
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
ModelType,
@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
)
from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
class ModelStatus(Enum):
"""
@ -80,9 +80,8 @@ class DefaultModelEntity(BaseModel):
provider: DefaultModelProviderEntity
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=()
)
model_config = ConfigDict(protected_namespaces=())
else:
class Config:
protected_namespaces = ()
protected_namespaces = ()

View File

@ -4,9 +4,6 @@ import logging
from json import JSONDecodeError
from typing import Dict, Iterator, List, Optional
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,
ModelWithProviderEntity,
@ -29,6 +26,9 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im
ModelProvider,
)
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
logger = logging.getLogger(__name__)
@ -351,11 +351,9 @@ class ProviderModelBundle(BaseModel):
model_type_instance: AIModel
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True
)
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
else:
class Config:
protected_namespaces = ()

View File

@ -1,11 +1,11 @@
from enum import Enum
from typing import List, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
class ProviderType(Enum):
CUSTOM = "custom"
@ -59,10 +59,9 @@ class RestrictModel(BaseModel):
model_type: ModelType
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=()
)
model_config = ConfigDict(protected_namespaces=())
else:
class Config:
protected_namespaces = ()
@ -109,10 +108,9 @@ class CustomModelConfiguration(BaseModel):
credentials: dict
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=()
)
model_config = ConfigDict(protected_namespaces=())
else:
class Config:
protected_namespaces = ()

View File

@ -1,13 +1,13 @@
from enum import Enum
from typing import Any
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
from ..._models import BaseModel
class QueueEvent(Enum):
"""

View File

@ -2,8 +2,6 @@ from decimal import Decimal
from enum import Enum
from typing import List, Optional
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -13,6 +11,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
PriceInfo,
)
from ...._models import BaseModel
class LLMMode(Enum):
"""

View File

@ -2,11 +2,11 @@ from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Optional
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from ...._compat import PYDANTIC_V2, ConfigDict
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
class ModelType(Enum):
"""
@ -164,10 +164,9 @@ class ProviderModel(BaseModel):
model_properties: Dict[ModelPropertyKey, Any]
deprecated: bool = False
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=()
)
model_config = ConfigDict(protected_namespaces=())
else:
class Config:
protected_namespaces = ()

View File

@ -1,9 +1,6 @@
from enum import Enum
from typing import List, Optional
from ...._compat import PYDANTIC_V2, ConfigDict
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
ProviderModel,
)
from ...._compat import PYDANTIC_V2, ConfigDict
from ...._models import BaseModel
class ConfigurateMethod(Enum):
"""
@ -136,10 +136,9 @@ class ProviderEntity(BaseModel):
model_credential_schema: Optional[ModelCredentialSchema] = None
if PYDANTIC_V2:
model_config = ConfigDict(
protected_namespaces=()
)
model_config = ConfigDict(protected_namespaces=())
else:
class Config:
protected_namespaces = ()

View File

@ -1,10 +1,10 @@
from decimal import Decimal
from typing import List
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelUsage
from ...._models import BaseModel
class EmbeddingUsage(ModelUsage):
"""

View File

@ -1,5 +1,3 @@
from ....._models import BaseModel
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from model_providers.core.model_runtime.entities.llm_entities import LLMMode
from model_providers.core.model_runtime.entities.model_entities import (
@ -14,6 +12,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
PriceConfig,
)
from ....._models import BaseModel
AZURE_OPENAI_API_VERSION = "2024-02-15-preview"

View File

@ -1,6 +1,5 @@
import logging
from typing import Generator
from typing import Dict, List, Optional, Type, Union, cast
from typing import Dict, Generator, List, Optional, Type, Union, cast
import cohere
from cohere.responses import Chat, Generations

View File

@ -1,7 +1,6 @@
import logging
from typing import Generator
from typing import List, Optional, Union, cast
from decimal import Decimal
from typing import Generator, List, Optional, Union, cast
import tiktoken
from openai import OpenAI, Stream
@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
)
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
I18nObject,
ModelFeature,
ModelPropertyKey,
ModelType,
PriceConfig, ParameterRule, ParameterType, ModelFeature, ModelPropertyKey, DefaultParameterName,
ParameterRule,
ParameterType,
PriceConfig,
)
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.deepseek._common import _CommonDeepseek
from model_providers.core.model_runtime.model_providers.deepseek._common import (
_CommonDeepseek,
)
logger = logging.getLogger(__name__)
@ -1117,7 +1123,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
return num_tokens
def get_customizable_model_schema(
self, model: str, credentials: dict
self, model: str, credentials: dict
) -> AIModelEntity:
"""
Get customizable model schema.
@ -1129,7 +1135,6 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
"""
extras = {}
entity = AIModelEntity(
model=model,
label=I18nObject(zh_Hans=model, en_US=model),
@ -1149,8 +1154,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="The temperature of the model. "
"Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)"
"Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)"
),
default=0.8,
min=0,
@ -1163,8 +1168,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
"more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)"
"more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)"
),
default=0.9,
min=0,
@ -1177,8 +1182,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
help=I18nObject(
en_US="A number between -2.0 and 2.0. If positive, ",
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,"
"那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,"
"降低模型重复相同内容的可能性"
"那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,"
"降低模型重复相同内容的可能性",
),
default=0,
min=-2,
@ -1190,7 +1195,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Sets how strongly to presence_penalty. ",
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。"
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。",
),
default=1.1,
min=-2,
@ -1204,7 +1209,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
help=I18nObject(
en_US="Maximum number of tokens to predict when generating text. ",
zh_Hans="限制一次请求中模型生成 completion 的最大 token 数。"
"输入 token 和输出 token 的总长度受模型的上下文长度的限制。"
"输入 token 和输出 token 的总长度受模型的上下文长度的限制。",
),
default=128,
min=-2,
@ -1216,7 +1221,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.BOOLEAN,
help=I18nObject(
en_US="Whether to return the log probabilities of the tokens. ",
zh_Hans="是否返回所输出 token 的对数概率。如果为 true则在 message 的 content 中返回每个输出 token 的对数概率。"
zh_Hans="是否返回所输出 token 的对数概率。如果为 true则在 message 的 content 中返回每个输出 token 的对数概率。",
),
),
ParameterRule(
@ -1226,7 +1231,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
help=I18nObject(
en_US="the format to return a response in.",
zh_Hans="一个介于 0 到 20 之间的整数 N指定每个输出位置返回输出概率 top N 的 token"
"且返回这些 token 的对数概率。指定此参数时logprobs 必须为 true。"
"且返回这些 token 的对数概率。指定此参数时logprobs 必须为 true。",
),
default=0,
min=0,

View File

@ -3,8 +3,6 @@ import logging
import os
from typing import Dict, List, Optional, Union
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.provider_entities import (
ProviderConfig,
@ -25,6 +23,8 @@ from model_providers.core.utils.position_helper import (
sort_to_dict_by_position_map,
)
from ...._models import BaseModel
logger = logging.getLogger(__name__)

View File

@ -1,5 +1,4 @@
from typing import Generator
from typing import List, Optional, Union
from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (

View File

@ -1,7 +1,6 @@
import logging
from typing import Generator
from typing import List, Optional, Union, cast
from decimal import Decimal
from typing import Generator, List, Optional, Union, cast
import tiktoken
from openai import OpenAI, Stream
@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
)
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
I18nObject,
ModelFeature,
ModelPropertyKey,
ModelType,
PriceConfig, ModelFeature, ModelPropertyKey, DefaultParameterName, ParameterRule, ParameterType,
ParameterRule,
ParameterType,
PriceConfig,
)
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.ollama._common import _CommonOllama
from model_providers.core.model_runtime.model_providers.ollama._common import (
_CommonOllama,
)
logger = logging.getLogger(__name__)
@ -661,7 +667,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model,
stream=stream,
extra_body=extra_body
extra_body=extra_body,
)
if stream:
@ -1120,7 +1126,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
return num_tokens
def get_customizable_model_schema(
self, model: str, credentials: dict
self, model: str, credentials: dict
) -> AIModelEntity:
"""
Get customizable model schema.
@ -1154,8 +1160,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="The temperature of the model. "
"Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)"
"Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)"
),
default=0.8,
min=0,
@ -1168,8 +1174,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
"more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)"
"more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)"
),
default=0.9,
min=0,
@ -1181,8 +1187,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Reduces the probability of generating nonsense. "
"A higher value (e.g. 100) will give more diverse answers, "
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"
"A higher value (e.g. 100) will give more diverse answers, "
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"
),
default=40,
min=1,
@ -1194,8 +1200,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Sets how strongly to penalize repetitions. "
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
),
default=1.1,
min=-2,
@ -1208,7 +1214,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Maximum number of tokens to predict when generating text. "
"(Default: 128, -1 = infinite generation, -2 = fill context)"
"(Default: 128, -1 = infinite generation, -2 = fill context)"
),
default=128,
min=-2,
@ -1220,7 +1226,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Enable Mirostat sampling for controlling perplexity. "
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
),
default=0,
min=0,
@ -1232,9 +1238,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Influences how quickly the algorithm responds to feedback from "
"the generated text. A lower learning rate will result in slower adjustments, "
"while a higher learning rate will make the algorithm more responsive. "
"(Default: 0.1)"
"the generated text. A lower learning rate will result in slower adjustments, "
"while a higher learning rate will make the algorithm more responsive. "
"(Default: 0.1)"
),
default=0.1,
precision=1,
@ -1245,7 +1251,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Controls the balance between coherence and diversity of the output. "
"A lower value will result in more focused and coherent text. (Default: 5.0)"
"A lower value will result in more focused and coherent text. (Default: 5.0)"
),
default=5.0,
precision=1,
@ -1256,7 +1262,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the size of the context window used to generate the next token. "
"(Default: 2048)"
"(Default: 2048)"
),
default=2048,
min=1,
@ -1267,7 +1273,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="The number of layers to send to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable."
"On macOS it defaults to 1 to enable metal support, 0 to disable."
),
default=1,
min=0,
@ -1279,9 +1285,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the number of threads to use during computation. "
"By default, Ollama will detect this for optimal performance. "
"It is recommended to set this value to the number of physical CPU cores "
"your system has (as opposed to the logical number of cores)."
"By default, Ollama will detect this for optimal performance. "
"It is recommended to set this value to the number of physical CPU cores "
"your system has (as opposed to the logical number of cores)."
),
min=1,
),
@ -1291,7 +1297,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets how far back for the model to look back to prevent repetition. "
"(Default: 64, 0 = disabled, -1 = num_ctx)"
"(Default: 64, 0 = disabled, -1 = num_ctx)"
),
default=64,
min=-1,
@ -1302,8 +1308,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Tail free sampling is used to reduce the impact of less probable tokens "
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
"while a value of 1.0 disables this setting. (default: 1)"
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
"while a value of 1.0 disables this setting. (default: 1)"
),
default=1,
precision=1,
@ -1314,8 +1320,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the random number seed to use for generation. Setting this to "
"a specific number will make the model generate the same text for "
"the same prompt. (Default: 0)"
"a specific number will make the model generate the same text for "
"the same prompt. (Default: 0)"
),
default=0,
),
@ -1325,7 +1331,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.STRING,
help=I18nObject(
en_US="the format to return a response in."
" Currently the only accepted value is json."
" Currently the only accepted value is json."
),
options=["json"],
),

View File

@ -1,6 +1,5 @@
import logging
from typing import Generator
from typing import List, Optional, Union, cast
from typing import Generator, List, Optional, Union, cast
import tiktoken
from openai import OpenAI, Stream

View File

@ -1,8 +1,7 @@
import json
import logging
from typing import Generator
from decimal import Decimal
from typing import List, Optional, Union, cast
from typing import Generator, List, Optional, Union, cast
from urllib.parse import urljoin
import requests

View File

@ -1,7 +1,6 @@
from typing import Generator
from enum import Enum
from json import dumps, loads
from typing import Any, Union
from typing import Any, Generator, Union
from requests import Response, post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema

View File

@ -1,6 +1,5 @@
import threading
from typing import Generator
from typing import Dict, List, Optional, Type, Union
from typing import Dict, Generator, List, Optional, Type, Union
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,

View File

@ -1,5 +1,4 @@
from typing import Generator
from typing import List, Optional, Union
from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (

View File

@ -1,5 +1,4 @@
from typing import Generator
from typing import Dict, List, Optional, Type, Union
from typing import Dict, Generator, List, Optional, Type, Union
from dashscope import get_tokenizer
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse

View File

@ -1,6 +1,4 @@
from typing import Generator, Iterator
from typing import Dict, List, Union, cast, Type
from typing import Dict, Generator, Iterator, List, Type, Union, cast
from openai import (
APIConnectionError,
@ -527,8 +525,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _extract_response_tool_calls(
self,
response_tool_calls: Union[
List[ChatCompletionMessageToolCall],
List[ChoiceDeltaToolCall]
List[ChatCompletionMessageToolCall], List[ChoiceDeltaToolCall]
],
) -> List[AssistantPromptMessage.ToolCall]:
"""

View File

@ -195,8 +195,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if stop:
extra_model_kwargs["stop"] = stop
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
client = ZhipuAI(
base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)
if len(prompt_messages) == 0:
raise ValueError("At least one message is required")

View File

@ -43,8 +43,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
client = ZhipuAI(
base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)
embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
@ -85,8 +87,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try:
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
client = ZhipuAI(
base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)
# call embedding model
self.embed_documents(

View File

@ -1,4 +1,5 @@
import pydantic
from ...._models import BaseModel

View File

@ -1,6 +1,7 @@
import os
import orjson
from ..._models import BaseModel

View File

@ -1,8 +1,7 @@
import os
import shutil
from typing import Generator
from contextlib import closing
from typing import Union
from typing import Generator, Union
import boto3
from botocore.exceptions import ClientError

View File

@ -104,15 +104,17 @@ def logging_conf() -> dict:
111,
)
@pytest.fixture
def providers_file(request) -> str:
from pathlib import Path
import os
from pathlib import Path
# 当前执行目录
# 获取当前测试文件的路径
test_file_path = Path(str(request.fspath)).parent
print("test_file_path:",test_file_path)
return os.path.join(test_file_path,"model_providers.yaml")
print("test_file_path:", test_file_path)
return os.path.join(test_file_path, "model_providers.yaml")
@pytest.fixture
@ -121,9 +123,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path=providers_file
)
.model_providers_cfg_path(model_providers_cfg_path=providers_file)
.host(host="127.0.0.1")
.port(port=20000)
.build()
@ -139,5 +139,4 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
boot.destroy()
except SystemExit:
raise

View File

@ -1,15 +1,20 @@
import logging
import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(model_name="deepseek-chat", openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758", openai_api_base=f"{init_server}/deepseek/v1")
llm = ChatOpenAI(
model_name="deepseek-chat",
openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758",
openai_api_base=f"{init_server}/deepseek/v1",
)
template = """Question: {question}
Answer: Let's think step by step."""

View File

@ -1,15 +1,20 @@
import logging
import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1")
llm = ChatOpenAI(
model_name="llama3",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/ollama/v1",
)
template = """Question: {question}
Answer: Let's think step by step."""
@ -23,9 +28,11 @@ def test_llm(init_server: str):
@pytest.mark.requires("openai")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1")
embeddings = OpenAIEmbeddings(
model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)
text = "你好"

View File

@ -1,15 +1,18 @@
import logging
import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
llm = ChatOpenAI(
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1"
)
template = """Question: {question}
Answer: Let's think step by step."""
@ -21,17 +24,16 @@ def test_llm(init_server: str):
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
@pytest.mark.requires("openai")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1")
embeddings = OpenAIEmbeddings(
model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)
text = "你好"
query_result = embeddings.embed_query(text)
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")

View File

@ -1,8 +1,9 @@
import logging
import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__)
@ -10,9 +11,10 @@ logger = logging.getLogger(__name__)
@pytest.mark.requires("xinference_client")
def test_llm(init_server: str):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/xinference/v1")
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1",
)
template = """Question: {question}
Answer: Let's think step by step."""
@ -26,9 +28,11 @@ def test_llm(init_server: str):
@pytest.mark.requires("xinference-client")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1")
embeddings = OpenAIEmbeddings(
model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1",
)
text = "你好"

View File

@ -1,17 +1,20 @@
import logging
import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__)
@pytest.mark.requires("zhipuai")
def test_llm(init_server: str):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)
template = """Question: {question}
Answer: Let's think step by step."""
@ -25,15 +28,14 @@ def test_llm(init_server: str):
@pytest.mark.requires("zhipuai")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1")
embeddings = OpenAIEmbeddings(
model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)
text = "你好"
query_result = embeddings.embed_query(text)
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")

View File

@ -7,7 +7,10 @@ from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__)
@ -16,9 +19,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
@ -35,15 +36,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
@ -34,16 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)
# 获取预定义模型
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")