diff --git a/server/xinference/__init__.py b/server/xinference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/_assets/icon_l_en.svg b/server/xinference/_assets/icon_l_en.svg new file mode 100644 index 00000000..81091765 --- /dev/null +++ b/server/xinference/_assets/icon_l_en.svg @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/server/xinference/_assets/icon_s_en.svg b/server/xinference/_assets/icon_s_en.svg new file mode 100644 index 00000000..f5c5f75e --- /dev/null +++ b/server/xinference/_assets/icon_s_en.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/server/xinference/llm/__init__.py b/server/xinference/llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/llm/llm.py b/server/xinference/llm/llm.py new file mode 100644 index 00000000..602d0b74 --- /dev/null +++ b/server/xinference/llm/llm.py @@ -0,0 +1,734 @@ +from collections.abc import Generator, Iterator +from typing import cast + +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + UnprocessableEntityError, +) +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.completion import Completion +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, + RESTfulGenerateModelHandle, +) + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.xinference.xinference_helper import ( + XinferenceHelper, + XinferenceModelExtraParameter, +) +from core.model_runtime.utils import helper + + +class XinferenceAILargeLanguageModel(LargeLanguageModel): + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + """ + invoke LLM + + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + """ + if 'temperature' in model_parameters: + if model_parameters['temperature'] < 0.01: + model_parameters['temperature'] = 0.01 + elif model_parameters['temperature'] > 1.0: + model_parameters['temperature'] = 
0.99
+
+        return self._generate(
+            model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
+            tools=tools, stop=stop, stream=stream, user=user,
+            extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
+                server_url=credentials['server_url'],
+                model_uid=credentials['model_uid']
+            )
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        validate credentials
+
+        credentials should be like:
+        {
+            'model_type': 'text-generation',
+            'server_url': 'server url',
+            'model_uid': 'model uid',
+        }
+        """
+        try:
+            if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
+                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+            extra_param = XinferenceHelper.get_xinference_extra_parameter(
+                server_url=credentials['server_url'],
+                model_uid=credentials['model_uid']
+            )
+            if 'completion_type' not in credentials:
+                if 'chat' in extra_param.model_ability:
+                    credentials['completion_type'] = 'chat'
+                elif 'generate' in extra_param.model_ability:
+                    credentials['completion_type'] = 'completion'
+                else:
+                    raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
+
+            if extra_param.support_function_call:
+                credentials['support_function_call'] = True
+
+            if extra_param.context_length:
+                credentials['context_length'] = extra_param.context_length
+
+        except RuntimeError as e:
+            raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
+        except KeyError as e:
+            raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
+        except Exception as e:
+            raise e
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: list[PromptMessageTool] | None = None) -> int:
+        """
+        get number of tokens
+
+        Because the Xinference LLM is a customized model, we cannot detect which tokenizer it uses,
+        so we fall back to the GPT-2 tokenizer by default.
+        """
+        return self._num_tokens_from_messages(prompt_messages, tools)
+
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
+                                  is_completion_model: bool = False) -> int:
+        def tokens(text: str):
+            return self._get_num_tokens_by_gpt2(text)
+
+        if is_completion_model:
+            return sum([tokens(str(message.content)) for message in messages])
+
+        tokens_per_message = 3
+        tokens_per_name = 1
+
+        num_tokens = 0
+        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                if isinstance(value, list):
+                    text = ''
+                    for item in value:
+                        if isinstance(item, dict) and item['type'] == 'text':
+                            # item is a plain dict here, so index it rather than using attribute access
+                            text += item['text']
+
+                    value = text
+
+                if key == "tool_calls":
+                    for tool_call in value:
+                        for t_key, t_value in tool_call.items():
+                            num_tokens += tokens(t_key)
+                            if t_key == "function":
+                                for f_key, f_value in t_value.items():
+                                    num_tokens += tokens(f_key)
+                                    num_tokens += tokens(f_value)
+                            else:
+                                num_tokens += tokens(t_key)
+                                num_tokens += tokens(t_value)
+                if key == "function_call":
+                    for t_key, t_value in value.items():
+                        num_tokens += tokens(t_key)
+                        if t_key == "function":
+                            for f_key, f_value in t_value.items():
+                                num_tokens += tokens(f_key)
+                                num_tokens += tokens(f_value)
+                        else:
+                            num_tokens += tokens(t_key)
+                            num_tokens += tokens(t_value)
+                else:
+                    num_tokens += tokens(str(value))
+
+                if key == "name":
num_tokens += tokens_per_name + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: + """ + Calculate num tokens for tool calling + + :param encoding: encoding + :param tools: tools for tool calling + :return: number of tokens + """ + def tokens(text: str): + return self._get_num_tokens_by_gpt2(text) + + num_tokens = 0 + for tool in tools: + # calculate num tokens for function object + num_tokens += tokens('name') + num_tokens += tokens(tool.name) + num_tokens += tokens('description') + num_tokens += tokens(tool.description) + parameters = tool.parameters + num_tokens += tokens('parameters') + num_tokens += tokens('type') + num_tokens += tokens(parameters.get("type")) + if 'properties' in parameters: + num_tokens += tokens('properties') + for key, value in parameters.get('properties').items(): + num_tokens += tokens(key) + for field_key, field_value in value.items(): + num_tokens += tokens(field_key) + if field_key == 'enum': + for enum_field in field_value: + num_tokens += 3 + num_tokens += tokens(enum_field) + else: + num_tokens += tokens(field_key) + num_tokens += tokens(str(field_value)) + if 'required' in parameters: + num_tokens += tokens('required') + for required_field in parameters['required']: + num_tokens += 3 + num_tokens += tokens(required_field) + + return num_tokens + + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: + """ + convert prompt message to text + """ + text = '' + for item in message: + if isinstance(item, UserPromptMessage): + text += item.content + elif isinstance(item, SystemPromptMessage): + text += item.content + elif isinstance(item, AssistantPromptMessage): + text += item.content + else: + raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + return text + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI Compatibility API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + raise ValueError("User message content must be str") + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") + + return message_dict + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + rules = [ + ParameterRule( + name='temperature', + type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', + en_US='Temperature' + ), + ), + ParameterRule( + name='top_p', + type=ParameterType.FLOAT, + use_template='top_p', + 
label=I18nObject( + zh_Hans='Top P', + en_US='Top P' + ) + ), + ParameterRule( + name='max_tokens', + type=ParameterType.INT, + use_template='max_tokens', + min=1, + max=credentials.get('context_length', 2048), + default=512, + label=I18nObject( + zh_Hans='最大生成长度', + en_US='Max Tokens' + ) + ) + ] + + completion_type = None + + if 'completion_type' in credentials: + if credentials['completion_type'] == 'chat': + completion_type = LLMMode.CHAT.value + elif credentials['completion_type'] == 'completion': + completion_type = LLMMode.COMPLETION.value + else: + raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') + else: + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=credentials['server_url'], + model_uid=credentials['model_uid'] + ) + + if 'chat' in extra_args.model_ability: + completion_type = LLMMode.CHAT.value + elif 'generate' in extra_args.model_ability: + completion_type = LLMMode.COMPLETION.value + else: + raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + + support_function_call = credentials.get('support_function_call', False) + context_length = credentials.get('context_length', 2048) + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + features=[ + ModelFeature.TOOL_CALL + ] if support_function_call else [], + model_properties={ + ModelPropertyKey.MODE: completion_type, + ModelPropertyKey.CONTEXT_SIZE: context_length + }, + parameter_rules=rules + ) + + return entity + + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + """ + generate text from LLM + + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + """ + if 'server_url' not in credentials: + raise CredentialsValidateFailedError('server_url is required in credentials') + + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + client = OpenAI( + base_url=f'{credentials["server_url"]}/v1', + api_key='abc', + max_retries=3, + timeout=60, + ) + + xinference_client = Client( + base_url=credentials['server_url'], + ) + + xinference_model = xinference_client.get_model(credentials['model_uid']) + + generate_config = { + 'temperature': model_parameters.get('temperature', 1.0), + 'top_p': model_parameters.get('top_p', 0.7), + 'max_tokens': model_parameters.get('max_tokens', 512), + } + + if stop: + generate_config['stop'] = stop + + if tools and len(tools) > 0: + generate_config['tools'] = [ + { + 'type': 'function', + 'function': helper.dump_model(tool) + } for tool in tools + ] + + if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): + resp = client.chat.completions.create( + model=credentials['model_uid'], + messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], + stream=stream, + user=user, + **generate_config, + ) + if stream: + if tools and len(tools) > 0: + raise InvokeBadRequestError('xinference tool calls does not support stream mode') + return self._handle_chat_stream_response(model=model, 
credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + elif isinstance(xinference_model, RESTfulGenerateModelHandle): + resp = client.completions.create( + model=credentials['model_uid'], + prompt=self._convert_prompt_message_to_text(prompt_messages), + stream=stream, + user=user, + **generate_config, + ) + if stream: + return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + else: + raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + + def _extract_response_tool_calls(self, + response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ + -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.function.name, + arguments=response_tool_call.function.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.id, + type=response_tool_call.type, + function=function + ) + tool_calls.append(tool_call) + + return tool_calls + + def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ + -> AssistantPromptMessage.ToolCall: + """ + Extract function call from response + + :param response_function_call: response function call + :return: tool call + """ + tool_call = None + if response_function_call: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function_call.name, + arguments=response_function_call.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_function_call.name, + type="function", + function=function + ) + + return tool_call + + def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion) -> LLMResult: + """ + handle normal chat generate response + """ + if len(resp.choices) == 0: + raise InvokeServerUnavailableError("Empty response") + + assistant_message = resp.choices[0].message + + # convert tool call to assistant message tool call + tool_calls = assistant_message.tool_calls + assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) + function_call = assistant_message.function_call + if function_call: + assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=assistant_message.content, + tool_calls=assistant_prompt_message_tool_calls + ) + + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) + + usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + response = LLMResult( 
+ model=model, + prompt_messages=prompt_messages, + system_fingerprint=resp.system_fingerprint, + usage=usage, + message=assistant_prompt_message, + ) + + return response + + def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk]) -> Generator: + """ + handle stream chat generate response + """ + full_response = '' + + for chunk in resp: + if len(chunk.choices) == 0: + continue + + delta = chunk.choices[0] + + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + continue + + # check if there is a tool call in the response + function_call = None + tool_calls = [] + if delta.delta.tool_calls: + tool_calls += delta.delta.tool_calls + if delta.delta.function_call: + function_call = delta.delta.function_call + + assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls) + if function_call: + assistant_message_tool_calls += [self._extract_response_function_call(function_call)] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=delta.delta.content if delta.delta.content else '', + tool_calls=assistant_message_tool_calls + ) + + if delta.finish_reason is not None: + # temp_assistant_prompt_message is used to calculate usage + temp_assistant_prompt_message = AssistantPromptMessage( + content=full_response, + tool_calls=assistant_message_tool_calls + ) + + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) + + usage = self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + usage=usage + ), + ) + else: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + ), + ) + + full_response += delta.delta.content + + def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion) -> LLMResult: + """ + handle normal completion generate response + """ + if len(resp.choices) == 0: + raise InvokeServerUnavailableError("Empty response") + + assistant_message = resp.choices[0].text + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=assistant_message, + tool_calls=[] + ) + + prompt_tokens = self._get_num_tokens_by_gpt2( + self._convert_prompt_message_to_text(prompt_messages) + ) + completion_tokens = self._num_tokens_from_messages( + messages=[assistant_prompt_message], tools=[], is_completion_model=True + ) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) + + response = LLMResult( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=resp.system_fingerprint, + usage=usage, + message=assistant_prompt_message, + ) + + return response + + def _handle_completion_stream_response(self, model: str, credentials: 
dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion]) -> Generator: + """ + handle stream completion generate response + """ + full_response = '' + + for chunk in resp: + if len(chunk.choices) == 0: + continue + + delta = chunk.choices[0] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=delta.text if delta.text else '', + tool_calls=[] + ) + + if delta.finish_reason is not None: + # temp_assistant_prompt_message is used to calculate usage + temp_assistant_prompt_message = AssistantPromptMessage( + content=full_response, + tool_calls=[] + ) + + prompt_tokens = self._get_num_tokens_by_gpt2( + self._convert_prompt_message_to_text(prompt_messages) + ) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True + ) + usage = self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + usage=usage + ), + ) + else: + if delta.text is None or delta.text == '': + continue + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + ), + ) + + full_response += delta.text + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + APIConnectionError, + APITimeoutError, + ], + InvokeServerUnavailableError: [ + InternalServerError, + ConflictError, + NotFoundError, + UnprocessableEntityError, + PermissionDeniedError + ], + InvokeRateLimitError: [ + RateLimitError + ], + InvokeAuthorizationError: [ + AuthenticationError + ], + InvokeBadRequestError: [ + ValueError + ] + } \ No newline at end of file diff --git a/server/xinference/rerank/__init__.py b/server/xinference/rerank/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/rerank/rerank.py b/server/xinference/rerank/rerank.py new file mode 100644 index 00000000..dd25037d --- /dev/null +++ b/server/xinference/rerank/rerank.py @@ -0,0 +1,160 @@ +from typing import Optional + +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class XinferenceRerankModel(RerankModel): + """ + Model class for Xinference rerank model. + """ + + def _invoke(self, model: str, credentials: dict, + query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) \ + -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult( + model=model, + docs=[] + ) + + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + # initialize client + client = Client( + base_url=credentials['server_url'] + ) + + xinference_client = client.get_model(model_uid=credentials['model_uid']) + + if not isinstance(xinference_client, RESTfulRerankModelHandle): + raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model') + + response = xinference_client.rerank( + documents=docs, + query=query, + top_n=top_n, + ) + + rerank_documents = [] + for idx, result in enumerate(response['results']): + # format document + index = result['index'] + page_content = result['document'] + rerank_document = RerankDocument( + index=index, + text=page_content, + score=result['relevance_score'], + ) + + # score threshold check + if score_threshold is not None: + if result['relevance_score'] >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult( + model=model, + docs=rerank_documents + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + + self.invoke( + model=model, + credentials=credentials, + query="Whose kasumi", + docs=[ + "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty." + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={ }, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/server/xinference/text_embedding/__init__.py b/server/xinference/text_embedding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/text_embedding/text_embedding.py b/server/xinference/text_embedding/text_embedding.py new file mode 100644 index 00000000..32d2b151 --- /dev/null +++ b/server/xinference/text_embedding/text_embedding.py @@ -0,0 +1,201 @@ +import time +from typing import Optional + +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper + + +class XinferenceTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Xinference text embedding model. 
+ """ + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + credentials should be like: + { + 'server_url': 'server url', + 'model_uid': 'model uid', + } + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + + if server_url.endswith('/'): + server_url = server_url[:-1] + + client = Client(base_url=server_url) + + try: + handle = client.get_model(model_uid=model_uid) + except RuntimeError as e: + raise InvokeAuthorizationError(e) + + if not isinstance(handle, RESTfulEmbeddingModelHandle): + raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + + try: + embeddings = handle.create_embedding(input=texts) + except RuntimeError as e: + raise InvokeServerUnavailableError(e) + + """ + for convenience, the response json is like: + class Embedding(TypedDict): + object: Literal["list"] + model: str + data: List[EmbeddingData] + usage: EmbeddingUsage + class EmbeddingUsage(TypedDict): + prompt_tokens: int + total_tokens: int + class EmbeddingData(TypedDict): + index: int + object: str + embedding: List[float] + """ + + usage = embeddings['usage'] + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + + result = TextEmbeddingResult( + model=model, + embeddings=[embedding['embedding'] for embedding in embeddings['data']], + usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + num_tokens = 0 + for text in texts: + # use GPT2Tokenizer to get num tokens + num_tokens += self._get_num_tokens_by_gpt2(text) + return num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + + if extra_args.max_tokens: + credentials['max_tokens'] = extra_args.max_tokens + + self._invoke(model=model, credentials=credentials, texts=['ping']) + except InvokeAuthorizationError as e: + raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + except RuntimeError as e: + raise CredentialsValidateFailedError(e) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + KeyError + ] + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TEXT_EMBEDDING, + model_properties={ + ModelPropertyKey.MAX_CHUNKS: 1, + ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + }, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/server/xinference/xinference.py b/server/xinference/xinference.py new file mode 100644 index 00000000..d85f7c82 --- /dev/null +++ b/server/xinference/xinference.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XinferenceAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/server/xinference/xinference.yaml b/server/xinference/xinference.yaml new file mode 100644 index 00000000..bb6c6d86 --- /dev/null +++ b/server/xinference/xinference.yaml @@ -0,0 +1,47 @@ +provider: xinference +label: + en_US: Xorbits Inference +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FAF5FF" +help: + title: + en_US: How to deploy Xinference + zh_Hans: 如何部署 Xinference + url: + en_US: https://github.com/xorbitsai/inference +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + 
zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: server_url + label: + zh_Hans: 服务器URL + en_US: Server url + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997 + en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997 + - variable: model_uid + label: + zh_Hans: 模型UID + en_US: Model uid + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的Model UID + en_US: Enter the model uid diff --git a/server/xinference/xinference_helper.py b/server/xinference/xinference_helper.py new file mode 100644 index 00000000..66dab658 --- /dev/null +++ b/server/xinference/xinference_helper.py @@ -0,0 +1,103 @@ +from threading import Lock +from time import time + +from requests.adapters import HTTPAdapter +from requests.exceptions import ConnectionError, MissingSchema, Timeout +from requests.sessions import Session +from yarl import URL + + +class XinferenceModelExtraParameter: + model_format: str + model_handle_type: str + model_ability: list[str] + max_tokens: int = 512 + context_length: int = 2048 + support_function_call: bool = False + + def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], + support_function_call: bool, max_tokens: int, context_length: int) -> None: + self.model_format = model_format + self.model_handle_type = model_handle_type + self.model_ability = model_ability + self.support_function_call = support_function_call + self.max_tokens = max_tokens + self.context_length = context_length + +cache = {} +cache_lock = Lock() + +class XinferenceHelper: + @staticmethod + def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + XinferenceHelper._clean_cache() + with cache_lock: + if model_uid not in cache: + cache[model_uid] = { + 'expires': time() + 300, + 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) + } + return cache[model_uid]['value'] + + @staticmethod + def _clean_cache() -> None: + try: + with cache_lock: + expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + for model_uid in expired_keys: + del cache[model_uid] + except RuntimeError as e: + pass + + @staticmethod + def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + """ + get xinference model extra parameter like model_format and model_handle_type + """ + + if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): + raise RuntimeError('model_uid is empty') + + url = str(URL(server_url) / 'v1' / 'models' / model_uid) + + # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 + session = Session() + session.mount('http://', HTTPAdapter(max_retries=3)) + session.mount('https://', HTTPAdapter(max_retries=3)) + + try: + response = session.get(url, timeout=10) + except (MissingSchema, ConnectionError, Timeout) as e: + raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') + if response.status_code != 200: + raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') + + response_json = response.json() + + model_format = response_json.get('model_format', 'ggmlv3') + model_ability = response_json.get('model_ability', []) + + if response_json.get('model_type') == 'embedding': + model_handle_type = 'embedding' + elif model_format == 'ggmlv3' and 
'chatglm' in response_json['model_name']:
+            model_handle_type = 'chatglm'
+        elif 'generate' in model_ability:
+            model_handle_type = 'generate'
+        elif 'chat' in model_ability:
+            model_handle_type = 'chat'
+        else:
+            # model_handle_type is never assigned on this branch, so report the ability list instead
+            raise NotImplementedError(f'xinference model ability {model_ability} is not supported')
+
+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+
+        context_length = response_json.get('context_length', 2048)
+
+        return XinferenceModelExtraParameter(
+            model_format=model_format,
+            model_handle_type=model_handle_type,
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens,
+            context_length=context_length
+        )
\ No newline at end of file
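
A minimal usage sketch of the XinferenceHelper added above. The import path follows the one already used in llm.py; the server URL and model UID are placeholders, not values from this patch.

from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

# The first call issues GET {server_url}/v1/models/{model_uid}; the result is cached
# per model_uid for 300 seconds, so repeated credential validations skip the HTTP round trip.
extra = XinferenceHelper.get_xinference_extra_parameter(
    server_url='http://192.168.1.100:9997',  # placeholder Xinference endpoint
    model_uid='my-model-uid',                # placeholder model UID
)

print(extra.model_handle_type)      # 'chat', 'generate', 'chatglm' or 'embedding'
print(extra.support_function_call)  # True when 'tools' is listed in model_ability
print(extra.context_length)         # falls back to 2048 when the server omits it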