diff --git a/server/xinference/__init__.py b/server/xinference/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/xinference/_assets/icon_l_en.svg b/server/xinference/_assets/icon_l_en.svg
new file mode 100644
index 00000000..81091765
--- /dev/null
+++ b/server/xinference/_assets/icon_l_en.svg
@@ -0,0 +1,42 @@
+
diff --git a/server/xinference/_assets/icon_s_en.svg b/server/xinference/_assets/icon_s_en.svg
new file mode 100644
index 00000000..f5c5f75e
--- /dev/null
+++ b/server/xinference/_assets/icon_s_en.svg
@@ -0,0 +1,24 @@
+
diff --git a/server/xinference/llm/__init__.py b/server/xinference/llm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/xinference/llm/llm.py b/server/xinference/llm/llm.py
new file mode 100644
index 00000000..602d0b74
--- /dev/null
+++ b/server/xinference/llm/llm.py
@@ -0,0 +1,734 @@
+from collections.abc import Generator, Iterator
+from typing import cast
+
+from openai import (
+ APIConnectionError,
+ APITimeoutError,
+ AuthenticationError,
+ ConflictError,
+ InternalServerError,
+ NotFoundError,
+ OpenAI,
+ PermissionDeniedError,
+ RateLimitError,
+ UnprocessableEntityError,
+)
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
+from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.completion import Completion
+from xinference_client.client.restful.restful_client import (
+ Client,
+ RESTfulChatglmCppChatModelHandle,
+ RESTfulChatModelHandle,
+ RESTfulGenerateModelHandle,
+)
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ PromptMessageTool,
+ SystemPromptMessage,
+ ToolPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ FetchFrom,
+ ModelFeature,
+ ModelPropertyKey,
+ ModelType,
+ ParameterRule,
+ ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.xinference.xinference_helper import (
+ XinferenceHelper,
+ XinferenceModelExtraParameter,
+)
+from core.model_runtime.utils import helper
+
+
+class XinferenceAILargeLanguageModel(LargeLanguageModel):
+ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+ stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+ -> LLMResult | Generator:
+ """
+ invoke LLM
+
+ see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
+ """
+ if 'temperature' in model_parameters:
+ if model_parameters['temperature'] < 0.01:
+ model_parameters['temperature'] = 0.01
+ elif model_parameters['temperature'] > 1.0:
+ model_parameters['temperature'] = 0.99
+
+ return self._generate(
+ model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
+ tools=tools, stop=stop, stream=stream, user=user,
+ extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
+ server_url=credentials['server_url'],
+ model_uid=credentials['model_uid']
+ )
+ )
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ validate credentials
+
+ credentials should be like:
+ {
+ 'model_type': 'text-generation',
+ 'server_url': 'server url',
+ 'model_uid': 'model uid',
+ }
+ """
+ try:
+ if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
+ raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+ extra_param = XinferenceHelper.get_xinference_extra_parameter(
+ server_url=credentials['server_url'],
+ model_uid=credentials['model_uid']
+ )
+ if 'completion_type' not in credentials:
+ if 'chat' in extra_param.model_ability:
+ credentials['completion_type'] = 'chat'
+ elif 'generate' in extra_param.model_ability:
+ credentials['completion_type'] = 'completion'
+ else:
+ raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
+
+ if extra_param.support_function_call:
+ credentials['support_function_call'] = True
+
+ if extra_param.context_length:
+ credentials['context_length'] = extra_param.context_length
+
+ except (RuntimeError, KeyError) as e:
+ raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
+ except Exception as e:
+ raise e
+
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: list[PromptMessageTool] | None = None) -> int:
+ """
+ get number of tokens
+
+ Because Xinference serves user-deployed, customized models, we cannot detect which tokenizer is in use,
+ so the GPT-2 tokenizer is used as an approximation by default.
+ """
+ return self._num_tokens_from_messages(prompt_messages, tools)
+
+ def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool] | None,
+ is_completion_model: bool = False) -> int:
+ def tokens(text: str):
+ return self._get_num_tokens_by_gpt2(text)
+
+ if is_completion_model:
+ return sum([tokens(str(message.content)) for message in messages])
+
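+ # approximate the OpenAI chat-format accounting: each message carries a fixed overhead
+ # of 3 tokens, a name field adds 1 token, and the reply is primed with 3 extra tokens
+ # (the counts are estimates, since the GPT-2 tokenizer is used here)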
+ tokens_per_message = 3
+ tokens_per_name = 1
+
+ num_tokens = 0
+ messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+ for message in messages_dict:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ if isinstance(value, list):
+ text = ''
+ for item in value:
+ if isinstance(item, dict) and item['type'] == 'text':
+ text += item['text']
+
+ value = text
+
+ if key == "tool_calls":
+ for tool_call in value:
+ for t_key, t_value in tool_call.items():
+ num_tokens += tokens(t_key)
+ if t_key == "function":
+ for f_key, f_value in t_value.items():
+ num_tokens += tokens(f_key)
+ num_tokens += tokens(f_value)
+ else:
+ num_tokens += tokens(t_key)
+ num_tokens += tokens(t_value)
+ if key == "function_call":
+ for t_key, t_value in value.items():
+ num_tokens += tokens(t_key)
+ if t_key == "function":
+ for f_key, f_value in t_value.items():
+ num_tokens += tokens(f_key)
+ num_tokens += tokens(f_value)
+ else:
+ num_tokens += tokens(t_key)
+ num_tokens += tokens(t_value)
+ else:
+ num_tokens += tokens(str(value))
+
+ if key == "name":
+ num_tokens += tokens_per_name
+ num_tokens += 3
+
+ if tools:
+ num_tokens += self._num_tokens_for_tools(tools)
+
+ return num_tokens
+
+ def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
+ """
+ Calculate num tokens for tool calling
+
+ :param tools: tools for tool calling
+ :return: number of tokens
+ """
+ def tokens(text: str):
+ return self._get_num_tokens_by_gpt2(text)
+
+ num_tokens = 0
+ for tool in tools:
+ # calculate num tokens for function object
+ num_tokens += tokens('name')
+ num_tokens += tokens(tool.name)
+ num_tokens += tokens('description')
+ num_tokens += tokens(tool.description)
+ parameters = tool.parameters
+ num_tokens += tokens('parameters')
+ num_tokens += tokens('type')
+ num_tokens += tokens(parameters.get("type"))
+ if 'properties' in parameters:
+ num_tokens += tokens('properties')
+ for key, value in parameters.get('properties').items():
+ num_tokens += tokens(key)
+ for field_key, field_value in value.items():
+ num_tokens += tokens(field_key)
+ if field_key == 'enum':
+ for enum_field in field_value:
+ num_tokens += 3
+ num_tokens += tokens(enum_field)
+ else:
+ num_tokens += tokens(field_key)
+ num_tokens += tokens(str(field_value))
+ if 'required' in parameters:
+ num_tokens += tokens('required')
+ for required_field in parameters['required']:
+ num_tokens += 3
+ num_tokens += tokens(required_field)
+
+ return num_tokens
+
+ def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
+ """
+ convert prompt message to text
+ """
+ text = ''
+ for item in message:
+ if isinstance(item, UserPromptMessage):
+ text += item.content
+ elif isinstance(item, SystemPromptMessage):
+ text += item.content
+ elif isinstance(item, AssistantPromptMessage):
+ text += item.content
+ else:
+ raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
+ return text
+
+ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+ """
+ Convert PromptMessage to dict for OpenAI Compatibility API
+ """
+ if isinstance(message, UserPromptMessage):
+ message = cast(UserPromptMessage, message)
+ if isinstance(message.content, str):
+ message_dict = {"role": "user", "content": message.content}
+ else:
+ raise ValueError("User message content must be str")
+ elif isinstance(message, AssistantPromptMessage):
+ message = cast(AssistantPromptMessage, message)
+ message_dict = {"role": "assistant", "content": message.content}
+ if message.tool_calls and len(message.tool_calls) > 0:
+ message_dict["function_call"] = {
+ "name": message.tool_calls[0].function.name,
+ "arguments": message.tool_calls[0].function.arguments
+ }
+ elif isinstance(message, SystemPromptMessage):
+ message = cast(SystemPromptMessage, message)
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, ToolPromptMessage):
+ message = cast(ToolPromptMessage, message)
+ message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
+ else:
+ raise ValueError(f"Unknown message type {type(message)}")
+
+ return message_dict
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+ """
+ used to define customizable model schema
+ """
+ rules = [
+ ParameterRule(
+ name='temperature',
+ type=ParameterType.FLOAT,
+ use_template='temperature',
+ label=I18nObject(
+ zh_Hans='温度',
+ en_US='Temperature'
+ ),
+ ),
+ ParameterRule(
+ name='top_p',
+ type=ParameterType.FLOAT,
+ use_template='top_p',
+ label=I18nObject(
+ zh_Hans='Top P',
+ en_US='Top P'
+ )
+ ),
+ ParameterRule(
+ name='max_tokens',
+ type=ParameterType.INT,
+ use_template='max_tokens',
+ min=1,
+ max=credentials.get('context_length', 2048),
+ default=512,
+ label=I18nObject(
+ zh_Hans='最大生成长度',
+ en_US='Max Tokens'
+ )
+ )
+ ]
+
+ completion_type = None
+
+ if 'completion_type' in credentials:
+ if credentials['completion_type'] == 'chat':
+ completion_type = LLMMode.CHAT.value
+ elif credentials['completion_type'] == 'completion':
+ completion_type = LLMMode.COMPLETION.value
+ else:
+ raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+ else:
+ extra_args = XinferenceHelper.get_xinference_extra_parameter(
+ server_url=credentials['server_url'],
+ model_uid=credentials['model_uid']
+ )
+
+ if 'chat' in extra_args.model_ability:
+ completion_type = LLMMode.CHAT.value
+ elif 'generate' in extra_args.model_ability:
+ completion_type = LLMMode.COMPLETION.value
+ else:
+ raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
+
+ support_function_call = credentials.get('support_function_call', False)
+ context_length = credentials.get('context_length', 2048)
+
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(
+ en_US=model
+ ),
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_type=ModelType.LLM,
+ features=[
+ ModelFeature.TOOL_CALL
+ ] if support_function_call else [],
+ model_properties={
+ ModelPropertyKey.MODE: completion_type,
+ ModelPropertyKey.CONTEXT_SIZE: context_length
+ },
+ parameter_rules=rules
+ )
+
+ return entity
+
+ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
+ tools: list[PromptMessageTool] | None = None,
+ stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+ -> LLMResult | Generator:
+ """
+ generate text from LLM
+
+ see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
+
+ extra_model_kwargs can be obtained via `XinferenceHelper.get_xinference_extra_parameter`
+ """
+ if 'server_url' not in credentials:
+ raise CredentialsValidateFailedError('server_url is required in credentials')
+
+ if credentials['server_url'].endswith('/'):
+ credentials['server_url'] = credentials['server_url'][:-1]
+
+ client = OpenAI(
+ base_url=f'{credentials["server_url"]}/v1',
+ api_key='abc',  # placeholder; the OpenAI-compatible endpoint is not expected to validate the key
+ max_retries=3,
+ timeout=60,
+ )
+
+ xinference_client = Client(
+ base_url=credentials['server_url'],
+ )
+
+ xinference_model = xinference_client.get_model(credentials['model_uid'])
+
+ generate_config = {
+ 'temperature': model_parameters.get('temperature', 1.0),
+ 'top_p': model_parameters.get('top_p', 0.7),
+ 'max_tokens': model_parameters.get('max_tokens', 512),
+ }
+
+ if stop:
+ generate_config['stop'] = stop
+
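+ # tools are forwarded in the OpenAI function-calling schema, roughly
+ # {'type': 'function', 'function': {'name': ..., 'description': ..., 'parameters': {...}}}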
+ if tools and len(tools) > 0:
+ generate_config['tools'] = [
+ {
+ 'type': 'function',
+ 'function': helper.dump_model(tool)
+ } for tool in tools
+ ]
+
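+ # chat-capable models (chat / chatglm handles) are routed to the chat completions
+ # endpoint, while plain generate models use the completions endpoint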
+ if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
+ resp = client.chat.completions.create(
+ model=credentials['model_uid'],
+ messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
+ stream=stream,
+ user=user,
+ **generate_config,
+ )
+ if stream:
+ if tools and len(tools) > 0:
+ raise InvokeBadRequestError('xinference tool calls do not support stream mode')
+ return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+ tools=tools, resp=resp)
+ return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+ tools=tools, resp=resp)
+ elif isinstance(xinference_model, RESTfulGenerateModelHandle):
+ resp = client.completions.create(
+ model=credentials['model_uid'],
+ prompt=self._convert_prompt_message_to_text(prompt_messages),
+ stream=stream,
+ user=user,
+ **generate_config,
+ )
+ if stream:
+ return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+ tools=tools, resp=resp)
+ return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+ tools=tools, resp=resp)
+ else:
+ raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')
+
+ def _extract_response_tool_calls(self,
+ response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
+ -> list[AssistantPromptMessage.ToolCall]:
+ """
+ Extract tool calls from response
+
+ :param response_tool_calls: response tool calls
+ :return: list of tool calls
+ """
+ tool_calls = []
+ if response_tool_calls:
+ for response_tool_call in response_tool_calls:
+ function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=response_tool_call.function.name,
+ arguments=response_tool_call.function.arguments
+ )
+
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=response_tool_call.id,
+ type=response_tool_call.type,
+ function=function
+ )
+ tool_calls.append(tool_call)
+
+ return tool_calls
+
+ def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
+ -> AssistantPromptMessage.ToolCall:
+ """
+ Extract function call from response
+
+ :param response_function_call: response function call
+ :return: tool call
+ """
+ tool_call = None
+ if response_function_call:
+ function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=response_function_call.name,
+ arguments=response_function_call.arguments
+ )
+
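+ # the legacy function_call payload carries no id, so the function name is reused
+ # as the ToolCall id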
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=response_function_call.name,
+ type="function",
+ function=function
+ )
+
+ return tool_call
+
+ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: list[PromptMessageTool],
+ resp: ChatCompletion) -> LLMResult:
+ """
+ handle normal chat generate response
+ """
+ if len(resp.choices) == 0:
+ raise InvokeServerUnavailableError("Empty response")
+
+ assistant_message = resp.choices[0].message
+
+ # convert tool call to assistant message tool call
+ tool_calls = assistant_message.tool_calls
+ assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
+ function_call = assistant_message.function_call
+ if function_call:
+ assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=assistant_message.content,
+ tool_calls=assistant_prompt_message_tool_calls
+ )
+
+ prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+ completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
+
+ usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+
+ response = LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=resp.system_fingerprint,
+ usage=usage,
+ message=assistant_prompt_message,
+ )
+
+ return response
+
+ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: list[PromptMessageTool],
+ resp: Iterator[ChatCompletionChunk]) -> Generator:
+ """
+ handle stream chat generate response
+ """
+ full_response = ''
+
+ for chunk in resp:
+ if len(chunk.choices) == 0:
+ continue
+
+ delta = chunk.choices[0]
+
+ if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+ continue
+
+ # check if there is a tool call in the response
+ function_call = None
+ tool_calls = []
+ if delta.delta.tool_calls:
+ tool_calls += delta.delta.tool_calls
+ if delta.delta.function_call:
+ function_call = delta.delta.function_call
+
+ assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls)
+ if function_call:
+ assistant_message_tool_calls += [self._extract_response_function_call(function_call)]
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=delta.delta.content if delta.delta.content else '',
+ tool_calls=assistant_message_tool_calls
+ )
+
+ if delta.finish_reason is not None:
+ # temp_assistant_prompt_message is used to calculate usage
+ temp_assistant_prompt_message = AssistantPromptMessage(
+ content=full_response,
+ tool_calls=assistant_message_tool_calls
+ )
+
+ prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+ completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
+
+ usage = self._calc_response_usage(model=model, credentials=credentials,
+ prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=assistant_prompt_message,
+ finish_reason=delta.finish_reason,
+ usage=usage
+ ),
+ )
+ else:
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=assistant_prompt_message,
+ ),
+ )
+
+ full_response += delta.delta.content or ''
+
+ def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: list[PromptMessageTool],
+ resp: Completion) -> LLMResult:
+ """
+ handle normal completion generate response
+ """
+ if len(resp.choices) == 0:
+ raise InvokeServerUnavailableError("Empty response")
+
+ assistant_message = resp.choices[0].text
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=assistant_message,
+ tool_calls=[]
+ )
+
+ prompt_tokens = self._get_num_tokens_by_gpt2(
+ self._convert_prompt_message_to_text(prompt_messages)
+ )
+ completion_tokens = self._num_tokens_from_messages(
+ messages=[assistant_prompt_message], tools=[], is_completion_model=True
+ )
+ usage = self._calc_response_usage(
+ model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
+ )
+
+ response = LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=resp.system_fingerprint,
+ usage=usage,
+ message=assistant_prompt_message,
+ )
+
+ return response
+
+ def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: list[PromptMessageTool],
+ resp: Iterator[Completion]) -> Generator:
+ """
+ handle stream completion generate response
+ """
+ full_response = ''
+
+ for chunk in resp:
+ if len(chunk.choices) == 0:
+ continue
+
+ delta = chunk.choices[0]
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(
+ content=delta.text if delta.text else '',
+ tool_calls=[]
+ )
+
+ if delta.finish_reason is not None:
+ # temp_assistant_prompt_message is used to calculate usage
+ temp_assistant_prompt_message = AssistantPromptMessage(
+ content=full_response,
+ tool_calls=[]
+ )
+
+ prompt_tokens = self._get_num_tokens_by_gpt2(
+ self._convert_prompt_message_to_text(prompt_messages)
+ )
+ completion_tokens = self._num_tokens_from_messages(
+ messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
+ )
+ usage = self._calc_response_usage(model=model, credentials=credentials,
+ prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=assistant_prompt_message,
+ finish_reason=delta.finish_reason,
+ usage=usage
+ ),
+ )
+ else:
+ if delta.text is None or delta.text == '':
+ continue
+
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=assistant_prompt_message,
+ ),
+ )
+
+ full_response += delta.text
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: [
+ APIConnectionError,
+ APITimeoutError,
+ ],
+ InvokeServerUnavailableError: [
+ InternalServerError,
+ ConflictError,
+ NotFoundError,
+ UnprocessableEntityError,
+ PermissionDeniedError
+ ],
+ InvokeRateLimitError: [
+ RateLimitError
+ ],
+ InvokeAuthorizationError: [
+ AuthenticationError
+ ],
+ InvokeBadRequestError: [
+ ValueError
+ ]
+ }
\ No newline at end of file
diff --git a/server/xinference/rerank/__init__.py b/server/xinference/rerank/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/xinference/rerank/rerank.py b/server/xinference/rerank/rerank.py
new file mode 100644
index 00000000..dd25037d
--- /dev/null
+++ b/server/xinference/rerank/rerank.py
@@ -0,0 +1,160 @@
+from typing import Optional
+
+from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class XinferenceRerankModel(RerankModel):
+ """
+ Model class for Xinference rerank model.
+ """
+
+ def _invoke(self, model: str, credentials: dict,
+ query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+ user: Optional[str] = None) \
+ -> RerankResult:
+ """
+ Invoke rerank model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param query: search query
+ :param docs: docs for reranking
+ :param score_threshold: score threshold
+ :param top_n: top n
+ :param user: unique user id
+ :return: rerank result
+ """
+ if len(docs) == 0:
+ return RerankResult(
+ model=model,
+ docs=[]
+ )
+
+ if credentials['server_url'].endswith('/'):
+ credentials['server_url'] = credentials['server_url'][:-1]
+
+ # initialize client
+ client = Client(
+ base_url=credentials['server_url']
+ )
+
+ xinference_client = client.get_model(model_uid=credentials['model_uid'])
+
+ if not isinstance(xinference_client, RESTfulRerankModelHandle):
+ raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
+
+ response = xinference_client.rerank(
+ documents=docs,
+ query=query,
+ top_n=top_n,
+ )
+
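+ # response['results'] is assumed to be a list of entries shaped like
+ # {'index': int, 'document': str, 'relevance_score': float}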
+ rerank_documents = []
+ for idx, result in enumerate(response['results']):
+ # format document
+ index = result['index']
+ page_content = result['document']
+ rerank_document = RerankDocument(
+ index=index,
+ text=page_content,
+ score=result['relevance_score'],
+ )
+
+ # keep the document unless a score threshold is set and the score falls below it
+ if score_threshold is None or result['relevance_score'] >= score_threshold:
+ rerank_documents.append(rerank_document)
+
+ return RerankResult(
+ model=model,
+ docs=rerank_documents
+ )
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
+ raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+ self.invoke(
+ model=model,
+ credentials=credentials,
+ query="Whose kasumi",
+ docs=[
+ "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
+ "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
+ "and she leads a team named PopiParty."
+ ],
+ score_threshold=0.8
+ )
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: [
+ InvokeConnectionError
+ ],
+ InvokeServerUnavailableError: [
+ InvokeServerUnavailableError
+ ],
+ InvokeRateLimitError: [
+ InvokeRateLimitError
+ ],
+ InvokeAuthorizationError: [
+ InvokeAuthorizationError
+ ],
+ InvokeBadRequestError: [
+ InvokeBadRequestError,
+ KeyError,
+ ValueError
+ ]
+ }
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+ """
+ used to define customizable model schema
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(
+ en_US=model
+ ),
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_type=ModelType.RERANK,
+ model_properties={},
+ parameter_rules=[]
+ )
+
+ return entity
\ No newline at end of file
diff --git a/server/xinference/text_embedding/__init__.py b/server/xinference/text_embedding/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/xinference/text_embedding/text_embedding.py b/server/xinference/text_embedding/text_embedding.py
new file mode 100644
index 00000000..32d2b151
--- /dev/null
+++ b/server/xinference/text_embedding/text_embedding.py
@@ -0,0 +1,201 @@
+import time
+from typing import Optional
+
+from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
+
+
+class XinferenceTextEmbeddingModel(TextEmbeddingModel):
+ """
+ Model class for Xinference text embedding model.
+ """
+ def _invoke(self, model: str, credentials: dict,
+ texts: list[str], user: Optional[str] = None) \
+ -> TextEmbeddingResult:
+ """
+ Invoke text embedding model
+
+ credentials should be like:
+ {
+ 'server_url': 'server url',
+ 'model_uid': 'model uid',
+ }
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :param user: unique user id
+ :return: embeddings result
+ """
+ server_url = credentials['server_url']
+ model_uid = credentials['model_uid']
+
+ if server_url.endswith('/'):
+ server_url = server_url[:-1]
+
+ client = Client(base_url=server_url)
+
+ try:
+ handle = client.get_model(model_uid=model_uid)
+ except RuntimeError as e:
+ raise InvokeAuthorizationError(e)
+
+ if not isinstance(handle, RESTfulEmbeddingModelHandle):
+ raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
+
+ try:
+ embeddings = handle.create_embedding(input=texts)
+ except RuntimeError as e:
+ raise InvokeServerUnavailableError(e)
+
+ """
+ For reference, the response JSON looks like:
+ class Embedding(TypedDict):
+ object: Literal["list"]
+ model: str
+ data: List[EmbeddingData]
+ usage: EmbeddingUsage
+ class EmbeddingUsage(TypedDict):
+ prompt_tokens: int
+ total_tokens: int
+ class EmbeddingData(TypedDict):
+ index: int
+ object: str
+ embedding: List[float]
+ """
+
+ usage = embeddings['usage']
+ usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
+
+ result = TextEmbeddingResult(
+ model=model,
+ embeddings=[embedding['embedding'] for embedding in embeddings['data']],
+ usage=usage
+ )
+
+ return result
+
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return:
+ """
+ num_tokens = 0
+ for text in texts:
+ # use GPT2Tokenizer to get num tokens
+ num_tokens += self._get_num_tokens_by_gpt2(text)
+ return num_tokens
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
+ raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+ server_url = credentials['server_url']
+ model_uid = credentials['model_uid']
+ extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+
+ if extra_args.max_tokens:
+ credentials['max_tokens'] = extra_args.max_tokens
+
+ self._invoke(model=model, credentials=credentials, texts=['ping'])
+ except InvokeAuthorizationError as e:
+ raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}')
+ except RuntimeError as e:
+ raise CredentialsValidateFailedError(e)
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ return {
+ InvokeConnectionError: [
+ InvokeConnectionError
+ ],
+ InvokeServerUnavailableError: [
+ InvokeServerUnavailableError
+ ],
+ InvokeRateLimitError: [
+ InvokeRateLimitError
+ ],
+ InvokeAuthorizationError: [
+ InvokeAuthorizationError
+ ],
+ InvokeBadRequestError: [
+ KeyError
+ ]
+ }
+
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param tokens: input tokens
+ :return: usage
+ """
+ # get input price info
+ input_price_info = self.get_price(
+ model=model,
+ credentials=credentials,
+ price_type=PriceType.INPUT,
+ tokens=tokens
+ )
+
+ # transform usage
+ usage = EmbeddingUsage(
+ tokens=tokens,
+ total_tokens=tokens,
+ unit_price=input_price_info.unit_price,
+ price_unit=input_price_info.unit,
+ total_price=input_price_info.total_amount,
+ currency=input_price_info.currency,
+ latency=time.perf_counter() - self.started_at
+ )
+
+ return usage
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+ """
+ used to define customizable model schema
+ """
+
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(
+ en_US=model
+ ),
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_type=ModelType.TEXT_EMBEDDING,
+ model_properties={
+ ModelPropertyKey.MAX_CHUNKS: 1,
+ ModelPropertyKey.CONTEXT_SIZE: credentials.get('max_tokens') or 512,
+ },
+ parameter_rules=[]
+ )
+
+ return entity
\ No newline at end of file
diff --git a/server/xinference/xinference.py b/server/xinference/xinference.py
new file mode 100644
index 00000000..d85f7c82
--- /dev/null
+++ b/server/xinference/xinference.py
@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class XinferenceAIProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Xinference exposes only customizable models, so there are no provider-level
+ credentials to validate here; credentials are validated per model instead.
+ """
+ pass
diff --git a/server/xinference/xinference.yaml b/server/xinference/xinference.yaml
new file mode 100644
index 00000000..bb6c6d86
--- /dev/null
+++ b/server/xinference/xinference.yaml
@@ -0,0 +1,47 @@
+provider: xinference
+label:
+ en_US: Xorbits Inference
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#FAF5FF"
+help:
+ title:
+ en_US: How to deploy Xinference
+ zh_Hans: 如何部署 Xinference
+ url:
+ en_US: https://github.com/xorbitsai/inference
+supported_model_types:
+ - llm
+ - text-embedding
+ - rerank
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your model name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: server_url
+ label:
+ zh_Hans: 服务器 URL
+ en_US: Server URL
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+ en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
+ - variable: model_uid
+ label:
+ zh_Hans: 模型 UID
+ en_US: Model UID
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的Model UID
+ en_US: Enter the model uid
diff --git a/server/xinference/xinference_helper.py b/server/xinference/xinference_helper.py
new file mode 100644
index 00000000..66dab658
--- /dev/null
+++ b/server/xinference/xinference_helper.py
@@ -0,0 +1,103 @@
+from threading import Lock
+from time import time
+
+from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, MissingSchema, Timeout
+from requests.sessions import Session
+from yarl import URL
+
+
+class XinferenceModelExtraParameter:
+ model_format: str
+ model_handle_type: str
+ model_ability: list[str]
+ max_tokens: int = 512
+ context_length: int = 2048
+ support_function_call: bool = False
+
+ def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
+ support_function_call: bool, max_tokens: int, context_length: int) -> None:
+ self.model_format = model_format
+ self.model_handle_type = model_handle_type
+ self.model_ability = model_ability
+ self.support_function_call = support_function_call
+ self.max_tokens = max_tokens
+ self.context_length = context_length
+
+cache = {}
+cache_lock = Lock()
+
+class XinferenceHelper:
+ @staticmethod
+ def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
+ XinferenceHelper._clean_cache()
+ with cache_lock:
+ if model_uid not in cache:
+ cache[model_uid] = {
+ 'expires': time() + 300,
+ 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
+ }
+ return cache[model_uid]['value']
+
+ @staticmethod
+ def _clean_cache() -> None:
+ try:
+ with cache_lock:
+ expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
+ for model_uid in expired_keys:
+ del cache[model_uid]
+ except RuntimeError as e:
+ pass
+
+ @staticmethod
+ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
+ """
+ get xinference model extra parameter like model_format and model_handle_type
+ """
+
+ if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
+ raise RuntimeError('server_url and model_uid must not be empty')
+
+ url = str(URL(server_url) / 'v1' / 'models' / model_uid)
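+ # resolves to e.g. http://<host>:9997/v1/models/<model_uid>, the Xinference model-info endpoint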
+
+ # this method is called while holding the cache lock, and requests can hang indefinitely by default,
+ # so mount an HTTPAdapter with max_retries=3 and use an explicit timeout
+ session = Session()
+ session.mount('http://', HTTPAdapter(max_retries=3))
+ session.mount('https://', HTTPAdapter(max_retries=3))
+
+ try:
+ response = session.get(url, timeout=10)
+ except (MissingSchema, ConnectionError, Timeout) as e:
+ raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
+ if response.status_code != 200:
+ raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
+
+ response_json = response.json()
+
+ model_format = response_json.get('model_format', 'ggmlv3')
+ model_ability = response_json.get('model_ability', [])
+
+ if response_json.get('model_type') == 'embedding':
+ model_handle_type = 'embedding'
+ elif model_format == 'ggmlv3' and 'chatglm' in response_json.get('model_name', ''):
+ model_handle_type = 'chatglm'
+ elif 'generate' in model_ability:
+ model_handle_type = 'generate'
+ elif 'chat' in model_ability:
+ model_handle_type = 'chat'
+ else:
+ raise NotImplementedError(f'xinference model ability {model_ability} is not supported')
+
+ support_function_call = 'tools' in model_ability
+ max_tokens = response_json.get('max_tokens', 512)
+
+ context_length = response_json.get('context_length', 2048)
+
+ return XinferenceModelExtraParameter(
+ model_format=model_format,
+ model_handle_type=model_handle_type,
+ model_ability=model_ability,
+ support_function_call=support_function_call,
+ max_tokens=max_tokens,
+ context_length=context_length
+ )
\ No newline at end of file