make format

This commit is contained in:
glide-the 2024-05-20 11:58:30 +08:00
parent 6dd00b5d94
commit 4e9b1d6edf
49 changed files with 631 additions and 407 deletions

View File

@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime from datetime import date, datetime
from typing_extensions import Self from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
import pydantic import pydantic
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import Self
from ._types import StrBytesIntFloat from ._types import StrBytesIntFloat
@ -45,23 +45,49 @@ if TYPE_CHECKING:
else: else:
if PYDANTIC_V2: if PYDANTIC_V2:
from pydantic.v1.datetime_parse import (
parse_date as parse_date,
)
from pydantic.v1.datetime_parse import (
parse_datetime as parse_datetime,
)
from pydantic.v1.typing import ( from pydantic.v1.typing import (
get_args as get_args, get_args as get_args,
is_union as is_union, )
from pydantic.v1.typing import (
get_origin as get_origin, get_origin as get_origin,
is_typeddict as is_typeddict, )
from pydantic.v1.typing import (
is_literal_type as is_literal_type, is_literal_type as is_literal_type,
) )
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime from pydantic.v1.typing import (
is_typeddict as is_typeddict,
)
from pydantic.v1.typing import (
is_union as is_union,
)
else: else:
from pydantic.datetime_parse import (
parse_date as parse_date,
)
from pydantic.datetime_parse import (
parse_datetime as parse_datetime,
)
from pydantic.typing import ( from pydantic.typing import (
get_args as get_args, get_args as get_args,
is_union as is_union, )
from pydantic.typing import (
get_origin as get_origin, get_origin as get_origin,
is_typeddict as is_typeddict, )
from pydantic.typing import (
is_literal_type as is_literal_type, is_literal_type as is_literal_type,
) )
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime from pydantic.typing import (
is_typeddict as is_typeddict,
)
from pydantic.typing import (
is_union as is_union,
)
# refactored config # refactored config
@ -204,7 +230,9 @@ if TYPE_CHECKING:
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
... ...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: def __get__(
self, instance: object, owner: type[Any] | None = None
) -> _T | Self:
raise NotImplementedError() raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None: def __set_name__(self, owner: type[Any], name: str) -> None:

View File

@ -4,20 +4,20 @@ import io
import os import os
import pathlib import pathlib
from typing import overload from typing import overload
from typing_extensions import TypeGuard
import anyio import anyio
from typing_extensions import TypeGuard
from ._types import ( from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput, Base64FileInput,
FileContent,
FileTypes,
HttpxFileContent, HttpxFileContent,
HttpxFileTypes,
HttpxRequestFiles, HttpxRequestFiles,
RequestFiles,
) )
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t from ._utils import is_mapping_t, is_sequence_t, is_tuple_t
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
@ -26,13 +26,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
def is_file_content(obj: object) -> TypeGuard[FileContent]: def is_file_content(obj: object) -> TypeGuard[FileContent]:
return ( return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) isinstance(obj, bytes)
or isinstance(obj, tuple)
or isinstance(obj, io.IOBase)
or isinstance(obj, os.PathLike)
) )
def assert_is_file_content(obj: object, *, key: str | None = None) -> None: def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj): if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" prefix = (
f"Expected entry at `{key}`"
if key is not None
else f"Expected file input `{obj!r}`"
)
raise RuntimeError( raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads" f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None ) from None
@ -57,7 +64,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
elif is_sequence_t(files): elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files] files = [(key, _transform_file(file)) for key, file in files]
else: else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") raise TypeError(
f"Unexpected file type input {type(files)}, expected mapping or sequence"
)
return files return files
@ -73,7 +82,9 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_tuple_t(file): if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:]) return (file[0], _read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)
def _read_file_content(file: FileContent) -> HttpxFileContent: def _read_file_content(file: FileContent) -> HttpxFileContent:
@ -101,7 +112,9 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
elif is_sequence_t(files): elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files] files = [(key, await _async_transform_file(file)) for key, file in files]
else: else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence") raise TypeError(
"Unexpected file type input {type(files)}, expected mapping or sequence"
)
return files return files
@ -117,7 +130,9 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_tuple_t(file): if is_tuple_t(file):
return (file[0], await _async_read_file_content(file[1]), *file[2:]) return (file[0], await _async_read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)
async def _async_read_file_content(file: FileContent) -> HttpxFileContent: async def _async_read_file_content(file: FileContent) -> HttpxFileContent:

View File

@ -1,60 +1,62 @@
from __future__ import annotations from __future__ import annotations
import os
import inspect import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast import os
from datetime import date, datetime from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast
import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from typing_extensions import ( from typing_extensions import (
Unpack,
Literal,
ClassVar, ClassVar,
Literal,
Protocol, Protocol,
Required, Required,
TypedDict, TypedDict,
TypeGuard, TypeGuard,
Unpack,
final, final,
override, override,
runtime_checkable, runtime_checkable,
) )
import pydantic from ._compat import (
import pydantic.generics PYDANTIC_V2,
from pydantic.fields import FieldInfo ConfigDict,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._compat import (
GenericModel as BaseGenericModel,
)
from ._types import ( from ._types import (
IncEx, IncEx,
ModelT, ModelT,
) )
from ._utils import ( from ._utils import (
PropertyInfo, PropertyInfo,
is_list,
is_given,
lru_cache,
is_mapping,
parse_date,
coerce_boolean, coerce_boolean,
parse_datetime,
strip_not_given,
extract_type_arg, extract_type_arg,
is_annotated_type, is_annotated_type,
is_given,
is_list,
is_mapping,
lru_cache,
parse_date,
parse_datetime,
strip_annotated_type, strip_annotated_type,
) strip_not_given,
from ._compat import (
PYDANTIC_V2,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
is_union,
parse_obj,
get_origin,
is_literal_type,
get_model_config,
get_model_fields,
field_get_default,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"] __all__ = ["BaseModel", "GenericModel"]
@ -69,7 +71,8 @@ class _ConfigProtocol(Protocol):
class BaseModel(pydantic.BaseModel): class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2: if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict( model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) extra="allow",
defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")),
) )
else: else:
@ -189,7 +192,9 @@ class BaseModel(pydantic.BaseModel):
key = name key = name
if key in values: if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key) fields_values[name] = _construct_field(
value=values[key], field=field, key=key
)
_fields_set.add(name) _fields_set.add(name)
else: else:
fields_values[name] = field_get_default(field) fields_values[name] = field_get_default(field)
@ -412,9 +417,13 @@ def construct_type(*, value: object, type_: object) -> object:
# #
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`. # we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) discriminator = _build_discriminated_union_meta(
union=type_, meta_annotations=meta
)
if discriminator and is_mapping(value): if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) variant_value = value.get(
discriminator.field_alias_from or discriminator.field_name
)
if variant_value and isinstance(variant_value, str): if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value) variant_type = discriminator.mapping.get(variant_value)
if variant_type: if variant_type:
@ -434,11 +443,19 @@ def construct_type(*, value: object, type_: object) -> object:
return value return value
_, items_type = get_args(type_) # Dict[_, items_type] _, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} return {
key: construct_type(value=item, type_=items_type)
for key, item in value.items()
}
if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): if not is_literal_type(type_) and (
issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
):
if is_list(value): if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] return [
cast(Any, type_).construct(**entry) if is_mapping(entry) else entry
for entry in value
]
if is_mapping(value): if is_mapping(value):
if issubclass(type_, BaseModel): if issubclass(type_, BaseModel):
@ -523,14 +540,19 @@ class DiscriminatorDetails:
self.field_alias_from = discriminator_alias self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: def _build_discriminated_union_meta(
*, union: type, meta_annotations: tuple[Any, ...]
) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType): if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__ return union.__discriminator__
discriminator_field_name: str | None = None discriminator_field_name: str | None = None
for annotation in meta_annotations: for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: if (
isinstance(annotation, PropertyInfo)
and annotation.discriminator is not None
):
discriminator_field_name = annotation.discriminator discriminator_field_name = annotation.discriminator
break break
@ -558,7 +580,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
if isinstance(entry, str): if isinstance(entry, str):
mapping[entry] = variant mapping[entry] = variant
else: else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(
discriminator_field_name
) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info: if not field_info:
continue continue
@ -582,7 +606,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
return details return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: def _extract_field_schema_pv2(
model: type[BaseModel], field_name: str
) -> ModelField | None:
schema = model.__pydantic_core_schema__ schema = model.__pydantic_core_schema__
if schema["type"] != "model": if schema["type"] != "model":
return None return None
@ -621,7 +647,9 @@ else:
if PYDANTIC_V2: if PYDANTIC_V2:
from pydantic import TypeAdapter as _TypeAdapter from pydantic import TypeAdapter as _TypeAdapter
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) _CachedTypeAdapter = cast(
"TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)
)
if TYPE_CHECKING: if TYPE_CHECKING:
from pydantic import TypeAdapter from pydantic import TypeAdapter
@ -652,6 +680,3 @@ elif not TYPE_CHECKING: # TODO: condition is weird
def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]: def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
return RootModel[type_] # type: ignore return RootModel[type_] # type: ignore

View File

@ -5,22 +5,29 @@ from typing import (
IO, IO,
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable,
Dict, Dict,
List, List,
Type,
Tuple,
Union,
Mapping, Mapping,
TypeVar,
Callable,
Optional, Optional,
Sequence, Sequence,
Tuple,
Type,
TypeVar,
Union,
) )
from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
import httpx import httpx
import pydantic import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport from httpx import URL, AsyncBaseTransport, BaseTransport, Proxy, Response, Timeout
from typing_extensions import (
Literal,
Protocol,
TypeAlias,
TypedDict,
override,
runtime_checkable,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from ._models import BaseModel from ._models import BaseModel
@ -43,7 +50,9 @@ if TYPE_CHECKING:
FileContent = Union[IO[bytes], bytes, PathLike[str]] FileContent = Union[IO[bytes], bytes, PathLike[str]]
else: else:
Base64FileInput = Union[IO[bytes], PathLike] Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. FileContent = Union[
IO[bytes], bytes, PathLike
] # PathLike is not subscriptable in Python 3.8.
FileTypes = Union[ FileTypes = Union[
# file (or bytes) # file (or bytes)
FileContent, FileContent,
@ -68,7 +77,9 @@ HttpxFileTypes = Union[
# (filename, file (or bytes), content_type, headers) # (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
] ]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] HttpxRequestFiles = Union[
Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]
]
# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT # Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly # where ResponseT includes `None`. In order to support directly

View File

@ -1,48 +1,126 @@
from ._utils import ( from ._transform import (
flatten as flatten, PropertyInfo as PropertyInfo,
is_dict as is_dict, )
is_list as is_list, from ._transform import (
is_given as is_given, async_maybe_transform as async_maybe_transform,
is_tuple as is_tuple, )
lru_cache as lru_cache, from ._transform import (
is_mapping as is_mapping, async_transform as async_transform,
is_tuple_t as is_tuple_t, )
parse_date as parse_date, from ._transform import (
is_iterable as is_iterable, maybe_transform as maybe_transform,
is_sequence as is_sequence, )
coerce_float as coerce_float, from ._transform import (
is_mapping_t as is_mapping_t, transform as transform,
removeprefix as removeprefix, )
removesuffix as removesuffix, from ._typing import (
extract_files as extract_files, extract_type_arg as extract_type_arg,
is_sequence_t as is_sequence_t, )
required_args as required_args, from ._typing import (
coerce_boolean as coerce_boolean, extract_type_var_from_base as extract_type_var_from_base,
coerce_integer as coerce_integer, )
file_from_path as file_from_path, from ._typing import (
parse_datetime as parse_datetime, is_annotated_type as is_annotated_type,
strip_not_given as strip_not_given, )
deepcopy_minimal as deepcopy_minimal, from ._typing import (
get_async_library as get_async_library, is_iterable_type as is_iterable_type,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
) )
from ._typing import ( from ._typing import (
is_list_type as is_list_type, is_list_type as is_list_type,
is_union_type as is_union_type, )
extract_type_arg as extract_type_arg, from ._typing import (
is_iterable_type as is_iterable_type,
is_required_type as is_required_type, is_required_type as is_required_type,
is_annotated_type as is_annotated_type, )
from ._typing import (
is_union_type as is_union_type,
)
from ._typing import (
strip_annotated_type as strip_annotated_type, strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
) )
from ._transform import ( from ._utils import (
PropertyInfo as PropertyInfo, coerce_boolean as coerce_boolean,
transform as transform, )
async_transform as async_transform, from ._utils import (
maybe_transform as maybe_transform, coerce_float as coerce_float,
async_maybe_transform as async_maybe_transform, )
from ._utils import (
coerce_integer as coerce_integer,
)
from ._utils import (
deepcopy_minimal as deepcopy_minimal,
)
from ._utils import (
extract_files as extract_files,
)
from ._utils import (
file_from_path as file_from_path,
)
from ._utils import (
flatten as flatten,
)
from ._utils import (
get_async_library as get_async_library,
)
from ._utils import (
get_required_header as get_required_header,
)
from ._utils import (
is_dict as is_dict,
)
from ._utils import (
is_given as is_given,
)
from ._utils import (
is_iterable as is_iterable,
)
from ._utils import (
is_list as is_list,
)
from ._utils import (
is_mapping as is_mapping,
)
from ._utils import (
is_mapping_t as is_mapping_t,
)
from ._utils import (
is_sequence as is_sequence,
)
from ._utils import (
is_sequence_t as is_sequence_t,
)
from ._utils import (
is_tuple as is_tuple,
)
from ._utils import (
is_tuple_t as is_tuple_t,
)
from ._utils import (
lru_cache as lru_cache,
)
from ._utils import (
maybe_coerce_boolean as maybe_coerce_boolean,
)
from ._utils import (
maybe_coerce_float as maybe_coerce_float,
)
from ._utils import (
maybe_coerce_integer as maybe_coerce_integer,
)
from ._utils import (
parse_date as parse_date,
)
from ._utils import (
parse_datetime as parse_datetime,
)
from ._utils import (
removeprefix as removeprefix,
)
from ._utils import (
removesuffix as removesuffix,
)
from ._utils import (
required_args as required_args,
)
from ._utils import (
strip_not_given as strip_not_given,
) )

View File

@ -1,31 +1,31 @@
from __future__ import annotations from __future__ import annotations
import io
import base64 import base64
import io
import pathlib import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints from typing import Any, Mapping, TypeVar, cast
import anyio import anyio
import pydantic import pydantic
from typing_extensions import Literal, get_args, get_type_hints, override
from ._utils import ( from .._compat import is_typeddict, model_dump
is_list,
is_mapping,
is_iterable,
)
from .._files import is_base64_file_input from .._files import is_base64_file_input
from ._typing import ( from ._typing import (
is_list_type,
is_union_type,
extract_type_arg, extract_type_arg,
is_iterable_type,
is_required_type,
is_annotated_type, is_annotated_type,
is_iterable_type,
is_list_type,
is_required_type,
is_union_type,
strip_annotated_type, strip_annotated_type,
) )
from .._compat import model_dump, is_typeddict from ._utils import (
is_iterable,
is_list,
is_mapping,
)
_T = TypeVar("_T") _T = TypeVar("_T")
@ -171,10 +171,17 @@ def _transform_recursive(
# List[T] # List[T]
(is_list_type(stripped_type) and is_list(data)) (is_list_type(stripped_type) and is_list(data))
# Iterable[T] # Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) or (
is_iterable_type(stripped_type)
and is_iterable(data)
and not isinstance(data, str)
)
): ):
inner_type = extract_type_arg(stripped_type, 0) inner_type = extract_type_arg(stripped_type, 0)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] return [
_transform_recursive(d, annotation=annotation, inner_type=inner_type)
for d in data
]
if is_union_type(stripped_type): if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed. # For union types we run the transformation against all subtypes to ensure that everything is transformed.
@ -201,7 +208,9 @@ def _transform_recursive(
return data return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: def _format_data(
data: object, format_: PropertyFormat, format_template: str | None
) -> object:
if isinstance(data, (date, datetime)): if isinstance(data, (date, datetime)):
if format_ == "iso8601": if format_ == "iso8601":
return data.isoformat() return data.isoformat()
@ -221,7 +230,9 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
binary = binary.encode() binary = binary.encode()
if not isinstance(binary, bytes): if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") raise RuntimeError(
f"Could not read bytes from {data}; Received {type(binary)}"
)
return base64.b64encode(binary).decode("ascii") return base64.b64encode(binary).decode("ascii")
@ -240,7 +251,9 @@ def _transform_typeddict(
# we do not have a type annotation for this field, leave it as is # we do not have a type annotation for this field, leave it as is
result[key] = value result[key] = value
else: else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) result[_maybe_transform_key(key, type_)] = _transform_recursive(
value, annotation=type_
)
return result return result
@ -276,7 +289,9 @@ async def async_transform(
It should be noted that the transformations that this function does are not represented in the type system. It should be noted that the transformations that this function does are not represented in the type system.
""" """
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) transformed = await _async_transform_recursive(
data, annotation=cast(type, expected_type)
)
return cast(_T, transformed) return cast(_T, transformed)
@ -309,10 +324,19 @@ async def _async_transform_recursive(
# List[T] # List[T]
(is_list_type(stripped_type) and is_list(data)) (is_list_type(stripped_type) and is_list(data))
# Iterable[T] # Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) or (
is_iterable_type(stripped_type)
and is_iterable(data)
and not isinstance(data, str)
)
): ):
inner_type = extract_type_arg(stripped_type, 0) inner_type = extract_type_arg(stripped_type, 0)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] return [
await _async_transform_recursive(
d, annotation=annotation, inner_type=inner_type
)
for d in data
]
if is_union_type(stripped_type): if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed. # For union types we run the transformation against all subtypes to ensure that everything is transformed.
@ -320,7 +344,9 @@ async def _async_transform_recursive(
# TODO: there may be edge cases where the same normalized field name will transform to two different names # TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes. # in different subtypes.
for subtype in get_args(stripped_type): for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) data = await _async_transform_recursive(
data, annotation=annotation, inner_type=subtype
)
return data return data
if isinstance(data, pydantic.BaseModel): if isinstance(data, pydantic.BaseModel):
@ -334,12 +360,16 @@ async def _async_transform_recursive(
annotations = get_args(annotated_type)[1:] annotations = get_args(annotated_type)[1:]
for annotation in annotations: for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None: if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template) return await _async_format_data(
data, annotation.format, annotation.format_template
)
return data return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: async def _async_format_data(
data: object, format_: PropertyFormat, format_template: str | None
) -> object:
if isinstance(data, (date, datetime)): if isinstance(data, (date, datetime)):
if format_ == "iso8601": if format_ == "iso8601":
return data.isoformat() return data.isoformat()
@ -359,7 +389,9 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
binary = binary.encode() binary = binary.encode()
if not isinstance(binary, bytes): if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") raise RuntimeError(
f"Could not read bytes from {data}; Received {type(binary)}"
)
return base64.b64encode(binary).decode("ascii") return base64.b64encode(binary).decode("ascii")
@ -378,5 +410,7 @@ async def _async_transform_typeddict(
# we do not have a type annotation for this field, leave it as is # we do not have a type annotation for this field, leave it as is
result[key] = value result[key] = value
else: else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(
value, annotation=type_
)
return result return result

View File

@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc from collections import abc as _c_abc
from typing_extensions import Required, Annotated, get_args, get_origin from typing import Any, Iterable, TypeVar, cast
from typing_extensions import Annotated, Required, get_args, get_origin
from .._types import InheritsGeneric
from .._compat import is_union as _is_union from .._compat import is_union as _is_union
from .._types import InheritsGeneric
def is_annotated_type(typ: type) -> bool: def is_annotated_type(typ: type) -> bool:
@ -49,15 +50,17 @@ def extract_type_arg(typ: type, index: int) -> type:
try: try:
return cast(type, args[index]) return cast(type, args[index])
except IndexError as err: except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err raise RuntimeError(
f"Expected type {typ} to have a type argument at index {index} but it did not"
) from err
def extract_type_var_from_base( def extract_type_var_from_base(
typ: type, typ: type,
*, *,
generic_bases: tuple[type, ...], generic_bases: tuple[type, ...],
index: int, index: int,
failure_message: str | None = None, failure_message: str | None = None,
) -> type: ) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`. """Given a type like `Foo[T]`, returns the generic type variable `T`.
@ -117,4 +120,7 @@ def extract_type_var_from_base(
return extracted return extracted
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") raise RuntimeError(
failure_message
or f"Could not resolve inner type variable at index {index} for {typ}"
)

View File

@ -1,27 +1,28 @@
from __future__ import annotations from __future__ import annotations
import functools
import inspect
import os import os
import re import re
import inspect from pathlib import Path
import functools
from typing import ( from typing import (
Any, Any,
Tuple,
Mapping,
TypeVar,
Callable, Callable,
Iterable, Iterable,
Mapping,
Sequence, Sequence,
Tuple,
TypeVar,
cast, cast,
overload, overload,
) )
from pathlib import Path
from typing_extensions import TypeGuard
import sniffio import sniffio
from typing_extensions import TypeGuard
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike from .._compat import parse_date as parse_date
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime from .._compat import parse_datetime as parse_datetime
from .._types import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr
_T = TypeVar("_T") _T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
@ -108,7 +109,9 @@ def _extract_items(
item, item,
path, path,
index=index, index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", flattened_key=flattened_key + "[]"
if flattened_key is not None
else "[]",
) )
for item in obj for item in obj
] ]
@ -261,7 +264,12 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
else: # no break else: # no break
if len(variants) > 1: if len(variants) > 1:
variations = human_join( variations = human_join(
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] [
"("
+ human_join([quote(arg) for arg in variant], final="and")
+ ")"
for variant in variants
]
) )
msg = f"Missing required arguments; Expected either {variations} arguments to be given" msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else: else:
@ -376,7 +384,11 @@ def get_required_header(headers: HeadersLike, header: str) -> str:
return v return v
""" to deal with the case where the header looks like Stainless-Event-Id """ """ to deal with the case where the header looks like Stainless-Event-Id """
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) intercaps_header = re.sub(
r"([^\w])(\w)",
lambda pat: pat.group(1) + pat.group(2).upper(),
header.capitalize(),
)
for normalized_header in [header, lower_header, header.upper(), intercaps_header]: for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
value = headers.get(normalized_header) value = headers.get(normalized_header)

View File

@ -1,9 +1,6 @@
from enum import Enum from enum import Enum
from typing import List, Literal, Optional from typing import List, Literal, Optional
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
from model_providers.core.entities.model_entities import ( from model_providers.core.entities.model_entities import (
ModelStatus, ModelStatus,
ModelWithProviderEntity, ModelWithProviderEntity,
@ -26,6 +23,9 @@ from model_providers.core.model_runtime.entities.provider_entities import (
SimpleProviderEntity, SimpleProviderEntity,
) )
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
class CustomConfigurationStatus(Enum): class CustomConfigurationStatus(Enum):
""" """
@ -74,10 +74,9 @@ class ProviderResponse(BaseModel):
custom_configuration: CustomConfigurationResponse custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse system_configuration: SystemConfigurationResponse
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=())
protected_namespaces=()
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
@ -180,10 +179,9 @@ class DefaultModelResponse(BaseModel):
provider: SimpleProviderEntityResponse provider: SimpleProviderEntityResponse
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=())
protected_namespaces=()
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()

View File

@ -73,7 +73,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._server = None self._server = None
self._server_thread = None self._server_thread = None
def logging_conf(self,logging_conf: Optional[dict] = None): def logging_conf(self, logging_conf: Optional[dict] = None):
self._logging_conf = logging_conf self._logging_conf = logging_conf
@classmethod @classmethod
@ -132,7 +132,10 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._app.include_router(self._router) self._app.include_router(self._router)
config = Config( config = Config(
app=self._app, host=self._host, port=self._port, log_config=self._logging_conf app=self._app,
host=self._host,
port=self._port,
log_config=self._logging_conf,
) )
self._server = Server(config) self._server = Server(config)
@ -145,7 +148,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._server_thread.start() self._server_thread.start()
def destroy(self): def destroy(self):
logger.info("Shutting down server") logger.info("Shutting down server")
self._server.should_exit = True # 设置退出标志 self._server.should_exit = True # 设置退出标志
self._server.shutdown() # 停止服务器 self._server.shutdown() # 停止服务器
@ -155,7 +157,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def join(self): def join(self):
self._server_thread.join() self._server_thread.join()
def set_app_event(self, started_event: mp.Event = None): def set_app_event(self, started_event: mp.Event = None):
@self._app.on_event("startup") @self._app.on_event("startup")
async def on_startup(): async def on_startup():
@ -190,12 +191,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
provider_model_bundle.model_type_instance.predefined_models() provider_model_bundle.model_type_instance.predefined_models()
) )
# 获取自定义模型 # 获取自定义模型
for model in provider_model_bundle.configuration.custom_configuration.models: for (
model
llm_models.append(provider_model_bundle.model_type_instance.get_model_schema( ) in provider_model_bundle.configuration.custom_configuration.models:
model=model.model, llm_models.append(
credentials=model.credentials, provider_model_bundle.model_type_instance.get_model_schema(
)) model=model.model,
credentials=model.credentials,
)
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error while fetching models for provider: {provider}, model_type: {model_type}" f"Error while fetching models for provider: {provider}, model_type: {model_type}"
@ -225,13 +229,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
) )
# 判断embeddings_request.input是否为list # 判断embeddings_request.input是否为list
input = '' input = ""
if isinstance(embeddings_request.input, list): if isinstance(embeddings_request.input, list):
tokens = embeddings_request.input tokens = embeddings_request.input
try: try:
encoding = tiktoken.encoding_for_model(embeddings_request.model) encoding = tiktoken.encoding_for_model(embeddings_request.model)
except KeyError: except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.") logger.warning(
"Warning: model not found. Using cl100k_base encoding."
)
model = "cl100k_base" model = "cl100k_base"
encoding = tiktoken.get_encoding(model) encoding = tiktoken.get_encoding(model)
for i, token in enumerate(tokens): for i, token in enumerate(tokens):
@ -241,7 +247,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
else: else:
input = embeddings_request.input input = embeddings_request.input
response = model_instance.invoke_text_embedding(texts=[input], user="abc-123") response = model_instance.invoke_text_embedding(
texts=[input], user="abc-123"
)
return await openai_embedding_text(response) return await openai_embedding_text(response)
except ValueError as e: except ValueError as e:

View File

@ -1,10 +1,12 @@
import time import time
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from ..._models import BaseModel
from pydantic import Field as FieldInfo from pydantic import Field as FieldInfo
from typing_extensions import Literal from typing_extensions import Literal
from ..._models import BaseModel
class Role(str, Enum): class Role(str, Enum):
USER = "user" USER = "user"

View File

@ -1,5 +1,5 @@
import enum import enum
from typing import Any, cast, List from typing import Any, List, cast
from langchain.schema import ( from langchain.schema import (
AIMessage, AIMessage,
@ -8,7 +8,6 @@ from langchain.schema import (
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.message_entities import ( from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
@ -20,6 +19,8 @@ from model_providers.core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from ..._models import BaseModel
class PromptMessageFileType(enum.Enum): class PromptMessageFileType(enum.Enum):
IMAGE = "image" IMAGE = "image"

View File

@ -1,9 +1,6 @@
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import ( from model_providers.core.model_runtime.entities.model_entities import (
ModelType, ModelType,
@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
) )
from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
class ModelStatus(Enum): class ModelStatus(Enum):
""" """
@ -80,9 +80,8 @@ class DefaultModelEntity(BaseModel):
provider: DefaultModelProviderEntity provider: DefaultModelProviderEntity
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=())
protected_namespaces=()
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()

View File

@ -4,9 +4,6 @@ import logging
from json import JSONDecodeError from json import JSONDecodeError
from typing import Dict, Iterator, List, Optional from typing import Dict, Iterator, List, Optional
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
from model_providers.core.entities.model_entities import ( from model_providers.core.entities.model_entities import (
ModelStatus, ModelStatus,
ModelWithProviderEntity, ModelWithProviderEntity,
@ -29,6 +26,9 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im
ModelProvider, ModelProvider,
) )
from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -351,11 +351,9 @@ class ProviderModelBundle(BaseModel):
model_type_instance: AIModel model_type_instance: AIModel
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
protected_namespaces=(),
arbitrary_types_allowed=True
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()

View File

@ -1,11 +1,11 @@
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from ..._compat import PYDANTIC_V2, ConfigDict from ..._compat import PYDANTIC_V2, ConfigDict
from ..._models import BaseModel from ..._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
class ProviderType(Enum): class ProviderType(Enum):
CUSTOM = "custom" CUSTOM = "custom"
@ -59,10 +59,9 @@ class RestrictModel(BaseModel):
model_type: ModelType model_type: ModelType
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=())
protected_namespaces=()
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
@ -109,10 +108,9 @@ class CustomModelConfiguration(BaseModel):
credentials: dict credentials: dict
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=())
protected_namespaces=()
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()

View File

@ -1,13 +1,13 @@
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from ..._models import BaseModel
from model_providers.core.model_runtime.entities.llm_entities import ( from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult, LLMResult,
LLMResultChunk, LLMResultChunk,
) )
from ..._models import BaseModel
class QueueEvent(Enum): class QueueEvent(Enum):
""" """

View File

@ -2,8 +2,6 @@ from decimal import Decimal
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.message_entities import ( from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
@ -13,6 +11,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
PriceInfo, PriceInfo,
) )
from ...._models import BaseModel
class LLMMode(Enum): class LLMMode(Enum):
""" """

View File

@ -2,11 +2,11 @@ from decimal import Decimal
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from ...._compat import PYDANTIC_V2, ConfigDict from ...._compat import PYDANTIC_V2, ConfigDict
from ...._models import BaseModel from ...._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject
class ModelType(Enum): class ModelType(Enum):
""" """
@ -164,10 +164,9 @@ class ProviderModel(BaseModel):
model_properties: Dict[ModelPropertyKey, Any] model_properties: Dict[ModelPropertyKey, Any]
deprecated: bool = False deprecated: bool = False
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=())
protected_namespaces=()
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()

View File

@ -1,9 +1,6 @@
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
from ...._compat import PYDANTIC_V2, ConfigDict
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.common_entities import I18nObject from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import ( from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,
@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
ProviderModel, ProviderModel,
) )
from ...._compat import PYDANTIC_V2, ConfigDict
from ...._models import BaseModel
class ConfigurateMethod(Enum): class ConfigurateMethod(Enum):
""" """
@ -136,10 +136,9 @@ class ProviderEntity(BaseModel):
model_credential_schema: Optional[ModelCredentialSchema] = None model_credential_schema: Optional[ModelCredentialSchema] = None
if PYDANTIC_V2: if PYDANTIC_V2:
model_config = ConfigDict( model_config = ConfigDict(protected_namespaces=())
protected_namespaces=()
)
else: else:
class Config: class Config:
protected_namespaces = () protected_namespaces = ()

View File

@ -1,10 +1,10 @@
from decimal import Decimal from decimal import Decimal
from typing import List from typing import List
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelUsage from model_providers.core.model_runtime.entities.model_entities import ModelUsage
from ...._models import BaseModel
class EmbeddingUsage(ModelUsage): class EmbeddingUsage(ModelUsage):
""" """

View File

@ -1,5 +1,3 @@
from ....._models import BaseModel
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from model_providers.core.model_runtime.entities.llm_entities import LLMMode from model_providers.core.model_runtime.entities.llm_entities import LLMMode
from model_providers.core.model_runtime.entities.model_entities import ( from model_providers.core.model_runtime.entities.model_entities import (
@ -14,6 +12,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
PriceConfig, PriceConfig,
) )
from ....._models import BaseModel
AZURE_OPENAI_API_VERSION = "2024-02-15-preview" AZURE_OPENAI_API_VERSION = "2024-02-15-preview"

View File

@ -1,6 +1,5 @@
import logging import logging
from typing import Generator from typing import Dict, Generator, List, Optional, Type, Union, cast
from typing import Dict, List, Optional, Type, Union, cast
import cohere import cohere
from cohere.responses import Chat, Generations from cohere.responses import Chat, Generations

View File

@ -1,7 +1,6 @@
import logging import logging
from typing import Generator
from typing import List, Optional, Union, cast
from decimal import Decimal from decimal import Decimal
from typing import Generator, List, Optional, Union, cast
import tiktoken import tiktoken
from openai import OpenAI, Stream from openai import OpenAI, Stream
@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
) )
from model_providers.core.model_runtime.entities.model_entities import ( from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,
DefaultParameterName,
FetchFrom, FetchFrom,
I18nObject, I18nObject,
ModelFeature,
ModelPropertyKey,
ModelType, ModelType,
PriceConfig, ParameterRule, ParameterType, ModelFeature, ModelPropertyKey, DefaultParameterName, ParameterRule,
ParameterType,
PriceConfig,
) )
from model_providers.core.model_runtime.errors.validate import ( from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError, CredentialsValidateFailedError,
@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel, LargeLanguageModel,
) )
from model_providers.core.model_runtime.model_providers.deepseek._common import _CommonDeepseek from model_providers.core.model_runtime.model_providers.deepseek._common import (
_CommonDeepseek,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -1117,7 +1123,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
return num_tokens return num_tokens
def get_customizable_model_schema( def get_customizable_model_schema(
self, model: str, credentials: dict self, model: str, credentials: dict
) -> AIModelEntity: ) -> AIModelEntity:
""" """
Get customizable model schema. Get customizable model schema.
@ -1129,7 +1135,6 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
""" """
extras = {} extras = {}
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
label=I18nObject(zh_Hans=model, en_US=model), label=I18nObject(zh_Hans=model, en_US=model),
@ -1149,8 +1154,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="The temperature of the model. " en_US="The temperature of the model. "
"Increasing the temperature will make the model answer " "Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)" "more creatively. (Default: 0.8)"
), ),
default=0.8, default=0.8,
min=0, min=0,
@ -1163,8 +1168,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
"more diverse text, while a lower value (e.g., 0.5) will generate more " "more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)" "focused and conservative text. (Default: 0.9)"
), ),
default=0.9, default=0.9,
min=0, min=0,
@ -1177,8 +1182,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
help=I18nObject( help=I18nObject(
en_US="A number between -2.0 and 2.0. If positive, ", en_US="A number between -2.0 and 2.0. If positive, ",
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正," zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,"
"那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚," "那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,"
"降低模型重复相同内容的可能性" "降低模型重复相同内容的可能性",
), ),
default=0, default=0,
min=-2, min=-2,
@ -1190,7 +1195,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="Sets how strongly to presence_penalty. ", en_US="Sets how strongly to presence_penalty. ",
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。" zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。",
), ),
default=1.1, default=1.1,
min=-2, min=-2,
@ -1204,7 +1209,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
help=I18nObject( help=I18nObject(
en_US="Maximum number of tokens to predict when generating text. ", en_US="Maximum number of tokens to predict when generating text. ",
zh_Hans="限制一次请求中模型生成 completion 的最大 token 数。" zh_Hans="限制一次请求中模型生成 completion 的最大 token 数。"
"输入 token 和输出 token 的总长度受模型的上下文长度的限制。" "输入 token 和输出 token 的总长度受模型的上下文长度的限制。",
), ),
default=128, default=128,
min=-2, min=-2,
@ -1216,7 +1221,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
type=ParameterType.BOOLEAN, type=ParameterType.BOOLEAN,
help=I18nObject( help=I18nObject(
en_US="Whether to return the log probabilities of the tokens. ", en_US="Whether to return the log probabilities of the tokens. ",
zh_Hans="是否返回所输出 token 的对数概率。如果为 true则在 message 的 content 中返回每个输出 token 的对数概率。" zh_Hans="是否返回所输出 token 的对数概率。如果为 true则在 message 的 content 中返回每个输出 token 的对数概率。",
), ),
), ),
ParameterRule( ParameterRule(
@ -1226,7 +1231,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
help=I18nObject( help=I18nObject(
en_US="the format to return a response in.", en_US="the format to return a response in.",
zh_Hans="一个介于 0 到 20 之间的整数 N指定每个输出位置返回输出概率 top N 的 token" zh_Hans="一个介于 0 到 20 之间的整数 N指定每个输出位置返回输出概率 top N 的 token"
"且返回这些 token 的对数概率。指定此参数时logprobs 必须为 true。" "且返回这些 token 的对数概率。指定此参数时logprobs 必须为 true。",
), ),
default=0, default=0,
min=0, min=0,

View File

@ -3,8 +3,6 @@ import logging
import os import os
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from ...._models import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.provider_entities import ( from model_providers.core.model_runtime.entities.provider_entities import (
ProviderConfig, ProviderConfig,
@ -25,6 +23,8 @@ from model_providers.core.utils.position_helper import (
sort_to_dict_by_position_map, sort_to_dict_by_position_map,
) )
from ...._models import BaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,5 +1,4 @@
from typing import Generator from typing import Generator, List, Optional, Union
from typing import List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import ( from model_providers.core.model_runtime.entities.message_entities import (

View File

@ -1,7 +1,6 @@
import logging import logging
from typing import Generator
from typing import List, Optional, Union, cast
from decimal import Decimal from decimal import Decimal
from typing import Generator, List, Optional, Union, cast
import tiktoken import tiktoken
from openai import OpenAI, Stream from openai import OpenAI, Stream
@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
) )
from model_providers.core.model_runtime.entities.model_entities import ( from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,
DefaultParameterName,
FetchFrom, FetchFrom,
I18nObject, I18nObject,
ModelFeature,
ModelPropertyKey,
ModelType, ModelType,
PriceConfig, ModelFeature, ModelPropertyKey, DefaultParameterName, ParameterRule, ParameterType, ParameterRule,
ParameterType,
PriceConfig,
) )
from model_providers.core.model_runtime.errors.validate import ( from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError, CredentialsValidateFailedError,
@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel, LargeLanguageModel,
) )
from model_providers.core.model_runtime.model_providers.ollama._common import _CommonOllama from model_providers.core.model_runtime.model_providers.ollama._common import (
_CommonOllama,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -661,7 +667,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model, model=model,
stream=stream, stream=stream,
extra_body=extra_body extra_body=extra_body,
) )
if stream: if stream:
@ -1120,7 +1126,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
return num_tokens return num_tokens
def get_customizable_model_schema( def get_customizable_model_schema(
self, model: str, credentials: dict self, model: str, credentials: dict
) -> AIModelEntity: ) -> AIModelEntity:
""" """
Get customizable model schema. Get customizable model schema.
@ -1154,8 +1160,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="The temperature of the model. " en_US="The temperature of the model. "
"Increasing the temperature will make the model answer " "Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)" "more creatively. (Default: 0.8)"
), ),
default=0.8, default=0.8,
min=0, min=0,
@ -1168,8 +1174,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
"more diverse text, while a lower value (e.g., 0.5) will generate more " "more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)" "focused and conservative text. (Default: 0.9)"
), ),
default=0.9, default=0.9,
min=0, min=0,
@ -1181,8 +1187,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="Reduces the probability of generating nonsense. " en_US="Reduces the probability of generating nonsense. "
"A higher value (e.g. 100) will give more diverse answers, " "A higher value (e.g. 100) will give more diverse answers, "
"while a lower value (e.g. 10) will be more conservative. (Default: 40)" "while a lower value (e.g. 10) will be more conservative. (Default: 40)"
), ),
default=40, default=40,
min=1, min=1,
@ -1194,8 +1200,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="Sets how strongly to penalize repetitions. " en_US="Sets how strongly to penalize repetitions. "
"A higher value (e.g., 1.5) will penalize repetitions more strongly, " "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)" "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
), ),
default=1.1, default=1.1,
min=-2, min=-2,
@ -1208,7 +1214,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="Maximum number of tokens to predict when generating text. " en_US="Maximum number of tokens to predict when generating text. "
"(Default: 128, -1 = infinite generation, -2 = fill context)" "(Default: 128, -1 = infinite generation, -2 = fill context)"
), ),
default=128, default=128,
min=-2, min=-2,
@ -1220,7 +1226,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="Enable Mirostat sampling for controlling perplexity. " en_US="Enable Mirostat sampling for controlling perplexity. "
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)" "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
), ),
default=0, default=0,
min=0, min=0,
@ -1232,9 +1238,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="Influences how quickly the algorithm responds to feedback from " en_US="Influences how quickly the algorithm responds to feedback from "
"the generated text. A lower learning rate will result in slower adjustments, " "the generated text. A lower learning rate will result in slower adjustments, "
"while a higher learning rate will make the algorithm more responsive. " "while a higher learning rate will make the algorithm more responsive. "
"(Default: 0.1)" "(Default: 0.1)"
), ),
default=0.1, default=0.1,
precision=1, precision=1,
@ -1245,7 +1251,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="Controls the balance between coherence and diversity of the output. " en_US="Controls the balance between coherence and diversity of the output. "
"A lower value will result in more focused and coherent text. (Default: 5.0)" "A lower value will result in more focused and coherent text. (Default: 5.0)"
), ),
default=5.0, default=5.0,
precision=1, precision=1,
@ -1256,7 +1262,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="Sets the size of the context window used to generate the next token. " en_US="Sets the size of the context window used to generate the next token. "
"(Default: 2048)" "(Default: 2048)"
), ),
default=2048, default=2048,
min=1, min=1,
@ -1267,7 +1273,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="The number of layers to send to the GPU(s). " en_US="The number of layers to send to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable." "On macOS it defaults to 1 to enable metal support, 0 to disable."
), ),
default=1, default=1,
min=0, min=0,
@ -1279,9 +1285,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="Sets the number of threads to use during computation. " en_US="Sets the number of threads to use during computation. "
"By default, Ollama will detect this for optimal performance. " "By default, Ollama will detect this for optimal performance. "
"It is recommended to set this value to the number of physical CPU cores " "It is recommended to set this value to the number of physical CPU cores "
"your system has (as opposed to the logical number of cores)." "your system has (as opposed to the logical number of cores)."
), ),
min=1, min=1,
), ),
@ -1291,7 +1297,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="Sets how far back for the model to look back to prevent repetition. " en_US="Sets how far back for the model to look back to prevent repetition. "
"(Default: 64, 0 = disabled, -1 = num_ctx)" "(Default: 64, 0 = disabled, -1 = num_ctx)"
), ),
default=64, default=64,
min=-1, min=-1,
@ -1302,8 +1308,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
help=I18nObject( help=I18nObject(
en_US="Tail free sampling is used to reduce the impact of less probable tokens " en_US="Tail free sampling is used to reduce the impact of less probable tokens "
"from the output. A higher value (e.g., 2.0) will reduce the impact more, " "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
"while a value of 1.0 disables this setting. (default: 1)" "while a value of 1.0 disables this setting. (default: 1)"
), ),
default=1, default=1,
precision=1, precision=1,
@ -1314,8 +1320,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
help=I18nObject( help=I18nObject(
en_US="Sets the random number seed to use for generation. Setting this to " en_US="Sets the random number seed to use for generation. Setting this to "
"a specific number will make the model generate the same text for " "a specific number will make the model generate the same text for "
"the same prompt. (Default: 0)" "the same prompt. (Default: 0)"
), ),
default=0, default=0,
), ),
@ -1325,7 +1331,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.STRING, type=ParameterType.STRING,
help=I18nObject( help=I18nObject(
en_US="the format to return a response in." en_US="the format to return a response in."
" Currently the only accepted value is json." " Currently the only accepted value is json."
), ),
options=["json"], options=["json"],
), ),

View File

@ -1,6 +1,5 @@
import logging import logging
from typing import Generator from typing import Generator, List, Optional, Union, cast
from typing import List, Optional, Union, cast
import tiktoken import tiktoken
from openai import OpenAI, Stream from openai import OpenAI, Stream

View File

@ -1,8 +1,7 @@
import json import json
import logging import logging
from typing import Generator
from decimal import Decimal from decimal import Decimal
from typing import List, Optional, Union, cast from typing import Generator, List, Optional, Union, cast
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests

View File

@ -1,7 +1,6 @@
from typing import Generator
from enum import Enum from enum import Enum
from json import dumps, loads from json import dumps, loads
from typing import Any, Union from typing import Any, Generator, Union
from requests import Response, post from requests import Response, post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema

View File

@ -1,6 +1,5 @@
import threading import threading
from typing import Generator from typing import Dict, Generator, List, Optional, Type, Union
from typing import Dict, List, Optional, Type, Union
from model_providers.core.model_runtime.entities.llm_entities import ( from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult, LLMResult,

View File

@ -1,5 +1,4 @@
from typing import Generator from typing import Generator, List, Optional, Union
from typing import List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import ( from model_providers.core.model_runtime.entities.message_entities import (

View File

@ -1,5 +1,4 @@
from typing import Generator from typing import Dict, Generator, List, Optional, Type, Union
from typing import Dict, List, Optional, Type, Union
from dashscope import get_tokenizer from dashscope import get_tokenizer
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse from dashscope.api_entities.dashscope_response import DashScopeAPIResponse

View File

@ -1,6 +1,4 @@
from typing import Generator, Iterator from typing import Dict, Generator, Iterator, List, Type, Union, cast
from typing import Dict, List, Union, cast, Type
from openai import ( from openai import (
APIConnectionError, APIConnectionError,
@ -527,8 +525,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _extract_response_tool_calls( def _extract_response_tool_calls(
self, self,
response_tool_calls: Union[ response_tool_calls: Union[
List[ChatCompletionMessageToolCall], List[ChatCompletionMessageToolCall], List[ChoiceDeltaToolCall]
List[ChoiceDeltaToolCall]
], ],
) -> List[AssistantPromptMessage.ToolCall]: ) -> List[AssistantPromptMessage.ToolCall]:
""" """

View File

@ -195,8 +195,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if stop: if stop:
extra_model_kwargs["stop"] = stop extra_model_kwargs["stop"] = stop
client = ZhipuAI(base_url=credentials_kwargs["api_base"], client = ZhipuAI(
api_key=credentials_kwargs["api_key"]) base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)
if len(prompt_messages) == 0: if len(prompt_messages) == 0:
raise ValueError("At least one message is required") raise ValueError("At least one message is required")

View File

@ -43,8 +43,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result :return: embeddings result
""" """
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(base_url=credentials_kwargs["api_base"], client = ZhipuAI(
api_key=credentials_kwargs["api_key"]) base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)
embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
@ -85,8 +87,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try: try:
# transform credentials to kwargs for model instance # transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(base_url=credentials_kwargs["api_base"], client = ZhipuAI(
api_key=credentials_kwargs["api_key"]) base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)
# call embedding model # call embedding model
self.embed_documents( self.embed_documents(

View File

@ -1,4 +1,5 @@
import pydantic import pydantic
from ...._models import BaseModel from ...._models import BaseModel

View File

@ -1,6 +1,7 @@
import os import os
import orjson import orjson
from ..._models import BaseModel from ..._models import BaseModel

View File

@ -1,8 +1,7 @@
import os import os
import shutil import shutil
from typing import Generator
from contextlib import closing from contextlib import closing
from typing import Union from typing import Generator, Union
import boto3 import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError

View File

@ -104,15 +104,17 @@ def logging_conf() -> dict:
111, 111,
) )
@pytest.fixture @pytest.fixture
def providers_file(request) -> str: def providers_file(request) -> str:
from pathlib import Path
import os import os
from pathlib import Path
# 当前执行目录 # 当前执行目录
# 获取当前测试文件的路径 # 获取当前测试文件的路径
test_file_path = Path(str(request.fspath)).parent test_file_path = Path(str(request.fspath)).parent
print("test_file_path:",test_file_path) print("test_file_path:", test_file_path)
return os.path.join(test_file_path,"model_providers.yaml") return os.path.join(test_file_path, "model_providers.yaml")
@pytest.fixture @pytest.fixture
@ -121,9 +123,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
try: try:
boot = ( boot = (
BootstrapWebBuilder() BootstrapWebBuilder()
.model_providers_cfg_path( .model_providers_cfg_path(model_providers_cfg_path=providers_file)
model_providers_cfg_path=providers_file
)
.host(host="127.0.0.1") .host(host="127.0.0.1")
.port(port=20000) .port(port=20000)
.build() .build()
@ -139,5 +139,4 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
boot.destroy() boot.destroy()
except SystemExit: except SystemExit:
raise raise

View File

@ -1,15 +1,20 @@
import logging
import pytest
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_llm(init_server: str): def test_llm(init_server: str):
llm = ChatOpenAI(model_name="deepseek-chat", openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758", openai_api_base=f"{init_server}/deepseek/v1") llm = ChatOpenAI(
model_name="deepseek-chat",
openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758",
openai_api_base=f"{init_server}/deepseek/v1",
)
template = """Question: {question} template = """Question: {question}
Answer: Let's think step by step.""" Answer: Let's think step by step."""

View File

@ -1,15 +1,20 @@
import logging
import pytest
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_llm(init_server: str): def test_llm(init_server: str):
llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1") llm = ChatOpenAI(
model_name="llama3",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/ollama/v1",
)
template = """Question: {question} template = """Question: {question}
Answer: Let's think step by step.""" Answer: Let's think step by step."""
@ -23,9 +28,11 @@ def test_llm(init_server: str):
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_embedding(init_server: str): def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text-embedding-3-large", embeddings = OpenAIEmbeddings(
openai_api_key="YOUR_API_KEY", model="text-embedding-3-large",
openai_api_base=f"{init_server}/zhipuai/v1") openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)
text = "你好" text = "你好"

View File

@ -1,15 +1,18 @@
import logging
import pytest
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_llm(init_server: str): def test_llm(init_server: str):
llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1") llm = ChatOpenAI(
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1"
)
template = """Question: {question} template = """Question: {question}
Answer: Let's think step by step.""" Answer: Let's think step by step."""
@ -21,14 +24,13 @@ def test_llm(init_server: str):
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m") logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_embedding(init_server: str): def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(
embeddings = OpenAIEmbeddings(model="text-embedding-3-large", model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY", openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1") openai_api_base=f"{init_server}/zhipuai/v1",
)
text = "你好" text = "你好"

View File

@ -1,8 +1,9 @@
import logging
import pytest
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -10,9 +11,10 @@ logger = logging.getLogger(__name__)
@pytest.mark.requires("xinference_client") @pytest.mark.requires("xinference_client")
def test_llm(init_server: str): def test_llm(init_server: str):
llm = ChatOpenAI( llm = ChatOpenAI(
model_name="glm-4", model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/xinference/v1") openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1",
)
template = """Question: {question} template = """Question: {question}
Answer: Let's think step by step.""" Answer: Let's think step by step."""
@ -26,9 +28,11 @@ def test_llm(init_server: str):
@pytest.mark.requires("xinference-client") @pytest.mark.requires("xinference-client")
def test_embedding(init_server: str): def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text_embedding", embeddings = OpenAIEmbeddings(
openai_api_key="YOUR_API_KEY", model="text_embedding",
openai_api_base=f"{init_server}/xinference/v1") openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1",
)
text = "你好" text = "你好"

View File

@ -1,17 +1,20 @@
import logging
import pytest
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@pytest.mark.requires("zhipuai") @pytest.mark.requires("zhipuai")
def test_llm(init_server: str): def test_llm(init_server: str):
llm = ChatOpenAI( llm = ChatOpenAI(
model_name="glm-4", model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1") openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)
template = """Question: {question} template = """Question: {question}
Answer: Let's think step by step.""" Answer: Let's think step by step."""
@ -25,15 +28,14 @@ def test_llm(init_server: str):
@pytest.mark.requires("zhipuai") @pytest.mark.requires("zhipuai")
def test_embedding(init_server: str): def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(
embeddings = OpenAIEmbeddings(model="text_embedding", model="text_embedding",
openai_api_key="YOUR_API_KEY", openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1") openai_api_base=f"{init_server}/zhipuai/v1",
)
text = "你好" text = "你好"
query_result = embeddings.embed_query(text) query_result = embeddings.embed_query(text)
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m") logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")

View File

@ -7,7 +7,10 @@ from omegaconf import OmegaConf
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.provider_manager import ProviderManager from model_providers.core.provider_manager import ProviderManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,9 +19,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件 # 读取配置文件
cfg = OmegaConf.load( cfg = OmegaConf.load(providers_file)
providers_file
)
# 转换配置文件 # 转换配置文件
( (
provider_name_to_provider_records_dict, provider_name_to_provider_records_dict,
@ -35,15 +36,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
) )
llm_models: List[AIModelEntity] = [] llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models: for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model, model=model.model,
credentials=model.credentials, credentials=model.credentials,
)) )
)
# 获取预定义模型 # 获取预定义模型
llm_models.extend( llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {llm_models}") logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件 # 读取配置文件
cfg = OmegaConf.load( cfg = OmegaConf.load(providers_file)
providers_file
)
# 转换配置文件 # 转换配置文件
( (
provider_name_to_provider_records_dict, provider_name_to_provider_records_dict,
@ -34,16 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
) )
llm_models: List[AIModelEntity] = [] llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models: for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model, model=model.model,
credentials=model.credentials, credentials=model.credentials,
)) )
)
# 获取预定义模型 # 获取预定义模型
llm_models.extend( llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {llm_models}") logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件 # 读取配置文件
cfg = OmegaConf.load( cfg = OmegaConf.load(providers_file)
providers_file
)
# 转换配置文件 # 转换配置文件
( (
provider_name_to_provider_records_dict, provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
) )
llm_models: List[AIModelEntity] = [] llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models: for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model, model=model.model,
credentials=model.credentials, credentials=model.credentials,
)) )
)
# 获取预定义模型 # 获取预定义模型
llm_models.extend( llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {llm_models}") logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件 # 读取配置文件
cfg = OmegaConf.load( cfg = OmegaConf.load(providers_file)
providers_file
)
# 转换配置文件 # 转换配置文件
( (
provider_name_to_provider_records_dict, provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
) )
llm_models: List[AIModelEntity] = [] llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models: for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model, model=model.model,
credentials=model.credentials, credentials=model.credentials,
)) )
)
# 获取预定义模型 # 获取预定义模型
llm_models.extend( llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {llm_models}") logger.info(f"predefined_models: {llm_models}")

View File

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件 # 读取配置文件
cfg = OmegaConf.load( cfg = OmegaConf.load(providers_file)
providers_file
)
# 转换配置文件 # 转换配置文件
( (
provider_name_to_provider_records_dict, provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
) )
llm_models: List[AIModelEntity] = [] llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models: for model in provider_model_bundle_llm.configuration.custom_configuration.models:
llm_models.append(
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model, model=model.model,
credentials=model.credentials, credentials=model.credentials,
)) )
)
# 获取预定义模型 # 获取预定义模型
llm_models.extend( llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
provider_model_bundle_llm.model_type_instance.predefined_models()
)
logger.info(f"predefined_models: {llm_models}") logger.info(f"predefined_models: {llm_models}")