mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-25 00:05:58 +08:00
201 lines
6.8 KiB
Python
201 lines
6.8 KiB
Python
import time
|
|
from typing import Optional
|
|
|
|
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
|
|
|
|
from core.model_runtime.entities.common_entities import I18nObject
|
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
from core.model_runtime.errors.invoke import (
|
|
InvokeAuthorizationError,
|
|
InvokeBadRequestError,
|
|
InvokeConnectionError,
|
|
InvokeError,
|
|
InvokeRateLimitError,
|
|
InvokeServerUnavailableError,
|
|
)
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
|
|
|
|
|
|
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Xinference text embedding models.

    Talks to a Xinference server through its RESTful client; credentials
    carry the server url and the uid of the deployed embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        credentials should be like:
        {
            'server_url': 'server url',
            'model_uid': 'model uid',
        }

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id (accepted for interface parity; not forwarded)
        :return: embeddings result
        :raises InvokeAuthorizationError: when the model uid cannot be resolved
        :raises InvokeBadRequestError: when the uid points at a non-embedding model
        :raises InvokeServerUnavailableError: when the embedding call itself fails
        """
        server_url = credentials['server_url']
        model_uid = credentials['model_uid']

        # Normalize trailing slashes so the REST client composes urls correctly
        # (the previous code stripped only a single trailing '/').
        server_url = server_url.rstrip('/')

        client = Client(base_url=server_url)

        try:
            handle = client.get_model(model_uid=model_uid)
        except RuntimeError as e:
            # a failed lookup usually means a wrong server_url / model_uid pair
            raise InvokeAuthorizationError(e) from e

        if not isinstance(handle, RESTfulEmbeddingModelHandle):
            raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')

        try:
            embeddings = handle.create_embedding(input=texts)
        except RuntimeError as e:
            raise InvokeServerUnavailableError(e) from e

        """
        for convenience, the response json is like:
        class Embedding(TypedDict):
            object: Literal["list"]
            model: str
            data: List[EmbeddingData]
            usage: EmbeddingUsage
        class EmbeddingUsage(TypedDict):
            prompt_tokens: int
            total_tokens: int
        class EmbeddingData(TypedDict):
            index: int
            object: str
            embedding: List[float]
        """
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=embeddings['usage']['total_tokens'],
        )

        return TextEmbeddingResult(
            model=model,
            embeddings=[item['embedding'] for item in embeddings['data']],
            usage=usage,
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages.

        Uses the local GPT-2 tokenizer as an approximation instead of a
        round-trip to the server.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: total token count across all texts
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by performing a minimal embedding call.

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: when the credentials are
            malformed or the server rejects the test invocation
        """
        try:
            model_uid = credentials['model_uid']
            # these characters would corrupt the REST url if embedded in the uid
            if any(c in model_uid for c in ('/', '?', '#')):
                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

            server_url = credentials['server_url']
            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)

            # cache the server-reported context size for get_customizable_model_schema
            if extra_args.max_tokens:
                credentials['max_tokens'] = extra_args.max_tokens

            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except InvokeAuthorizationError as e:
            raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') from e
        except RuntimeError as e:
            raise CredentialsValidateFailedError(e) from e

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke errors to unified runtime errors.

        The key is the unified error type exposed to callers; the value lists
        the raw error types raised by the invocation that map onto it.
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # NOTE(review): self.started_at is presumably set by the base
            # model's invoke wrapper before this runs — confirm upstream.
            latency=time.perf_counter() - self.started_at
        )

        return usage

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        Context size falls back to 512 when credentials carry no truthy
        'max_tokens'; validate_credentials fills it in from the server
        when available.
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.MAX_CHUNKS: 1,
                # `.get(...) or 512` preserves the original `x and y or z`
                # semantics: any falsy/missing max_tokens falls back to 512
                ModelPropertyKey.CONTEXT_SIZE: credentials.get('max_tokens') or 512,
            },
            parameter_rules=[]
        )

        return entity