From eec8d51a910edef9c8a0fff72553d132185cb3a0 Mon Sep 17 00:00:00 2001 From: Leb Date: Wed, 13 Mar 2024 16:37:37 +0800 Subject: [PATCH 1/4] =?UTF-8?q?xinference=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 先传 我后面来改 --- server/xinference/__init__.py | 0 server/xinference/_assets/icon_l_en.svg | 42 + server/xinference/_assets/icon_s_en.svg | 24 + server/xinference/llm/__init__.py | 0 server/xinference/llm/llm.py | 734 ++++++++++++++++++ server/xinference/rerank/__init__.py | 0 server/xinference/rerank/rerank.py | 160 ++++ server/xinference/text_embedding/__init__.py | 0 .../text_embedding/text_embedding.py | 201 +++++ server/xinference/xinference.py | 10 + server/xinference/xinference.yaml | 47 ++ server/xinference/xinference_helper.py | 103 +++ 12 files changed, 1321 insertions(+) create mode 100644 server/xinference/__init__.py create mode 100644 server/xinference/_assets/icon_l_en.svg create mode 100644 server/xinference/_assets/icon_s_en.svg create mode 100644 server/xinference/llm/__init__.py create mode 100644 server/xinference/llm/llm.py create mode 100644 server/xinference/rerank/__init__.py create mode 100644 server/xinference/rerank/rerank.py create mode 100644 server/xinference/text_embedding/__init__.py create mode 100644 server/xinference/text_embedding/text_embedding.py create mode 100644 server/xinference/xinference.py create mode 100644 server/xinference/xinference.yaml create mode 100644 server/xinference/xinference_helper.py diff --git a/server/xinference/__init__.py b/server/xinference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/_assets/icon_l_en.svg b/server/xinference/_assets/icon_l_en.svg new file mode 100644 index 00000000..81091765 --- /dev/null +++ b/server/xinference/_assets/icon_l_en.svg @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/server/xinference/_assets/icon_s_en.svg b/server/xinference/_assets/icon_s_en.svg new file mode 100644 index 00000000..f5c5f75e --- /dev/null +++ b/server/xinference/_assets/icon_s_en.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/server/xinference/llm/__init__.py b/server/xinference/llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/llm/llm.py b/server/xinference/llm/llm.py new file mode 100644 index 00000000..602d0b74 --- /dev/null +++ b/server/xinference/llm/llm.py @@ -0,0 +1,734 @@ +from collections.abc import Generator, Iterator +from typing import cast + +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + UnprocessableEntityError, +) +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.completion import Completion +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, + RESTfulGenerateModelHandle, +) + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from 
core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.xinference.xinference_helper import ( + XinferenceHelper, + XinferenceModelExtraParameter, +) +from core.model_runtime.utils import helper + + +class XinferenceAILargeLanguageModel(LargeLanguageModel): + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + """ + invoke LLM + + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + """ + if 'temperature' in model_parameters: + if model_parameters['temperature'] < 0.01: + model_parameters['temperature'] = 0.01 + elif model_parameters['temperature'] > 1.0: + model_parameters['temperature'] = 0.99 + + return self._generate( + model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, + tools=tools, stop=stop, stream=stream, user=user, + extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( + server_url=credentials['server_url'], + model_uid=credentials['model_uid'] + ) + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + validate credentials + + credentials should be like: + { + 'model_type': 'text-generation', + 'server_url': 'server url', + 'model_uid': 'model uid', + } + """ + try: + if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + + extra_param = XinferenceHelper.get_xinference_extra_parameter( + server_url=credentials['server_url'], + model_uid=credentials['model_uid'] + ) + if 'completion_type' not in credentials: + if 'chat' in extra_param.model_ability: + credentials['completion_type'] = 'chat' + elif 'generate' in extra_param.model_ability: + credentials['completion_type'] = 'completion' + else: + raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') + + if extra_param.support_function_call: + credentials['support_function_call'] = True + + if extra_param.context_length: + credentials['context_length'] = extra_param.context_length + + except RuntimeError as e: + raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + except KeyError as e: + raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + except Exception as e: + raise e + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None) -> int: + """ + get number of tokens + + cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default + """ + return self._num_tokens_from_messages(prompt_messages, tools) + + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], + is_completion_model: bool = False) -> int: + def tokens(text: str): + return self._get_num_tokens_by_gpt2(text) + + if is_completion_model: + return sum([tokens(str(message.content)) for message in messages]) + + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + if isinstance(value, list): + text = '' + for item in value: + if isinstance(item, dict) and item['type'] == 'text': + text += item.text + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += tokens(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + if key == "function_call": + for t_key, t_value in value.items(): + num_tokens += tokens(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + else: + num_tokens += tokens(str(value)) + + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: + """ + Calculate num tokens for tool calling + + :param encoding: encoding + :param tools: tools for tool calling + :return: number of tokens + """ + def tokens(text: str): + return self._get_num_tokens_by_gpt2(text) + + num_tokens = 0 + for tool in tools: + # calculate num tokens for function object + num_tokens += tokens('name') + num_tokens += tokens(tool.name) + num_tokens += tokens('description') + num_tokens += 
tokens(tool.description) + parameters = tool.parameters + num_tokens += tokens('parameters') + num_tokens += tokens('type') + num_tokens += tokens(parameters.get("type")) + if 'properties' in parameters: + num_tokens += tokens('properties') + for key, value in parameters.get('properties').items(): + num_tokens += tokens(key) + for field_key, field_value in value.items(): + num_tokens += tokens(field_key) + if field_key == 'enum': + for enum_field in field_value: + num_tokens += 3 + num_tokens += tokens(enum_field) + else: + num_tokens += tokens(field_key) + num_tokens += tokens(str(field_value)) + if 'required' in parameters: + num_tokens += tokens('required') + for required_field in parameters['required']: + num_tokens += 3 + num_tokens += tokens(required_field) + + return num_tokens + + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: + """ + convert prompt message to text + """ + text = '' + for item in message: + if isinstance(item, UserPromptMessage): + text += item.content + elif isinstance(item, SystemPromptMessage): + text += item.content + elif isinstance(item, AssistantPromptMessage): + text += item.content + else: + raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + return text + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI Compatibility API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + raise ValueError("User message content must be str") + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") + + return message_dict + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + rules = [ + ParameterRule( + name='temperature', + type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', + en_US='Temperature' + ), + ), + ParameterRule( + name='top_p', + type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', + en_US='Top P' + ) + ), + ParameterRule( + name='max_tokens', + type=ParameterType.INT, + use_template='max_tokens', + min=1, + max=credentials.get('context_length', 2048), + default=512, + label=I18nObject( + zh_Hans='最大生成长度', + en_US='Max Tokens' + ) + ) + ] + + completion_type = None + + if 'completion_type' in credentials: + if credentials['completion_type'] == 'chat': + completion_type = LLMMode.CHAT.value + elif credentials['completion_type'] == 'completion': + completion_type = LLMMode.COMPLETION.value + else: + raise ValueError(f'completion_type {credentials["completion_type"]} is not 
supported') + else: + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=credentials['server_url'], + model_uid=credentials['model_uid'] + ) + + if 'chat' in extra_args.model_ability: + completion_type = LLMMode.CHAT.value + elif 'generate' in extra_args.model_ability: + completion_type = LLMMode.COMPLETION.value + else: + raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + + support_function_call = credentials.get('support_function_call', False) + context_length = credentials.get('context_length', 2048) + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + features=[ + ModelFeature.TOOL_CALL + ] if support_function_call else [], + model_properties={ + ModelPropertyKey.MODE: completion_type, + ModelPropertyKey.CONTEXT_SIZE: context_length + }, + parameter_rules=rules + ) + + return entity + + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + """ + generate text from LLM + + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + """ + if 'server_url' not in credentials: + raise CredentialsValidateFailedError('server_url is required in credentials') + + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + client = OpenAI( + base_url=f'{credentials["server_url"]}/v1', + api_key='abc', + max_retries=3, + timeout=60, + ) + + xinference_client = Client( + base_url=credentials['server_url'], + ) + + xinference_model = xinference_client.get_model(credentials['model_uid']) + + generate_config = { + 'temperature': model_parameters.get('temperature', 1.0), + 'top_p': model_parameters.get('top_p', 0.7), + 'max_tokens': model_parameters.get('max_tokens', 512), + } + + if stop: + generate_config['stop'] = stop + + if tools and len(tools) > 0: + generate_config['tools'] = [ + { + 'type': 'function', + 'function': helper.dump_model(tool) + } for tool in tools + ] + + if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): + resp = client.chat.completions.create( + model=credentials['model_uid'], + messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], + stream=stream, + user=user, + **generate_config, + ) + if stream: + if tools and len(tools) > 0: + raise InvokeBadRequestError('xinference tool calls does not support stream mode') + return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + elif isinstance(xinference_model, RESTfulGenerateModelHandle): + resp = client.completions.create( + model=credentials['model_uid'], + prompt=self._convert_prompt_message_to_text(prompt_messages), + stream=stream, + user=user, + **generate_config, + ) + if stream: + return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + return 
self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, + tools=tools, resp=resp) + else: + raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + + def _extract_response_tool_calls(self, + response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ + -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.function.name, + arguments=response_tool_call.function.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.id, + type=response_tool_call.type, + function=function + ) + tool_calls.append(tool_call) + + return tool_calls + + def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ + -> AssistantPromptMessage.ToolCall: + """ + Extract function call from response + + :param response_function_call: response function call + :return: tool call + """ + tool_call = None + if response_function_call: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function_call.name, + arguments=response_function_call.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_function_call.name, + type="function", + function=function + ) + + return tool_call + + def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion) -> LLMResult: + """ + handle normal chat generate response + """ + if len(resp.choices) == 0: + raise InvokeServerUnavailableError("Empty response") + + assistant_message = resp.choices[0].message + + # convert tool call to assistant message tool call + tool_calls = assistant_message.tool_calls + assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) + function_call = assistant_message.function_call + if function_call: + assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=assistant_message.content, + tool_calls=assistant_prompt_message_tool_calls + ) + + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) + + usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + response = LLMResult( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=resp.system_fingerprint, + usage=usage, + message=assistant_prompt_message, + ) + + return response + + def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk]) -> Generator: + """ + handle stream chat generate response + """ + full_response = '' + + for chunk in resp: + if len(chunk.choices) == 0: + continue + + delta = chunk.choices[0] + + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + continue + + 
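+            # depending on the model served by xinference, the delta may carry
+            # OpenAI-style `tool_calls` or the legacy `function_call` field;
+            # both are normalized into AssistantPromptMessage.ToolCall below so
+            # downstream handling is uniform.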
# check if there is a tool call in the response + function_call = None + tool_calls = [] + if delta.delta.tool_calls: + tool_calls += delta.delta.tool_calls + if delta.delta.function_call: + function_call = delta.delta.function_call + + assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls) + if function_call: + assistant_message_tool_calls += [self._extract_response_function_call(function_call)] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=delta.delta.content if delta.delta.content else '', + tool_calls=assistant_message_tool_calls + ) + + if delta.finish_reason is not None: + # temp_assistant_prompt_message is used to calculate usage + temp_assistant_prompt_message = AssistantPromptMessage( + content=full_response, + tool_calls=assistant_message_tool_calls + ) + + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) + + usage = self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + usage=usage + ), + ) + else: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + ), + ) + + full_response += delta.delta.content + + def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion) -> LLMResult: + """ + handle normal completion generate response + """ + if len(resp.choices) == 0: + raise InvokeServerUnavailableError("Empty response") + + assistant_message = resp.choices[0].text + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=assistant_message, + tool_calls=[] + ) + + prompt_tokens = self._get_num_tokens_by_gpt2( + self._convert_prompt_message_to_text(prompt_messages) + ) + completion_tokens = self._num_tokens_from_messages( + messages=[assistant_prompt_message], tools=[], is_completion_model=True + ) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) + + response = LLMResult( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=resp.system_fingerprint, + usage=usage, + message=assistant_prompt_message, + ) + + return response + + def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion]) -> Generator: + """ + handle stream completion generate response + """ + full_response = '' + + for chunk in resp: + if len(chunk.choices) == 0: + continue + + delta = chunk.choices[0] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=delta.text if delta.text else '', + tool_calls=[] + ) + + if delta.finish_reason is not None: + # temp_assistant_prompt_message is used to calculate usage + temp_assistant_prompt_message = AssistantPromptMessage( + content=full_response, + 
tool_calls=[] + ) + + prompt_tokens = self._get_num_tokens_by_gpt2( + self._convert_prompt_message_to_text(prompt_messages) + ) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True + ) + usage = self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + usage=usage + ), + ) + else: + if delta.text is None or delta.text == '': + continue + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + ), + ) + + full_response += delta.text + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + APIConnectionError, + APITimeoutError, + ], + InvokeServerUnavailableError: [ + InternalServerError, + ConflictError, + NotFoundError, + UnprocessableEntityError, + PermissionDeniedError + ], + InvokeRateLimitError: [ + RateLimitError + ], + InvokeAuthorizationError: [ + AuthenticationError + ], + InvokeBadRequestError: [ + ValueError + ] + } \ No newline at end of file diff --git a/server/xinference/rerank/__init__.py b/server/xinference/rerank/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/rerank/rerank.py b/server/xinference/rerank/rerank.py new file mode 100644 index 00000000..dd25037d --- /dev/null +++ b/server/xinference/rerank/rerank.py @@ -0,0 +1,160 @@ +from typing import Optional + +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class XinferenceRerankModel(RerankModel): + """ + Model class for Xinference rerank model. 
+ """ + + def _invoke(self, model: str, credentials: dict, + query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) \ + -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult( + model=model, + docs=[] + ) + + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + # initialize client + client = Client( + base_url=credentials['server_url'] + ) + + xinference_client = client.get_model(model_uid=credentials['model_uid']) + + if not isinstance(xinference_client, RESTfulRerankModelHandle): + raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model') + + response = xinference_client.rerank( + documents=docs, + query=query, + top_n=top_n, + ) + + rerank_documents = [] + for idx, result in enumerate(response['results']): + # format document + index = result['index'] + page_content = result['document'] + rerank_document = RerankDocument( + index=index, + text=page_content, + score=result['relevance_score'], + ) + + # score threshold check + if score_threshold is not None: + if result['relevance_score'] >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult( + model=model, + docs=rerank_documents + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + + self.invoke( + model=model, + credentials=credentials, + query="Whose kasumi", + docs=[ + "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty." + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={ }, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/server/xinference/text_embedding/__init__.py b/server/xinference/text_embedding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/xinference/text_embedding/text_embedding.py b/server/xinference/text_embedding/text_embedding.py new file mode 100644 index 00000000..32d2b151 --- /dev/null +++ b/server/xinference/text_embedding/text_embedding.py @@ -0,0 +1,201 @@ +import time +from typing import Optional + +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper + + +class XinferenceTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Xinference text embedding model. 
+ """ + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + credentials should be like: + { + 'server_url': 'server url', + 'model_uid': 'model uid', + } + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + + if server_url.endswith('/'): + server_url = server_url[:-1] + + client = Client(base_url=server_url) + + try: + handle = client.get_model(model_uid=model_uid) + except RuntimeError as e: + raise InvokeAuthorizationError(e) + + if not isinstance(handle, RESTfulEmbeddingModelHandle): + raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + + try: + embeddings = handle.create_embedding(input=texts) + except RuntimeError as e: + raise InvokeServerUnavailableError(e) + + """ + for convenience, the response json is like: + class Embedding(TypedDict): + object: Literal["list"] + model: str + data: List[EmbeddingData] + usage: EmbeddingUsage + class EmbeddingUsage(TypedDict): + prompt_tokens: int + total_tokens: int + class EmbeddingData(TypedDict): + index: int + object: str + embedding: List[float] + """ + + usage = embeddings['usage'] + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + + result = TextEmbeddingResult( + model=model, + embeddings=[embedding['embedding'] for embedding in embeddings['data']], + usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + num_tokens = 0 + for text in texts: + # use GPT2Tokenizer to get num tokens + num_tokens += self._get_num_tokens_by_gpt2(text) + return num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + + if extra_args.max_tokens: + credentials['max_tokens'] = extra_args.max_tokens + + self._invoke(model=model, credentials=credentials, texts=['ping']) + except InvokeAuthorizationError as e: + raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + except RuntimeError as e: + raise CredentialsValidateFailedError(e) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + KeyError + ] + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TEXT_EMBEDDING, + model_properties={ + ModelPropertyKey.MAX_CHUNKS: 1, + ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + }, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/server/xinference/xinference.py b/server/xinference/xinference.py new file mode 100644 index 00000000..d85f7c82 --- /dev/null +++ b/server/xinference/xinference.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XinferenceAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/server/xinference/xinference.yaml b/server/xinference/xinference.yaml new file mode 100644 index 00000000..bb6c6d86 --- /dev/null +++ b/server/xinference/xinference.yaml @@ -0,0 +1,47 @@ +provider: xinference +label: + en_US: Xorbits Inference +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FAF5FF" +help: + title: + en_US: How to deploy Xinference + zh_Hans: 如何部署 Xinference + url: + en_US: https://github.com/xorbitsai/inference +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + 
      zh_Hans: 输入模型名称
+  credential_form_schemas:
+  - variable: server_url
+    label:
+      zh_Hans: 服务器URL
+      en_US: Server url
+    type: secret-input
+    required: true
+    placeholder:
+      zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+      en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
+  - variable: model_uid
+    label:
+      zh_Hans: 模型UID
+      en_US: Model uid
+    type: text-input
+    required: true
+    placeholder:
+      zh_Hans: 在此输入您的Model UID
+      en_US: Enter the model uid
diff --git a/server/xinference/xinference_helper.py b/server/xinference/xinference_helper.py
new file mode 100644
index 00000000..66dab658
--- /dev/null
+++ b/server/xinference/xinference_helper.py
@@ -0,0 +1,103 @@
+from threading import Lock
+from time import time
+
+from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, MissingSchema, Timeout
+from requests.sessions import Session
+from yarl import URL
+
+
+class XinferenceModelExtraParameter:
+    model_format: str
+    model_handle_type: str
+    model_ability: list[str]
+    max_tokens: int = 512
+    context_length: int = 2048
+    support_function_call: bool = False
+
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
+                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
+        self.model_format = model_format
+        self.model_handle_type = model_handle_type
+        self.model_ability = model_ability
+        self.support_function_call = support_function_call
+        self.max_tokens = max_tokens
+        self.context_length = context_length
+
+cache = {}
+cache_lock = Lock()
+
+class XinferenceHelper:
+    @staticmethod
+    def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
+        XinferenceHelper._clean_cache()
+        with cache_lock:
+            if model_uid not in cache:
+                cache[model_uid] = {
+                    'expires': time() + 300,
+                    'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
+                }
+            return cache[model_uid]['value']
+
+    @staticmethod
+    def _clean_cache() -> None:
+        try:
+            with cache_lock:
+                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
+                for model_uid in expired_keys:
+                    del cache[model_uid]
+        except RuntimeError:
+            pass
+
+    @staticmethod
+    def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
+        """
+        get xinference model extra parameters such as model_format and model_handle_type
+        """
+
+        if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
+            raise RuntimeError('model_uid or server_url is empty')
+
+        url = str(URL(server_url) / 'v1' / 'models' / model_uid)
+
+        # this method is called while cache_lock is held, and a default requests session
+        # may hang forever, so mount an HTTPAdapter with max_retries=3
+        session = Session()
+        session.mount('http://', HTTPAdapter(max_retries=3))
+        session.mount('https://', HTTPAdapter(max_retries=3))
+
+        try:
+            response = session.get(url, timeout=10)
+        except (MissingSchema, ConnectionError, Timeout) as e:
+            raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
+        if response.status_code != 200:
+            raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
+
+        response_json = response.json()
+
+        model_format = response_json.get('model_format', 'ggmlv3')
+        model_ability = response_json.get('model_ability', [])
+
+        if response_json.get('model_type') == 'embedding':
+            model_handle_type = 'embedding'
+        elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+            model_handle_type = 'chatglm'
+        elif 'generate' in model_ability:
+            model_handle_type = 'generate'
+        elif 'chat' in model_ability:
+            model_handle_type = 'chat'
+        else:
+            raise NotImplementedError(f'xinference model format {model_format} with ability {model_ability} is not supported')
+
+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+
+        context_length = response_json.get('context_length', 2048)
+
+        return XinferenceModelExtraParameter(
+            model_format=model_format,
+            model_handle_type=model_handle_type,
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens,
+            context_length=context_length
+        )
\ No newline at end of file
From 4939e736e164e1d056089df187df7f2335f74ab6 Mon Sep 17 00:00:00 2001
From: Leb
Date: Wed, 13 Mar 2024 16:56:53 +0800
Subject: [PATCH 2/4] Delete server/xinference directory

---
 server/xinference/__init__.py                |   0
 server/xinference/_assets/icon_l_en.svg      |  42 -
 server/xinference/_assets/icon_s_en.svg      |  24 -
 server/xinference/llm/__init__.py            |   0
 server/xinference/llm/llm.py                 | 734 ------
 server/xinference/rerank/__init__.py         |   0
 server/xinference/rerank/rerank.py           | 160 ----
 server/xinference/text_embedding/__init__.py |   0
 .../text_embedding/text_embedding.py         | 201 -----
 server/xinference/xinference.py              |  10 -
 server/xinference/xinference.yaml            |  47 --
 server/xinference/xinference_helper.py       | 103 ---
 12 files changed, 1321 deletions(-)
 delete mode 100644 server/xinference/__init__.py
 delete mode 100644 server/xinference/_assets/icon_l_en.svg
 delete mode 100644 server/xinference/_assets/icon_s_en.svg
 delete mode 100644 server/xinference/llm/__init__.py
 delete mode 100644 server/xinference/llm/llm.py
 delete mode 100644 server/xinference/rerank/__init__.py
 delete mode 100644 server/xinference/rerank/rerank.py
 delete mode 100644 server/xinference/text_embedding/__init__.py
 delete mode 100644 server/xinference/text_embedding/text_embedding.py
 delete mode 100644 server/xinference/xinference.py
 delete mode 100644 server/xinference/xinference.yaml
 delete mode 100644 server/xinference/xinference_helper.py

diff --git a/server/xinference/__init__.py b/server/xinference/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/server/xinference/_assets/icon_l_en.svg b/server/xinference/_assets/icon_l_en.svg
deleted file mode 100644
index 81091765..00000000
diff --git a/server/xinference/_assets/icon_s_en.svg b/server/xinference/_assets/icon_s_en.svg
deleted file mode 100644
index f5c5f75e..00000000
diff --git a/server/xinference/llm/__init__.py b/server/xinference/llm/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/server/xinference/llm/llm.py b/server/xinference/llm/llm.py
deleted file mode 100644
index 602d0b74..00000000
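A minimal sketch of how the helper introduced in PATCH 1/4 can be exercised against a running Xinference server; the server URL and model uid below are illustrative placeholders, not values taken from this patch:

    from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

    # placeholder endpoint and uid; substitute a real deployment
    extra = XinferenceHelper.get_xinference_extra_parameter(
        server_url='http://127.0.0.1:9997',
        model_uid='my-model-uid',
    )
    # repeated calls within 300 seconds are served from the module-level cache
    print(extra.model_handle_type, extra.model_ability, extra.context_length)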
ChatCompletionChunk, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall -from openai.types.chat.chat_completion_message import FunctionCall -from openai.types.completion import Completion -from xinference_client.client.restful.restful_client import ( - Client, - RESTfulChatglmCppChatModelHandle, - RESTfulChatModelHandle, - RESTfulGenerateModelHandle, -) - -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - FetchFrom, - ModelFeature, - ModelPropertyKey, - ModelType, - ParameterRule, - ParameterType, -) -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.xinference.xinference_helper import ( - XinferenceHelper, - XinferenceModelExtraParameter, -) -from core.model_runtime.utils import helper - - -class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - """ - invoke LLM - - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` - """ - if 'temperature' in model_parameters: - if model_parameters['temperature'] < 0.01: - model_parameters['temperature'] = 0.01 - elif model_parameters['temperature'] > 1.0: - model_parameters['temperature'] = 0.99 - - return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, - extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] - ) - ) - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - validate credentials - - credentials should be like: - { - 'model_type': 'text-generation', - 'server_url': 'server url', - 'model_uid': 'model uid', - } - """ - try: - if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: - raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - - extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] - ) - if 'completion_type' not in credentials: - if 'chat' in extra_param.model_ability: - credentials['completion_type'] = 'chat' - elif 'generate' in extra_param.model_ability: - credentials['completion_type'] = 'completion' - else: - raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') - - if extra_param.support_function_call: - credentials['support_function_call'] = True - - if extra_param.context_length: - credentials['context_length'] = extra_param.context_length - - except RuntimeError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') - except KeyError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') - except Exception as e: - raise e - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: - """ - get number of tokens - - cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default - """ - return self._num_tokens_from_messages(prompt_messages, tools) - - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: - def tokens(text: str): - return self._get_num_tokens_by_gpt2(text) - - if is_completion_model: - return sum([tokens(str(message.content)) for message in messages]) - - tokens_per_message = 3 - tokens_per_name = 1 - - num_tokens = 0 - messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] - for message in messages_dict: - num_tokens += tokens_per_message - for key, value in message.items(): - if isinstance(value, list): - text = '' - for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item.text - - value = text - - if key == "tool_calls": - for tool_call in value: - for t_key, t_value in tool_call.items(): - num_tokens += tokens(t_key) - if t_key == "function": - for f_key, f_value in t_value.items(): - num_tokens += tokens(f_key) - num_tokens += tokens(f_value) - else: - num_tokens += tokens(t_key) - num_tokens += tokens(t_value) - if key == "function_call": - for t_key, t_value in value.items(): - num_tokens += tokens(t_key) - if t_key == "function": - for f_key, f_value in t_value.items(): - num_tokens += tokens(f_key) - num_tokens += tokens(f_value) - else: - num_tokens += tokens(t_key) - num_tokens += tokens(t_value) - else: - num_tokens += tokens(str(value)) - - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 - - if tools: - num_tokens += self._num_tokens_for_tools(tools) - - return num_tokens - - def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: - """ - Calculate num tokens for tool calling - - :param encoding: encoding - :param tools: tools for tool calling - :return: number of tokens - """ - def tokens(text: str): - return self._get_num_tokens_by_gpt2(text) - - num_tokens = 0 - for tool in tools: - # calculate num tokens for function object - num_tokens += tokens('name') - num_tokens += tokens(tool.name) - num_tokens += tokens('description') - num_tokens += 
tokens(tool.description) - parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') - num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): - num_tokens += tokens(key) - for field_key, field_value in value.items(): - num_tokens += tokens(field_key) - if field_key == 'enum': - for enum_field in field_value: - num_tokens += 3 - num_tokens += tokens(enum_field) - else: - num_tokens += tokens(field_key) - num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: - num_tokens += 3 - num_tokens += tokens(required_field) - - return num_tokens - - def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: - """ - convert prompt message to text - """ - text = '' - for item in message: - if isinstance(item, UserPromptMessage): - text += item.content - elif isinstance(item, SystemPromptMessage): - text += item.content - elif isinstance(item, AssistantPromptMessage): - text += item.content - else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') - return text - - def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: - """ - Convert PromptMessage to dict for OpenAI Compatibility API - """ - if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) - if isinstance(message.content, str): - message_dict = {"role": "user", "content": message.content} - else: - raise ValueError("User message content must be str") - elif isinstance(message, AssistantPromptMessage): - message = cast(AssistantPromptMessage, message) - message_dict = {"role": "assistant", "content": message.content} - if message.tool_calls and len(message.tool_calls) > 0: - message_dict["function_call"] = { - "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments - } - elif isinstance(message, SystemPromptMessage): - message = cast(SystemPromptMessage, message) - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, ToolPromptMessage): - message = cast(ToolPromptMessage, message) - message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} - else: - raise ValueError(f"Unknown message type {type(message)}") - - return message_dict - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - used to define customizable model schema - """ - rules = [ - ParameterRule( - name='temperature', - type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), - ), - ParameterRule( - name='top_p', - type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) - ), - ParameterRule( - name='max_tokens', - type=ParameterType.INT, - use_template='max_tokens', - min=1, - max=credentials.get('context_length', 2048), - default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) - ] - - completion_type = None - - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': - completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': - completion_type = LLMMode.COMPLETION.value - else: - raise ValueError(f'completion_type {credentials["completion_type"]} is not 
supported') - else: - extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] - ) - - if 'chat' in extra_args.model_ability: - completion_type = LLMMode.CHAT.value - elif 'generate' in extra_args.model_ability: - completion_type = LLMMode.COMPLETION.value - else: - raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') - - support_function_call = credentials.get('support_function_call', False) - context_length = credentials.get('context_length', 2048) - - entity = AIModelEntity( - model=model, - label=I18nObject( - en_US=model - ), - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_type=ModelType.LLM, - features=[ - ModelFeature.TOOL_CALL - ] if support_function_call else [], - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules - ) - - return entity - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - """ - generate text from LLM - - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - - extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` - """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] - - client = OpenAI( - base_url=f'{credentials["server_url"]}/v1', - api_key='abc', - max_retries=3, - timeout=60, - ) - - xinference_client = Client( - base_url=credentials['server_url'], - ) - - xinference_model = xinference_client.get_model(credentials['model_uid']) - - generate_config = { - 'temperature': model_parameters.get('temperature', 1.0), - 'top_p': model_parameters.get('top_p', 0.7), - 'max_tokens': model_parameters.get('max_tokens', 512), - } - - if stop: - generate_config['stop'] = stop - - if tools and len(tools) > 0: - generate_config['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] - - if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): - resp = client.chat.completions.create( - model=credentials['model_uid'], - messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], - stream=stream, - user=user, - **generate_config, - ) - if stream: - if tools and len(tools) > 0: - raise InvokeBadRequestError('xinference tool calls does not support stream mode') - return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) - elif isinstance(xinference_model, RESTfulGenerateModelHandle): - resp = client.completions.create( - model=credentials['model_uid'], - prompt=self._convert_prompt_message_to_text(prompt_messages), - stream=stream, - user=user, - **generate_config, - ) - if stream: - return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) - return 
self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) - else: - raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') - - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: - """ - Extract tool calls from response - - :param response_tool_calls: response tool calls - :return: list of tool calls - """ - tool_calls = [] - if response_tool_calls: - for response_tool_call in response_tool_calls: - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments - ) - - tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function - ) - tool_calls.append(tool_call) - - return tool_calls - - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: - """ - Extract function call from response - - :param response_function_call: response function call - :return: tool call - """ - tool_call = None - if response_function_call: - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments - ) - - tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function - ) - - return tool_call - - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: ChatCompletion) -> LLMResult: - """ - handle normal chat generate response - """ - if len(resp.choices) == 0: - raise InvokeServerUnavailableError("Empty response") - - assistant_message = resp.choices[0].message - - # convert tool call to assistant message tool call - tool_calls = assistant_message.tool_calls - assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) - function_call = assistant_message.function_call - if function_call: - assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=assistant_prompt_message_tool_calls - ) - - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - - response = LLMResult( - model=model, - prompt_messages=prompt_messages, - system_fingerprint=resp.system_fingerprint, - usage=usage, - message=assistant_prompt_message, - ) - - return response - - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[ChatCompletionChunk]) -> Generator: - """ - handle stream chat generate response - """ - full_response = '' - - for chunk in resp: - if len(chunk.choices) == 0: - continue - - delta = chunk.choices[0] - - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): - continue - - 
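# NOTE: Xinference can surface tool invocations either as OpenAI-style tool_calls deltas - # or as a legacy function_call delta; both shapes are collected below before being - # converted into AssistantPromptMessage.ToolCall objects. - 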
# check if there is a tool call in the response - function_call = None - tool_calls = [] - if delta.delta.tool_calls: - tool_calls += delta.delta.tool_calls - if delta.delta.function_call: - function_call = delta.delta.function_call - - assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls) - if function_call: - assistant_message_tool_calls += [self._extract_response_function_call(function_call)] - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls - ) - - if delta.finish_reason is not None: - # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls - ) - - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage - ), - ) - else: - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - ), - ) - - # delta.delta.content may be None on the final chunk, so guard the concatenation - full_response += delta.delta.content or '' - - def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Completion) -> LLMResult: - """ - handle normal completion generate response - """ - if len(resp.choices) == 0: - raise InvokeServerUnavailableError("Empty response") - - assistant_message = resp.choices[0].text - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) - - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) - completion_tokens = self._num_tokens_from_messages( - messages=[assistant_prompt_message], tools=[], is_completion_model=True - ) - usage = self._calc_response_usage( - model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens - ) - - response = LLMResult( - model=model, - prompt_messages=prompt_messages, - system_fingerprint=resp.system_fingerprint, - usage=usage, - message=assistant_prompt_message, - ) - - return response - - def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[Completion]) -> Generator: - """ - handle stream completion generate response - """ - full_response = '' - - for chunk in resp: - if len(chunk.choices) == 0: - continue - - delta = chunk.choices[0] - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) - - if delta.finish_reason is not None: - # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - 
tool_calls=[] - ) - - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) - completion_tokens = self._num_tokens_from_messages( - messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True - ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage - ), - ) - else: - if delta.text is None or delta.text == '': - continue - - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - ), - ) - - # delta.text may be None on the final chunk, so guard the concatenation - full_response += delta.text or '' - - @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. - - :return: Invoke error mapping - """ - return { - InvokeConnectionError: [ - APIConnectionError, - APITimeoutError, - ], - InvokeServerUnavailableError: [ - InternalServerError, - ConflictError, - NotFoundError, - UnprocessableEntityError, - PermissionDeniedError - ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] - } \ No newline at end of file diff --git a/server/xinference/rerank/__init__.py b/server/xinference/rerank/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/server/xinference/rerank/rerank.py b/server/xinference/rerank/rerank.py deleted file mode 100644 index dd25037d..00000000 --- a/server/xinference/rerank/rerank.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import Optional - -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle - -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.rerank_model import RerankModel - - -class XinferenceRerankModel(RerankModel): - """ - Model class for Xinference rerank model. 
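- It delegates scoring to the rerank endpoint of an Xinference server through - xinference_client's RESTfulRerankModelHandle.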
- """ - - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: - """ - Invoke rerank model - - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :param user: unique user id - :return: rerank result - """ - if len(docs) == 0: - return RerankResult( - model=model, - docs=[] - ) - - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] - - # initialize client - client = Client( - base_url=credentials['server_url'] - ) - - xinference_client = client.get_model(model_uid=credentials['model_uid']) - - if not isinstance(xinference_client, RESTfulRerankModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model') - - response = xinference_client.rerank( - documents=docs, - query=query, - top_n=top_n, - ) - - rerank_documents = [] - for idx, result in enumerate(response['results']): - # format document - index = result['index'] - page_content = result['document'] - rerank_document = RerankDocument( - index=index, - text=page_content, - score=result['relevance_score'], - ) - - # score threshold check - if score_threshold is not None: - if result['relevance_score'] >= score_threshold: - rerank_documents.append(rerank_document) - else: - rerank_documents.append(rerank_document) - - return RerankResult( - model=model, - docs=rerank_documents - ) - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials - - :param model: model name - :param credentials: model credentials - :return: - """ - try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: - raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - - self.invoke( - model=model, - credentials=credentials, - query="Whose kasumi", - docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", - "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." - ], - score_threshold=0.8 - ) - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. 
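- Most unified Invoke* errors below map onto themselves because this class raises - them directly rather than wrapping a vendor SDK exception.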
- - :return: Invoke error mapping - """ - return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] - } - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - used to define customizable model schema - """ - entity = AIModelEntity( - model=model, - label=I18nObject( - en_US=model - ), - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_type=ModelType.RERANK, - model_properties={ }, - parameter_rules=[] - ) - - return entity \ No newline at end of file diff --git a/server/xinference/text_embedding/__init__.py b/server/xinference/text_embedding/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/server/xinference/text_embedding/text_embedding.py b/server/xinference/text_embedding/text_embedding.py deleted file mode 100644 index 32d2b151..00000000 --- a/server/xinference/text_embedding/text_embedding.py +++ /dev/null @@ -1,201 +0,0 @@ -import time -from typing import Optional - -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle - -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper - - -class XinferenceTextEmbeddingModel(TextEmbeddingModel): - """ - Model class for Xinference text embedding model. 
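- Embeddings are fetched through xinference_client's RESTfulEmbeddingModelHandle.create_embedding, - whose response follows the OpenAI-style embedding schema documented in _invoke below.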
- """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - """ - Invoke text embedding model - - credentials should be like: - { - 'server_url': 'server url', - 'model_uid': 'model uid', - } - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :return: embeddings result - """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - - if server_url.endswith('/'): - server_url = server_url[:-1] - - client = Client(base_url=server_url) - - try: - handle = client.get_model(model_uid=model_uid) - except RuntimeError as e: - raise InvokeAuthorizationError(e) - - if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') - - try: - embeddings = handle.create_embedding(input=texts) - except RuntimeError as e: - raise InvokeServerUnavailableError(e) - - """ - for convenience, the response json is like: - class Embedding(TypedDict): - object: Literal["list"] - model: str - data: List[EmbeddingData] - usage: EmbeddingUsage - class EmbeddingUsage(TypedDict): - prompt_tokens: int - total_tokens: int - class EmbeddingData(TypedDict): - index: int - object: str - embedding: List[float] - """ - - usage = embeddings['usage'] - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) - - result = TextEmbeddingResult( - model=model, - embeddings=[embedding['embedding'] for embedding in embeddings['data']], - usage=usage - ) - - return result - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - num_tokens = 0 - for text in texts: - # use GPT2Tokenizer to get num tokens - num_tokens += self._get_num_tokens_by_gpt2(text) - return num_tokens - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials - - :param model: model name - :param credentials: model credentials - :return: - """ - try: - if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: - raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) - - if extra_args.max_tokens: - credentials['max_tokens'] = extra_args.max_tokens - - self._invoke(model=model, credentials=credentials, texts=['ping']) - except InvokeAuthorizationError as e: - raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') - except RuntimeError as e: - raise CredentialsValidateFailedError(e) - - @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] - } - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param tokens: input tokens - :return: usage - """ - # get input price info - input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens - ) - - # transform usage - usage = EmbeddingUsage( - tokens=tokens, - total_tokens=tokens, - unit_price=input_price_info.unit_price, - price_unit=input_price_info.unit, - total_price=input_price_info.total_amount, - currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at - ) - - return usage - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - used to define customizable model schema - """ - - entity = AIModelEntity( - model=model, - label=I18nObject( - en_US=model - ), - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - ModelPropertyKey.MAX_CHUNKS: 1, - ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, - }, - parameter_rules=[] - ) - - return entity \ No newline at end of file diff --git a/server/xinference/xinference.py b/server/xinference/xinference.py deleted file mode 100644 index d85f7c82..00000000 --- a/server/xinference/xinference.py +++ /dev/null @@ -1,10 +0,0 @@ -import logging - -from core.model_runtime.model_providers.__base.model_provider import ModelProvider - -logger = logging.getLogger(__name__) - - -class XinferenceAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass diff --git a/server/xinference/xinference.yaml b/server/xinference/xinference.yaml deleted file mode 100644 index bb6c6d86..00000000 --- a/server/xinference/xinference.yaml +++ /dev/null @@ -1,47 +0,0 @@ -provider: xinference -label: - en_US: Xorbits Inference -icon_small: - en_US: icon_s_en.svg -icon_large: - en_US: icon_l_en.svg -background: "#FAF5FF" -help: - title: - en_US: How to deploy Xinference - zh_Hans: 如何部署 Xinference - url: - en_US: https://github.com/xorbitsai/inference -supported_model_types: - - llm - - text-embedding - - rerank -configurate_methods: - - customizable-model -model_credential_schema: - model: - label: - en_US: Model Name - zh_Hans: 模型名称 - placeholder: - en_US: Enter your model name - 
zh_Hans: 输入模型名称 - credential_form_schemas: - - variable: server_url - label: - zh_Hans: 服务器URL - en_US: Server url - type: secret-input - required: true - placeholder: - zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997 - en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997 - - variable: model_uid - label: - zh_Hans: 模型UID - en_US: Model uid - type: text-input - required: true - placeholder: - zh_Hans: 在此输入您的Model UID - en_US: Enter the model uid diff --git a/server/xinference/xinference_helper.py b/server/xinference/xinference_helper.py deleted file mode 100644 index 66dab658..00000000 --- a/server/xinference/xinference_helper.py +++ /dev/null @@ -1,103 +0,0 @@ -from threading import Lock -from time import time - -from requests.adapters import HTTPAdapter -from requests.exceptions import ConnectionError, MissingSchema, Timeout -from requests.sessions import Session -from yarl import URL - - -class XinferenceModelExtraParameter: - model_format: str - model_handle_type: str - model_ability: list[str] - max_tokens: int = 512 - context_length: int = 2048 - support_function_call: bool = False - - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, max_tokens: int, context_length: int) -> None: - self.model_format = model_format - self.model_handle_type = model_handle_type - self.model_ability = model_ability - self.support_function_call = support_function_call - self.max_tokens = max_tokens - self.context_length = context_length - -cache = {} -cache_lock = Lock() - -class XinferenceHelper: - @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: - XinferenceHelper._clean_cache() - with cache_lock: - if model_uid not in cache: - cache[model_uid] = { - 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) - } - return cache[model_uid]['value'] - - @staticmethod - def _clean_cache() -> None: - try: - with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] - for model_uid in expired_keys: - del cache[model_uid] - except RuntimeError: - # cache was mutated concurrently; skip this cleanup round - pass - - @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: - """ - get xinference model extra parameter like model_format and model_handle_type - """ - - if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): - raise RuntimeError('server_url and model_uid are both required') - - url = str(URL(server_url) / 'v1' / 'models' / model_uid) - - # this method is called under a lock and a default requests session may hang forever, so we mount an HTTPAdapter with max_retries=3 - session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) - - try: - response = session.get(url, timeout=10) - except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') - if response.status_code != 200: - raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') - - response_json = response.json() - - model_format = response_json.get('model_format', 'ggmlv3') - model_ability = response_json.get('model_ability', []) - - if response_json.get('model_type') == 'embedding': - model_handle_type = 'embedding' - elif model_format == 'ggmlv3' 
and 'chatglm' in response_json['model_name']: - model_handle_type = 'chatglm' - elif 'generate' in model_ability: - model_handle_type = 'generate' - elif 'chat' in model_ability: - model_handle_type = 'chat' - else: - raise NotImplementedError(f'xinference model ability {model_ability} is not supported') - - support_function_call = 'tools' in model_ability - max_tokens = response_json.get('max_tokens', 512) - - context_length = response_json.get('context_length', 2048) - - return XinferenceModelExtraParameter( - model_format=model_format, - model_handle_type=model_handle_type, - model_ability=model_ability, - support_function_call=support_function_call, - max_tokens=max_tokens, - context_length=context_length - ) \ No newline at end of file From 9208627138354f7f60badf4e4dc0e6b147ea1060 Mon Sep 17 00:00:00 2001 From: Leb Date: Wed, 13 Mar 2024 19:49:02 +0800 Subject: [PATCH 3/4] Create khazic --- server/khazic | 1 + 1 file changed, 1 insertion(+) create mode 100644 server/khazic diff --git a/server/khazic b/server/khazic new file mode 100644 index 00000000..e94f0c96 --- /dev/null +++ b/server/khazic @@ -0,0 +1 @@ +khazic From 980f321beca7409b00ce96acb5f759ef9437bac4 Mon Sep 17 00:00:00 2001 From: Leb Date: Wed, 13 Mar 2024 19:51:36 +0800 Subject: [PATCH 4/4] diiii diii --- server/schema_validators/__init__.py | 0 server/schema_validators/common_validator.py | 87 +++++++ .../model_credential_schema_validator.py | 28 +++ .../provider_credential_schema_validator.py | 20 ++ server/utils/__init__.py | 0 server/utils/_compat.py | 21 ++ server/utils/encoders.py | 228 ++++++++++++++++++ server/utils/helper.py | 9 + 8 files changed, 393 insertions(+) create mode 100644 server/schema_validators/__init__.py create mode 100644 server/schema_validators/common_validator.py create mode 100644 server/schema_validators/model_credential_schema_validator.py create mode 100644 server/schema_validators/provider_credential_schema_validator.py create mode 100644 server/utils/__init__.py create mode 100644 server/utils/_compat.py create mode 100644 server/utils/encoders.py create mode 100644 server/utils/helper.py diff --git a/server/schema_validators/__init__.py b/server/schema_validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/schema_validators/common_validator.py b/server/schema_validators/common_validator.py new file mode 100644 index 00000000..fe705d69 --- /dev/null +++ b/server/schema_validators/common_validator.py @@ -0,0 +1,87 @@ +from typing import Optional + +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType + + +class CommonValidator: + def _validate_and_filter_credential_form_schemas(self, + credential_form_schemas: list[CredentialFormSchema], + credentials: dict) -> dict: + need_validate_credential_form_schema_map = {} + for credential_form_schema in credential_form_schemas: + if not credential_form_schema.show_on: + need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema + continue + + all_show_on_match = True + for show_on_object in credential_form_schema.show_on: + if show_on_object.variable not in credentials: + all_show_on_match = False + break + + if credentials[show_on_object.variable] != show_on_object.value: + all_show_on_match = False + break + + if all_show_on_match: + need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema + + # Iterate over the remaining credential_form_schemas, verify each credential_form_schema + 
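# only the schemas whose show_on conditions matched (or that define no show_on) are validated here + 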
validated_credentials = {} + for credential_form_schema in need_validate_credential_form_schema_map.values(): + # add the value of the credential_form_schema corresponding to it to validated_credentials + result = self._validate_credential_form_schema(credential_form_schema, credentials) + if result is not None: + validated_credentials[credential_form_schema.variable] = result + + return validated_credentials + + def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ + -> Optional[str | bool]: + """ + Validate credential form schema + + :param credential_form_schema: credential form schema + :param credentials: credentials + :return: validated credential form schema value + """ + # If the variable does not exist in credentials + if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: + # If required is True, an exception is thrown + if credential_form_schema.required: + raise ValueError(f'Variable {credential_form_schema.variable} is required') + else: + # Get the value of default + if credential_form_schema.default: + # If it exists, add it to validated_credentials + return credential_form_schema.default + else: + # If default does not exist, skip + return None + + # Get the value corresponding to the variable from credentials + value = credentials[credential_form_schema.variable] + + # If max_length=0, no validation is performed + if credential_form_schema.max_length: + if len(value) > credential_form_schema.max_length: + raise ValueError(f'Variable {credential_form_schema.variable} length should not be greater than {credential_form_schema.max_length}') + + # check the type of value + if not isinstance(value, str): + raise ValueError(f'Variable {credential_form_schema.variable} should be string') + + if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: + # If options are defined, the value must be one of them + if credential_form_schema.options: + if value not in [option.value for option in credential_form_schema.options]: + raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + + if credential_form_schema.type == FormType.SWITCH: + # If the value is not in ['true', 'false'], an exception is thrown + if value.lower() not in ['true', 'false']: + raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + + value = value.lower() == 'true' + + return value diff --git a/server/schema_validators/model_credential_schema_validator.py b/server/schema_validators/model_credential_schema_validator.py new file mode 100644 index 00000000..c4786fad --- /dev/null +++ b/server/schema_validators/model_credential_schema_validator.py @@ -0,0 +1,28 @@ +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ModelCredentialSchema +from core.model_runtime.schema_validators.common_validator import CommonValidator + + +class ModelCredentialSchemaValidator(CommonValidator): + + def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): + self.model_type = model_type + self.model_credential_schema = model_credential_schema + + def validate_and_filter(self, credentials: dict) -> dict: + """ + Validate model credentials + + :param credentials: model credentials + :return: filtered credentials + """ + + if self.model_credential_schema is None: + raise ValueError("Model credential schema is None") + + # get the credential_form_schemas in 
model_credential_schema + credential_form_schemas = self.model_credential_schema.credential_form_schemas + + credentials["__model_type"] = self.model_type.value + + return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/server/schema_validators/provider_credential_schema_validator.py b/server/schema_validators/provider_credential_schema_validator.py new file mode 100644 index 00000000..c9450165 --- /dev/null +++ b/server/schema_validators/provider_credential_schema_validator.py @@ -0,0 +1,20 @@ +from core.model_runtime.entities.provider_entities import ProviderCredentialSchema +from core.model_runtime.schema_validators.common_validator import CommonValidator + + +class ProviderCredentialSchemaValidator(CommonValidator): + + def __init__(self, provider_credential_schema: ProviderCredentialSchema): + self.provider_credential_schema = provider_credential_schema + + def validate_and_filter(self, credentials: dict) -> dict: + """ + Validate provider credentials + + :param credentials: provider credentials + :return: validated provider credentials + """ + # get the credential_form_schemas in provider_credential_schema + credential_form_schemas = self.provider_credential_schema.credential_form_schemas + + return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/utils/_compat.py b/server/utils/_compat.py new file mode 100644 index 00000000..5c341527 --- /dev/null +++ b/server/utils/_compat.py @@ -0,0 +1,21 @@ +from typing import Any, Literal + +from pydantic import BaseModel +from pydantic.version import VERSION as PYDANTIC_VERSION + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + +if PYDANTIC_V2: + from pydantic_core import Url as Url + + def _model_dump( + model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any + ) -> Any: + return model.model_dump(mode=mode, **kwargs) +else: + from pydantic import AnyUrl as Url # noqa: F401 + + def _model_dump( + model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any + ) -> Any: + return model.dict(**kwargs) diff --git a/server/utils/encoders.py b/server/utils/encoders.py new file mode 100644 index 00000000..cf6c98e0 --- /dev/null +++ b/server/utils/encoders.py @@ -0,0 +1,228 @@ +import dataclasses +import datetime +from collections import defaultdict, deque +from collections.abc import Callable +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path, PurePath +from re import Pattern +from types import GeneratorType +from typing import Any, Optional, Union +from uuid import UUID + +from pydantic import BaseModel +from pydantic.color import Color +from pydantic.networks import AnyUrl, NameEmail +from pydantic.types import SecretBytes, SecretStr + +from ._compat import PYDANTIC_V2, Url, _model_dump + + +# Taken from Pydantic v1 as is +def isoformat(o: Union[datetime.date, datetime.time]) -> str: + return o.isoformat() + + +# Taken from Pydantic v1 as is +# TODO: pv2 should this return strings instead? +def decimal_encoder(dec_value: Decimal) -> Union[int, float]: + """ + Encodes a Decimal as int if there's no exponent, otherwise float + + This is useful when we use ConstrainedDecimal to represent Numeric(x,0) + where an integer (but not int typed) is used. 
Encoding this as a float + results in failed round-tripping between encode and parse. + Our Id type is a prime example of this. + + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] + return int(dec_value) + else: + return float(dec_value) + + +ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: isoformat, + datetime.datetime: isoformat, + datetime.time: isoformat, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + IPv4Address: str, + IPv4Interface: str, + IPv4Network: str, + IPv6Address: str, + IPv6Interface: str, + IPv6Network: str, + NameEmail: str, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, + Url: str, + AnyUrl: str, +} + + +def generate_encoders_by_class_tuples( + type_encoder_map: dict[Any, Callable[[Any], Any]] +) -> dict[Callable[[Any], Any], tuple[Any, ...]]: + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( + tuple + ) + for type_, encoder in type_encoder_map.items(): + encoders_by_class_tuples[encoder] += (type_,) + return encoders_by_class_tuples + + +encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) + + +def jsonable_encoder( + obj: Any, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, + sqlalchemy_safe: bool = True, +) -> Any: + custom_encoder = custom_encoder or {} + if custom_encoder: + if type(obj) in custom_encoder: + return custom_encoder[type(obj)](obj) + else: + for encoder_type, encoder_instance in custom_encoder.items(): + if isinstance(obj, encoder_type): + return encoder_instance(obj) + if isinstance(obj, BaseModel): + # TODO: remove when deprecating Pydantic v1 + encoders: dict[Any, Any] = {} + if not PYDANTIC_V2: + encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] + if custom_encoder: + encoders.update(custom_encoder) + obj_dict = _model_dump( + obj, + mode="json", + include=None, + exclude=None, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + if "__root__" in obj_dict: + obj_dict = obj_dict["__root__"] + return jsonable_encoder( + obj_dict, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + # TODO: remove when deprecating Pydantic v1 + custom_encoder=encoders, + sqlalchemy_safe=sqlalchemy_safe, + ) + if dataclasses.is_dataclass(obj): + obj_dict = dataclasses.asdict(obj) + return jsonable_encoder( + obj_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, PurePath): + return str(obj) + if isinstance(obj, str | int | float | type(None)): + return obj + if isinstance(obj, Decimal): + return format(obj, 'f') + if isinstance(obj, dict): + encoded_dict = {} + allowed_keys = set(obj.keys()) + for key, value in obj.items(): + if ( + ( + not sqlalchemy_safe + or (not isinstance(key, str)) + or (not key.startswith("_sa")) + ) + and (value is not None or not exclude_none) + and key in 
allowed_keys + ): + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_dict[encoded_key] = encoded_value + return encoded_dict + if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): + encoded_list = [] + for item in obj: + encoded_list.append( + jsonable_encoder( + item, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + ) + return encoded_list + + if type(obj) in ENCODERS_BY_TYPE: + return ENCODERS_BY_TYPE[type(obj)](obj) + for encoder, classes_tuple in encoders_by_class_tuples.items(): + if isinstance(obj, classes_tuple): + return encoder(obj) + + try: + data = dict(obj) + except Exception as e: + errors: list[Exception] = [] + errors.append(e) + try: + data = vars(obj) + except Exception as e: + errors.append(e) + raise ValueError(errors) from e + return jsonable_encoder( + data, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) diff --git a/server/utils/helper.py b/server/utils/helper.py new file mode 100644 index 00000000..09d08fa3 --- /dev/null +++ b/server/utils/helper.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +def dump_model(model: BaseModel) -> dict: + # pydantic v2 models expose model_dump(); v1 models fall back to dict() + if hasattr(model, 'model_dump'): + return model.model_dump() + else: + return model.dict()
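A quick sanity-check sketch for dump_model in server/utils/helper.py (illustrative only: the ServerConfig model and its fields are hypothetical, invented for this note; only dump_model itself comes from the patch):

    from pydantic import BaseModel

    class ServerConfig(BaseModel):  # hypothetical example model, not part of the patch
        server_url: str
        model_uid: str

    config = ServerConfig(server_url='http://192.168.1.100:9997', model_uid='my-model-uid')
    # resolves to model_dump() on pydantic v2 and dict() on v1
    print(dump_model(config))  # {'server_url': 'http://192.168.1.100:9997', 'model_uid': 'my-model-uid'}

This is the same version split that _model_dump in server/utils/_compat.py handles via the PYDANTIC_V2 flag; dump_model keys off the model instance instead, so it needs no version import.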