From 4e9b1d6edff21549c683de892fc23d60d9bedf0f Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Mon, 20 May 2024 11:58:30 +0800 Subject: [PATCH] make format --- model-providers/model_providers/_compat.py | 46 ++++- model-providers/model_providers/_files.py | 39 +++-- model-providers/model_providers/_models.py | 113 ++++++++----- model-providers/model_providers/_types.py | 29 +++- .../model_providers/_utils/__init__.py | 160 +++++++++++++----- .../model_providers/_utils/_transform.py | 86 +++++++--- .../model_providers/_utils/_typing.py | 26 +-- .../model_providers/_utils/_utils.py | 36 ++-- .../entities/model_provider_entities.py | 16 +- .../bootstrap_web/openai_bootstrap_web.py | 34 ++-- .../core/bootstrap/openai_protocol.py | 4 +- .../core/entities/message_entities.py | 5 +- .../core/entities/model_entities.py | 13 +- .../core/entities/provider_configuration.py | 12 +- .../core/entities/provider_entities.py | 14 +- .../core/entities/queue_entities.py | 4 +- .../model_runtime/entities/llm_entities.py | 4 +- .../model_runtime/entities/model_entities.py | 9 +- .../entities/provider_entities.py | 11 +- .../entities/text_embedding_entities.py | 4 +- .../model_providers/azure_openai/_constant.py | 4 +- .../model_providers/cohere/llm/llm.py | 3 +- .../model_providers/deepseek/llm/llm.py | 37 ++-- .../model_providers/model_provider_factory.py | 4 +- .../model_providers/moonshot/llm/llm.py | 3 +- .../model_providers/ollama/llm/llm.py | 68 ++++---- .../model_providers/openai/llm/llm.py | 3 +- .../openai_api_compatible/llm/llm.py | 3 +- .../openllm/llm/openllm_generate.py | 3 +- .../model_providers/spark/llm/llm.py | 3 +- .../model_providers/togetherai/llm/llm.py | 3 +- .../model_providers/tongyi/llm/llm.py | 3 +- .../model_providers/xinference/llm/llm.py | 7 +- .../model_providers/zhipuai/llm/llm.py | 6 +- .../zhipuai/text_embedding/text_embedding.py | 12 +- .../core/model_runtime/utils/helper.py | 1 + .../model_providers/core/utils/json_dumps.py | 1 + .../model_providers/extensions/ext_storage.py | 3 +- model-providers/tests/conftest.py | 13 +- .../test_deepseek_service.py | 11 +- .../test_ollama_service.py | 19 ++- .../test_openai_service.py | 22 +-- .../test_xinference_service.py | 18 +- .../test_zhipuai_service.py | 22 +-- .../test_deepseek_provider_manager_models.py | 24 +-- .../test_ollama_provider_manager_models.py | 20 +-- .../test_openai_provider_manager_models.py | 19 +-- ...test_xinference_provider_manager_models.py | 19 +-- .../test_zhipuai_provider_manager_models.py | 19 +-- 49 files changed, 631 insertions(+), 407 deletions(-) diff --git a/model-providers/model_providers/_compat.py b/model-providers/model_providers/_compat.py index 0339d10a..0e01c71f 100644 --- a/model-providers/model_providers/_compat.py +++ b/model-providers/model_providers/_compat.py @@ -1,11 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload from datetime import date, datetime -from typing_extensions import Self +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload import pydantic from pydantic.fields import FieldInfo +from typing_extensions import Self from ._types import StrBytesIntFloat @@ -45,23 +45,49 @@ if TYPE_CHECKING: else: if PYDANTIC_V2: + from pydantic.v1.datetime_parse import ( + parse_date as parse_date, + ) + from pydantic.v1.datetime_parse import ( + parse_datetime as parse_datetime, + ) from pydantic.v1.typing import ( get_args as get_args, - is_union as is_union, + ) + from pydantic.v1.typing import ( get_origin as get_origin, - is_typeddict as is_typeddict, + ) + from pydantic.v1.typing import ( is_literal_type as is_literal_type, ) - from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + from pydantic.v1.typing import ( + is_typeddict as is_typeddict, + ) + from pydantic.v1.typing import ( + is_union as is_union, + ) else: + from pydantic.datetime_parse import ( + parse_date as parse_date, + ) + from pydantic.datetime_parse import ( + parse_datetime as parse_datetime, + ) from pydantic.typing import ( get_args as get_args, - is_union as is_union, + ) + from pydantic.typing import ( get_origin as get_origin, - is_typeddict as is_typeddict, + ) + from pydantic.typing import ( is_literal_type as is_literal_type, ) - from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + from pydantic.typing import ( + is_typeddict as is_typeddict, + ) + from pydantic.typing import ( + is_union as is_union, + ) # refactored config @@ -204,7 +230,9 @@ if TYPE_CHECKING: def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... - def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: + def __get__( + self, instance: object, owner: type[Any] | None = None + ) -> _T | Self: raise NotImplementedError() def __set_name__(self, owner: type[Any], name: str) -> None: diff --git a/model-providers/model_providers/_files.py b/model-providers/model_providers/_files.py index ad7b668b..15e485ad 100644 --- a/model-providers/model_providers/_files.py +++ b/model-providers/model_providers/_files.py @@ -4,20 +4,20 @@ import io import os import pathlib from typing import overload -from typing_extensions import TypeGuard import anyio +from typing_extensions import TypeGuard from ._types import ( - FileTypes, - FileContent, - RequestFiles, - HttpxFileTypes, Base64FileInput, + FileContent, + FileTypes, HttpxFileContent, + HttpxFileTypes, HttpxRequestFiles, + RequestFiles, ) -from ._utils import is_tuple_t, is_mapping_t, is_sequence_t +from ._utils import is_mapping_t, is_sequence_t, is_tuple_t def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: @@ -26,13 +26,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: def is_file_content(obj: object) -> TypeGuard[FileContent]: return ( - isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + isinstance(obj, bytes) + or isinstance(obj, tuple) + or isinstance(obj, io.IOBase) + or isinstance(obj, os.PathLike) ) def assert_is_file_content(obj: object, *, key: str | None = None) -> None: if not is_file_content(obj): - prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" + prefix = ( + f"Expected entry at `{key}`" + if key is not None + else f"Expected file input `{obj!r}`" + ) raise RuntimeError( f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads" ) from None @@ -57,7 +64,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: elif is_sequence_t(files): files = [(key, _transform_file(file)) for key, file in files] else: - raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") + raise TypeError( + f"Unexpected file type input {type(files)}, expected mapping or sequence" + ) return files @@ -73,7 +82,9 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes: if is_tuple_t(file): return (file[0], _read_file_content(file[1]), *file[2:]) - raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + raise TypeError( + f"Expected file types input to be a FileContent type or to be a tuple" + ) def _read_file_content(file: FileContent) -> HttpxFileContent: @@ -101,7 +112,9 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles elif is_sequence_t(files): files = [(key, await _async_transform_file(file)) for key, file in files] else: - raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence") + raise TypeError( + "Unexpected file type input {type(files)}, expected mapping or sequence" + ) return files @@ -117,7 +130,9 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes: if is_tuple_t(file): return (file[0], await _async_read_file_content(file[1]), *file[2:]) - raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + raise TypeError( + f"Expected file types input to be a FileContent type or to be a tuple" + ) async def _async_read_file_content(file: FileContent) -> HttpxFileContent: diff --git a/model-providers/model_providers/_models.py b/model-providers/model_providers/_models.py index e7aa662a..c173353a 100644 --- a/model-providers/model_providers/_models.py +++ b/model-providers/model_providers/_models.py @@ -1,60 +1,62 @@ from __future__ import annotations -import os import inspect -from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast +import os from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast + +import pydantic +import pydantic.generics +from pydantic.fields import FieldInfo from typing_extensions import ( - Unpack, - Literal, ClassVar, + Literal, Protocol, Required, TypedDict, TypeGuard, + Unpack, final, override, runtime_checkable, ) -import pydantic -import pydantic.generics -from pydantic.fields import FieldInfo - +from ._compat import ( + PYDANTIC_V2, + ConfigDict, + field_get_default, + get_args, + get_model_config, + get_model_fields, + get_origin, + is_literal_type, + is_union, + parse_obj, +) +from ._compat import ( + GenericModel as BaseGenericModel, +) from ._types import ( IncEx, ModelT, ) from ._utils import ( PropertyInfo, - is_list, - is_given, - lru_cache, - is_mapping, - parse_date, coerce_boolean, - parse_datetime, - strip_not_given, extract_type_arg, is_annotated_type, + is_given, + is_list, + is_mapping, + lru_cache, + parse_date, + parse_datetime, strip_annotated_type, -) -from ._compat import ( - PYDANTIC_V2, - ConfigDict, - GenericModel as BaseGenericModel, - get_args, - is_union, - parse_obj, - get_origin, - is_literal_type, - get_model_config, - get_model_fields, - field_get_default, + strip_not_given, ) if TYPE_CHECKING: - from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema + from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema __all__ = ["BaseModel", "GenericModel"] @@ -69,7 +71,8 @@ class _ConfigProtocol(Protocol): class BaseModel(pydantic.BaseModel): if PYDANTIC_V2: model_config: ClassVar[ConfigDict] = ConfigDict( - extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + extra="allow", + defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")), ) else: @@ -189,7 +192,9 @@ class BaseModel(pydantic.BaseModel): key = name if key in values: - fields_values[name] = _construct_field(value=values[key], field=field, key=key) + fields_values[name] = _construct_field( + value=values[key], field=field, key=key + ) _fields_set.add(name) else: fields_values[name] = field_get_default(field) @@ -412,9 +417,13 @@ def construct_type(*, value: object, type_: object) -> object: # # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then # we'd end up constructing `FooType` when it should be `BarType`. - discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) + discriminator = _build_discriminated_union_meta( + union=type_, meta_annotations=meta + ) if discriminator and is_mapping(value): - variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) + variant_value = value.get( + discriminator.field_alias_from or discriminator.field_name + ) if variant_value and isinstance(variant_value, str): variant_type = discriminator.mapping.get(variant_value) if variant_type: @@ -434,11 +443,19 @@ def construct_type(*, value: object, type_: object) -> object: return value _, items_type = get_args(type_) # Dict[_, items_type] - return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} + return { + key: construct_type(value=item, type_=items_type) + for key, item in value.items() + } - if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): + if not is_literal_type(type_) and ( + issubclass(origin, BaseModel) or issubclass(origin, GenericModel) + ): if is_list(value): - return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] + return [ + cast(Any, type_).construct(**entry) if is_mapping(entry) else entry + for entry in value + ] if is_mapping(value): if issubclass(type_, BaseModel): @@ -523,14 +540,19 @@ class DiscriminatorDetails: self.field_alias_from = discriminator_alias -def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: +def _build_discriminated_union_meta( + *, union: type, meta_annotations: tuple[Any, ...] +) -> DiscriminatorDetails | None: if isinstance(union, CachedDiscriminatorType): return union.__discriminator__ discriminator_field_name: str | None = None for annotation in meta_annotations: - if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: + if ( + isinstance(annotation, PropertyInfo) + and annotation.discriminator is not None + ): discriminator_field_name = annotation.discriminator break @@ -558,7 +580,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, if isinstance(entry, str): mapping[entry] = variant else: - field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get( + discriminator_field_name + ) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] if not field_info: continue @@ -582,7 +606,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, return details -def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: +def _extract_field_schema_pv2( + model: type[BaseModel], field_name: str +) -> ModelField | None: schema = model.__pydantic_core_schema__ if schema["type"] != "model": return None @@ -621,7 +647,9 @@ else: if PYDANTIC_V2: from pydantic import TypeAdapter as _TypeAdapter - _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) + _CachedTypeAdapter = cast( + "TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter) + ) if TYPE_CHECKING: from pydantic import TypeAdapter @@ -652,6 +680,3 @@ elif not TYPE_CHECKING: # TODO: condition is weird def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]: return RootModel[type_] # type: ignore - - - diff --git a/model-providers/model_providers/_types.py b/model-providers/model_providers/_types.py index 6fce8e09..8f6730a3 100644 --- a/model-providers/model_providers/_types.py +++ b/model-providers/model_providers/_types.py @@ -5,22 +5,29 @@ from typing import ( IO, TYPE_CHECKING, Any, + Callable, Dict, List, - Type, - Tuple, - Union, Mapping, - TypeVar, - Callable, Optional, Sequence, + Tuple, + Type, + TypeVar, + Union, ) -from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable import httpx import pydantic -from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport +from httpx import URL, AsyncBaseTransport, BaseTransport, Proxy, Response, Timeout +from typing_extensions import ( + Literal, + Protocol, + TypeAlias, + TypedDict, + override, + runtime_checkable, +) if TYPE_CHECKING: from ._models import BaseModel @@ -43,7 +50,9 @@ if TYPE_CHECKING: FileContent = Union[IO[bytes], bytes, PathLike[str]] else: Base64FileInput = Union[IO[bytes], PathLike] - FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. + FileContent = Union[ + IO[bytes], bytes, PathLike + ] # PathLike is not subscriptable in Python 3.8. FileTypes = Union[ # file (or bytes) FileContent, @@ -68,7 +77,9 @@ HttpxFileTypes = Union[ # (filename, file (or bytes), content_type, headers) Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], ] -HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] +HttpxRequestFiles = Union[ + Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]] +] # Workaround to support (cast_to: Type[ResponseT]) -> ResponseT # where ResponseT includes `None`. In order to support directly diff --git a/model-providers/model_providers/_utils/__init__.py b/model-providers/model_providers/_utils/__init__.py index c5e3bee8..51f08eb7 100644 --- a/model-providers/model_providers/_utils/__init__.py +++ b/model-providers/model_providers/_utils/__init__.py @@ -1,48 +1,126 @@ -from ._utils import ( - flatten as flatten, - is_dict as is_dict, - is_list as is_list, - is_given as is_given, - is_tuple as is_tuple, - lru_cache as lru_cache, - is_mapping as is_mapping, - is_tuple_t as is_tuple_t, - parse_date as parse_date, - is_iterable as is_iterable, - is_sequence as is_sequence, - coerce_float as coerce_float, - is_mapping_t as is_mapping_t, - removeprefix as removeprefix, - removesuffix as removesuffix, - extract_files as extract_files, - is_sequence_t as is_sequence_t, - required_args as required_args, - coerce_boolean as coerce_boolean, - coerce_integer as coerce_integer, - file_from_path as file_from_path, - parse_datetime as parse_datetime, - strip_not_given as strip_not_given, - deepcopy_minimal as deepcopy_minimal, - get_async_library as get_async_library, - maybe_coerce_float as maybe_coerce_float, - get_required_header as get_required_header, - maybe_coerce_boolean as maybe_coerce_boolean, - maybe_coerce_integer as maybe_coerce_integer, +from ._transform import ( + PropertyInfo as PropertyInfo, +) +from ._transform import ( + async_maybe_transform as async_maybe_transform, +) +from ._transform import ( + async_transform as async_transform, +) +from ._transform import ( + maybe_transform as maybe_transform, +) +from ._transform import ( + transform as transform, +) +from ._typing import ( + extract_type_arg as extract_type_arg, +) +from ._typing import ( + extract_type_var_from_base as extract_type_var_from_base, +) +from ._typing import ( + is_annotated_type as is_annotated_type, +) +from ._typing import ( + is_iterable_type as is_iterable_type, ) from ._typing import ( is_list_type as is_list_type, - is_union_type as is_union_type, - extract_type_arg as extract_type_arg, - is_iterable_type as is_iterable_type, +) +from ._typing import ( is_required_type as is_required_type, - is_annotated_type as is_annotated_type, +) +from ._typing import ( + is_union_type as is_union_type, +) +from ._typing import ( strip_annotated_type as strip_annotated_type, - extract_type_var_from_base as extract_type_var_from_base, ) -from ._transform import ( - PropertyInfo as PropertyInfo, - transform as transform, - async_transform as async_transform, - maybe_transform as maybe_transform, - async_maybe_transform as async_maybe_transform, +from ._utils import ( + coerce_boolean as coerce_boolean, +) +from ._utils import ( + coerce_float as coerce_float, +) +from ._utils import ( + coerce_integer as coerce_integer, +) +from ._utils import ( + deepcopy_minimal as deepcopy_minimal, +) +from ._utils import ( + extract_files as extract_files, +) +from ._utils import ( + file_from_path as file_from_path, +) +from ._utils import ( + flatten as flatten, +) +from ._utils import ( + get_async_library as get_async_library, +) +from ._utils import ( + get_required_header as get_required_header, +) +from ._utils import ( + is_dict as is_dict, +) +from ._utils import ( + is_given as is_given, +) +from ._utils import ( + is_iterable as is_iterable, +) +from ._utils import ( + is_list as is_list, +) +from ._utils import ( + is_mapping as is_mapping, +) +from ._utils import ( + is_mapping_t as is_mapping_t, +) +from ._utils import ( + is_sequence as is_sequence, +) +from ._utils import ( + is_sequence_t as is_sequence_t, +) +from ._utils import ( + is_tuple as is_tuple, +) +from ._utils import ( + is_tuple_t as is_tuple_t, +) +from ._utils import ( + lru_cache as lru_cache, +) +from ._utils import ( + maybe_coerce_boolean as maybe_coerce_boolean, +) +from ._utils import ( + maybe_coerce_float as maybe_coerce_float, +) +from ._utils import ( + maybe_coerce_integer as maybe_coerce_integer, +) +from ._utils import ( + parse_date as parse_date, +) +from ._utils import ( + parse_datetime as parse_datetime, +) +from ._utils import ( + removeprefix as removeprefix, +) +from ._utils import ( + removesuffix as removesuffix, +) +from ._utils import ( + required_args as required_args, +) +from ._utils import ( + strip_not_given as strip_not_given, ) diff --git a/model-providers/model_providers/_utils/_transform.py b/model-providers/model_providers/_utils/_transform.py index 47e262a5..9527d650 100644 --- a/model-providers/model_providers/_utils/_transform.py +++ b/model-providers/model_providers/_utils/_transform.py @@ -1,31 +1,31 @@ from __future__ import annotations -import io import base64 +import io import pathlib -from typing import Any, Mapping, TypeVar, cast from datetime import date, datetime -from typing_extensions import Literal, get_args, override, get_type_hints +from typing import Any, Mapping, TypeVar, cast import anyio import pydantic +from typing_extensions import Literal, get_args, get_type_hints, override -from ._utils import ( - is_list, - is_mapping, - is_iterable, -) +from .._compat import is_typeddict, model_dump from .._files import is_base64_file_input from ._typing import ( - is_list_type, - is_union_type, extract_type_arg, - is_iterable_type, - is_required_type, is_annotated_type, + is_iterable_type, + is_list_type, + is_required_type, + is_union_type, strip_annotated_type, ) -from .._compat import model_dump, is_typeddict +from ._utils import ( + is_iterable, + is_list, + is_mapping, +) _T = TypeVar("_T") @@ -171,10 +171,17 @@ def _transform_recursive( # List[T] (is_list_type(stripped_type) and is_list(data)) # Iterable[T] - or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + or ( + is_iterable_type(stripped_type) + and is_iterable(data) + and not isinstance(data, str) + ) ): inner_type = extract_type_arg(stripped_type, 0) - return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + return [ + _transform_recursive(d, annotation=annotation, inner_type=inner_type) + for d in data + ] if is_union_type(stripped_type): # For union types we run the transformation against all subtypes to ensure that everything is transformed. @@ -201,7 +208,9 @@ def _transform_recursive( return data -def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: +def _format_data( + data: object, format_: PropertyFormat, format_template: str | None +) -> object: if isinstance(data, (date, datetime)): if format_ == "iso8601": return data.isoformat() @@ -221,7 +230,9 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N binary = binary.encode() if not isinstance(binary, bytes): - raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + raise RuntimeError( + f"Could not read bytes from {data}; Received {type(binary)}" + ) return base64.b64encode(binary).decode("ascii") @@ -240,7 +251,9 @@ def _transform_typeddict( # we do not have a type annotation for this field, leave it as is result[key] = value else: - result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) + result[_maybe_transform_key(key, type_)] = _transform_recursive( + value, annotation=type_ + ) return result @@ -276,7 +289,9 @@ async def async_transform( It should be noted that the transformations that this function does are not represented in the type system. """ - transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) + transformed = await _async_transform_recursive( + data, annotation=cast(type, expected_type) + ) return cast(_T, transformed) @@ -309,10 +324,19 @@ async def _async_transform_recursive( # List[T] (is_list_type(stripped_type) and is_list(data)) # Iterable[T] - or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + or ( + is_iterable_type(stripped_type) + and is_iterable(data) + and not isinstance(data, str) + ) ): inner_type = extract_type_arg(stripped_type, 0) - return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + return [ + await _async_transform_recursive( + d, annotation=annotation, inner_type=inner_type + ) + for d in data + ] if is_union_type(stripped_type): # For union types we run the transformation against all subtypes to ensure that everything is transformed. @@ -320,7 +344,9 @@ async def _async_transform_recursive( # TODO: there may be edge cases where the same normalized field name will transform to two different names # in different subtypes. for subtype in get_args(stripped_type): - data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) + data = await _async_transform_recursive( + data, annotation=annotation, inner_type=subtype + ) return data if isinstance(data, pydantic.BaseModel): @@ -334,12 +360,16 @@ async def _async_transform_recursive( annotations = get_args(annotated_type)[1:] for annotation in annotations: if isinstance(annotation, PropertyInfo) and annotation.format is not None: - return await _async_format_data(data, annotation.format, annotation.format_template) + return await _async_format_data( + data, annotation.format, annotation.format_template + ) return data -async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: +async def _async_format_data( + data: object, format_: PropertyFormat, format_template: str | None +) -> object: if isinstance(data, (date, datetime)): if format_ == "iso8601": return data.isoformat() @@ -359,7 +389,9 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ binary = binary.encode() if not isinstance(binary, bytes): - raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + raise RuntimeError( + f"Could not read bytes from {data}; Received {type(binary)}" + ) return base64.b64encode(binary).decode("ascii") @@ -378,5 +410,7 @@ async def _async_transform_typeddict( # we do not have a type annotation for this field, leave it as is result[key] = value else: - result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) + result[_maybe_transform_key(key, type_)] = await _async_transform_recursive( + value, annotation=type_ + ) return result diff --git a/model-providers/model_providers/_utils/_typing.py b/model-providers/model_providers/_utils/_typing.py index 003ca84a..5c200b93 100644 --- a/model-providers/model_providers/_utils/_typing.py +++ b/model-providers/model_providers/_utils/_typing.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import Any, TypeVar, Iterable, cast from collections import abc as _c_abc -from typing_extensions import Required, Annotated, get_args, get_origin +from typing import Any, Iterable, TypeVar, cast + +from typing_extensions import Annotated, Required, get_args, get_origin -from .._types import InheritsGeneric from .._compat import is_union as _is_union +from .._types import InheritsGeneric def is_annotated_type(typ: type) -> bool: @@ -49,15 +50,17 @@ def extract_type_arg(typ: type, index: int) -> type: try: return cast(type, args[index]) except IndexError as err: - raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err + raise RuntimeError( + f"Expected type {typ} to have a type argument at index {index} but it did not" + ) from err def extract_type_var_from_base( - typ: type, - *, - generic_bases: tuple[type, ...], - index: int, - failure_message: str | None = None, + typ: type, + *, + generic_bases: tuple[type, ...], + index: int, + failure_message: str | None = None, ) -> type: """Given a type like `Foo[T]`, returns the generic type variable `T`. @@ -117,4 +120,7 @@ def extract_type_var_from_base( return extracted - raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") + raise RuntimeError( + failure_message + or f"Could not resolve inner type variable at index {index} for {typ}" + ) diff --git a/model-providers/model_providers/_utils/_utils.py b/model-providers/model_providers/_utils/_utils.py index 17904ce6..7177a6ce 100644 --- a/model-providers/model_providers/_utils/_utils.py +++ b/model-providers/model_providers/_utils/_utils.py @@ -1,27 +1,28 @@ from __future__ import annotations +import functools +import inspect import os import re -import inspect -import functools +from pathlib import Path from typing import ( Any, - Tuple, - Mapping, - TypeVar, Callable, Iterable, + Mapping, Sequence, + Tuple, + TypeVar, cast, overload, ) -from pathlib import Path -from typing_extensions import TypeGuard import sniffio +from typing_extensions import TypeGuard -from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike -from .._compat import parse_date as parse_date, parse_datetime as parse_datetime +from .._compat import parse_date as parse_date +from .._compat import parse_datetime as parse_datetime +from .._types import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr _T = TypeVar("_T") _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) @@ -108,7 +109,9 @@ def _extract_items( item, path, index=index, - flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", + flattened_key=flattened_key + "[]" + if flattened_key is not None + else "[]", ) for item in obj ] @@ -261,7 +264,12 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: else: # no break if len(variants) > 1: variations = human_join( - ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] + [ + "(" + + human_join([quote(arg) for arg in variant], final="and") + + ")" + for variant in variants + ] ) msg = f"Missing required arguments; Expected either {variations} arguments to be given" else: @@ -376,7 +384,11 @@ def get_required_header(headers: HeadersLike, header: str) -> str: return v """ to deal with the case where the header looks like Stainless-Event-Id """ - intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) + intercaps_header = re.sub( + r"([^\w])(\w)", + lambda pat: pat.group(1) + pat.group(2).upper(), + header.capitalize(), + ) for normalized_header in [header, lower_header, header.upper(), intercaps_header]: value = headers.get(normalized_header) diff --git a/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py index 2ac109a9..06f815cd 100644 --- a/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py +++ b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py @@ -1,9 +1,6 @@ from enum import Enum from typing import List, Literal, Optional -from ..._compat import PYDANTIC_V2, ConfigDict -from ..._models import BaseModel - from model_providers.core.entities.model_entities import ( ModelStatus, ModelWithProviderEntity, @@ -26,6 +23,9 @@ from model_providers.core.model_runtime.entities.provider_entities import ( SimpleProviderEntity, ) +from ..._compat import PYDANTIC_V2, ConfigDict +from ..._models import BaseModel + class CustomConfigurationStatus(Enum): """ @@ -74,10 +74,9 @@ class ProviderResponse(BaseModel): custom_configuration: CustomConfigurationResponse system_configuration: SystemConfigurationResponse if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) else: + class Config: protected_namespaces = () @@ -180,10 +179,9 @@ class DefaultModelResponse(BaseModel): provider: SimpleProviderEntityResponse if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) else: + class Config: protected_namespaces = () diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 593fe394..9532553c 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -73,7 +73,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): self._server = None self._server_thread = None - def logging_conf(self,logging_conf: Optional[dict] = None): + def logging_conf(self, logging_conf: Optional[dict] = None): self._logging_conf = logging_conf @classmethod @@ -132,7 +132,10 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): self._app.include_router(self._router) config = Config( - app=self._app, host=self._host, port=self._port, log_config=self._logging_conf + app=self._app, + host=self._host, + port=self._port, + log_config=self._logging_conf, ) self._server = Server(config) @@ -145,7 +148,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): self._server_thread.start() def destroy(self): - logger.info("Shutting down server") self._server.should_exit = True # 设置退出标志 self._server.shutdown() # 停止服务器 @@ -155,7 +157,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def join(self): self._server_thread.join() - def set_app_event(self, started_event: mp.Event = None): @self._app.on_event("startup") async def on_startup(): @@ -190,12 +191,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): provider_model_bundle.model_type_instance.predefined_models() ) # 获取自定义模型 - for model in provider_model_bundle.configuration.custom_configuration.models: - - llm_models.append(provider_model_bundle.model_type_instance.get_model_schema( - model=model.model, - credentials=model.credentials, - )) + for ( + model + ) in provider_model_bundle.configuration.custom_configuration.models: + llm_models.append( + provider_model_bundle.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + ) + ) except Exception as e: logger.error( f"Error while fetching models for provider: {provider}, model_type: {model_type}" @@ -225,13 +229,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): ) # 判断embeddings_request.input是否为list - input = '' + input = "" if isinstance(embeddings_request.input, list): tokens = embeddings_request.input try: encoding = tiktoken.encoding_for_model(embeddings_request.model) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding." + ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, token in enumerate(tokens): @@ -241,7 +247,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): else: input = embeddings_request.input - response = model_instance.invoke_text_embedding(texts=[input], user="abc-123") + response = model_instance.invoke_text_embedding( + texts=[input], user="abc-123" + ) return await openai_embedding_text(response) except ValueError as e: diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 5f30c1cc..c4eeeb90 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -1,10 +1,12 @@ import time from enum import Enum from typing import Any, Dict, List, Optional, Union -from ..._models import BaseModel + from pydantic import Field as FieldInfo from typing_extensions import Literal +from ..._models import BaseModel + class Role(str, Enum): USER = "user" diff --git a/model-providers/model_providers/core/entities/message_entities.py b/model-providers/model_providers/core/entities/message_entities.py index c768c0e9..40c66e83 100644 --- a/model-providers/model_providers/core/entities/message_entities.py +++ b/model-providers/model_providers/core/entities/message_entities.py @@ -1,5 +1,5 @@ import enum -from typing import Any, cast, List +from typing import Any, List, cast from langchain.schema import ( AIMessage, @@ -8,7 +8,6 @@ from langchain.schema import ( HumanMessage, SystemMessage, ) -from ..._models import BaseModel from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -20,6 +19,8 @@ from model_providers.core.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from ..._models import BaseModel + class PromptMessageFileType(enum.Enum): IMAGE = "image" diff --git a/model-providers/model_providers/core/entities/model_entities.py b/model-providers/model_providers/core/entities/model_entities.py index e3a9702f..d9fd23de 100644 --- a/model-providers/model_providers/core/entities/model_entities.py +++ b/model-providers/model_providers/core/entities/model_entities.py @@ -1,9 +1,6 @@ from enum import Enum from typing import List, Optional -from ..._compat import PYDANTIC_V2, ConfigDict -from ..._models import BaseModel - from model_providers.core.model_runtime.entities.common_entities import I18nObject from model_providers.core.model_runtime.entities.model_entities import ( ModelType, @@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import ( ) from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity +from ..._compat import PYDANTIC_V2, ConfigDict +from ..._models import BaseModel + class ModelStatus(Enum): """ @@ -80,9 +80,8 @@ class DefaultModelEntity(BaseModel): provider: DefaultModelProviderEntity if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) else: + class Config: - protected_namespaces = () \ No newline at end of file + protected_namespaces = () diff --git a/model-providers/model_providers/core/entities/provider_configuration.py b/model-providers/model_providers/core/entities/provider_configuration.py index b0d35537..8253b2cf 100644 --- a/model-providers/model_providers/core/entities/provider_configuration.py +++ b/model-providers/model_providers/core/entities/provider_configuration.py @@ -4,9 +4,6 @@ import logging from json import JSONDecodeError from typing import Dict, Iterator, List, Optional -from ..._compat import PYDANTIC_V2, ConfigDict -from ..._models import BaseModel - from model_providers.core.entities.model_entities import ( ModelStatus, ModelWithProviderEntity, @@ -29,6 +26,9 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im ModelProvider, ) +from ..._compat import PYDANTIC_V2, ConfigDict +from ..._models import BaseModel + logger = logging.getLogger(__name__) @@ -351,11 +351,9 @@ class ProviderModelBundle(BaseModel): model_type_instance: AIModel if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=(), - arbitrary_types_allowed=True - ) + model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True) else: + class Config: protected_namespaces = () diff --git a/model-providers/model_providers/core/entities/provider_entities.py b/model-providers/model_providers/core/entities/provider_entities.py index 73f87d27..a06acb61 100644 --- a/model-providers/model_providers/core/entities/provider_entities.py +++ b/model-providers/model_providers/core/entities/provider_entities.py @@ -1,11 +1,11 @@ from enum import Enum from typing import List, Optional +from model_providers.core.model_runtime.entities.model_entities import ModelType + from ..._compat import PYDANTIC_V2, ConfigDict from ..._models import BaseModel -from model_providers.core.model_runtime.entities.model_entities import ModelType - class ProviderType(Enum): CUSTOM = "custom" @@ -59,10 +59,9 @@ class RestrictModel(BaseModel): model_type: ModelType if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) else: + class Config: protected_namespaces = () @@ -109,10 +108,9 @@ class CustomModelConfiguration(BaseModel): credentials: dict if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) else: + class Config: protected_namespaces = () diff --git a/model-providers/model_providers/core/entities/queue_entities.py b/model-providers/model_providers/core/entities/queue_entities.py index 3dba6d29..55cab2fe 100644 --- a/model-providers/model_providers/core/entities/queue_entities.py +++ b/model-providers/model_providers/core/entities/queue_entities.py @@ -1,13 +1,13 @@ from enum import Enum from typing import Any -from ..._models import BaseModel - from model_providers.core.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, ) +from ..._models import BaseModel + class QueueEvent(Enum): """ diff --git a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py index c9d09f86..52dd1bea 100644 --- a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py @@ -2,8 +2,6 @@ from decimal import Decimal from enum import Enum from typing import List, Optional -from ...._models import BaseModel - from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -13,6 +11,8 @@ from model_providers.core.model_runtime.entities.model_entities import ( PriceInfo, ) +from ...._models import BaseModel + class LLMMode(Enum): """ diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py index c372099c..786e49cf 100644 --- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py @@ -2,11 +2,11 @@ from decimal import Decimal from enum import Enum from typing import Any, Dict, List, Optional +from model_providers.core.model_runtime.entities.common_entities import I18nObject + from ...._compat import PYDANTIC_V2, ConfigDict from ...._models import BaseModel -from model_providers.core.model_runtime.entities.common_entities import I18nObject - class ModelType(Enum): """ @@ -164,10 +164,9 @@ class ProviderModel(BaseModel): model_properties: Dict[ModelPropertyKey, Any] deprecated: bool = False if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) else: + class Config: protected_namespaces = () diff --git a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py index 167c4ccb..c3e71ca2 100644 --- a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py @@ -1,9 +1,6 @@ from enum import Enum from typing import List, Optional -from ...._compat import PYDANTIC_V2, ConfigDict -from ...._models import BaseModel - from model_providers.core.model_runtime.entities.common_entities import I18nObject from model_providers.core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import ( ProviderModel, ) +from ...._compat import PYDANTIC_V2, ConfigDict +from ...._models import BaseModel + class ConfigurateMethod(Enum): """ @@ -136,10 +136,9 @@ class ProviderEntity(BaseModel): model_credential_schema: Optional[ModelCredentialSchema] = None if PYDANTIC_V2: - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) else: + class Config: protected_namespaces = () diff --git a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py index b222e649..4b4eeb9d 100644 --- a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py @@ -1,10 +1,10 @@ from decimal import Decimal from typing import List -from ...._models import BaseModel - from model_providers.core.model_runtime.entities.model_entities import ModelUsage +from ...._models import BaseModel + class EmbeddingUsage(ModelUsage): """ diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py index 2d53dcbc..77470470 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py @@ -1,5 +1,3 @@ -from ....._models import BaseModel - from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from model_providers.core.model_runtime.entities.llm_entities import LLMMode from model_providers.core.model_runtime.entities.model_entities import ( @@ -14,6 +12,8 @@ from model_providers.core.model_runtime.entities.model_entities import ( PriceConfig, ) +from ....._models import BaseModel + AZURE_OPENAI_API_VERSION = "2024-02-15-preview" diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py index a2031792..0bb104b0 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py @@ -1,6 +1,5 @@ import logging -from typing import Generator -from typing import Dict, List, Optional, Type, Union, cast +from typing import Dict, Generator, List, Optional, Type, Union, cast import cohere from cohere.responses import Chat, Generations diff --git a/model-providers/model_providers/core/model_runtime/model_providers/deepseek/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/deepseek/llm/llm.py index 5a68da78..bd170e13 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -1,7 +1,6 @@ import logging -from typing import Generator -from typing import List, Optional, Union, cast from decimal import Decimal +from typing import Generator, List, Optional, Union, cast import tiktoken from openai import OpenAI, Stream @@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import ( ) from model_providers.core.model_runtime.entities.model_entities import ( AIModelEntity, + DefaultParameterName, FetchFrom, I18nObject, + ModelFeature, + ModelPropertyKey, ModelType, - PriceConfig, ParameterRule, ParameterType, ModelFeature, ModelPropertyKey, DefaultParameterName, + ParameterRule, + ParameterType, + PriceConfig, ) from model_providers.core.model_runtime.errors.validate import ( CredentialsValidateFailedError, @@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import ( from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( LargeLanguageModel, ) -from model_providers.core.model_runtime.model_providers.deepseek._common import _CommonDeepseek +from model_providers.core.model_runtime.model_providers.deepseek._common import ( + _CommonDeepseek, +) logger = logging.getLogger(__name__) @@ -1117,7 +1123,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): return num_tokens def get_customizable_model_schema( - self, model: str, credentials: dict + self, model: str, credentials: dict ) -> AIModelEntity: """ Get customizable model schema. @@ -1129,7 +1135,6 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): """ extras = {} - entity = AIModelEntity( model=model, label=I18nObject(zh_Hans=model, en_US=model), @@ -1149,8 +1154,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="The temperature of the model. " - "Increasing the temperature will make the model answer " - "more creatively. (Default: 0.8)" + "Increasing the temperature will make the model answer " + "more creatively. (Default: 0.8)" ), default=0.8, min=0, @@ -1163,8 +1168,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " - "more diverse text, while a lower value (e.g., 0.5) will generate more " - "focused and conservative text. (Default: 0.9)" + "more diverse text, while a lower value (e.g., 0.5) will generate more " + "focused and conservative text. (Default: 0.9)" ), default=0.9, min=0, @@ -1177,8 +1182,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): help=I18nObject( en_US="A number between -2.0 and 2.0. If positive, ", zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正," - "那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚," - "降低模型重复相同内容的可能性" + "那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚," + "降低模型重复相同内容的可能性", ), default=0, min=-2, @@ -1190,7 +1195,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="Sets how strongly to presence_penalty. ", - zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。" + zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。", ), default=1.1, min=-2, @@ -1204,7 +1209,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): help=I18nObject( en_US="Maximum number of tokens to predict when generating text. ", zh_Hans="限制一次请求中模型生成 completion 的最大 token 数。" - "输入 token 和输出 token 的总长度受模型的上下文长度的限制。" + "输入 token 和输出 token 的总长度受模型的上下文长度的限制。", ), default=128, min=-2, @@ -1216,7 +1221,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): type=ParameterType.BOOLEAN, help=I18nObject( en_US="Whether to return the log probabilities of the tokens. ", - zh_Hans="是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。" + zh_Hans="是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。", ), ), ParameterRule( @@ -1226,7 +1231,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel): help=I18nObject( en_US="the format to return a response in.", zh_Hans="一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token," - "且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。" + "且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。", ), default=0, min=0, diff --git a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py index 740560a1..9fcf99e2 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py @@ -3,8 +3,6 @@ import logging import os from typing import Dict, List, Optional, Union -from ...._models import BaseModel - from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.model_runtime.entities.provider_entities import ( ProviderConfig, @@ -25,6 +23,8 @@ from model_providers.core.utils.position_helper import ( sort_to_dict_by_position_map, ) +from ...._models import BaseModel + logger = logging.getLogger(__name__) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py index 3300b3cc..9a166123 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -1,5 +1,4 @@ -from typing import Generator -from typing import List, Optional, Union +from typing import Generator, List, Optional, Union from model_providers.core.model_runtime.entities.llm_entities import LLMResult from model_providers.core.model_runtime.entities.message_entities import ( diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py index 59b0d666..92a8ca91 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py @@ -1,7 +1,6 @@ import logging -from typing import Generator -from typing import List, Optional, Union, cast from decimal import Decimal +from typing import Generator, List, Optional, Union, cast import tiktoken from openai import OpenAI, Stream @@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import ( ) from model_providers.core.model_runtime.entities.model_entities import ( AIModelEntity, + DefaultParameterName, FetchFrom, I18nObject, + ModelFeature, + ModelPropertyKey, ModelType, - PriceConfig, ModelFeature, ModelPropertyKey, DefaultParameterName, ParameterRule, ParameterType, + ParameterRule, + ParameterType, + PriceConfig, ) from model_providers.core.model_runtime.errors.validate import ( CredentialsValidateFailedError, @@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import ( from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( LargeLanguageModel, ) -from model_providers.core.model_runtime.model_providers.ollama._common import _CommonOllama +from model_providers.core.model_runtime.model_providers.ollama._common import ( + _CommonOllama, +) logger = logging.getLogger(__name__) @@ -661,7 +667,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model, stream=stream, - extra_body=extra_body + extra_body=extra_body, ) if stream: @@ -1120,7 +1126,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): return num_tokens def get_customizable_model_schema( - self, model: str, credentials: dict + self, model: str, credentials: dict ) -> AIModelEntity: """ Get customizable model schema. @@ -1154,8 +1160,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="The temperature of the model. " - "Increasing the temperature will make the model answer " - "more creatively. (Default: 0.8)" + "Increasing the temperature will make the model answer " + "more creatively. (Default: 0.8)" ), default=0.8, min=0, @@ -1168,8 +1174,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " - "more diverse text, while a lower value (e.g., 0.5) will generate more " - "focused and conservative text. (Default: 0.9)" + "more diverse text, while a lower value (e.g., 0.5) will generate more " + "focused and conservative text. (Default: 0.9)" ), default=0.9, min=0, @@ -1181,8 +1187,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="Reduces the probability of generating nonsense. " - "A higher value (e.g. 100) will give more diverse answers, " - "while a lower value (e.g. 10) will be more conservative. (Default: 40)" + "A higher value (e.g. 100) will give more diverse answers, " + "while a lower value (e.g. 10) will be more conservative. (Default: 40)" ), default=40, min=1, @@ -1194,8 +1200,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="Sets how strongly to penalize repetitions. " - "A higher value (e.g., 1.5) will penalize repetitions more strongly, " - "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)" + "A higher value (e.g., 1.5) will penalize repetitions more strongly, " + "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)" ), default=1.1, min=-2, @@ -1208,7 +1214,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="Maximum number of tokens to predict when generating text. " - "(Default: 128, -1 = infinite generation, -2 = fill context)" + "(Default: 128, -1 = infinite generation, -2 = fill context)" ), default=128, min=-2, @@ -1220,7 +1226,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="Enable Mirostat sampling for controlling perplexity. " - "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)" + "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)" ), default=0, min=0, @@ -1232,9 +1238,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="Influences how quickly the algorithm responds to feedback from " - "the generated text. A lower learning rate will result in slower adjustments, " - "while a higher learning rate will make the algorithm more responsive. " - "(Default: 0.1)" + "the generated text. A lower learning rate will result in slower adjustments, " + "while a higher learning rate will make the algorithm more responsive. " + "(Default: 0.1)" ), default=0.1, precision=1, @@ -1245,7 +1251,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="Controls the balance between coherence and diversity of the output. " - "A lower value will result in more focused and coherent text. (Default: 5.0)" + "A lower value will result in more focused and coherent text. (Default: 5.0)" ), default=5.0, precision=1, @@ -1256,7 +1262,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)" + "(Default: 2048)" ), default=2048, min=1, @@ -1267,7 +1273,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="The number of layers to send to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable." + "On macOS it defaults to 1 to enable metal support, 0 to disable." ), default=1, min=0, @@ -1279,9 +1285,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="Sets the number of threads to use during computation. " - "By default, Ollama will detect this for optimal performance. " - "It is recommended to set this value to the number of physical CPU cores " - "your system has (as opposed to the logical number of cores)." + "By default, Ollama will detect this for optimal performance. " + "It is recommended to set this value to the number of physical CPU cores " + "your system has (as opposed to the logical number of cores)." ), min=1, ), @@ -1291,7 +1297,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="Sets how far back for the model to look back to prevent repetition. " - "(Default: 64, 0 = disabled, -1 = num_ctx)" + "(Default: 64, 0 = disabled, -1 = num_ctx)" ), default=64, min=-1, @@ -1302,8 +1308,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.FLOAT, help=I18nObject( en_US="Tail free sampling is used to reduce the impact of less probable tokens " - "from the output. A higher value (e.g., 2.0) will reduce the impact more, " - "while a value of 1.0 disables this setting. (default: 1)" + "from the output. A higher value (e.g., 2.0) will reduce the impact more, " + "while a value of 1.0 disables this setting. (default: 1)" ), default=1, precision=1, @@ -1314,8 +1320,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.INT, help=I18nObject( en_US="Sets the random number seed to use for generation. Setting this to " - "a specific number will make the model generate the same text for " - "the same prompt. (Default: 0)" + "a specific number will make the model generate the same text for " + "the same prompt. (Default: 0)" ), default=0, ), @@ -1325,7 +1331,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel): type=ParameterType.STRING, help=I18nObject( en_US="the format to return a response in." - " Currently the only accepted value is json." + " Currently the only accepted value is json." ), options=["json"], ), diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py index e32399bd..ae3effd7 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,6 +1,5 @@ import logging -from typing import Generator -from typing import List, Optional, Union, cast +from typing import Generator, List, Optional, Union, cast import tiktoken from openai import OpenAI, Stream diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 1267d01e..d23ca90e 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -1,8 +1,7 @@ import json import logging -from typing import Generator from decimal import Decimal -from typing import List, Optional, Union, cast +from typing import Generator, List, Optional, Union, cast from urllib.parse import urljoin import requests diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index a45aac4b..b8eeb40d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,7 +1,6 @@ -from typing import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Generator, Union from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema diff --git a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py index cc4ee81e..8f6e0e60 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,6 +1,5 @@ import threading -from typing import Generator -from typing import Dict, List, Optional, Type, Union +from typing import Dict, Generator, List, Optional, Type, Union from model_providers.core.model_runtime.entities.llm_entities import ( LLMResult, diff --git a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py index 2010ab34..296d8707 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -1,5 +1,4 @@ -from typing import Generator -from typing import List, Optional, Union +from typing import Generator, List, Optional, Union from model_providers.core.model_runtime.entities.llm_entities import LLMResult from model_providers.core.model_runtime.entities.message_entities import ( diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py index 2edfb37c..4feb2ee3 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -1,5 +1,4 @@ -from typing import Generator -from typing import Dict, List, Optional, Type, Union +from typing import Dict, Generator, List, Optional, Type, Union from dashscope import get_tokenizer from dashscope.api_entities.dashscope_response import DashScopeAPIResponse diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py index 68b2cf58..68fe5e1d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py @@ -1,6 +1,4 @@ -from typing import Generator, Iterator - -from typing import Dict, List, Union, cast, Type +from typing import Dict, Generator, Iterator, List, Type, Union, cast from openai import ( APIConnectionError, @@ -527,8 +525,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def _extract_response_tool_calls( self, response_tool_calls: Union[ - List[ChatCompletionMessageToolCall], - List[ChoiceDeltaToolCall] + List[ChatCompletionMessageToolCall], List[ChoiceDeltaToolCall] ], ) -> List[AssistantPromptMessage.ToolCall]: """ diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py index b987e6fc..ddbcaa26 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -195,8 +195,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if stop: extra_model_kwargs["stop"] = stop - client = ZhipuAI(base_url=credentials_kwargs["api_base"], - api_key=credentials_kwargs["api_key"]) + client = ZhipuAI( + base_url=credentials_kwargs["api_base"], + api_key=credentials_kwargs["api_key"], + ) if len(prompt_messages) == 0: raise ValueError("At least one message is required") diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 9e12b53c..af254a69 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -43,8 +43,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI(base_url=credentials_kwargs["api_base"], - api_key=credentials_kwargs["api_key"]) + client = ZhipuAI( + base_url=credentials_kwargs["api_base"], + api_key=credentials_kwargs["api_key"], + ) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) @@ -85,8 +87,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI(base_url=credentials_kwargs["api_base"], - api_key=credentials_kwargs["api_key"]) + client = ZhipuAI( + base_url=credentials_kwargs["api_base"], + api_key=credentials_kwargs["api_key"], + ) # call embedding model self.embed_documents( diff --git a/model-providers/model_providers/core/model_runtime/utils/helper.py b/model-providers/model_providers/core/model_runtime/utils/helper.py index c6e97c83..6ad58e3c 100644 --- a/model-providers/model_providers/core/model_runtime/utils/helper.py +++ b/model-providers/model_providers/core/model_runtime/utils/helper.py @@ -1,4 +1,5 @@ import pydantic + from ...._models import BaseModel diff --git a/model-providers/model_providers/core/utils/json_dumps.py b/model-providers/model_providers/core/utils/json_dumps.py index 20b44de5..55deac59 100644 --- a/model-providers/model_providers/core/utils/json_dumps.py +++ b/model-providers/model_providers/core/utils/json_dumps.py @@ -1,6 +1,7 @@ import os import orjson + from ..._models import BaseModel diff --git a/model-providers/model_providers/extensions/ext_storage.py b/model-providers/model_providers/extensions/ext_storage.py index fcf863a6..7b88563d 100644 --- a/model-providers/model_providers/extensions/ext_storage.py +++ b/model-providers/model_providers/extensions/ext_storage.py @@ -1,8 +1,7 @@ import os import shutil -from typing import Generator from contextlib import closing -from typing import Union +from typing import Generator, Union import boto3 from botocore.exceptions import ClientError diff --git a/model-providers/tests/conftest.py b/model-providers/tests/conftest.py index a4508b81..2e1141be 100644 --- a/model-providers/tests/conftest.py +++ b/model-providers/tests/conftest.py @@ -104,15 +104,17 @@ def logging_conf() -> dict: 111, ) + @pytest.fixture def providers_file(request) -> str: - from pathlib import Path import os + from pathlib import Path + # 当前执行目录 # 获取当前测试文件的路径 test_file_path = Path(str(request.fspath)).parent - print("test_file_path:",test_file_path) - return os.path.join(test_file_path,"model_providers.yaml") + print("test_file_path:", test_file_path) + return os.path.join(test_file_path, "model_providers.yaml") @pytest.fixture @@ -121,9 +123,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None: try: boot = ( BootstrapWebBuilder() - .model_providers_cfg_path( - model_providers_cfg_path=providers_file - ) + .model_providers_cfg_path(model_providers_cfg_path=providers_file) .host(host="127.0.0.1") .port(port=20000) .build() @@ -139,5 +139,4 @@ def init_server(logging_conf: dict, providers_file: str) -> None: boot.destroy() except SystemExit: - raise diff --git a/model-providers/tests/integration_tests/deepseek_providers_test/test_deepseek_service.py b/model-providers/tests/integration_tests/deepseek_providers_test/test_deepseek_service.py index a529ecc7..d9f20773 100644 --- a/model-providers/tests/integration_tests/deepseek_providers_test/test_deepseek_service.py +++ b/model-providers/tests/integration_tests/deepseek_providers_test/test_deepseek_service.py @@ -1,15 +1,20 @@ +import logging + +import pytest from langchain.chains import LLMChain from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI, OpenAIEmbeddings -import pytest -import logging logger = logging.getLogger(__name__) @pytest.mark.requires("openai") def test_llm(init_server: str): - llm = ChatOpenAI(model_name="deepseek-chat", openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758", openai_api_base=f"{init_server}/deepseek/v1") + llm = ChatOpenAI( + model_name="deepseek-chat", + openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758", + openai_api_base=f"{init_server}/deepseek/v1", + ) template = """Question: {question} Answer: Let's think step by step.""" diff --git a/model-providers/tests/integration_tests/ollama_providers_test/test_ollama_service.py b/model-providers/tests/integration_tests/ollama_providers_test/test_ollama_service.py index 5338efb9..e08d4fd7 100644 --- a/model-providers/tests/integration_tests/ollama_providers_test/test_ollama_service.py +++ b/model-providers/tests/integration_tests/ollama_providers_test/test_ollama_service.py @@ -1,15 +1,20 @@ +import logging + +import pytest from langchain.chains import LLMChain from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI, OpenAIEmbeddings -import pytest -import logging logger = logging.getLogger(__name__) @pytest.mark.requires("openai") def test_llm(init_server: str): - llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1") + llm = ChatOpenAI( + model_name="llama3", + openai_api_key="YOUR_API_KEY", + openai_api_base=f"{init_server}/ollama/v1", + ) template = """Question: {question} Answer: Let's think step by step.""" @@ -23,9 +28,11 @@ def test_llm(init_server: str): @pytest.mark.requires("openai") def test_embedding(init_server: str): - embeddings = OpenAIEmbeddings(model="text-embedding-3-large", - openai_api_key="YOUR_API_KEY", - openai_api_base=f"{init_server}/zhipuai/v1") + embeddings = OpenAIEmbeddings( + model="text-embedding-3-large", + openai_api_key="YOUR_API_KEY", + openai_api_base=f"{init_server}/zhipuai/v1", + ) text = "你好" diff --git a/model-providers/tests/integration_tests/openai_providers_test/test_openai_service.py b/model-providers/tests/integration_tests/openai_providers_test/test_openai_service.py index 0bdf455e..b3298256 100644 --- a/model-providers/tests/integration_tests/openai_providers_test/test_openai_service.py +++ b/model-providers/tests/integration_tests/openai_providers_test/test_openai_service.py @@ -1,15 +1,18 @@ +import logging + +import pytest from langchain.chains import LLMChain from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI, OpenAIEmbeddings -import pytest -import logging logger = logging.getLogger(__name__) @pytest.mark.requires("openai") def test_llm(init_server: str): - llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1") + llm = ChatOpenAI( + openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1" + ) template = """Question: {question} Answer: Let's think step by step.""" @@ -21,17 +24,16 @@ def test_llm(init_server: str): logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m") - - @pytest.mark.requires("openai") def test_embedding(init_server: str): - - embeddings = OpenAIEmbeddings(model="text-embedding-3-large", - openai_api_key="YOUR_API_KEY", - openai_api_base=f"{init_server}/zhipuai/v1") + embeddings = OpenAIEmbeddings( + model="text-embedding-3-large", + openai_api_key="YOUR_API_KEY", + openai_api_base=f"{init_server}/zhipuai/v1", + ) text = "你好" query_result = embeddings.embed_query(text) - logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m") \ No newline at end of file + logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m") diff --git a/model-providers/tests/integration_tests/xinference_providers_test/test_xinference_service.py b/model-providers/tests/integration_tests/xinference_providers_test/test_xinference_service.py index 831406ff..31634c2f 100644 --- a/model-providers/tests/integration_tests/xinference_providers_test/test_xinference_service.py +++ b/model-providers/tests/integration_tests/xinference_providers_test/test_xinference_service.py @@ -1,8 +1,9 @@ +import logging + +import pytest from langchain.chains import LLMChain from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI, OpenAIEmbeddings -import pytest -import logging logger = logging.getLogger(__name__) @@ -10,9 +11,10 @@ logger = logging.getLogger(__name__) @pytest.mark.requires("xinference_client") def test_llm(init_server: str): llm = ChatOpenAI( - model_name="glm-4", - openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/xinference/v1") + openai_api_key="YOUR_API_KEY", + openai_api_base=f"{init_server}/xinference/v1", + ) template = """Question: {question} Answer: Let's think step by step.""" @@ -26,9 +28,11 @@ def test_llm(init_server: str): @pytest.mark.requires("xinference-client") def test_embedding(init_server: str): - embeddings = OpenAIEmbeddings(model="text_embedding", - openai_api_key="YOUR_API_KEY", - openai_api_base=f"{init_server}/xinference/v1") + embeddings = OpenAIEmbeddings( + model="text_embedding", + openai_api_key="YOUR_API_KEY", + openai_api_base=f"{init_server}/xinference/v1", + ) text = "你好" diff --git a/model-providers/tests/integration_tests/zhipuai_providers_test/test_zhipuai_service.py b/model-providers/tests/integration_tests/zhipuai_providers_test/test_zhipuai_service.py index c110b71c..9d4088b3 100644 --- a/model-providers/tests/integration_tests/zhipuai_providers_test/test_zhipuai_service.py +++ b/model-providers/tests/integration_tests/zhipuai_providers_test/test_zhipuai_service.py @@ -1,17 +1,20 @@ +import logging + +import pytest from langchain.chains import LLMChain from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI, OpenAIEmbeddings -import pytest -import logging logger = logging.getLogger(__name__) + @pytest.mark.requires("zhipuai") def test_llm(init_server: str): llm = ChatOpenAI( - model_name="glm-4", - openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1") + openai_api_key="YOUR_API_KEY", + openai_api_base=f"{init_server}/zhipuai/v1", + ) template = """Question: {question} Answer: Let's think step by step.""" @@ -25,15 +28,14 @@ def test_llm(init_server: str): @pytest.mark.requires("zhipuai") def test_embedding(init_server: str): - - embeddings = OpenAIEmbeddings(model="text_embedding", - openai_api_key="YOUR_API_KEY", - openai_api_base=f"{init_server}/zhipuai/v1") + embeddings = OpenAIEmbeddings( + model="text_embedding", + openai_api_key="YOUR_API_KEY", + openai_api_base=f"{init_server}/zhipuai/v1", + ) text = "你好" query_result = embeddings.embed_query(text) logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m") - - diff --git a/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py b/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py index f28e8ba4..77569b4b 100644 --- a/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py +++ b/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py @@ -7,7 +7,10 @@ from omegaconf import OmegaConf from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers.core.model_manager import ModelManager -from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelType, +) from model_providers.core.provider_manager import ProviderManager logger = logging.getLogger(__name__) @@ -16,9 +19,7 @@ logger = logging.getLogger(__name__) def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: logging.config.dictConfig(logging_conf) # type: ignore # 读取配置文件 - cfg = OmegaConf.load( - providers_file - ) + cfg = OmegaConf.load(providers_file) # 转换配置文件 ( provider_name_to_provider_records_dict, @@ -35,15 +36,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non ) llm_models: List[AIModelEntity] = [] for model in provider_model_bundle_llm.configuration.custom_configuration.models: - - llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( - model=model.model, - credentials=model.credentials, - )) + llm_models.append( + provider_model_bundle_llm.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + ) + ) # 获取预定义模型 - llm_models.extend( - provider_model_bundle_llm.model_type_instance.predefined_models() - ) + llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") diff --git a/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py b/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py index 6d307698..eaaa14fe 100644 --- a/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py +++ b/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py @@ -15,9 +15,7 @@ logger = logging.getLogger(__name__) def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: logging.config.dictConfig(logging_conf) # type: ignore # 读取配置文件 - cfg = OmegaConf.load( - providers_file - ) + cfg = OmegaConf.load(providers_file) # 转换配置文件 ( provider_name_to_provider_records_dict, @@ -34,16 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non ) llm_models: List[AIModelEntity] = [] for model in provider_model_bundle_llm.configuration.custom_configuration.models: - - llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( - model=model.model, - credentials=model.credentials, - )) + llm_models.append( + provider_model_bundle_llm.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + ) + ) # 获取预定义模型 - llm_models.extend( - provider_model_bundle_llm.model_type_instance.predefined_models() - ) + llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") - diff --git a/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py b/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py index 43c0a3d2..21707821 100644 --- a/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py +++ b/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py @@ -15,9 +15,7 @@ logger = logging.getLogger(__name__) def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: logging.config.dictConfig(logging_conf) # type: ignore # 读取配置文件 - cfg = OmegaConf.load( - providers_file - ) + cfg = OmegaConf.load(providers_file) # 转换配置文件 ( provider_name_to_provider_records_dict, @@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non ) llm_models: List[AIModelEntity] = [] for model in provider_model_bundle_llm.configuration.custom_configuration.models: - - llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( - model=model.model, - credentials=model.credentials, - )) + llm_models.append( + provider_model_bundle_llm.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + ) + ) # 获取预定义模型 - llm_models.extend( - provider_model_bundle_llm.model_type_instance.predefined_models() - ) + llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") diff --git a/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py b/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py index a8c404ef..88b98ed4 100644 --- a/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py +++ b/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py @@ -15,9 +15,7 @@ logger = logging.getLogger(__name__) def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: logging.config.dictConfig(logging_conf) # type: ignore # 读取配置文件 - cfg = OmegaConf.load( - providers_file - ) + cfg = OmegaConf.load(providers_file) # 转换配置文件 ( provider_name_to_provider_records_dict, @@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non ) llm_models: List[AIModelEntity] = [] for model in provider_model_bundle_llm.configuration.custom_configuration.models: - - llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( - model=model.model, - credentials=model.credentials, - )) + llm_models.append( + provider_model_bundle_llm.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + ) + ) # 获取预定义模型 - llm_models.extend( - provider_model_bundle_llm.model_type_instance.predefined_models() - ) + llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") diff --git a/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py b/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py index 23ba6f62..38999168 100644 --- a/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py +++ b/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py @@ -15,9 +15,7 @@ logger = logging.getLogger(__name__) def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: logging.config.dictConfig(logging_conf) # type: ignore # 读取配置文件 - cfg = OmegaConf.load( - providers_file - ) + cfg = OmegaConf.load(providers_file) # 转换配置文件 ( provider_name_to_provider_records_dict, @@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non ) llm_models: List[AIModelEntity] = [] for model in provider_model_bundle_llm.configuration.custom_configuration.models: - - llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( - model=model.model, - credentials=model.credentials, - )) + llm_models.append( + provider_model_bundle_llm.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + ) + ) # 获取预定义模型 - llm_models.extend( - provider_model_bundle_llm.model_type_instance.predefined_models() - ) + llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}")