diff --git a/server/xinference/__init__.py b/server/xinference/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/server/xinference/_assets/icon_l_en.svg b/server/xinference/_assets/icon_l_en.svg
deleted file mode 100644
index 81091765..00000000
--- a/server/xinference/_assets/icon_l_en.svg
+++ /dev/null
@@ -1,42 +0,0 @@
-
diff --git a/server/xinference/_assets/icon_s_en.svg b/server/xinference/_assets/icon_s_en.svg
deleted file mode 100644
index f5c5f75e..00000000
--- a/server/xinference/_assets/icon_s_en.svg
+++ /dev/null
@@ -1,24 +0,0 @@
-
diff --git a/server/xinference/llm/__init__.py b/server/xinference/llm/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/server/xinference/llm/llm.py b/server/xinference/llm/llm.py
deleted file mode 100644
index 602d0b74..00000000
--- a/server/xinference/llm/llm.py
+++ /dev/null
@@ -1,734 +0,0 @@
-from collections.abc import Generator, Iterator
-from typing import cast
-
-from openai import (
- APIConnectionError,
- APITimeoutError,
- AuthenticationError,
- ConflictError,
- InternalServerError,
- NotFoundError,
- OpenAI,
- PermissionDeniedError,
- RateLimitError,
- UnprocessableEntityError,
-)
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
-from openai.types.completion import Completion
-from xinference_client.client.restful.restful_client import (
- Client,
- RESTfulChatglmCppChatModelHandle,
- RESTfulChatModelHandle,
- RESTfulGenerateModelHandle,
-)
-
-from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (
- AssistantPromptMessage,
- PromptMessage,
- PromptMessageTool,
- SystemPromptMessage,
- ToolPromptMessage,
- UserPromptMessage,
-)
-from core.model_runtime.entities.model_entities import (
- AIModelEntity,
- FetchFrom,
- ModelFeature,
- ModelPropertyKey,
- ModelType,
- ParameterRule,
- ParameterType,
-)
-from core.model_runtime.errors.invoke import (
- InvokeAuthorizationError,
- InvokeBadRequestError,
- InvokeConnectionError,
- InvokeError,
- InvokeRateLimitError,
- InvokeServerUnavailableError,
-)
-from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.xinference.xinference_helper import (
- XinferenceHelper,
- XinferenceModelExtraParameter,
-)
-from core.model_runtime.utils import helper
-
-
-class XinferenceAILargeLanguageModel(LargeLanguageModel):
- def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- model_parameters: dict, tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
- -> LLMResult | Generator:
- """
- invoke LLM
-
- see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
- """
- if 'temperature' in model_parameters:
- if model_parameters['temperature'] < 0.01:
- model_parameters['temperature'] = 0.01
- elif model_parameters['temperature'] > 1.0:
- model_parameters['temperature'] = 0.99
-
- return self._generate(
- model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
- tools=tools, stop=stop, stream=stream, user=user,
- extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
- server_url=credentials['server_url'],
- model_uid=credentials['model_uid']
- )
- )
-
- def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- validate credentials
-
- credentials should be like:
- {
- 'model_type': 'text-generation',
- 'server_url': 'server url',
- 'model_uid': 'model uid',
- }
- """
- try:
- if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
- raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-
- extra_param = XinferenceHelper.get_xinference_extra_parameter(
- server_url=credentials['server_url'],
- model_uid=credentials['model_uid']
- )
- if 'completion_type' not in credentials:
- if 'chat' in extra_param.model_ability:
- credentials['completion_type'] = 'chat'
- elif 'generate' in extra_param.model_ability:
- credentials['completion_type'] = 'completion'
- else:
- raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
-
- if extra_param.support_function_call:
- credentials['support_function_call'] = True
-
- if extra_param.context_length:
- credentials['context_length'] = extra_param.context_length
-
-        except (RuntimeError, KeyError) as e:
-            raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
-
- def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None) -> int:
- """
- get number of tokens
-
-        Since Xinference LLM is a customized model, we cannot detect which tokenizer it uses,
-        so we simply fall back to the GPT-2 tokenizer by default.
- """
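-        # A minimal sketch of that fallback, assuming the base class's _get_num_tokens_by_gpt2
-        # is backed by the Hugging Face GPT-2 tokenizer (an assumption, not defined in this file):
-        #
-        #   from transformers import GPT2TokenizerFast
-        #   tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
-        #   num_tokens = len(tokenizer.encode(text))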
- return self._num_tokens_from_messages(prompt_messages, tools)
-
- def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
- is_completion_model: bool = False) -> int:
- def tokens(text: str):
- return self._get_num_tokens_by_gpt2(text)
-
- if is_completion_model:
- return sum([tokens(str(message.content)) for message in messages])
-
- tokens_per_message = 3
- tokens_per_name = 1
-
- num_tokens = 0
- messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
- for message in messages_dict:
- num_tokens += tokens_per_message
- for key, value in message.items():
- if isinstance(value, list):
- text = ''
- for item in value:
- if isinstance(item, dict) and item['type'] == 'text':
-                            text += item['text']
-
- value = text
-
- if key == "tool_calls":
- for tool_call in value:
- for t_key, t_value in tool_call.items():
- num_tokens += tokens(t_key)
- if t_key == "function":
- for f_key, f_value in t_value.items():
- num_tokens += tokens(f_key)
- num_tokens += tokens(f_value)
- else:
- num_tokens += tokens(t_key)
- num_tokens += tokens(t_value)
- if key == "function_call":
- for t_key, t_value in value.items():
- num_tokens += tokens(t_key)
- if t_key == "function":
- for f_key, f_value in t_value.items():
- num_tokens += tokens(f_key)
- num_tokens += tokens(f_value)
- else:
- num_tokens += tokens(t_key)
- num_tokens += tokens(t_value)
- else:
- num_tokens += tokens(str(value))
-
- if key == "name":
- num_tokens += tokens_per_name
- num_tokens += 3
-
- if tools:
- num_tokens += self._num_tokens_for_tools(tools)
-
- return num_tokens
-
- def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
- """
- Calculate num tokens for tool calling
-
- :param tools: tools for tool calling
- :return: number of tokens
- """
- def tokens(text: str):
- return self._get_num_tokens_by_gpt2(text)
-
- num_tokens = 0
- for tool in tools:
- # calculate num tokens for function object
- num_tokens += tokens('name')
- num_tokens += tokens(tool.name)
- num_tokens += tokens('description')
- num_tokens += tokens(tool.description)
- parameters = tool.parameters
- num_tokens += tokens('parameters')
- num_tokens += tokens('type')
- num_tokens += tokens(parameters.get("type"))
- if 'properties' in parameters:
- num_tokens += tokens('properties')
- for key, value in parameters.get('properties').items():
- num_tokens += tokens(key)
- for field_key, field_value in value.items():
- num_tokens += tokens(field_key)
- if field_key == 'enum':
- for enum_field in field_value:
- num_tokens += 3
- num_tokens += tokens(enum_field)
- else:
- num_tokens += tokens(field_key)
- num_tokens += tokens(str(field_value))
- if 'required' in parameters:
- num_tokens += tokens('required')
- for required_field in parameters['required']:
- num_tokens += 3
- num_tokens += tokens(required_field)
-
- return num_tokens
-
- def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
- """
- convert prompt message to text
- """
- text = ''
- for item in message:
- if isinstance(item, UserPromptMessage):
- text += item.content
- elif isinstance(item, SystemPromptMessage):
- text += item.content
- elif isinstance(item, AssistantPromptMessage):
- text += item.content
- else:
- raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
- return text
-
- def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
- """
- Convert PromptMessage to dict for OpenAI Compatibility API
- """
- if isinstance(message, UserPromptMessage):
- message = cast(UserPromptMessage, message)
- if isinstance(message.content, str):
- message_dict = {"role": "user", "content": message.content}
- else:
- raise ValueError("User message content must be str")
- elif isinstance(message, AssistantPromptMessage):
- message = cast(AssistantPromptMessage, message)
- message_dict = {"role": "assistant", "content": message.content}
- if message.tool_calls and len(message.tool_calls) > 0:
- message_dict["function_call"] = {
- "name": message.tool_calls[0].function.name,
- "arguments": message.tool_calls[0].function.arguments
- }
- elif isinstance(message, SystemPromptMessage):
- message = cast(SystemPromptMessage, message)
- message_dict = {"role": "system", "content": message.content}
- elif isinstance(message, ToolPromptMessage):
- message = cast(ToolPromptMessage, message)
- message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
- else:
- raise ValueError(f"Unknown message type {type(message)}")
-
- return message_dict
-
- def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
- """
- used to define customizable model schema
- """
- rules = [
- ParameterRule(
- name='temperature',
- type=ParameterType.FLOAT,
- use_template='temperature',
- label=I18nObject(
- zh_Hans='温度',
- en_US='Temperature'
- ),
- ),
- ParameterRule(
- name='top_p',
- type=ParameterType.FLOAT,
- use_template='top_p',
- label=I18nObject(
- zh_Hans='Top P',
- en_US='Top P'
- )
- ),
- ParameterRule(
- name='max_tokens',
- type=ParameterType.INT,
- use_template='max_tokens',
- min=1,
- max=credentials.get('context_length', 2048),
- default=512,
- label=I18nObject(
- zh_Hans='最大生成长度',
- en_US='Max Tokens'
- )
- )
- ]
-
- completion_type = None
-
- if 'completion_type' in credentials:
- if credentials['completion_type'] == 'chat':
- completion_type = LLMMode.CHAT.value
- elif credentials['completion_type'] == 'completion':
- completion_type = LLMMode.COMPLETION.value
- else:
- raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
- else:
- extra_args = XinferenceHelper.get_xinference_extra_parameter(
- server_url=credentials['server_url'],
- model_uid=credentials['model_uid']
- )
-
- if 'chat' in extra_args.model_ability:
- completion_type = LLMMode.CHAT.value
- elif 'generate' in extra_args.model_ability:
- completion_type = LLMMode.COMPLETION.value
- else:
- raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
-
- support_function_call = credentials.get('support_function_call', False)
- context_length = credentials.get('context_length', 2048)
-
- entity = AIModelEntity(
- model=model,
- label=I18nObject(
- en_US=model
- ),
- fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- model_type=ModelType.LLM,
- features=[
- ModelFeature.TOOL_CALL
- ] if support_function_call else [],
- model_properties={
- ModelPropertyKey.MODE: completion_type,
- ModelPropertyKey.CONTEXT_SIZE: context_length
- },
- parameter_rules=rules
- )
-
- return entity
-
- def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
- -> LLMResult | Generator:
- """
- generate text from LLM
-
- see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
-
-        extra_model_kwargs can be obtained via `XinferenceHelper.get_xinference_extra_parameter`
- """
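-        # For reference, _invoke above obtains extra_model_kwargs like this:
-        #
-        #   extra_model_kwargs = XinferenceHelper.get_xinference_extra_parameter(
-        #       server_url=credentials['server_url'],
-        #       model_uid=credentials['model_uid']
-        #   )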
- if 'server_url' not in credentials:
- raise CredentialsValidateFailedError('server_url is required in credentials')
-
- if credentials['server_url'].endswith('/'):
- credentials['server_url'] = credentials['server_url'][:-1]
-
- client = OpenAI(
- base_url=f'{credentials["server_url"]}/v1',
- api_key='abc',
- max_retries=3,
- timeout=60,
- )
-
- xinference_client = Client(
- base_url=credentials['server_url'],
- )
-
- xinference_model = xinference_client.get_model(credentials['model_uid'])
-
- generate_config = {
- 'temperature': model_parameters.get('temperature', 1.0),
- 'top_p': model_parameters.get('top_p', 0.7),
- 'max_tokens': model_parameters.get('max_tokens', 512),
- }
-
- if stop:
- generate_config['stop'] = stop
-
- if tools and len(tools) > 0:
- generate_config['tools'] = [
- {
- 'type': 'function',
- 'function': helper.dump_model(tool)
- } for tool in tools
- ]
-
- if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
- resp = client.chat.completions.create(
- model=credentials['model_uid'],
- messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
- stream=stream,
- user=user,
- **generate_config,
- )
- if stream:
- if tools and len(tools) > 0:
-                    raise InvokeBadRequestError('xinference tool calls do not support stream mode')
- return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
- tools=tools, resp=resp)
- return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
- tools=tools, resp=resp)
- elif isinstance(xinference_model, RESTfulGenerateModelHandle):
- resp = client.completions.create(
- model=credentials['model_uid'],
- prompt=self._convert_prompt_message_to_text(prompt_messages),
- stream=stream,
- user=user,
- **generate_config,
- )
- if stream:
- return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
- tools=tools, resp=resp)
- return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
- tools=tools, resp=resp)
- else:
- raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')
-
- def _extract_response_tool_calls(self,
- response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
- -> list[AssistantPromptMessage.ToolCall]:
- """
- Extract tool calls from response
-
- :param response_tool_calls: response tool calls
- :return: list of tool calls
- """
- tool_calls = []
- if response_tool_calls:
- for response_tool_call in response_tool_calls:
- function = AssistantPromptMessage.ToolCall.ToolCallFunction(
- name=response_tool_call.function.name,
- arguments=response_tool_call.function.arguments
- )
-
- tool_call = AssistantPromptMessage.ToolCall(
- id=response_tool_call.id,
- type=response_tool_call.type,
- function=function
- )
- tool_calls.append(tool_call)
-
- return tool_calls
-
- def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
- -> AssistantPromptMessage.ToolCall:
- """
- Extract function call from response
-
- :param response_function_call: response function call
- :return: tool call
- """
- tool_call = None
- if response_function_call:
- function = AssistantPromptMessage.ToolCall.ToolCallFunction(
- name=response_function_call.name,
- arguments=response_function_call.arguments
- )
-
- tool_call = AssistantPromptMessage.ToolCall(
- id=response_function_call.name,
- type="function",
- function=function
- )
-
- return tool_call
-
- def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
- resp: ChatCompletion) -> LLMResult:
- """
- handle normal chat generate response
- """
- if len(resp.choices) == 0:
- raise InvokeServerUnavailableError("Empty response")
-
- assistant_message = resp.choices[0].message
-
- # convert tool call to assistant message tool call
- tool_calls = assistant_message.tool_calls
- assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
- function_call = assistant_message.function_call
- if function_call:
- assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
-
- # transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
- content=assistant_message.content,
- tool_calls=assistant_prompt_message_tool_calls
- )
-
- prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
- completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
-
- usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-
- response = LLMResult(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=resp.system_fingerprint,
- usage=usage,
- message=assistant_prompt_message,
- )
-
- return response
-
- def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
- resp: Iterator[ChatCompletionChunk]) -> Generator:
- """
- handle stream chat generate response
- """
- full_response = ''
-
- for chunk in resp:
- if len(chunk.choices) == 0:
- continue
-
- delta = chunk.choices[0]
-
- if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
- continue
-
- # check if there is a tool call in the response
- function_call = None
- tool_calls = []
- if delta.delta.tool_calls:
- tool_calls += delta.delta.tool_calls
- if delta.delta.function_call:
- function_call = delta.delta.function_call
-
- assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls)
- if function_call:
- assistant_message_tool_calls += [self._extract_response_function_call(function_call)]
-
- # transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
- content=delta.delta.content if delta.delta.content else '',
- tool_calls=assistant_message_tool_calls
- )
-
- if delta.finish_reason is not None:
- # temp_assistant_prompt_message is used to calculate usage
- temp_assistant_prompt_message = AssistantPromptMessage(
- content=full_response,
- tool_calls=assistant_message_tool_calls
- )
-
- prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
- completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
-
- usage = self._calc_response_usage(model=model, credentials=credentials,
- prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-
- yield LLMResultChunk(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=chunk.system_fingerprint,
- delta=LLMResultChunkDelta(
- index=0,
- message=assistant_prompt_message,
- finish_reason=delta.finish_reason,
- usage=usage
- ),
- )
- else:
- yield LLMResultChunk(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=chunk.system_fingerprint,
- delta=LLMResultChunkDelta(
- index=0,
- message=assistant_prompt_message,
- ),
- )
-
-            full_response += delta.delta.content or ''
-
- def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
- resp: Completion) -> LLMResult:
- """
- handle normal completion generate response
- """
- if len(resp.choices) == 0:
- raise InvokeServerUnavailableError("Empty response")
-
- assistant_message = resp.choices[0].text
-
- # transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
- content=assistant_message,
- tool_calls=[]
- )
-
- prompt_tokens = self._get_num_tokens_by_gpt2(
- self._convert_prompt_message_to_text(prompt_messages)
- )
- completion_tokens = self._num_tokens_from_messages(
- messages=[assistant_prompt_message], tools=[], is_completion_model=True
- )
- usage = self._calc_response_usage(
- model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
- )
-
- response = LLMResult(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=resp.system_fingerprint,
- usage=usage,
- message=assistant_prompt_message,
- )
-
- return response
-
- def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
- resp: Iterator[Completion]) -> Generator:
- """
- handle stream completion generate response
- """
- full_response = ''
-
- for chunk in resp:
- if len(chunk.choices) == 0:
- continue
-
- delta = chunk.choices[0]
-
- # transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
- content=delta.text if delta.text else '',
- tool_calls=[]
- )
-
- if delta.finish_reason is not None:
- # temp_assistant_prompt_message is used to calculate usage
- temp_assistant_prompt_message = AssistantPromptMessage(
- content=full_response,
- tool_calls=[]
- )
-
- prompt_tokens = self._get_num_tokens_by_gpt2(
- self._convert_prompt_message_to_text(prompt_messages)
- )
- completion_tokens = self._num_tokens_from_messages(
- messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
- )
- usage = self._calc_response_usage(model=model, credentials=credentials,
- prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-
- yield LLMResultChunk(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=chunk.system_fingerprint,
- delta=LLMResultChunkDelta(
- index=0,
- message=assistant_prompt_message,
- finish_reason=delta.finish_reason,
- usage=usage
- ),
- )
- else:
- if delta.text is None or delta.text == '':
- continue
-
- yield LLMResultChunk(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=chunk.system_fingerprint,
- delta=LLMResultChunkDelta(
- index=0,
- message=assistant_prompt_message,
- ),
- )
-
-            full_response += delta.text or ''
-
- @property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
- """
- Map model invoke error to unified error
- The key is the error type thrown to the caller
- The value is the error type thrown by the model,
- which needs to be converted into a unified error type for the caller.
-
- :return: Invoke error mapping
- """
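-        # A sketch of how a caller could apply this mapping (illustrative only):
-        #
-        #   for unified_error, provider_errors in self._invoke_error_mapping.items():
-        #       if isinstance(ex, tuple(provider_errors)):
-        #           raise unified_error(str(ex))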
- return {
- InvokeConnectionError: [
- APIConnectionError,
- APITimeoutError,
- ],
- InvokeServerUnavailableError: [
- InternalServerError,
- ConflictError,
- NotFoundError,
- UnprocessableEntityError,
- PermissionDeniedError
- ],
- InvokeRateLimitError: [
- RateLimitError
- ],
- InvokeAuthorizationError: [
- AuthenticationError
- ],
- InvokeBadRequestError: [
- ValueError
- ]
- }
\ No newline at end of file
diff --git a/server/xinference/rerank/__init__.py b/server/xinference/rerank/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/server/xinference/rerank/rerank.py b/server/xinference/rerank/rerank.py
deleted file mode 100644
index dd25037d..00000000
--- a/server/xinference/rerank/rerank.py
+++ /dev/null
@@ -1,160 +0,0 @@
-from typing import Optional
-
-from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle
-
-from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
-from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
-from core.model_runtime.errors.invoke import (
- InvokeAuthorizationError,
- InvokeBadRequestError,
- InvokeConnectionError,
- InvokeError,
- InvokeRateLimitError,
- InvokeServerUnavailableError,
-)
-from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.rerank_model import RerankModel
-
-
-class XinferenceRerankModel(RerankModel):
- """
- Model class for Xinference rerank model.
- """
-
- def _invoke(self, model: str, credentials: dict,
- query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
- user: Optional[str] = None) \
- -> RerankResult:
- """
- Invoke rerank model
-
- :param model: model name
- :param credentials: model credentials
- :param query: search query
- :param docs: docs for reranking
- :param score_threshold: score threshold
- :param top_n: top n
- :param user: unique user id
- :return: rerank result
- """
- if len(docs) == 0:
- return RerankResult(
- model=model,
- docs=[]
- )
-
- if credentials['server_url'].endswith('/'):
- credentials['server_url'] = credentials['server_url'][:-1]
-
- # initialize client
- client = Client(
- base_url=credentials['server_url']
- )
-
- xinference_client = client.get_model(model_uid=credentials['model_uid'])
-
- if not isinstance(xinference_client, RESTfulRerankModelHandle):
-            raise InvokeBadRequestError('The model you are trying to invoke is not a rerank model, please check the model type')
-
- response = xinference_client.rerank(
- documents=docs,
- query=query,
- top_n=top_n,
- )
-
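-        # The handle is expected to return a payload shaped roughly like the sketch below
-        # (only the keys read by the loop that follows):
-        #
-        #   {'results': [{'index': 0, 'document': '...', 'relevance_score': 0.98}, ...]}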
- rerank_documents = []
- for idx, result in enumerate(response['results']):
- # format document
- index = result['index']
- page_content = result['document']
- rerank_document = RerankDocument(
- index=index,
- text=page_content,
- score=result['relevance_score'],
- )
-
- # score threshold check
- if score_threshold is not None:
- if result['relevance_score'] >= score_threshold:
- rerank_documents.append(rerank_document)
- else:
- rerank_documents.append(rerank_document)
-
- return RerankResult(
- model=model,
- docs=rerank_documents
- )
-
- def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return:
- """
- try:
- if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
- raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-
- self.invoke(
- model=model,
- credentials=credentials,
- query="Whose kasumi",
- docs=[
- "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
- "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
- "and she leads a team named PopiParty."
- ],
- score_threshold=0.8
- )
- except Exception as ex:
- raise CredentialsValidateFailedError(str(ex))
-
- @property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
- """
- Map model invoke error to unified error
- The key is the error type thrown to the caller
- The value is the error type thrown by the model,
- which needs to be converted into a unified error type for the caller.
-
- :return: Invoke error mapping
- """
- return {
- InvokeConnectionError: [
- InvokeConnectionError
- ],
- InvokeServerUnavailableError: [
- InvokeServerUnavailableError
- ],
- InvokeRateLimitError: [
- InvokeRateLimitError
- ],
- InvokeAuthorizationError: [
- InvokeAuthorizationError
- ],
- InvokeBadRequestError: [
- InvokeBadRequestError,
- KeyError,
- ValueError
- ]
- }
-
- def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
- """
- used to define customizable model schema
- """
- entity = AIModelEntity(
- model=model,
- label=I18nObject(
- en_US=model
- ),
- fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- model_type=ModelType.RERANK,
- model_properties={ },
- parameter_rules=[]
- )
-
- return entity
\ No newline at end of file
diff --git a/server/xinference/text_embedding/__init__.py b/server/xinference/text_embedding/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/server/xinference/text_embedding/text_embedding.py b/server/xinference/text_embedding/text_embedding.py
deleted file mode 100644
index 32d2b151..00000000
--- a/server/xinference/text_embedding/text_embedding.py
+++ /dev/null
@@ -1,201 +0,0 @@
-import time
-from typing import Optional
-
-from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
-
-from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
-from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
-from core.model_runtime.errors.invoke import (
- InvokeAuthorizationError,
- InvokeBadRequestError,
- InvokeConnectionError,
- InvokeError,
- InvokeRateLimitError,
- InvokeServerUnavailableError,
-)
-from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
-
-
-class XinferenceTextEmbeddingModel(TextEmbeddingModel):
- """
- Model class for Xinference text embedding model.
- """
- def _invoke(self, model: str, credentials: dict,
- texts: list[str], user: Optional[str] = None) \
- -> TextEmbeddingResult:
- """
- Invoke text embedding model
-
- credentials should be like:
- {
- 'server_url': 'server url',
- 'model_uid': 'model uid',
- }
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :param user: unique user id
- :return: embeddings result
- """
- server_url = credentials['server_url']
- model_uid = credentials['model_uid']
-
- if server_url.endswith('/'):
- server_url = server_url[:-1]
-
- client = Client(base_url=server_url)
-
- try:
- handle = client.get_model(model_uid=model_uid)
- except RuntimeError as e:
- raise InvokeAuthorizationError(e)
-
- if not isinstance(handle, RESTfulEmbeddingModelHandle):
-            raise InvokeBadRequestError('The model you are trying to invoke is not a text embedding model, please check the model type')
-
- try:
- embeddings = handle.create_embedding(input=texts)
- except RuntimeError as e:
- raise InvokeServerUnavailableError(e)
-
- """
- for convenience, the response json is like:
- class Embedding(TypedDict):
- object: Literal["list"]
- model: str
- data: List[EmbeddingData]
- usage: EmbeddingUsage
- class EmbeddingUsage(TypedDict):
- prompt_tokens: int
- total_tokens: int
- class EmbeddingData(TypedDict):
- index: int
- object: str
- embedding: List[float]
- """
-
- usage = embeddings['usage']
- usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
-
- result = TextEmbeddingResult(
- model=model,
- embeddings=[embedding['embedding'] for embedding in embeddings['data']],
- usage=usage
- )
-
- return result
-
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
- """
- Get number of tokens for given prompt messages
-
- :param model: model name
- :param credentials: model credentials
- :param texts: texts to embed
- :return:
- """
- num_tokens = 0
- for text in texts:
- # use GPT2Tokenizer to get num tokens
- num_tokens += self._get_num_tokens_by_gpt2(text)
- return num_tokens
-
- def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return:
- """
- try:
- if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
- raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-
- server_url = credentials['server_url']
- model_uid = credentials['model_uid']
- extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
-
- if extra_args.max_tokens:
- credentials['max_tokens'] = extra_args.max_tokens
-
- self._invoke(model=model, credentials=credentials, texts=['ping'])
- except InvokeAuthorizationError as e:
- raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}')
- except RuntimeError as e:
- raise CredentialsValidateFailedError(e)
-
- @property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
- return {
- InvokeConnectionError: [
- InvokeConnectionError
- ],
- InvokeServerUnavailableError: [
- InvokeServerUnavailableError
- ],
- InvokeRateLimitError: [
- InvokeRateLimitError
- ],
- InvokeAuthorizationError: [
- InvokeAuthorizationError
- ],
- InvokeBadRequestError: [
- KeyError
- ]
- }
-
- def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
- """
- Calculate response usage
-
- :param model: model name
- :param credentials: model credentials
- :param tokens: input tokens
- :return: usage
- """
- # get input price info
- input_price_info = self.get_price(
- model=model,
- credentials=credentials,
- price_type=PriceType.INPUT,
- tokens=tokens
- )
-
- # transform usage
- usage = EmbeddingUsage(
- tokens=tokens,
- total_tokens=tokens,
- unit_price=input_price_info.unit_price,
- price_unit=input_price_info.unit,
- total_price=input_price_info.total_amount,
- currency=input_price_info.currency,
- latency=time.perf_counter() - self.started_at
- )
-
- return usage
-
- def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
- """
- used to define customizable model schema
- """
-
- entity = AIModelEntity(
- model=model,
- label=I18nObject(
- en_US=model
- ),
- fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- model_type=ModelType.TEXT_EMBEDDING,
- model_properties={
- ModelPropertyKey.MAX_CHUNKS: 1,
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('max_tokens') or 512,
- },
- parameter_rules=[]
- )
-
- return entity
\ No newline at end of file
diff --git a/server/xinference/xinference.py b/server/xinference/xinference.py
deleted file mode 100644
index d85f7c82..00000000
--- a/server/xinference/xinference.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import logging
-
-from core.model_runtime.model_providers.__base.model_provider import ModelProvider
-
-logger = logging.getLogger(__name__)
-
-
-class XinferenceAIProvider(ModelProvider):
- def validate_provider_credentials(self, credentials: dict) -> None:
- pass
diff --git a/server/xinference/xinference.yaml b/server/xinference/xinference.yaml
deleted file mode 100644
index bb6c6d86..00000000
--- a/server/xinference/xinference.yaml
+++ /dev/null
@@ -1,47 +0,0 @@
-provider: xinference
-label:
- en_US: Xorbits Inference
-icon_small:
- en_US: icon_s_en.svg
-icon_large:
- en_US: icon_l_en.svg
-background: "#FAF5FF"
-help:
- title:
- en_US: How to deploy Xinference
- zh_Hans: 如何部署 Xinference
- url:
- en_US: https://github.com/xorbitsai/inference
-supported_model_types:
- - llm
- - text-embedding
- - rerank
-configurate_methods:
- - customizable-model
-model_credential_schema:
- model:
- label:
- en_US: Model Name
- zh_Hans: 模型名称
- placeholder:
- en_US: Enter your model name
- zh_Hans: 输入模型名称
- credential_form_schemas:
- - variable: server_url
- label:
- zh_Hans: 服务器URL
-        en_US: Server URL
- type: secret-input
- required: true
- placeholder:
- zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
- en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
- - variable: model_uid
- label:
- zh_Hans: 模型UID
-        en_US: Model UID
- type: text-input
- required: true
- placeholder:
- zh_Hans: 在此输入您的Model UID
- en_US: Enter the model uid
diff --git a/server/xinference/xinference_helper.py b/server/xinference/xinference_helper.py
deleted file mode 100644
index 66dab658..00000000
--- a/server/xinference/xinference_helper.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from threading import Lock
-from time import time
-
-from requests.adapters import HTTPAdapter
-from requests.exceptions import ConnectionError, MissingSchema, Timeout
-from requests.sessions import Session
-from yarl import URL
-
-
-class XinferenceModelExtraParameter:
- model_format: str
- model_handle_type: str
- model_ability: list[str]
- max_tokens: int = 512
- context_length: int = 2048
- support_function_call: bool = False
-
- def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
- support_function_call: bool, max_tokens: int, context_length: int) -> None:
- self.model_format = model_format
- self.model_handle_type = model_handle_type
- self.model_ability = model_ability
- self.support_function_call = support_function_call
- self.max_tokens = max_tokens
- self.context_length = context_length
-
-cache = {}
-cache_lock = Lock()
-
-class XinferenceHelper:
- @staticmethod
- def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
- XinferenceHelper._clean_cache()
- with cache_lock:
- if model_uid not in cache:
- cache[model_uid] = {
- 'expires': time() + 300,
- 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
- }
- return cache[model_uid]['value']
-
- @staticmethod
- def _clean_cache() -> None:
- try:
- with cache_lock:
- expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
- for model_uid in expired_keys:
- del cache[model_uid]
-        except RuntimeError:
-            pass
-
- @staticmethod
- def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
- """
- get xinference model extra parameter like model_format and model_handle_type
- """
-
- if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
-            raise RuntimeError('model_uid and server_url must not be empty')
-
- url = str(URL(server_url) / 'v1' / 'models' / model_uid)
-
-        # this method is called while holding a lock, and a plain requests call may hang forever, so we mount an HTTPAdapter with max_retries=3
- session = Session()
- session.mount('http://', HTTPAdapter(max_retries=3))
- session.mount('https://', HTTPAdapter(max_retries=3))
-
- try:
- response = session.get(url, timeout=10)
- except (MissingSchema, ConnectionError, Timeout) as e:
- raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
- if response.status_code != 200:
- raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
-
- response_json = response.json()
-
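-        # A sketch of the fields consumed from GET /v1/models/{model_uid}; only the keys read
-        # below are shown, and the values are illustrative:
-        #
-        #   {
-        #       'model_format': 'ggmlv3',
-        #       'model_ability': ['generate', 'chat'],
-        #       'model_type': 'LLM',
-        #       'model_name': 'chatglm3',
-        #       'max_tokens': 512,
-        #       'context_length': 2048,
-        #   }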
- model_format = response_json.get('model_format', 'ggmlv3')
- model_ability = response_json.get('model_ability', [])
-
- if response_json.get('model_type') == 'embedding':
- model_handle_type = 'embedding'
- elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
- model_handle_type = 'chatglm'
- elif 'generate' in model_ability:
- model_handle_type = 'generate'
- elif 'chat' in model_ability:
- model_handle_type = 'chat'
- else:
-            raise NotImplementedError(f'xinference model ability {model_ability} is not supported')
-
- support_function_call = 'tools' in model_ability
- max_tokens = response_json.get('max_tokens', 512)
-
- context_length = response_json.get('context_length', 2048)
-
- return XinferenceModelExtraParameter(
- model_format=model_format,
- model_handle_type=model_handle_type,
- model_ability=model_ability,
- support_function_call=support_function_call,
- max_tokens=max_tokens,
- context_length=context_length
- )
\ No newline at end of file