Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Synced: 2026-02-04 13:43:12 +08:00

Commit 4e9b1d6edf (parent 6dd00b5d94): make format
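The changes below are a mechanical formatting pass: imports are sorted alphabetically within standard-library, third-party, and local groups; re-export modules are rewritten as one parenthesized import group per name; and statements that exceed the line-length limit are wrapped with one argument per line and a trailing comma. The commit does not name the tools behind its `make format` target, so the sketch below only illustrates the wrapping style on a hypothetical helper function.

    # Runnable illustration of the line-wrapping style applied throughout this
    # commit. `describe` is hypothetical; only the before/after shape matters.
    from typing import Any


    def describe(value: Any, *, prefix: str, suffix: str) -> str:
        return f"{prefix}{value}{suffix}"


    # Before: message = describe(42, prefix="got `", suffix="` instead")
    # After the formatter, once the call exceeds the line limit:
    message = describe(
        42,
        prefix="got `",
        suffix="` instead",
    )
    print(message)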
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
 from datetime import date, datetime
-from typing_extensions import Self
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
 
 import pydantic
 from pydantic.fields import FieldInfo
+from typing_extensions import Self
 
 from ._types import StrBytesIntFloat
 
@@ -45,23 +45,49 @@ if TYPE_CHECKING:
 
 else:
     if PYDANTIC_V2:
+        from pydantic.v1.datetime_parse import (
+            parse_date as parse_date,
+        )
+        from pydantic.v1.datetime_parse import (
+            parse_datetime as parse_datetime,
+        )
         from pydantic.v1.typing import (
             get_args as get_args,
-            is_union as is_union,
+        )
+        from pydantic.v1.typing import (
             get_origin as get_origin,
-            is_typeddict as is_typeddict,
+        )
+        from pydantic.v1.typing import (
             is_literal_type as is_literal_type,
         )
-        from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
+        from pydantic.v1.typing import (
+            is_typeddict as is_typeddict,
+        )
+        from pydantic.v1.typing import (
+            is_union as is_union,
+        )
     else:
+        from pydantic.datetime_parse import (
+            parse_date as parse_date,
+        )
+        from pydantic.datetime_parse import (
+            parse_datetime as parse_datetime,
+        )
         from pydantic.typing import (
             get_args as get_args,
-            is_union as is_union,
+        )
+        from pydantic.typing import (
             get_origin as get_origin,
-            is_typeddict as is_typeddict,
+        )
+        from pydantic.typing import (
             is_literal_type as is_literal_type,
         )
-        from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
+        from pydantic.typing import (
+            is_typeddict as is_typeddict,
+        )
+        from pydantic.typing import (
+            is_union as is_union,
+        )
 
 
 # refactored config
@@ -204,7 +230,9 @@ if TYPE_CHECKING:
         def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
             ...
 
-        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
+        def __get__(
+            self, instance: object, owner: type[Any] | None = None
+        ) -> _T | Self:
             raise NotImplementedError()
 
         def __set_name__(self, owner: type[Any], name: str) -> None:
@@ -4,20 +4,20 @@ import io
 import os
 import pathlib
 from typing import overload
-from typing_extensions import TypeGuard
 
 import anyio
+from typing_extensions import TypeGuard
 
 from ._types import (
-    FileTypes,
-    FileContent,
-    RequestFiles,
-    HttpxFileTypes,
     Base64FileInput,
+    FileContent,
+    FileTypes,
     HttpxFileContent,
+    HttpxFileTypes,
     HttpxRequestFiles,
+    RequestFiles,
 )
-from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
+from ._utils import is_mapping_t, is_sequence_t, is_tuple_t
 
 
 def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
@@ -26,13 +26,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
 
 def is_file_content(obj: object) -> TypeGuard[FileContent]:
     return (
-        isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
+        isinstance(obj, bytes)
+        or isinstance(obj, tuple)
+        or isinstance(obj, io.IOBase)
+        or isinstance(obj, os.PathLike)
     )
 
 
 def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
     if not is_file_content(obj):
-        prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
+        prefix = (
+            f"Expected entry at `{key}`"
+            if key is not None
+            else f"Expected file input `{obj!r}`"
+        )
         raise RuntimeError(
            f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
        ) from None
@@ -57,7 +64,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
     elif is_sequence_t(files):
         files = [(key, _transform_file(file)) for key, file in files]
     else:
-        raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
+        raise TypeError(
+            f"Unexpected file type input {type(files)}, expected mapping or sequence"
+        )
 
     return files
 
@@ -73,7 +82,9 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_tuple_t(file):
         return (file[0], _read_file_content(file[1]), *file[2:])
 
-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError(
+        f"Expected file types input to be a FileContent type or to be a tuple"
+    )
 
 
 def _read_file_content(file: FileContent) -> HttpxFileContent:
@@ -101,7 +112,9 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
     elif is_sequence_t(files):
         files = [(key, await _async_transform_file(file)) for key, file in files]
     else:
-        raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
+        raise TypeError(
+            "Unexpected file type input {type(files)}, expected mapping or sequence"
+        )
 
     return files
 
@@ -117,7 +130,9 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_tuple_t(file):
         return (file[0], await _async_read_file_content(file[1]), *file[2:])
 
-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError(
+        f"Expected file types input to be a FileContent type or to be a tuple"
+    )
 
 
 async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
@@ -1,60 +1,62 @@
 from __future__ import annotations
 
-import os
 import inspect
-from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
+import os
 from datetime import date, datetime
+from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast
 
+import pydantic
+import pydantic.generics
+from pydantic.fields import FieldInfo
 from typing_extensions import (
-    Unpack,
-    Literal,
     ClassVar,
+    Literal,
     Protocol,
     Required,
     TypedDict,
     TypeGuard,
+    Unpack,
     final,
     override,
     runtime_checkable,
 )
 
-import pydantic
-import pydantic.generics
-from pydantic.fields import FieldInfo
+from ._compat import (
+    PYDANTIC_V2,
+    ConfigDict,
+    field_get_default,
+    get_args,
+    get_model_config,
+    get_model_fields,
+    get_origin,
+    is_literal_type,
+    is_union,
+    parse_obj,
+)
+from ._compat import (
+    GenericModel as BaseGenericModel,
+)
 from ._types import (
     IncEx,
     ModelT,
 )
 from ._utils import (
     PropertyInfo,
-    is_list,
-    is_given,
-    lru_cache,
-    is_mapping,
-    parse_date,
     coerce_boolean,
-    parse_datetime,
-    strip_not_given,
     extract_type_arg,
     is_annotated_type,
+    is_given,
+    is_list,
+    is_mapping,
+    lru_cache,
+    parse_date,
+    parse_datetime,
     strip_annotated_type,
+    strip_not_given,
 )
-from ._compat import (
-    PYDANTIC_V2,
-    ConfigDict,
-    GenericModel as BaseGenericModel,
-    get_args,
-    is_union,
-    parse_obj,
-    get_origin,
-    is_literal_type,
-    get_model_config,
-    get_model_fields,
-    field_get_default,
-)
 
 if TYPE_CHECKING:
-    from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
+    from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
 
 __all__ = ["BaseModel", "GenericModel"]
 
@@ -69,7 +71,8 @@ class _ConfigProtocol(Protocol):
 class BaseModel(pydantic.BaseModel):
     if PYDANTIC_V2:
         model_config: ClassVar[ConfigDict] = ConfigDict(
-            extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
+            extra="allow",
+            defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")),
         )
     else:
 
@@ -189,7 +192,9 @@ class BaseModel(pydantic.BaseModel):
                 key = name
 
             if key in values:
-                fields_values[name] = _construct_field(value=values[key], field=field, key=key)
+                fields_values[name] = _construct_field(
+                    value=values[key], field=field, key=key
+                )
                 _fields_set.add(name)
             else:
                 fields_values[name] = field_get_default(field)
@@ -412,9 +417,13 @@ def construct_type(*, value: object, type_: object) -> object:
        #
        # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
        # we'd end up constructing `FooType` when it should be `BarType`.
-        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
+        discriminator = _build_discriminated_union_meta(
+            union=type_, meta_annotations=meta
+        )
         if discriminator and is_mapping(value):
-            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
+            variant_value = value.get(
+                discriminator.field_alias_from or discriminator.field_name
+            )
             if variant_value and isinstance(variant_value, str):
                 variant_type = discriminator.mapping.get(variant_value)
                 if variant_type:
@@ -434,11 +443,19 @@ def construct_type(*, value: object, type_: object) -> object:
             return value
 
         _, items_type = get_args(type_)  # Dict[_, items_type]
-        return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
+        return {
+            key: construct_type(value=item, type_=items_type)
+            for key, item in value.items()
+        }
 
-    if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
+    if not is_literal_type(type_) and (
+        issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
+    ):
         if is_list(value):
-            return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
+            return [
+                cast(Any, type_).construct(**entry) if is_mapping(entry) else entry
+                for entry in value
+            ]
 
         if is_mapping(value):
             if issubclass(type_, BaseModel):
@@ -523,14 +540,19 @@ class DiscriminatorDetails:
         self.field_alias_from = discriminator_alias
 
 
-def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
+def _build_discriminated_union_meta(
+    *, union: type, meta_annotations: tuple[Any, ...]
+) -> DiscriminatorDetails | None:
     if isinstance(union, CachedDiscriminatorType):
         return union.__discriminator__
 
     discriminator_field_name: str | None = None
 
     for annotation in meta_annotations:
-        if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
+        if (
+            isinstance(annotation, PropertyInfo)
+            and annotation.discriminator is not None
+        ):
             discriminator_field_name = annotation.discriminator
             break
 
@@ -558,7 +580,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
             if isinstance(entry, str):
                 mapping[entry] = variant
             else:
-                field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name)  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
+                field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(
+                    discriminator_field_name
+                )  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                 if not field_info:
                     continue
 
@@ -582,7 +606,9 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
     return details
 
 
-def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
+def _extract_field_schema_pv2(
+    model: type[BaseModel], field_name: str
+) -> ModelField | None:
     schema = model.__pydantic_core_schema__
     if schema["type"] != "model":
         return None
@@ -621,7 +647,9 @@ else:
 if PYDANTIC_V2:
     from pydantic import TypeAdapter as _TypeAdapter
 
-    _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
+    _CachedTypeAdapter = cast(
+        "TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)
+    )
 
     if TYPE_CHECKING:
         from pydantic import TypeAdapter
@@ -652,6 +680,3 @@ elif not TYPE_CHECKING:  # TODO: condition is weird
 
     def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
         return RootModel[type_]  # type: ignore
-
-
-
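Several hunks above re-wrap calls inside `construct_type`'s discriminated-union path without changing behavior. The mechanism those lines implement: a discriminator field in the incoming mapping (e.g. `kind`) selects which union variant to construct, so, as the in-code comment says, `{'kind': 'bar', 'value': 'foo'}` builds `BarType` rather than `FooType`. A minimal, self-contained sketch of that dispatch, with illustrative classes rather than the vendored API:

    # Sketch of discriminator-based variant selection; FooType/BarType and the
    # mapping are illustrative, not the library's actual types.
    from dataclasses import dataclass
    from typing import Any, Mapping


    @dataclass
    class FooType:
        kind: str
        value: str


    @dataclass
    class BarType:
        kind: str
        value: str


    DISCRIMINATOR_MAPPING = {"foo": FooType, "bar": BarType}


    def construct_variant(value: Mapping[str, Any]) -> Any:
        # Read the discriminator field, look up the variant type, and fall
        # back to the raw value when nothing matches.
        variant_value = value.get("kind")
        if isinstance(variant_value, str):
            variant_type = DISCRIMINATOR_MAPPING.get(variant_value)
            if variant_type:
                return variant_type(**value)
        return value


    print(construct_variant({"kind": "bar", "value": "foo"}))  # builds BarType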
@@ -5,22 +5,29 @@ from typing import (
     IO,
     TYPE_CHECKING,
     Any,
+    Callable,
     Dict,
     List,
-    Type,
-    Tuple,
-    Union,
     Mapping,
-    TypeVar,
-    Callable,
     Optional,
     Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
 )
-from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
 
 import httpx
 import pydantic
-from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
+from httpx import URL, AsyncBaseTransport, BaseTransport, Proxy, Response, Timeout
+from typing_extensions import (
+    Literal,
+    Protocol,
+    TypeAlias,
+    TypedDict,
+    override,
+    runtime_checkable,
+)
 
 if TYPE_CHECKING:
     from ._models import BaseModel
@@ -43,7 +50,9 @@ if TYPE_CHECKING:
     FileContent = Union[IO[bytes], bytes, PathLike[str]]
 else:
     Base64FileInput = Union[IO[bytes], PathLike]
-    FileContent = Union[IO[bytes], bytes, PathLike]  # PathLike is not subscriptable in Python 3.8.
+    FileContent = Union[
+        IO[bytes], bytes, PathLike
+    ]  # PathLike is not subscriptable in Python 3.8.
 FileTypes = Union[
     # file (or bytes)
     FileContent,
@@ -68,7 +77,9 @@ HttpxFileTypes = Union[
     # (filename, file (or bytes), content_type, headers)
     Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
 ]
-HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
+HttpxRequestFiles = Union[
+    Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]
+]
 
 # Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
 # where ResponseT includes `None`. In order to support directly
@@ -1,48 +1,126 @@
-from ._utils import (
-    flatten as flatten,
-    is_dict as is_dict,
-    is_list as is_list,
-    is_given as is_given,
-    is_tuple as is_tuple,
-    lru_cache as lru_cache,
-    is_mapping as is_mapping,
-    is_tuple_t as is_tuple_t,
-    parse_date as parse_date,
-    is_iterable as is_iterable,
-    is_sequence as is_sequence,
-    coerce_float as coerce_float,
-    is_mapping_t as is_mapping_t,
-    removeprefix as removeprefix,
-    removesuffix as removesuffix,
-    extract_files as extract_files,
-    is_sequence_t as is_sequence_t,
-    required_args as required_args,
-    coerce_boolean as coerce_boolean,
-    coerce_integer as coerce_integer,
-    file_from_path as file_from_path,
-    parse_datetime as parse_datetime,
-    strip_not_given as strip_not_given,
-    deepcopy_minimal as deepcopy_minimal,
-    get_async_library as get_async_library,
-    maybe_coerce_float as maybe_coerce_float,
-    get_required_header as get_required_header,
-    maybe_coerce_boolean as maybe_coerce_boolean,
-    maybe_coerce_integer as maybe_coerce_integer,
-)
-from ._typing import (
-    is_list_type as is_list_type,
-    is_union_type as is_union_type,
-    extract_type_arg as extract_type_arg,
-    is_iterable_type as is_iterable_type,
-    is_required_type as is_required_type,
-    is_annotated_type as is_annotated_type,
-    strip_annotated_type as strip_annotated_type,
-    extract_type_var_from_base as extract_type_var_from_base,
-)
-from ._transform import (
-    PropertyInfo as PropertyInfo,
-    transform as transform,
-    async_transform as async_transform,
-    maybe_transform as maybe_transform,
-    async_maybe_transform as async_maybe_transform,
-)
+from ._transform import (
+    PropertyInfo as PropertyInfo,
+)
+from ._transform import (
+    async_maybe_transform as async_maybe_transform,
+)
+from ._transform import (
+    async_transform as async_transform,
+)
+from ._transform import (
+    maybe_transform as maybe_transform,
+)
+from ._transform import (
+    transform as transform,
+)
+from ._typing import (
+    extract_type_arg as extract_type_arg,
+)
+from ._typing import (
+    extract_type_var_from_base as extract_type_var_from_base,
+)
+from ._typing import (
+    is_annotated_type as is_annotated_type,
+)
+from ._typing import (
+    is_iterable_type as is_iterable_type,
+)
+from ._typing import (
+    is_list_type as is_list_type,
+)
+from ._typing import (
+    is_required_type as is_required_type,
+)
+from ._typing import (
+    is_union_type as is_union_type,
+)
+from ._typing import (
+    strip_annotated_type as strip_annotated_type,
+)
+from ._utils import (
+    coerce_boolean as coerce_boolean,
+)
+from ._utils import (
+    coerce_float as coerce_float,
+)
+from ._utils import (
+    coerce_integer as coerce_integer,
+)
+from ._utils import (
+    deepcopy_minimal as deepcopy_minimal,
+)
+from ._utils import (
+    extract_files as extract_files,
+)
+from ._utils import (
+    file_from_path as file_from_path,
+)
+from ._utils import (
+    flatten as flatten,
+)
+from ._utils import (
+    get_async_library as get_async_library,
+)
+from ._utils import (
+    get_required_header as get_required_header,
+)
+from ._utils import (
+    is_dict as is_dict,
+)
+from ._utils import (
+    is_given as is_given,
+)
+from ._utils import (
+    is_iterable as is_iterable,
+)
+from ._utils import (
+    is_list as is_list,
+)
+from ._utils import (
+    is_mapping as is_mapping,
+)
+from ._utils import (
+    is_mapping_t as is_mapping_t,
+)
+from ._utils import (
+    is_sequence as is_sequence,
+)
+from ._utils import (
+    is_sequence_t as is_sequence_t,
+)
+from ._utils import (
+    is_tuple as is_tuple,
+)
+from ._utils import (
+    is_tuple_t as is_tuple_t,
+)
+from ._utils import (
+    lru_cache as lru_cache,
+)
+from ._utils import (
+    maybe_coerce_boolean as maybe_coerce_boolean,
+)
+from ._utils import (
+    maybe_coerce_float as maybe_coerce_float,
+)
+from ._utils import (
+    maybe_coerce_integer as maybe_coerce_integer,
+)
+from ._utils import (
+    parse_date as parse_date,
+)
+from ._utils import (
+    parse_datetime as parse_datetime,
+)
+from ._utils import (
+    removeprefix as removeprefix,
+)
+from ._utils import (
+    removesuffix as removesuffix,
+)
+from ._utils import (
+    required_args as required_args,
+)
+from ._utils import (
+    strip_not_given as strip_not_given,
+)
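The hunk above rewrites the package's re-export module into one import group per name while keeping the redundant-looking `name as name` aliases. That self-alias is a convention type checkers recognize (from the PEP 484 stub rules, also enforced by mypy's strict no-implicit-reexport mode): it marks the import as an intentional public re-export rather than an unused import. A runnable sketch using stdlib modules, since the vendored package isn't importable here:

    # The `name as name` self-alias marks an explicit public re-export for
    # type checkers; stdlib names are used so the snippet runs on its own.
    from os.path import join as join
    from os.path import split as split

    print(join("a", "b"), split("a/b"))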
@@ -1,31 +1,31 @@
 from __future__ import annotations
 
-import io
 import base64
+import io
 import pathlib
-from typing import Any, Mapping, TypeVar, cast
 from datetime import date, datetime
-from typing_extensions import Literal, get_args, override, get_type_hints
+from typing import Any, Mapping, TypeVar, cast
 
 import anyio
 import pydantic
+from typing_extensions import Literal, get_args, get_type_hints, override
 
-from ._utils import (
-    is_list,
-    is_mapping,
-    is_iterable,
-)
+from .._compat import is_typeddict, model_dump
 from .._files import is_base64_file_input
 from ._typing import (
-    is_list_type,
-    is_union_type,
     extract_type_arg,
-    is_iterable_type,
-    is_required_type,
     is_annotated_type,
+    is_iterable_type,
+    is_list_type,
+    is_required_type,
+    is_union_type,
     strip_annotated_type,
 )
-from .._compat import model_dump, is_typeddict
+from ._utils import (
+    is_iterable,
+    is_list,
+    is_mapping,
+)
 
 _T = TypeVar("_T")
 
@@ -171,10 +171,17 @@ def _transform_recursive(
         # List[T]
         (is_list_type(stripped_type) and is_list(data))
         # Iterable[T]
-        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+        or (
+            is_iterable_type(stripped_type)
+            and is_iterable(data)
+            and not isinstance(data, str)
+        )
     ):
         inner_type = extract_type_arg(stripped_type, 0)
-        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+        return [
+            _transform_recursive(d, annotation=annotation, inner_type=inner_type)
+            for d in data
+        ]
 
     if is_union_type(stripped_type):
         # For union types we run the transformation against all subtypes to ensure that everything is transformed.
@@ -201,7 +208,9 @@ def _transform_recursive(
     return data
 
 
-def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+def _format_data(
+    data: object, format_: PropertyFormat, format_template: str | None
+) -> object:
     if isinstance(data, (date, datetime)):
         if format_ == "iso8601":
             return data.isoformat()
@@ -221,7 +230,9 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
             binary = binary.encode()
 
         if not isinstance(binary, bytes):
-            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+            raise RuntimeError(
+                f"Could not read bytes from {data}; Received {type(binary)}"
+            )
 
         return base64.b64encode(binary).decode("ascii")
 
@@ -240,7 +251,9 @@ def _transform_typeddict(
             # we do not have a type annotation for this field, leave it as is
             result[key] = value
         else:
-            result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
+            result[_maybe_transform_key(key, type_)] = _transform_recursive(
+                value, annotation=type_
+            )
     return result
 
 
@@ -276,7 +289,9 @@ async def async_transform(
 
     It should be noted that the transformations that this function does are not represented in the type system.
     """
-    transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
+    transformed = await _async_transform_recursive(
+        data, annotation=cast(type, expected_type)
+    )
     return cast(_T, transformed)
 
 
@@ -309,10 +324,19 @@ async def _async_transform_recursive(
         # List[T]
         (is_list_type(stripped_type) and is_list(data))
         # Iterable[T]
-        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+        or (
+            is_iterable_type(stripped_type)
+            and is_iterable(data)
+            and not isinstance(data, str)
+        )
     ):
         inner_type = extract_type_arg(stripped_type, 0)
-        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+        return [
+            await _async_transform_recursive(
+                d, annotation=annotation, inner_type=inner_type
+            )
+            for d in data
+        ]
 
     if is_union_type(stripped_type):
         # For union types we run the transformation against all subtypes to ensure that everything is transformed.
@@ -320,7 +344,9 @@ async def _async_transform_recursive(
         # TODO: there may be edge cases where the same normalized field name will transform to two different names
         # in different subtypes.
         for subtype in get_args(stripped_type):
-            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
+            data = await _async_transform_recursive(
+                data, annotation=annotation, inner_type=subtype
+            )
         return data
 
     if isinstance(data, pydantic.BaseModel):
@@ -334,12 +360,16 @@ async def _async_transform_recursive(
         annotations = get_args(annotated_type)[1:]
         for annotation in annotations:
             if isinstance(annotation, PropertyInfo) and annotation.format is not None:
-                return await _async_format_data(data, annotation.format, annotation.format_template)
+                return await _async_format_data(
+                    data, annotation.format, annotation.format_template
+                )
 
     return data
 
 
-async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+async def _async_format_data(
+    data: object, format_: PropertyFormat, format_template: str | None
+) -> object:
     if isinstance(data, (date, datetime)):
         if format_ == "iso8601":
             return data.isoformat()
@@ -359,7 +389,9 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
             binary = binary.encode()
 
         if not isinstance(binary, bytes):
-            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+            raise RuntimeError(
+                f"Could not read bytes from {data}; Received {type(binary)}"
+            )
 
         return base64.b64encode(binary).decode("ascii")
 
@@ -378,5 +410,7 @@ async def _async_transform_typeddict(
             # we do not have a type annotation for this field, leave it as is
             result[key] = value
         else:
-            result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
+            result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(
+                value, annotation=type_
+            )
     return result
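The `_format_data` hunks above only re-wrap long lines; the behavior they preserve is a two-branch value formatter: `date`/`datetime` values become ISO-8601 strings and binary payloads become base64 text. A standalone sketch of those two branches (a simplification for illustration, not the vendored function itself):

    # Minimal sketch of the iso8601 and base64 branches seen in the diff.
    import base64
    from datetime import date, datetime
    from typing import Union


    def format_data(data: Union[date, datetime, bytes, str]) -> str:
        if isinstance(data, (date, datetime)):
            return data.isoformat()  # "iso8601" branch
        binary = data.encode() if isinstance(data, str) else data
        if not isinstance(binary, bytes):
            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
        return base64.b64encode(binary).decode("ascii")  # "base64" branch


    print(format_data(datetime(2024, 1, 1, 12, 0)))  # 2024-01-01T12:00:00
    print(format_data(b"hello"))  # aGVsbG8=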
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, TypeVar, Iterable, cast
 from collections import abc as _c_abc
-from typing_extensions import Required, Annotated, get_args, get_origin
+from typing import Any, Iterable, TypeVar, cast
 
-from .._types import InheritsGeneric
+from typing_extensions import Annotated, Required, get_args, get_origin
+
 from .._compat import is_union as _is_union
+from .._types import InheritsGeneric
 
 
 def is_annotated_type(typ: type) -> bool:
@@ -49,15 +50,17 @@ def extract_type_arg(typ: type, index: int) -> type:
     try:
         return cast(type, args[index])
     except IndexError as err:
-        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
+        raise RuntimeError(
+            f"Expected type {typ} to have a type argument at index {index} but it did not"
+        ) from err
 
 
 def extract_type_var_from_base(
     typ: type,
     *,
     generic_bases: tuple[type, ...],
     index: int,
     failure_message: str | None = None,
 ) -> type:
     """Given a type like `Foo[T]`, returns the generic type variable `T`.
 
@@ -117,4 +120,7 @@ def extract_type_var_from_base(
 
         return extracted
 
-    raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
+    raise RuntimeError(
+        failure_message
+        or f"Could not resolve inner type variable at index {index} for {typ}"
+    )
@@ -1,27 +1,28 @@
 from __future__ import annotations
 
+import functools
+import inspect
 import os
 import re
-import inspect
-import functools
+from pathlib import Path
 from typing import (
     Any,
-    Tuple,
-    Mapping,
-    TypeVar,
     Callable,
     Iterable,
+    Mapping,
     Sequence,
+    Tuple,
+    TypeVar,
     cast,
     overload,
 )
-from pathlib import Path
-from typing_extensions import TypeGuard
 
 import sniffio
+from typing_extensions import TypeGuard
 
-from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
-from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
+from .._compat import parse_date as parse_date
+from .._compat import parse_datetime as parse_datetime
+from .._types import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr
 
 _T = TypeVar("_T")
 _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
@@ -108,7 +109,9 @@ def _extract_items(
             item,
             path,
             index=index,
-            flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
+            flattened_key=flattened_key + "[]"
+            if flattened_key is not None
+            else "[]",
         )
         for item in obj
     ]
@@ -261,7 +264,12 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
         else:  # no break
             if len(variants) > 1:
                 variations = human_join(
-                    ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
+                    [
+                        "("
+                        + human_join([quote(arg) for arg in variant], final="and")
+                        + ")"
+                        for variant in variants
+                    ]
                 )
                 msg = f"Missing required arguments; Expected either {variations} arguments to be given"
             else:
@@ -376,7 +384,11 @@ def get_required_header(headers: HeadersLike, header: str) -> str:
             return v
 
     """ to deal with the case where the header looks like Stainless-Event-Id """
-    intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
+    intercaps_header = re.sub(
+        r"([^\w])(\w)",
+        lambda pat: pat.group(1) + pat.group(2).upper(),
+        header.capitalize(),
+    )
 
     for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
         value = headers.get(normalized_header)
@@ -1,9 +1,6 @@
 from enum import Enum
 from typing import List, Literal, Optional
 
-from ..._compat import PYDANTIC_V2, ConfigDict
-from ..._models import BaseModel
-
 from model_providers.core.entities.model_entities import (
     ModelStatus,
     ModelWithProviderEntity,
@@ -26,6 +23,9 @@ from model_providers.core.model_runtime.entities.provider_entities import (
     SimpleProviderEntity,
 )
 
+from ..._compat import PYDANTIC_V2, ConfigDict
+from ..._models import BaseModel
+
 
 class CustomConfigurationStatus(Enum):
     """
@@ -74,10 +74,9 @@ class ProviderResponse(BaseModel):
     custom_configuration: CustomConfigurationResponse
     system_configuration: SystemConfigurationResponse
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=()
-        )
+        model_config = ConfigDict(protected_namespaces=())
     else:
 
         class Config:
             protected_namespaces = ()
 
@@ -180,10 +179,9 @@ class DefaultModelResponse(BaseModel):
     provider: SimpleProviderEntityResponse
 
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=()
-        )
+        model_config = ConfigDict(protected_namespaces=())
     else:
 
         class Config:
             protected_namespaces = ()
 
@@ -73,7 +73,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         self._server = None
         self._server_thread = None
 
-    def logging_conf(self,logging_conf: Optional[dict] = None):
+    def logging_conf(self, logging_conf: Optional[dict] = None):
         self._logging_conf = logging_conf
 
     @classmethod
@@ -132,7 +132,10 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         self._app.include_router(self._router)
 
         config = Config(
-            app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
+            app=self._app,
+            host=self._host,
+            port=self._port,
+            log_config=self._logging_conf,
         )
         self._server = Server(config)
 
@@ -145,7 +148,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         self._server_thread.start()
 
     def destroy(self):
-
         logger.info("Shutting down server")
         self._server.should_exit = True  # set the exit flag
         self._server.shutdown()  # stop the server
@@ -155,7 +157,6 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
     def join(self):
         self._server_thread.join()
 
-
     def set_app_event(self, started_event: mp.Event = None):
         @self._app.on_event("startup")
         async def on_startup():
@@ -190,12 +191,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
                         provider_model_bundle.model_type_instance.predefined_models()
                     )
                     # fetch the custom models
-                    for model in provider_model_bundle.configuration.custom_configuration.models:
-                        llm_models.append(provider_model_bundle.model_type_instance.get_model_schema(
-                            model=model.model,
-                            credentials=model.credentials,
-                        ))
+                    for (
+                        model
+                    ) in provider_model_bundle.configuration.custom_configuration.models:
+                        llm_models.append(
+                            provider_model_bundle.model_type_instance.get_model_schema(
+                                model=model.model,
+                                credentials=model.credentials,
+                            )
+                        )
                 except Exception as e:
                     logger.error(
                         f"Error while fetching models for provider: {provider}, model_type: {model_type}"
@@ -225,13 +229,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
             )
 
             # check whether embeddings_request.input is a list
-            input = ''
+            input = ""
             if isinstance(embeddings_request.input, list):
                 tokens = embeddings_request.input
                 try:
                     encoding = tiktoken.encoding_for_model(embeddings_request.model)
                 except KeyError:
-                    logger.warning("Warning: model not found. Using cl100k_base encoding.")
+                    logger.warning(
+                        "Warning: model not found. Using cl100k_base encoding."
+                    )
                     model = "cl100k_base"
                     encoding = tiktoken.get_encoding(model)
                 for i, token in enumerate(tokens):
@@ -241,7 +247,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
             else:
                 input = embeddings_request.input
 
-            response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
+            response = model_instance.invoke_text_embedding(
+                texts=[input], user="abc-123"
+            )
             return await openai_embedding_text(response)
 
         except ValueError as e:
@@ -1,10 +1,12 @@
 import time
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
-from ..._models import BaseModel
 from pydantic import Field as FieldInfo
 from typing_extensions import Literal
 
+from ..._models import BaseModel
+
 
 class Role(str, Enum):
     USER = "user"
@@ -1,5 +1,5 @@
 import enum
-from typing import Any, cast, List
+from typing import Any, List, cast
 
 from langchain.schema import (
     AIMessage,
@@ -8,7 +8,6 @@ from langchain.schema import (
     HumanMessage,
     SystemMessage,
 )
-from ..._models import BaseModel
 
 from model_providers.core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -20,6 +19,8 @@ from model_providers.core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 
+from ..._models import BaseModel
+
 
 class PromptMessageFileType(enum.Enum):
     IMAGE = "image"
@@ -1,9 +1,6 @@
 from enum import Enum
 from typing import List, Optional
 
-from ..._compat import PYDANTIC_V2, ConfigDict
-from ..._models import BaseModel
-
 from model_providers.core.model_runtime.entities.common_entities import I18nObject
 from model_providers.core.model_runtime.entities.model_entities import (
     ModelType,
@@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
 )
 from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
 
+from ..._compat import PYDANTIC_V2, ConfigDict
+from ..._models import BaseModel
+
 
 class ModelStatus(Enum):
     """
@@ -80,9 +80,8 @@ class DefaultModelEntity(BaseModel):
     provider: DefaultModelProviderEntity
 
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=()
-        )
+        model_config = ConfigDict(protected_namespaces=())
     else:
 
         class Config:
             protected_namespaces = ()
@@ -4,9 +4,6 @@ import logging
 from json import JSONDecodeError
 from typing import Dict, Iterator, List, Optional
 
-from ..._compat import PYDANTIC_V2, ConfigDict
-from ..._models import BaseModel
-
 from model_providers.core.entities.model_entities import (
     ModelStatus,
     ModelWithProviderEntity,
@@ -29,6 +26,9 @@ from model_providers.core.model_runtime.model_providers.__base.model_provider import (
     ModelProvider,
 )
 
+from ..._compat import PYDANTIC_V2, ConfigDict
+from ..._models import BaseModel
+
 logger = logging.getLogger(__name__)
 
 
@@ -351,11 +351,9 @@ class ProviderModelBundle(BaseModel):
     model_type_instance: AIModel
 
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=(),
-            arbitrary_types_allowed=True
-        )
+        model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
     else:
 
         class Config:
             protected_namespaces = ()
 
@@ -1,11 +1,11 @@
 from enum import Enum
 from typing import List, Optional
 
+from model_providers.core.model_runtime.entities.model_entities import ModelType
+
 from ..._compat import PYDANTIC_V2, ConfigDict
 from ..._models import BaseModel
 
-from model_providers.core.model_runtime.entities.model_entities import ModelType
-
 
 class ProviderType(Enum):
     CUSTOM = "custom"
@@ -59,10 +59,9 @@ class RestrictModel(BaseModel):
     model_type: ModelType
 
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=()
-        )
+        model_config = ConfigDict(protected_namespaces=())
     else:
 
         class Config:
             protected_namespaces = ()
 
@@ -109,10 +108,9 @@ class CustomModelConfiguration(BaseModel):
     credentials: dict
 
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=()
-        )
+        model_config = ConfigDict(protected_namespaces=())
     else:
 
         class Config:
             protected_namespaces = ()
 
@@ -1,13 +1,13 @@
 from enum import Enum
 from typing import Any
 
-from ..._models import BaseModel
-
 from model_providers.core.model_runtime.entities.llm_entities import (
     LLMResult,
     LLMResultChunk,
 )
 
+from ..._models import BaseModel
+
 
 class QueueEvent(Enum):
     """
@@ -2,8 +2,6 @@ from decimal import Decimal
 from enum import Enum
 from typing import List, Optional
 
-from ...._models import BaseModel
-
 from model_providers.core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -13,6 +11,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
     PriceInfo,
 )
 
+from ...._models import BaseModel
+
 
 class LLMMode(Enum):
     """
@@ -2,11 +2,11 @@ from decimal import Decimal
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
+from model_providers.core.model_runtime.entities.common_entities import I18nObject
+
 from ...._compat import PYDANTIC_V2, ConfigDict
 from ...._models import BaseModel
 
-from model_providers.core.model_runtime.entities.common_entities import I18nObject
-
 
 class ModelType(Enum):
     """
@@ -164,10 +164,9 @@ class ProviderModel(BaseModel):
     model_properties: Dict[ModelPropertyKey, Any]
     deprecated: bool = False
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=()
-        )
+        model_config = ConfigDict(protected_namespaces=())
     else:
+
         class Config:
             protected_namespaces = ()
@@ -1,9 +1,6 @@
 from enum import Enum
 from typing import List, Optional
 
-from ...._compat import PYDANTIC_V2, ConfigDict
-from ...._models import BaseModel
-
 from model_providers.core.model_runtime.entities.common_entities import I18nObject
 from model_providers.core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -11,6 +8,9 @@ from model_providers.core.model_runtime.entities.model_entities import (
     ProviderModel,
 )
 
+from ...._compat import PYDANTIC_V2, ConfigDict
+from ...._models import BaseModel
+
 
 class ConfigurateMethod(Enum):
     """
@@ -136,10 +136,9 @@ class ProviderEntity(BaseModel):
     model_credential_schema: Optional[ModelCredentialSchema] = None
 
     if PYDANTIC_V2:
-        model_config = ConfigDict(
-            protected_namespaces=()
-        )
+        model_config = ConfigDict(protected_namespaces=())
     else:
+
         class Config:
             protected_namespaces = ()
@@ -1,10 +1,10 @@
 from decimal import Decimal
 from typing import List
 
-from ...._models import BaseModel
-
 from model_providers.core.model_runtime.entities.model_entities import ModelUsage
 
+from ...._models import BaseModel
+
 
 class EmbeddingUsage(ModelUsage):
     """
@@ -1,5 +1,3 @@
-from ....._models import BaseModel
-
 from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from model_providers.core.model_runtime.entities.llm_entities import LLMMode
 from model_providers.core.model_runtime.entities.model_entities import (
@@ -14,6 +12,8 @@ from model_providers.core.model_runtime.entities.model_entities import (
     PriceConfig,
 )
 
+from ....._models import BaseModel
+
 AZURE_OPENAI_API_VERSION = "2024-02-15-preview"
@@ -1,6 +1,5 @@
 import logging
-from typing import Generator
-from typing import Dict, List, Optional, Type, Union, cast
+from typing import Dict, Generator, List, Optional, Type, Union, cast
 
 import cohere
 from cohere.responses import Chat, Generations
@@ -1,7 +1,6 @@
 import logging
-from typing import Generator
-from typing import List, Optional, Union, cast
 from decimal import Decimal
+from typing import Generator, List, Optional, Union, cast
 
 import tiktoken
 from openai import OpenAI, Stream
@@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
 )
 from model_providers.core.model_runtime.entities.model_entities import (
     AIModelEntity,
+    DefaultParameterName,
     FetchFrom,
     I18nObject,
+    ModelFeature,
+    ModelPropertyKey,
     ModelType,
-    PriceConfig, ParameterRule, ParameterType, ModelFeature, ModelPropertyKey, DefaultParameterName,
+    ParameterRule,
+    ParameterType,
+    PriceConfig,
 )
 from model_providers.core.model_runtime.errors.validate import (
     CredentialsValidateFailedError,
@@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
 from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
     LargeLanguageModel,
 )
-from model_providers.core.model_runtime.model_providers.deepseek._common import _CommonDeepseek
+from model_providers.core.model_runtime.model_providers.deepseek._common import (
+    _CommonDeepseek,
+)
 
 logger = logging.getLogger(__name__)
@@ -1117,7 +1123,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
         return num_tokens
 
     def get_customizable_model_schema(
         self, model: str, credentials: dict
     ) -> AIModelEntity:
         """
         Get customizable model schema.
@@ -1129,7 +1135,6 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
         """
         extras = {}
-
 
         entity = AIModelEntity(
             model=model,
             label=I18nObject(zh_Hans=model, en_US=model),
@@ -1149,8 +1154,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="The temperature of the model. "
                         "Increasing the temperature will make the model answer "
                         "more creatively. (Default: 0.8)"
                     ),
                     default=0.8,
                     min=0,
@@ -1163,8 +1168,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                         "more diverse text, while a lower value (e.g., 0.5) will generate more "
                         "focused and conservative text. (Default: 0.9)"
                     ),
                     default=0.9,
                     min=0,
@@ -1177,8 +1182,8 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
                     help=I18nObject(
                         en_US="A number between -2.0 and 2.0. If positive, ",
                         zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,"
                         "那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,"
-                        "降低模型重复相同内容的可能性"
+                        "降低模型重复相同内容的可能性",
                     ),
                     default=0,
                     min=-2,
@@ -1190,7 +1195,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="Sets how strongly to presence_penalty. ",
-                        zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。"
+                        zh_Hans="介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。",
                     ),
                     default=1.1,
                     min=-2,
@@ -1204,7 +1209,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
                     help=I18nObject(
                         en_US="Maximum number of tokens to predict when generating text. ",
                         zh_Hans="限制一次请求中模型生成 completion 的最大 token 数。"
-                        "输入 token 和输出 token 的总长度受模型的上下文长度的限制。"
+                        "输入 token 和输出 token 的总长度受模型的上下文长度的限制。",
                     ),
                     default=128,
                     min=-2,
@@ -1216,7 +1221,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
                     type=ParameterType.BOOLEAN,
                     help=I18nObject(
                         en_US="Whether to return the log probabilities of the tokens. ",
-                        zh_Hans="是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。"
+                        zh_Hans="是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。",
                     ),
                 ),
                 ParameterRule(
@@ -1226,7 +1231,7 @@ class DeepseekLargeLanguageModel(_CommonDeepseek, LargeLanguageModel):
                     help=I18nObject(
                         en_US="the format to return a response in.",
                         zh_Hans="一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,"
-                        "且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。"
+                        "且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。",
                     ),
                     default=0,
                     min=0,
@@ -3,8 +3,6 @@ import logging
 import os
 from typing import Dict, List, Optional, Union
 
-from ...._models import BaseModel
-
 from model_providers.core.model_runtime.entities.model_entities import ModelType
 from model_providers.core.model_runtime.entities.provider_entities import (
     ProviderConfig,
@@ -25,6 +23,8 @@ from model_providers.core.utils.position_helper import (
     sort_to_dict_by_position_map,
 )
 
+from ...._models import BaseModel
+
 logger = logging.getLogger(__name__)
@@ -1,5 +1,4 @@
-from typing import Generator
-from typing import List, Optional, Union
+from typing import Generator, List, Optional, Union
 
 from model_providers.core.model_runtime.entities.llm_entities import LLMResult
 from model_providers.core.model_runtime.entities.message_entities import (
@@ -1,7 +1,6 @@
 import logging
-from typing import Generator
-from typing import List, Optional, Union, cast
 from decimal import Decimal
+from typing import Generator, List, Optional, Union, cast
 
 import tiktoken
 from openai import OpenAI, Stream
@@ -37,10 +36,15 @@ from model_providers.core.model_runtime.entities.message_entities import (
 )
 from model_providers.core.model_runtime.entities.model_entities import (
     AIModelEntity,
+    DefaultParameterName,
     FetchFrom,
     I18nObject,
+    ModelFeature,
+    ModelPropertyKey,
     ModelType,
-    PriceConfig, ModelFeature, ModelPropertyKey, DefaultParameterName, ParameterRule, ParameterType,
+    ParameterRule,
+    ParameterType,
+    PriceConfig,
 )
 from model_providers.core.model_runtime.errors.validate import (
     CredentialsValidateFailedError,
@@ -48,7 +52,9 @@ from model_providers.core.model_runtime.errors.validate import (
 from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
     LargeLanguageModel,
 )
-from model_providers.core.model_runtime.model_providers.ollama._common import _CommonOllama
+from model_providers.core.model_runtime.model_providers.ollama._common import (
+    _CommonOllama,
+)
 
 logger = logging.getLogger(__name__)
@@ -661,7 +667,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
             model=model,
             stream=stream,
-            extra_body=extra_body
+            extra_body=extra_body,
         )
 
         if stream:
@@ -1120,7 +1126,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
         return num_tokens
 
     def get_customizable_model_schema(
         self, model: str, credentials: dict
     ) -> AIModelEntity:
         """
         Get customizable model schema.
@@ -1154,8 +1160,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="The temperature of the model. "
                         "Increasing the temperature will make the model answer "
                         "more creatively. (Default: 0.8)"
                     ),
                     default=0.8,
                     min=0,
@@ -1168,8 +1174,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                         "more diverse text, while a lower value (e.g., 0.5) will generate more "
                         "focused and conservative text. (Default: 0.9)"
                     ),
                     default=0.9,
                     min=0,
@@ -1181,8 +1187,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="Reduces the probability of generating nonsense. "
                         "A higher value (e.g. 100) will give more diverse answers, "
                         "while a lower value (e.g. 10) will be more conservative. (Default: 40)"
                     ),
                     default=40,
                     min=1,
@@ -1194,8 +1200,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="Sets how strongly to penalize repetitions. "
                         "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
                         "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
                     ),
                     default=1.1,
                     min=-2,
@@ -1208,7 +1214,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="Maximum number of tokens to predict when generating text. "
                         "(Default: 128, -1 = infinite generation, -2 = fill context)"
                     ),
                     default=128,
                     min=-2,
@@ -1220,7 +1226,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="Enable Mirostat sampling for controlling perplexity. "
                         "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
                     ),
                     default=0,
                     min=0,
@@ -1232,9 +1238,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="Influences how quickly the algorithm responds to feedback from "
                         "the generated text. A lower learning rate will result in slower adjustments, "
                         "while a higher learning rate will make the algorithm more responsive. "
                         "(Default: 0.1)"
                     ),
                     default=0.1,
                     precision=1,
@@ -1245,7 +1251,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="Controls the balance between coherence and diversity of the output. "
                         "A lower value will result in more focused and coherent text. (Default: 5.0)"
                     ),
                     default=5.0,
                     precision=1,
@@ -1256,7 +1262,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="Sets the size of the context window used to generate the next token. "
                         "(Default: 2048)"
                     ),
                     default=2048,
                     min=1,
@@ -1267,7 +1273,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="The number of layers to send to the GPU(s). "
                         "On macOS it defaults to 1 to enable metal support, 0 to disable."
                     ),
                     default=1,
                     min=0,
@@ -1279,9 +1285,9 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="Sets the number of threads to use during computation. "
                         "By default, Ollama will detect this for optimal performance. "
                         "It is recommended to set this value to the number of physical CPU cores "
                         "your system has (as opposed to the logical number of cores)."
                     ),
                     min=1,
                 ),
@@ -1291,7 +1297,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="Sets how far back for the model to look back to prevent repetition. "
                         "(Default: 64, 0 = disabled, -1 = num_ctx)"
                     ),
                     default=64,
                     min=-1,
@@ -1302,8 +1308,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.FLOAT,
                     help=I18nObject(
                         en_US="Tail free sampling is used to reduce the impact of less probable tokens "
                         "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
                         "while a value of 1.0 disables this setting. (default: 1)"
                     ),
                     default=1,
                     precision=1,
@@ -1314,8 +1320,8 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.INT,
                     help=I18nObject(
                         en_US="Sets the random number seed to use for generation. Setting this to "
                         "a specific number will make the model generate the same text for "
                         "the same prompt. (Default: 0)"
                     ),
                     default=0,
                 ),
@@ -1325,7 +1331,7 @@ class OllamaLargeLanguageModel(_CommonOllama, LargeLanguageModel):
                     type=ParameterType.STRING,
                     help=I18nObject(
                         en_US="the format to return a response in."
                         " Currently the only accepted value is json."
                     ),
                     options=["json"],
                 ),
@@ -1,6 +1,5 @@
 import logging
-from typing import Generator
-from typing import List, Optional, Union, cast
+from typing import Generator, List, Optional, Union, cast
 
 import tiktoken
 from openai import OpenAI, Stream
@@ -1,8 +1,7 @@
 import json
 import logging
-from typing import Generator
 from decimal import Decimal
-from typing import List, Optional, Union, cast
+from typing import Generator, List, Optional, Union, cast
 from urllib.parse import urljoin
 
 import requests
@@ -1,7 +1,6 @@
-from typing import Generator
 from enum import Enum
 from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Generator, Union
 
 from requests import Response, post
 from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@@ -1,6 +1,5 @@
 import threading
-from typing import Generator
-from typing import Dict, List, Optional, Type, Union
+from typing import Dict, Generator, List, Optional, Type, Union
 
 from model_providers.core.model_runtime.entities.llm_entities import (
     LLMResult,
@@ -1,5 +1,4 @@
-from typing import Generator
-from typing import List, Optional, Union
+from typing import Generator, List, Optional, Union
 
 from model_providers.core.model_runtime.entities.llm_entities import LLMResult
 from model_providers.core.model_runtime.entities.message_entities import (
@@ -1,5 +1,4 @@
-from typing import Generator
-from typing import Dict, List, Optional, Type, Union
+from typing import Dict, Generator, List, Optional, Type, Union
 
 from dashscope import get_tokenizer
 from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
@@ -1,6 +1,4 @@
-from typing import Generator, Iterator
-
-from typing import Dict, List, Union, cast, Type
+from typing import Dict, Generator, Iterator, List, Type, Union, cast
 
 from openai import (
     APIConnectionError,
@@ -527,8 +525,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
     def _extract_response_tool_calls(
         self,
         response_tool_calls: Union[
-            List[ChatCompletionMessageToolCall],
-            List[ChoiceDeltaToolCall]
+            List[ChatCompletionMessageToolCall], List[ChoiceDeltaToolCall]
         ],
     ) -> List[AssistantPromptMessage.ToolCall]:
         """
@@ -195,8 +195,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if stop:
             extra_model_kwargs["stop"] = stop
 
-        client = ZhipuAI(base_url=credentials_kwargs["api_base"],
-                         api_key=credentials_kwargs["api_key"])
+        client = ZhipuAI(
+            base_url=credentials_kwargs["api_base"],
+            api_key=credentials_kwargs["api_key"],
+        )
 
         if len(prompt_messages) == 0:
             raise ValueError("At least one message is required")
@@ -43,8 +43,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         :return: embeddings result
         """
         credentials_kwargs = self._to_credential_kwargs(credentials)
-        client = ZhipuAI(base_url=credentials_kwargs["api_base"],
-                         api_key=credentials_kwargs["api_key"])
+        client = ZhipuAI(
+            base_url=credentials_kwargs["api_base"],
+            api_key=credentials_kwargs["api_key"],
+        )
 
         embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
@@ -85,8 +87,10 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         try:
             # transform credentials to kwargs for model instance
             credentials_kwargs = self._to_credential_kwargs(credentials)
-            client = ZhipuAI(base_url=credentials_kwargs["api_base"],
-                             api_key=credentials_kwargs["api_key"])
+            client = ZhipuAI(
+                base_url=credentials_kwargs["api_base"],
+                api_key=credentials_kwargs["api_key"],
+            )
 
             # call embedding model
             self.embed_documents(
@@ -1,4 +1,5 @@
 import pydantic
+
 from ...._models import BaseModel
@@ -1,6 +1,7 @@
 import os
+
 import orjson
 
 from ..._models import BaseModel
@@ -1,8 +1,7 @@
 import os
 import shutil
-from typing import Generator
 from contextlib import closing
-from typing import Union
+from typing import Generator, Union
 
 import boto3
 from botocore.exceptions import ClientError
@@ -104,15 +104,17 @@ def logging_conf() -> dict:
         111,
     )
 
 
 @pytest.fixture
 def providers_file(request) -> str:
-    from pathlib import Path
     import os
+    from pathlib import Path
 
     # Current working directory
     # Get the path of the current test file
     test_file_path = Path(str(request.fspath)).parent
-    print("test_file_path:",test_file_path)
-    return os.path.join(test_file_path,"model_providers.yaml")
+    print("test_file_path:", test_file_path)
+    return os.path.join(test_file_path, "model_providers.yaml")
 
 
 @pytest.fixture
@@ -121,9 +123,7 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
     try:
         boot = (
             BootstrapWebBuilder()
-            .model_providers_cfg_path(
-                model_providers_cfg_path=providers_file
-            )
+            .model_providers_cfg_path(model_providers_cfg_path=providers_file)
             .host(host="127.0.0.1")
             .port(port=20000)
             .build()
@@ -139,5 +139,4 @@ def init_server(logging_conf: dict, providers_file: str) -> None:
         boot.destroy()
 
     except SystemExit:
-
         raise
@@ -1,15 +1,20 @@
+import logging
+
+import pytest
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-import pytest
-import logging
 
 logger = logging.getLogger(__name__)
 
 
 @pytest.mark.requires("openai")
 def test_llm(init_server: str):
-    llm = ChatOpenAI(model_name="deepseek-chat", openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758", openai_api_base=f"{init_server}/deepseek/v1")
+    llm = ChatOpenAI(
+        model_name="deepseek-chat",
+        openai_api_key="sk-dcb625fcbc1e497d80b7b9493b51d758",
+        openai_api_base=f"{init_server}/deepseek/v1",
+    )
     template = """Question: {question}
 
     Answer: Let's think step by step."""
@@ -1,15 +1,20 @@
+import logging
+
+import pytest
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-import pytest
-import logging
 
 logger = logging.getLogger(__name__)
 
 
 @pytest.mark.requires("openai")
 def test_llm(init_server: str):
-    llm = ChatOpenAI(model_name="llama3", openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/ollama/v1")
+    llm = ChatOpenAI(
+        model_name="llama3",
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/ollama/v1",
+    )
     template = """Question: {question}
 
     Answer: Let's think step by step."""
@@ -23,9 +28,11 @@ def test_llm(init_server: str):
 
 @pytest.mark.requires("openai")
 def test_embedding(init_server: str):
-    embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
-                                  openai_api_key="YOUR_API_KEY",
-                                  openai_api_base=f"{init_server}/zhipuai/v1")
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-large",
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/zhipuai/v1",
+    )
 
     text = "你好"
@@ -1,15 +1,18 @@
+import logging
+
+import pytest
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-import pytest
-import logging
 
 logger = logging.getLogger(__name__)
 
 
 @pytest.mark.requires("openai")
 def test_llm(init_server: str):
-    llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
+    llm = ChatOpenAI(
+        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1"
+    )
     template = """Question: {question}
 
     Answer: Let's think step by step."""
@@ -21,17 +24,16 @@ def test_llm(init_server: str):
     logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
 
 
-
-
 @pytest.mark.requires("openai")
 def test_embedding(init_server: str):
-    embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
-                                  openai_api_key="YOUR_API_KEY",
-                                  openai_api_base=f"{init_server}/zhipuai/v1")
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-large",
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/zhipuai/v1",
+    )
 
     text = "你好"
 
     query_result = embeddings.embed_query(text)
 
     logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
@@ -1,8 +1,9 @@
+import logging
+
+import pytest
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-import pytest
-import logging
 
 logger = logging.getLogger(__name__)
 
@@ -10,9 +11,10 @@ logger = logging.getLogger(__name__)
 @pytest.mark.requires("xinference_client")
 def test_llm(init_server: str):
     llm = ChatOpenAI(
-
         model_name="glm-4",
-        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/xinference/v1")
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/xinference/v1",
+    )
     template = """Question: {question}
 
     Answer: Let's think step by step."""
@@ -26,9 +28,11 @@ def test_llm(init_server: str):
 
 @pytest.mark.requires("xinference-client")
 def test_embedding(init_server: str):
-    embeddings = OpenAIEmbeddings(model="text_embedding",
-                                  openai_api_key="YOUR_API_KEY",
-                                  openai_api_base=f"{init_server}/xinference/v1")
+    embeddings = OpenAIEmbeddings(
+        model="text_embedding",
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/xinference/v1",
+    )
 
     text = "你好"
@@ -1,17 +1,20 @@
+import logging
+
+import pytest
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-import pytest
-import logging
 
 logger = logging.getLogger(__name__)
 
 
 @pytest.mark.requires("zhipuai")
 def test_llm(init_server: str):
     llm = ChatOpenAI(
-
         model_name="glm-4",
-        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/zhipuai/v1",
+    )
     template = """Question: {question}
 
     Answer: Let's think step by step."""
@@ -25,15 +28,14 @@ def test_llm(init_server: str):
 
 @pytest.mark.requires("zhipuai")
 def test_embedding(init_server: str):
-    embeddings = OpenAIEmbeddings(model="text_embedding",
-                                  openai_api_key="YOUR_API_KEY",
-                                  openai_api_base=f"{init_server}/zhipuai/v1")
+    embeddings = OpenAIEmbeddings(
+        model="text_embedding",
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/zhipuai/v1",
+    )
 
     text = "你好"
 
     query_result = embeddings.embed_query(text)
 
     logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
@@ -7,7 +7,10 @@ from omegaconf import OmegaConf
 
 from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
 from model_providers.core.model_manager import ModelManager
-from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity
+from model_providers.core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelType,
+)
 from model_providers.core.provider_manager import ProviderManager
 
 logger = logging.getLogger(__name__)
@@ -16,9 +19,7 @@ logger = logging.getLogger(__name__)
 def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     logging.config.dictConfig(logging_conf)  # type: ignore
     # Read the config file
-    cfg = OmegaConf.load(
-        providers_file
-    )
+    cfg = OmegaConf.load(providers_file)
     # Convert the config file
     (
         provider_name_to_provider_records_dict,
@@ -35,15 +36,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     )
     llm_models: List[AIModelEntity] = []
     for model in provider_model_bundle_llm.configuration.custom_configuration.models:
-        llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
-            model=model.model,
-            credentials=model.credentials,
-        ))
+        llm_models.append(
+            provider_model_bundle_llm.model_type_instance.get_model_schema(
+                model=model.model,
+                credentials=model.credentials,
+            )
+        )
 
     # Fetch the predefined models
-    llm_models.extend(
-        provider_model_bundle_llm.model_type_instance.predefined_models()
-    )
+    llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
 
     logger.info(f"predefined_models: {llm_models}")
@@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
 def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     logging.config.dictConfig(logging_conf)  # type: ignore
     # Read the config file
-    cfg = OmegaConf.load(
-        providers_file
-    )
+    cfg = OmegaConf.load(providers_file)
     # Convert the config file
     (
         provider_name_to_provider_records_dict,
@@ -34,16 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     )
     llm_models: List[AIModelEntity] = []
     for model in provider_model_bundle_llm.configuration.custom_configuration.models:
-        llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
-            model=model.model,
-            credentials=model.credentials,
-        ))
+        llm_models.append(
+            provider_model_bundle_llm.model_type_instance.get_model_schema(
+                model=model.model,
+                credentials=model.credentials,
+            )
+        )
 
     # Fetch the predefined models
-    llm_models.extend(
-        provider_model_bundle_llm.model_type_instance.predefined_models()
-    )
+    llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
 
     logger.info(f"predefined_models: {llm_models}")
@@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
 def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     logging.config.dictConfig(logging_conf)  # type: ignore
     # Read the config file
-    cfg = OmegaConf.load(
-        providers_file
-    )
+    cfg = OmegaConf.load(providers_file)
     # Convert the config file
     (
         provider_name_to_provider_records_dict,
@@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     )
     llm_models: List[AIModelEntity] = []
     for model in provider_model_bundle_llm.configuration.custom_configuration.models:
-        llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
-            model=model.model,
-            credentials=model.credentials,
-        ))
+        llm_models.append(
+            provider_model_bundle_llm.model_type_instance.get_model_schema(
+                model=model.model,
+                credentials=model.credentials,
+            )
+        )
 
     # Fetch the predefined models
-    llm_models.extend(
-        provider_model_bundle_llm.model_type_instance.predefined_models()
-    )
+    llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
 
     logger.info(f"predefined_models: {llm_models}")
@@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
 def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     logging.config.dictConfig(logging_conf)  # type: ignore
     # Read the config file
-    cfg = OmegaConf.load(
-        providers_file
-    )
+    cfg = OmegaConf.load(providers_file)
     # Convert the config file
     (
         provider_name_to_provider_records_dict,
@@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     )
     llm_models: List[AIModelEntity] = []
     for model in provider_model_bundle_llm.configuration.custom_configuration.models:
-        llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
-            model=model.model,
-            credentials=model.credentials,
-        ))
+        llm_models.append(
+            provider_model_bundle_llm.model_type_instance.get_model_schema(
+                model=model.model,
+                credentials=model.credentials,
+            )
+        )
 
     # Fetch the predefined models
-    llm_models.extend(
-        provider_model_bundle_llm.model_type_instance.predefined_models()
-    )
+    llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
 
     logger.info(f"predefined_models: {llm_models}")
@@ -15,9 +15,7 @@ logger = logging.getLogger(__name__)
 def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     logging.config.dictConfig(logging_conf)  # type: ignore
     # Read the config file
-    cfg = OmegaConf.load(
-        providers_file
-    )
+    cfg = OmegaConf.load(providers_file)
     # Convert the config file
     (
         provider_name_to_provider_records_dict,
@@ -34,15 +32,14 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
     )
     llm_models: List[AIModelEntity] = []
     for model in provider_model_bundle_llm.configuration.custom_configuration.models:
-        llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
-            model=model.model,
-            credentials=model.credentials,
-        ))
+        llm_models.append(
+            provider_model_bundle_llm.model_type_instance.get_model_schema(
+                model=model.model,
+                credentials=model.credentials,
+            )
+        )
 
     # Fetch the predefined models
-    llm_models.extend(
-        provider_model_bundle_llm.model_type_instance.predefined_models()
-    )
+    llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
 
     logger.info(f"predefined_models: {llm_models}")