Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-19 21:37:20 +08:00)

make format

commit 4e9b1d6edf
parent 6dd00b5d94
@@ -1,11 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload

import pydantic
from pydantic.fields import FieldInfo
from typing_extensions import Self

from ._types import StrBytesIntFloat

@@ -45,23 +45,49 @@ if TYPE_CHECKING:

else:
if PYDANTIC_V2:
from pydantic.v1.datetime_parse import (
parse_date as parse_date,
)
from pydantic.v1.datetime_parse import (
parse_datetime as parse_datetime,
)
from pydantic.v1.typing import (
get_args as get_args,
is_union as is_union,
)
from pydantic.v1.typing import (
get_origin as get_origin,
is_typeddict as is_typeddict,
)
from pydantic.v1.typing import (
is_literal_type as is_literal_type,
)
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
from pydantic.v1.typing import (
is_typeddict as is_typeddict,
)
from pydantic.v1.typing import (
is_union as is_union,
)
else:
from pydantic.datetime_parse import (
parse_date as parse_date,
)
from pydantic.datetime_parse import (
parse_datetime as parse_datetime,
)
from pydantic.typing import (
get_args as get_args,
is_union as is_union,
)
from pydantic.typing import (
get_origin as get_origin,
is_typeddict as is_typeddict,
)
from pydantic.typing import (
is_literal_type as is_literal_type,
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
from pydantic.typing import (
is_typeddict as is_typeddict,
)
from pydantic.typing import (
is_union as is_union,
)


# refactored config
@@ -204,7 +230,9 @@ if TYPE_CHECKING:
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
...

def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
def __get__(
self, instance: object, owner: type[Any] | None = None
) -> _T | Self:
raise NotImplementedError()

def __set_name__(self, owner: type[Any], name: str) -> None:
@@ -4,20 +4,20 @@ import io
import os
import pathlib
from typing import overload
from typing_extensions import TypeGuard

import anyio
from typing_extensions import TypeGuard

from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
FileContent,
FileTypes,
HttpxFileContent,
HttpxFileTypes,
HttpxRequestFiles,
RequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
from ._utils import is_mapping_t, is_sequence_t, is_tuple_t


def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:

@@ -26,13 +26,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:

def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
isinstance(obj, bytes)
or isinstance(obj, tuple)
or isinstance(obj, io.IOBase)
or isinstance(obj, os.PathLike)
)


def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
prefix = (
f"Expected entry at `{key}`"
if key is not None
else f"Expected file input `{obj!r}`"
)
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None

@@ -57,7 +64,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(
f"Unexpected file type input {type(files)}, expected mapping or sequence"
)

return files

@@ -73,7 +82,9 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:])

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)


def _read_file_content(file: FileContent) -> HttpxFileContent:

@@ -101,7 +112,9 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(
"Unexpected file type input {type(files)}, expected mapping or sequence"
)

return files

@@ -117,7 +130,9 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_tuple_t(file):
return (file[0], await _async_read_file_content(file[1]), *file[2:])

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)


async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
@@ -1,60 +1,62 @@
from __future__ import annotations

import os
import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
import os
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast

import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo
from typing_extensions import (
Unpack,
Literal,
ClassVar,
Literal,
Protocol,
Required,
TypedDict,
TypeGuard,
Unpack,
final,
override,
runtime_checkable,
)

import pydantic
import pydantic.generics
from pydantic.fields import FieldInfo

from ._compat import (
PYDANTIC_V2,
ConfigDict,
field_get_default,
get_args,
get_model_config,
get_model_fields,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._compat import (
GenericModel as BaseGenericModel,
)
from ._types import (
IncEx,
ModelT,
)
from ._utils import (
PropertyInfo,
is_list,
is_given,
lru_cache,
is_mapping,
parse_date,
coerce_boolean,
parse_datetime,
strip_not_given,
extract_type_arg,
is_annotated_type,
is_given,
is_list,
is_mapping,
lru_cache,
parse_date,
parse_datetime,
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
is_union,
parse_obj,
get_origin,
is_literal_type,
get_model_config,
get_model_fields,
field_get_default,
strip_not_given,
)

if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema

__all__ = ["BaseModel", "GenericModel"]
@@ -69,7 +71,8 @@ class _ConfigProtocol(Protocol):
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
extra="allow",
defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")),
)
else:

@@ -189,7 +192,9 @@ class BaseModel(pydantic.BaseModel):
key = name

if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
fields_values[name] = _construct_field(
value=values[key], field=field, key=key
)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)

@@ -412,9 +417,13 @@ def construct_type(*, value: object, type_: object) -> object:
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
discriminator = _build_discriminated_union_meta(
union=type_, meta_annotations=meta
)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
variant_value = value.get(
discriminator.field_alias_from or discriminator.field_name
)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:

@@ -434,11 +443,19 @@ def construct_type(*, value: object, type_: object) -> object:
return value

_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
return {
key: construct_type(value=item, type_=items_type)
for key, item in value.items()
}

if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
if not is_literal_type(type_) and (
issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
return [
cast(Any, type_).construct(**entry) if is_mapping(entry) else entry
for entry in value
]

if is_mapping(value):
if issubclass(type_, BaseModel):
@@ -523,14 +540,19 @@ class DiscriminatorDetails:
self.field_alias_from = discriminator_alias


def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
def _build_discriminated_union_meta(
*, union: type, meta_annotations: tuple[Any, ...]
) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__

discriminator_field_name: str | None = None

for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
if (
isinstance(annotation, PropertyInfo)
and annotation.discriminator is not None
):
discriminator_field_name = annotation.discriminator
break

@@ -558,7 +580,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(
discriminator_field_name
) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue

@@ -582,7 +606,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
return details


def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
def _extract_field_schema_pv2(
model: type[BaseModel], field_name: str
) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None

@@ -621,7 +647,9 @@ else:
if PYDANTIC_V2:
from pydantic import TypeAdapter as _TypeAdapter

_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
_CachedTypeAdapter = cast(
"TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)
)

if TYPE_CHECKING:
from pydantic import TypeAdapter

@@ -652,6 +680,3 @@ elif not TYPE_CHECKING: # TODO: condition is weird

def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
return RootModel[type_] # type: ignore
@@ -5,22 +5,29 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable

import httpx
import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
from httpx import URL, AsyncBaseTransport, BaseTransport, Proxy, Response, Timeout
from typing_extensions import (
Literal,
Protocol,
TypeAlias,
TypedDict,
override,
runtime_checkable,
)

if TYPE_CHECKING:
from ._models import BaseModel

@@ -43,7 +50,9 @@ if TYPE_CHECKING:
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
FileContent = Union[
IO[bytes], bytes, PathLike
] # PathLike is not subscriptable in Python 3.8.
FileTypes = Union[
# file (or bytes)
FileContent,

@@ -68,7 +77,9 @@ HttpxFileTypes = Union[
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
HttpxRequestFiles = Union[
Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]
]

# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly
@@ -1,48 +1,126 @@
from ._utils import (
flatten as flatten,
is_dict as is_dict,
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
parse_date as parse_date,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
is_mapping_t as is_mapping_t,
removeprefix as removeprefix,
removesuffix as removesuffix,
extract_files as extract_files,
is_sequence_t as is_sequence_t,
required_args as required_args,
coerce_boolean as coerce_boolean,
coerce_integer as coerce_integer,
file_from_path as file_from_path,
parse_datetime as parse_datetime,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
from ._transform import (
PropertyInfo as PropertyInfo,
)
from ._transform import (
async_maybe_transform as async_maybe_transform,
)
from ._transform import (
async_transform as async_transform,
)
from ._transform import (
maybe_transform as maybe_transform,
)
from ._transform import (
transform as transform,
)
from ._typing import (
extract_type_arg as extract_type_arg,
)
from ._typing import (
extract_type_var_from_base as extract_type_var_from_base,
)
from ._typing import (
is_annotated_type as is_annotated_type,
)
from ._typing import (
is_iterable_type as is_iterable_type,
)
from ._typing import (
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
)
from ._typing import (
is_required_type as is_required_type,
is_annotated_type as is_annotated_type,
)
from ._typing import (
is_union_type as is_union_type,
)
from ._typing import (
strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
async_transform as async_transform,
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
from ._utils import (
coerce_boolean as coerce_boolean,
)
from ._utils import (
coerce_float as coerce_float,
)
from ._utils import (
coerce_integer as coerce_integer,
)
from ._utils import (
deepcopy_minimal as deepcopy_minimal,
)
from ._utils import (
extract_files as extract_files,
)
from ._utils import (
file_from_path as file_from_path,
)
from ._utils import (
flatten as flatten,
)
from ._utils import (
get_async_library as get_async_library,
)
from ._utils import (
get_required_header as get_required_header,
)
from ._utils import (
is_dict as is_dict,
)
from ._utils import (
is_given as is_given,
)
from ._utils import (
is_iterable as is_iterable,
)
from ._utils import (
is_list as is_list,
)
from ._utils import (
is_mapping as is_mapping,
)
from ._utils import (
is_mapping_t as is_mapping_t,
)
from ._utils import (
is_sequence as is_sequence,
)
from ._utils import (
is_sequence_t as is_sequence_t,
)
from ._utils import (
is_tuple as is_tuple,
)
from ._utils import (
is_tuple_t as is_tuple_t,
)
from ._utils import (
lru_cache as lru_cache,
)
from ._utils import (
maybe_coerce_boolean as maybe_coerce_boolean,
)
from ._utils import (
maybe_coerce_float as maybe_coerce_float,
)
from ._utils import (
maybe_coerce_integer as maybe_coerce_integer,
)
from ._utils import (
parse_date as parse_date,
)
from ._utils import (
parse_datetime as parse_datetime,
)
from ._utils import (
removeprefix as removeprefix,
)
from ._utils import (
removesuffix as removesuffix,
)
from ._utils import (
required_args as required_args,
)
from ._utils import (
strip_not_given as strip_not_given,
)
@ -1,31 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import base64
|
||||
import io
|
||||
import pathlib
|
||||
from typing import Any, Mapping, TypeVar, cast
|
||||
from datetime import date, datetime
|
||||
from typing_extensions import Literal, get_args, override, get_type_hints
|
||||
from typing import Any, Mapping, TypeVar, cast
|
||||
|
||||
import anyio
|
||||
import pydantic
|
||||
from typing_extensions import Literal, get_args, get_type_hints, override
|
||||
|
||||
from ._utils import (
|
||||
is_list,
|
||||
is_mapping,
|
||||
is_iterable,
|
||||
)
|
||||
from .._compat import is_typeddict, model_dump
|
||||
from .._files import is_base64_file_input
|
||||
from ._typing import (
|
||||
is_list_type,
|
||||
is_union_type,
|
||||
extract_type_arg,
|
||||
is_iterable_type,
|
||||
is_required_type,
|
||||
is_annotated_type,
|
||||
is_iterable_type,
|
||||
is_list_type,
|
||||
is_required_type,
|
||||
is_union_type,
|
||||
strip_annotated_type,
|
||||
)
|
||||
from .._compat import model_dump, is_typeddict
|
||||
from ._utils import (
|
||||
is_iterable,
|
||||
is_list,
|
||||
is_mapping,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@ -171,10 +171,17 @@ def _transform_recursive(
|
||||
# List[T]
|
||||
(is_list_type(stripped_type) and is_list(data))
|
||||
# Iterable[T]
|
||||
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
|
||||
or (
|
||||
is_iterable_type(stripped_type)
|
||||
and is_iterable(data)
|
||||
and not isinstance(data, str)
|
||||
)
|
||||
):
|
||||
inner_type = extract_type_arg(stripped_type, 0)
|
||||
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
|
||||
return [
|
||||
_transform_recursive(d, annotation=annotation, inner_type=inner_type)
|
||||
for d in data
|
||||
]
|
||||
|
||||
if is_union_type(stripped_type):
|
||||
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
|
||||
@ -201,7 +208,9 @@ def _transform_recursive(
|
||||
return data
|
||||
|
||||
|
||||
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||
def _format_data(
|
||||
data: object, format_: PropertyFormat, format_template: str | None
|
||||
) -> object:
|
||||
if isinstance(data, (date, datetime)):
|
||||
if format_ == "iso8601":
|
||||
return data.isoformat()
|
||||
@ -221,7 +230,9 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
|
||||
binary = binary.encode()
|
||||
|
||||
if not isinstance(binary, bytes):
|
||||
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||
raise RuntimeError(
|
||||
f"Could not read bytes from {data}; Received {type(binary)}"
|
||||
)
|
||||
|
||||
return base64.b64encode(binary).decode("ascii")
|
||||
|
||||
@ -240,7 +251,9 @@ def _transform_typeddict(
|
||||
# we do not have a type annotation for this field, leave it as is
|
||||
result[key] = value
|
||||
else:
|
||||
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
|
||||
result[_maybe_transform_key(key, type_)] = _transform_recursive(
|
||||
value, annotation=type_
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@ -276,7 +289,9 @@ async def async_transform(
|
||||
|
||||
It should be noted that the transformations that this function does are not represented in the type system.
|
||||
"""
|
||||
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
|
||||
transformed = await _async_transform_recursive(
|
||||
data, annotation=cast(type, expected_type)
|
||||
)
|
||||
return cast(_T, transformed)
|
||||
|
||||
|
||||
@ -309,10 +324,19 @@ async def _async_transform_recursive(
|
||||
# List[T]
|
||||
(is_list_type(stripped_type) and is_list(data))
|
||||
# Iterable[T]
|
||||
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
|
||||
or (
|
||||
is_iterable_type(stripped_type)
|
||||
and is_iterable(data)
|
||||
and not isinstance(data, str)
|
||||
)
|
||||
):
|
||||
inner_type = extract_type_arg(stripped_type, 0)
|
||||
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
|
||||
return [
|
||||
await _async_transform_recursive(
|
||||
d, annotation=annotation, inner_type=inner_type
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
if is_union_type(stripped_type):
|
||||
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
|
||||
@ -320,7 +344,9 @@ async def _async_transform_recursive(
|
||||
# TODO: there may be edge cases where the same normalized field name will transform to two different names
|
||||
# in different subtypes.
|
||||
for subtype in get_args(stripped_type):
|
||||
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
|
||||
data = await _async_transform_recursive(
|
||||
data, annotation=annotation, inner_type=subtype
|
||||
)
|
||||
return data
|
||||
|
||||
if isinstance(data, pydantic.BaseModel):
|
||||
@ -334,12 +360,16 @@ async def _async_transform_recursive(
|
||||
annotations = get_args(annotated_type)[1:]
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
|
||||
return await _async_format_data(data, annotation.format, annotation.format_template)
|
||||
return await _async_format_data(
|
||||
data, annotation.format, annotation.format_template
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||
async def _async_format_data(
|
||||
data: object, format_: PropertyFormat, format_template: str | None
|
||||
) -> object:
|
||||
if isinstance(data, (date, datetime)):
|
||||
if format_ == "iso8601":
|
||||
return data.isoformat()
|
||||
@ -359,7 +389,9 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
|
||||
binary = binary.encode()
|
||||
|
||||
if not isinstance(binary, bytes):
|
||||
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||
raise RuntimeError(
|
||||
f"Could not read bytes from {data}; Received {type(binary)}"
|
||||
)
|
||||
|
||||
return base64.b64encode(binary).decode("ascii")
|
||||
|
||||
@ -378,5 +410,7 @@ async def _async_transform_typeddict(
|
||||
# we do not have a type annotation for this field, leave it as is
|
||||
result[key] = value
|
||||
else:
|
||||
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
|
||||
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(
|
||||
value, annotation=type_
|
||||
)
|
||||
return result
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeVar, Iterable, cast
|
||||
from collections import abc as _c_abc
|
||||
from typing_extensions import Required, Annotated, get_args, get_origin
|
||||
from typing import Any, Iterable, TypeVar, cast
|
||||
|
||||
from typing_extensions import Annotated, Required, get_args, get_origin
|
||||
|
||||
from .._types import InheritsGeneric
|
||||
from .._compat import is_union as _is_union
|
||||
from .._types import InheritsGeneric
|
||||
|
||||
|
||||
def is_annotated_type(typ: type) -> bool:
|
||||
@ -49,15 +50,17 @@ def extract_type_arg(typ: type, index: int) -> type:
|
||||
try:
|
||||
return cast(type, args[index])
|
||||
except IndexError as err:
|
||||
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
|
||||
raise RuntimeError(
|
||||
f"Expected type {typ} to have a type argument at index {index} but it did not"
|
||||
) from err
|
||||
|
||||
|
||||
def extract_type_var_from_base(
|
||||
typ: type,
|
||||
*,
|
||||
generic_bases: tuple[type, ...],
|
||||
index: int,
|
||||
failure_message: str | None = None,
|
||||
typ: type,
|
||||
*,
|
||||
generic_bases: tuple[type, ...],
|
||||
index: int,
|
||||
failure_message: str | None = None,
|
||||
) -> type:
|
||||
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
|
||||
|
||||
@ -117,4 +120,7 @@ def extract_type_var_from_base(
|
||||
|
||||
return extracted
|
||||
|
||||
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
|
||||
raise RuntimeError(
|
||||
failure_message
|
||||
or f"Could not resolve inner type variable at index {index} for {typ}"
|
||||
)
|
||||
|
||||
@ -1,27 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import inspect
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Tuple,
|
||||
Mapping,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from pathlib import Path
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
import sniffio
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
|
||||
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
|
||||
from .._compat import parse_date as parse_date
|
||||
from .._compat import parse_datetime as parse_datetime
|
||||
from .._types import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
|
||||
@ -108,7 +109,9 @@ def _extract_items(
|
||||
item,
|
||||
path,
|
||||
index=index,
|
||||
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
|
||||
flattened_key=flattened_key + "[]"
|
||||
if flattened_key is not None
|
||||
else "[]",
|
||||
)
|
||||
for item in obj
|
||||
]
|
||||
@ -261,7 +264,12 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
|
||||
else: # no break
|
||||
if len(variants) > 1:
|
||||
variations = human_join(
|
||||
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
|
||||
[
|
||||
"("
|
||||
+ human_join([quote(arg) for arg in variant], final="and")
|
||||
+ ")"
|
||||
for variant in variants
|
||||
]
|
||||
)
|
||||
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
|
||||
else:
|
||||
@ -376,7 +384,11 @@ def get_required_header(headers: HeadersLike, header: str) -> str:
|
||||
return v
|
||||
|
||||
""" to deal with the case where the header looks like Stainless-Event-Id """
|
||||
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
|
||||
intercaps_header = re.sub(
|
||||
r"([^\w])(\w)",
|
||||
lambda pat: pat.group(1) + pat.group(2).upper(),
|
||||
header.capitalize(),
|
||||
)
|
||||
|
||||
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
|
||||
value = headers.get(normalized_header)
|
||||
|
||||
@ -1,9 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||
from ..._models import BaseModel
|
||||
|
||||
from model_providers.core.entities.model_entities import (
|
||||
ModelStatus,
|
||||
ModelWithProviderEntity,
|
||||
@ -26,6 +23,9 @@ from model_providers.core.model_runtime.entities.provider_entities import (
|
||||
SimpleProviderEntity,
|
||||
)
|
||||
|
||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||
from ..._models import BaseModel
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
"""
|
||||
@ -74,10 +74,9 @@ class ProviderResponse(BaseModel):
|
||||
custom_configuration: CustomConfigurationResponse
|
||||
system_configuration: SystemConfigurationResponse
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=()
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
@ -180,10 +179,9 @@ class DefaultModelResponse(BaseModel):
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=()
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
@ -73,7 +73,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
self._server = None
|
||||
self._server_thread = None
|
||||
|
||||
def logging_conf(self,logging_conf: Optional[dict] = None):
|
||||
def logging_conf(self, logging_conf: Optional[dict] = None):
|
||||
self._logging_conf = logging_conf
|
||||
|
||||
@classmethod
|
||||
@ -132,7 +132,10 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
self._app.include_router(self._router)
|
||||
|
||||
config = Config(
|
||||
app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
|
||||
app=self._app,
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
log_config=self._logging_conf,
|
||||
)
|
||||
self._server = Server(config)
|
||||
|
||||
@ -145,7 +148,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
self._server_thread.start()
|
||||
|
||||
def destroy(self):
|
||||
|
||||
logger.info("Shutting down server")
|
||||
self._server.should_exit = True # 设置退出标志
|
||||
self._server.shutdown() # 停止服务器
|
||||
@ -155,7 +157,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
def join(self):
|
||||
self._server_thread.join()
|
||||
|
||||
|
||||
def set_app_event(self, started_event: mp.Event = None):
|
||||
@self._app.on_event("startup")
|
||||
async def on_startup():
|
||||
@ -190,12 +191,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
provider_model_bundle.model_type_instance.predefined_models()
|
||||
)
|
||||
# 获取自定义模型
|
||||
for model in provider_model_bundle.configuration.custom_configuration.models:
|
||||
|
||||
llm_models.append(provider_model_bundle.model_type_instance.get_model_schema(
|
||||
model=model.model,
|
||||
credentials=model.credentials,
|
||||
))
|
||||
for (
|
||||
model
|
||||
) in provider_model_bundle.configuration.custom_configuration.models:
|
||||
llm_models.append(
|
||||
provider_model_bundle.model_type_instance.get_model_schema(
|
||||
model=model.model,
|
||||
credentials=model.credentials,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while fetching models for provider: {provider}, model_type: {model_type}"
|
||||
@ -225,13 +229,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
)
|
||||
|
||||
# 判断embeddings_request.input是否为list
|
||||
input = ''
|
||||
input = ""
|
||||
if isinstance(embeddings_request.input, list):
|
||||
tokens = embeddings_request.input
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(embeddings_request.model)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
logger.warning(
|
||||
"Warning: model not found. Using cl100k_base encoding."
|
||||
)
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, token in enumerate(tokens):
|
||||
@ -241,7 +247,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
else:
|
||||
input = embeddings_request.input
|
||||
|
||||
response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
|
||||
response = model_instance.invoke_text_embedding(
|
||||
texts=[input], user="abc-123"
|
||||
)
|
||||
return await openai_embedding_text(response)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from ..._models import BaseModel
|
||||
|
||||
from pydantic import Field as FieldInfo
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..._models import BaseModel
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import enum
|
||||
from typing import Any, cast, List
|
||||
from typing import Any, List, cast
|
||||
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
@ -8,7 +8,6 @@ from langchain.schema import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from ..._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
@ -20,6 +19,8 @@ from model_providers.core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
from ..._models import BaseModel
|
||||
|
||||
|
||||
class PromptMessageFileType(enum.Enum):
|
||||
IMAGE = "image"
|
||||
|
||||
@ -1,9 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||
from ..._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
ModelType,
|
||||
@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
|
||||
)
|
||||
from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
|
||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||
from ..._models import BaseModel
|
||||
|
||||
|
||||
class ModelStatus(Enum):
|
||||
"""
|
||||
@ -80,9 +80,8 @@ class DefaultModelEntity(BaseModel):
|
||||
provider: DefaultModelProviderEntity
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=()
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
protected_namespaces = ()
|
||||
|
||||
@ -4,9 +4,6 @@ import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Dict, Iterator, List, Optional
|
||||
|
||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||
from ..._models import BaseModel
|
||||
|
||||
from model_providers.core.entities.model_entities import (
|
||||
ModelStatus,
|
||||
ModelWithProviderEntity,
|
||||
@ -29,6 +26,9 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider im
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||
from ..._models import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -351,11 +351,9 @@ class ProviderModelBundle(BaseModel):
|
||||
model_type_instance: AIModel
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||
from ..._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
CUSTOM = "custom"
|
||||
@ -59,10 +59,9 @@ class RestrictModel(BaseModel):
|
||||
model_type: ModelType
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=()
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
@ -109,10 +108,9 @@ class CustomModelConfiguration(BaseModel):
|
||||
credentials: dict
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=()
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from ..._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
)
|
||||
|
||||
from ..._models import BaseModel
|
||||
|
||||
|
||||
class QueueEvent(Enum):
|
||||
"""
|
||||
|
||||
@ -2,8 +2,6 @@ from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from ...._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@ -13,6 +11,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
|
||||
PriceInfo,
|
||||
)
|
||||
|
||||
from ...._models import BaseModel
|
||||
|
||||
|
||||
class LLMMode(Enum):
|
||||
"""
|
||||
|
||||
@ -2,11 +2,11 @@ from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
from ...._compat import PYDANTIC_V2, ConfigDict
|
||||
from ...._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
"""
|
||||
@ -164,10 +164,9 @@ class ProviderModel(BaseModel):
|
||||
model_properties: Dict[ModelPropertyKey, Any]
|
||||
deprecated: bool = False
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=()
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
@ -1,9 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from ...._compat import PYDANTIC_V2, ConfigDict
|
||||
from ...._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
|
||||
ProviderModel,
|
||||
)
|
||||
|
||||
from ...._compat import PYDANTIC_V2, ConfigDict
|
||||
from ...._models import BaseModel
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
"""
|
||||
@ -136,10 +136,9 @@ class ProviderEntity(BaseModel):
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=()
|
||||
)
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
else:
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from decimal import Decimal
|
||||
from typing import List
|
||||
|
||||
from ...._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelUsage
|
||||
|
||||
from ...._models import BaseModel
|
||||
|
||||
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from ....._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMMode
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
@ -14,6 +12,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
|
||||
PriceConfig,
|
||||
)
|
||||
|
||||
from ....._models import BaseModel
|
||||
|
||||
AZURE_OPENAI_API_VERSION = "2024-02-15-preview"
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Generator
|
||||
from typing import Dict, List, Optional, Type, Union, cast
|
||||
from typing import Dict, Generator, List, Optional, Type, Union, cast
|
||||
|
||||
import cohere
|
||||
from cohere.responses import Chat, Generations
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Generator
|
||||
from typing import List, Optional, Union, cast
|
||||
from decimal import Decimal
|
||||
from typing import Generator, List, Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import OpenAI, Stream
|
||||
@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
I18nObject,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig, ParameterRule, ParameterType, ModelFeature, ModelPropertyKey, DefaultParameterName,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
)
|
||||
from model_providers.core.model_runtime.errors.validate import (
|
||||
CredentialsValidateFailedError,
|
||||
@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
|
||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||
LargeLanguageModel,
|
||||
)
|
||||
from model_providers.core.model_runtime.model_providers.deepseek._common import _CommonDeepseek
|
||||
from model_providers.core.model_runtime.model_providers.deepseek._common import (
|
||||
_CommonDeepseek,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1117,7 +1123,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
return num_tokens
|
||||
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity:
|
||||
"""
|
||||
Get customizable model schema.
|
||||
@ -1129,7 +1135,6 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
"""
|
||||
extras = {}
|
||||
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(zh_Hans=model, en_US=model),
|
||||
@ -1149,8 +1154,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="The temperature of the model. "
|
||||
"Increasing the temperature will make the model answer "
|
||||
"more creatively. (Default: 0.8)"
|
||||
"Increasing the temperature will make the model answer "
|
||||
"more creatively. (Default: 0.8)"
|
||||
),
|
||||
default=0.8,
|
||||
min=0,
|
||||
@ -1163,8 +1168,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
|
||||
"more diverse text, while a lower value (e.g., 0.5) will generate more "
|
||||
"focused and conservative text. (Default: 0.9)"
|
||||
"more diverse text, while a lower value (e.g., 0.5) will generate more "
|
||||
"focused and conservative text. (Default: 0.9)"
|
||||
),
|
||||
default=0.9,
|
||||
min=0,
|
||||
@ -1177,8 +1182,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
help=I18nObject(
|
||||
en_US="A number between -2.0 and 2.0. If positive, ",
|
||||
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,"
|
||||
"那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,"
|
||||
"降低模型重复相同内容的可能性"
|
||||
"那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,"
|
||||
"降低模型重复相同内容的可能性",
|
||||
),
|
||||
default=0,
|
||||
min=-2,
|
||||
@ -1190,7 +1195,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Sets how strongly to presence_penalty. ",
|
||||
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。"
|
||||
zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。",
|
||||
),
|
||||
default=1.1,
|
||||
min=-2,
|
||||
@ -1204,7 +1209,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
help=I18nObject(
|
||||
en_US="Maximum number of tokens to predict when generating text. ",
|
||||
zh_Hans="限制一次请求中模型生成 completion 的最大 token 数。"
|
||||
"输入 token 和输出 token 的总长度受模型的上下文长度的限制。"
|
||||
"输入 token 和输出 token 的总长度受模型的上下文长度的限制。",
|
||||
),
|
||||
default=128,
|
||||
min=-2,
|
||||
@ -1216,7 +1221,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
type=ParameterType.BOOLEAN,
|
||||
help=I18nObject(
|
||||
en_US="Whether to return the log probabilities of the tokens. ",
|
||||
zh_Hans="是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。"
|
||||
zh_Hans="是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。",
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
@ -1226,7 +1231,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
|
||||
help=I18nObject(
|
||||
en_US="the format to return a response in.",
|
||||
zh_Hans="一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,"
|
||||
"且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。"
|
||||
"且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。",
|
||||
),
|
||||
default=0,
|
||||
min=0,
|
||||
|
||||
@ -3,8 +3,6 @@ import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ...._models import BaseModel
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.entities.provider_entities import (
|
||||
ProviderConfig,
|
||||
@ -25,6 +23,8 @@ from model_providers.core.utils.position_helper import (
|
||||
sort_to_dict_by_position_map,
|
||||
)
|
||||
|
||||
from ...._models import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Generator
|
||||
from typing import List, Optional, Union
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Generator
|
||||
from typing import List, Optional, Union, cast
|
||||
from decimal import Decimal
|
||||
from typing import Generator, List, Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import OpenAI, Stream
|
||||
@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
I18nObject,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig, ModelFeature, ModelPropertyKey, DefaultParameterName, ParameterRule, ParameterType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
)
|
||||
from model_providers.core.model_runtime.errors.validate import (
|
||||
CredentialsValidateFailedError,
|
||||
@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
|
||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||
LargeLanguageModel,
|
||||
)
|
||||
from model_providers.core.model_runtime.model_providers.ollama._common import _CommonOllama
|
||||
from model_providers.core.model_runtime.model_providers.ollama._common import (
|
||||
_CommonOllama,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -661,7 +667,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model,
|
||||
stream=stream,
|
||||
extra_body=extra_body
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -1120,7 +1126,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
|
||||
return num_tokens
|
||||
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
self, model: str, credentials: dict
) -> AIModelEntity:
"""
Get customizable model schema.
@ -1154,8 +1160,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="The temperature of the model. "
"Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)"
"Increasing the temperature will make the model answer "
"more creatively. (Default: 0.8)"
),
default=0.8,
min=0,
@ -1168,8 +1174,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
"more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)"
"more diverse text, while a lower value (e.g., 0.5) will generate more "
"focused and conservative text. (Default: 0.9)"
),
default=0.9,
min=0,
@ -1181,8 +1187,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Reduces the probability of generating nonsense. "
"A higher value (e.g. 100) will give more diverse answers, "
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"
"A higher value (e.g. 100) will give more diverse answers, "
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"
),
default=40,
min=1,
@ -1194,8 +1200,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Sets how strongly to penalize repetitions. "
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
),
default=1.1,
min=-2,
@ -1208,7 +1214,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Maximum number of tokens to predict when generating text. "
"(Default: 128, -1 = infinite generation, -2 = fill context)"
"(Default: 128, -1 = infinite generation, -2 = fill context)"
),
default=128,
min=-2,
@ -1220,7 +1226,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Enable Mirostat sampling for controlling perplexity. "
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
),
default=0,
min=0,
@ -1232,9 +1238,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Influences how quickly the algorithm responds to feedback from "
"the generated text. A lower learning rate will result in slower adjustments, "
"while a higher learning rate will make the algorithm more responsive. "
"(Default: 0.1)"
"the generated text. A lower learning rate will result in slower adjustments, "
"while a higher learning rate will make the algorithm more responsive. "
"(Default: 0.1)"
),
default=0.1,
precision=1,
@ -1245,7 +1251,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Controls the balance between coherence and diversity of the output. "
"A lower value will result in more focused and coherent text. (Default: 5.0)"
"A lower value will result in more focused and coherent text. (Default: 5.0)"
),
default=5.0,
precision=1,
@ -1256,7 +1262,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the size of the context window used to generate the next token. "
"(Default: 2048)"
"(Default: 2048)"
),
default=2048,
min=1,
@ -1267,7 +1273,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="The number of layers to send to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable."
"On macOS it defaults to 1 to enable metal support, 0 to disable."
),
default=1,
min=0,
@ -1279,9 +1285,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the number of threads to use during computation. "
"By default, Ollama will detect this for optimal performance. "
"It is recommended to set this value to the number of physical CPU cores "
"your system has (as opposed to the logical number of cores)."
"By default, Ollama will detect this for optimal performance. "
"It is recommended to set this value to the number of physical CPU cores "
"your system has (as opposed to the logical number of cores)."
),
min=1,
),
@ -1291,7 +1297,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets how far back for the model to look back to prevent repetition. "
"(Default: 64, 0 = disabled, -1 = num_ctx)"
"(Default: 64, 0 = disabled, -1 = num_ctx)"
),
default=64,
min=-1,
@ -1302,8 +1308,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.FLOAT,
help=I18nObject(
en_US="Tail free sampling is used to reduce the impact of less probable tokens "
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
"while a value of 1.0 disables this setting. (default: 1)"
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
"while a value of 1.0 disables this setting. (default: 1)"
),
default=1,
precision=1,
@ -1314,8 +1320,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the random number seed to use for generation. Setting this to "
"a specific number will make the model generate the same text for "
"the same prompt. (Default: 0)"
"a specific number will make the model generate the same text for "
"the same prompt. (Default: 0)"
),
default=0,
),
@ -1325,7 +1331,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
type=ParameterType.STRING,
help=I18nObject(
en_US="the format to return a response in."
" Currently the only accepted value is json."
" Currently the only accepted value is json."
),
options=["json"],
),
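
Note: every hunk above only rewraps the implicitly concatenated en_US help strings; the parameter values (type, default, min, precision, options) are unchanged. A runnable stand-in showing the same adjacent-string-literal pattern in isolation — the dataclass below is a placeholder, not the project's real I18nObject or parameter-rule entity, and the plain dict stands in for whatever rule object these keyword arguments actually belong to:

from dataclasses import dataclass

@dataclass
class I18nObject:  # stand-in for the runtime's I18nObject entity
    en_US: str

# One reflowed rule, reduced to a plain dict; ParameterType.FLOAT is replaced
# by a bare string so the snippet runs without the project installed.
temperature_rule = {
    "type": "float",
    "help": I18nObject(
        en_US="The temperature of the model. "
        "Increasing the temperature will make the model answer "
        "more creatively. (Default: 0.8)"
    ),
    "default": 0.8,
    "min": 0,
}

print(temperature_rule["help"].en_US)
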
@ -1,6 +1,5 @@
import logging
from typing import Generator
from typing import List, Optional, Union, cast
from typing import Generator, List, Optional, Union, cast

import tiktoken
from openai import OpenAI, Stream

@ -1,8 +1,7 @@
import json
import logging
from typing import Generator
from decimal import Decimal
from typing import List, Optional, Union, cast
from typing import Generator, List, Optional, Union, cast
from urllib.parse import urljoin

import requests

@ -1,7 +1,6 @@
from typing import Generator
from enum import Enum
from json import dumps, loads
from typing import Any, Union
from typing import Any, Generator, Union

from requests import Response, post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema

@ -1,6 +1,5 @@
import threading
from typing import Generator
from typing import Dict, List, Optional, Type, Union
from typing import Dict, Generator, List, Optional, Type, Union

from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,

@ -1,5 +1,4 @@
from typing import Generator
from typing import List, Optional, Union
from typing import Generator, List, Optional, Union

from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (

@ -1,5 +1,4 @@
from typing import Generator
from typing import Dict, List, Optional, Type, Union
from typing import Dict, Generator, List, Optional, Type, Union

from dashscope import get_tokenizer
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse

@ -1,6 +1,4 @@
from typing import Generator, Iterator

from typing import Dict, List, Union, cast, Type
from typing import Dict, Generator, Iterator, List, Type, Union, cast

from openai import (
APIConnectionError,
@ -527,8 +525,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _extract_response_tool_calls(
self,
response_tool_calls: Union[
List[ChatCompletionMessageToolCall],
List[ChoiceDeltaToolCall]
List[ChatCompletionMessageToolCall], List[ChoiceDeltaToolCall]
],
) -> List[AssistantPromptMessage.ToolCall]:
"""
@ -195,8 +195,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if stop:
extra_model_kwargs["stop"] = stop

client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
client = ZhipuAI(
base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)

if len(prompt_messages) == 0:
raise ValueError("At least one message is required")

@ -43,8 +43,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
client = ZhipuAI(
base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)

embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)

@ -85,8 +87,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try:
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuAI(base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"])
client = ZhipuAI(
base_url=credentials_kwargs["api_base"],
api_key=credentials_kwargs["api_key"],
)

# call embedding model
self.embed_documents(
@ -1,4 +1,5 @@
import pydantic

from ...._models import BaseModel


@ -1,6 +1,7 @@
import os

import orjson

from ..._models import BaseModel


@ -1,8 +1,7 @@
import os
import shutil
from typing import Generator
from contextlib import closing
from typing import Union
from typing import Generator, Union

import boto3
from botocore.exceptions import ClientError

@ -104,15 +104,17 @@ def logging_conf() -> dict:
111,
)


@pytest.fixture
def providers_file(request) -> str:
from pathlib import Path
import os
from pathlib import Path

# current working directory
# get the path of the current test file
test_file_path = Path(str(request.fspath)).parent
print("test_file_path:",test_file_path)
return os.path.join(test_file_path,"model_providers.yaml")
print("test_file_path:", test_file_path)
return os.path.join(test_file_path, "model_providers.yaml")


@pytest.fixture
@ -121,9 +123,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path=providers_file
)
.model_providers_cfg_path(model_providers_cfg_path=providers_file)
.host(host="127.0.0.1")
.port(port=20000)
.build()
@ -139,5 +139,4 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
boot.destroy()

except SystemExit:

raise
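
For orientation, the fixtures above resolve model_providers.yaml next to the test file and boot the providers gateway with the builder chain shown in the hunk. A condensed, standalone sketch of the same flow outside pytest — only the builder calls and destroy() come from the diff; how the built object is actually served between build() and destroy() is elided here exactly as it is in the hunk:

import os
from pathlib import Path

from model_providers import BootstrapWebBuilder

# Resolve the providers config relative to this script, mirroring the
# providers_file fixture (which uses request.fspath instead of __file__).
providers_file = os.path.join(Path(__file__).parent, "model_providers.yaml")

boot = (
    BootstrapWebBuilder()
    .model_providers_cfg_path(model_providers_cfg_path=providers_file)
    .host(host="127.0.0.1")
    .port(port=20000)
    .build()
)
try:
    ...  # serve / run tests against http://127.0.0.1:20000 (not shown in the hunk)
finally:
    boot.destroy()
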
@ -1,15 +1,20 @@
import logging

import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging

logger = logging.getLogger(__name__)


@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(model_name="deepseek-chat", openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758", openai_api_base=f"{init_server}/deepseek/v1")
llm = ChatOpenAI(
model_name="deepseek-chat",
openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758",
openai_api_base=f"{init_server}/deepseek/v1",
)
template = """Question: {question}

Answer: Let's think step by step."""

@ -1,15 +1,20 @@
import logging

import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging

logger = logging.getLogger(__name__)


@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1")
llm = ChatOpenAI(
model_name="llama3",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/ollama/v1",
)
template = """Question: {question}

Answer: Let's think step by step."""
@ -23,9 +28,11 @@ def test_llm(init_server: str):

@pytest.mark.requires("openai")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1")
embeddings = OpenAIEmbeddings(
model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)

text = "你好"

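
These test hunks only rewrap the ChatOpenAI and OpenAIEmbeddings constructor calls; the chain wiring they feed into sits outside the hunks. A standalone sketch of that wiring against the gateway started by init_server — the endpoint, question, and LLMChain usage are assumptions (standard LangChain calls); only the model name, key placeholder, and prompt template come from the diff:

import logging

from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

logger = logging.getLogger(__name__)

# Assumed endpoint: the gateway built in conftest listens on 127.0.0.1:20000,
# so the Ollama-backed models are exposed under /ollama/v1.
llm = ChatOpenAI(
    model_name="llama3",
    openai_api_key="YOUR_API_KEY",
    openai_api_base="http://127.0.0.1:20000/ollama/v1",
)

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)

# LLMChain is imported by the tests above, but its use is not shown in the hunks.
llm_chain = LLMChain(prompt=prompt, llm=llm)
responses = llm_chain.run("What is 1 + 1?")
logger.info(responses)
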
@ -1,15 +1,18 @@
import logging

import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging

logger = logging.getLogger(__name__)


@pytest.mark.requires("openai")
def test_llm(init_server: str):
llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
llm = ChatOpenAI(
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1"
)
template = """Question: {question}

Answer: Let's think step by step."""
@ -21,17 +24,16 @@ def test_llm(init_server: str):
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")



@pytest.mark.requires("openai")
def test_embedding(init_server: str):

embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1")
embeddings = OpenAIEmbeddings(
model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)

text = "你好"

query_result = embeddings.embed_query(text)

logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")

@ -1,8 +1,9 @@
import logging

import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging

logger = logging.getLogger(__name__)

@ -10,9 +11,10 @@ logger = logging.getLogger(__name__)
@pytest.mark.requires("xinference_client")
def test_llm(init_server: str):
llm = ChatOpenAI(

model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/xinference/v1")
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1",
)
template = """Question: {question}

Answer: Let's think step by step."""
@ -26,9 +28,11 @@ def test_llm(init_server: str):

@pytest.mark.requires("xinference-client")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1")
embeddings = OpenAIEmbeddings(
model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/xinference/v1",
)

text = "你好"

@ -1,17 +1,20 @@
import logging

import pytest
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging

logger = logging.getLogger(__name__)


@pytest.mark.requires("zhipuai")
def test_llm(init_server: str):
llm = ChatOpenAI(

model_name="glm-4",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)
template = """Question: {question}

Answer: Let's think step by step."""
@ -25,15 +28,14 @@ def test_llm(init_server: str):

@pytest.mark.requires("zhipuai")
def test_embedding(init_server: str):

embeddings = OpenAIEmbeddings(model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1")
embeddings = OpenAIEmbeddings(
model="text_embedding",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1",
)

text = "你好"

query_result = embeddings.embed_query(text)

logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")

@ -7,7 +7,10 @@ from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)
@ -16,9 +19,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# load the configuration file
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# convert the configuration file
(
provider_name_to_provider_records_dict,
@ -35,15 +36,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:

llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)

# fetch the predefined models
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())

logger.info(f"predefined_models: {llm_models}")

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# load the configuration file
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# convert the configuration file
(
provider_name_to_provider_records_dict,
@ -34,16 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:

llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)

# fetch the predefined models
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())

logger.info(f"predefined_models: {llm_models}")


@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# load the configuration file
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# convert the configuration file
(
provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:

llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)

# fetch the predefined models
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())

logger.info(f"predefined_models: {llm_models}")

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# load the configuration file
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# convert the configuration file
(
provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:

llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)

# fetch the predefined models
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())

logger.info(f"predefined_models: {llm_models}")

@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# load the configuration file
cfg = OmegaConf.load(
providers_file
)
cfg = OmegaConf.load(providers_file)
# convert the configuration file
(
provider_name_to_provider_records_dict,
@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
)
llm_models: List[AIModelEntity] = []
for model in provider_model_bundle_llm.configuration.custom_configuration.models:

llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
))
llm_models.append(
provider_model_bundle_llm.model_type_instance.get_model_schema(
model=model.model,
credentials=model.credentials,
)
)

# fetch the predefined models
llm_models.extend(
provider_model_bundle_llm.model_type_instance.predefined_models()
)
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())

logger.info(f"predefined_models: {llm_models}")