from collections.abc import Generator, Iterator
from typing import cast

from openai import (
    APIConnectionError,
    APITimeoutError,
    AuthenticationError,
    ConflictError,
    InternalServerError,
    NotFoundError,
    OpenAI,
    PermissionDeniedError,
    RateLimitError,
    UnprocessableEntityError,
)
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
from xinference_client.client.restful.restful_client import (
    Client,
    RESTfulChatglmCppChatModelHandle,
    RESTfulChatModelHandle,
    RESTfulGenerateModelHandle,
)

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.xinference.xinference_helper import (
    XinferenceHelper,
    XinferenceModelExtraParameter,
)
from core.model_runtime.utils import helper

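# Minimal usage sketch (illustrative only): in practice this class is instantiated and
# driven by the surrounding model runtime; the credential values below are assumptions.
#
#   model = XinferenceAILargeLanguageModel()
#   result = model._invoke(
#       model='my-llm',
#       credentials={'server_url': 'http://127.0.0.1:9997', 'model_uid': 'my-llm'},
#       prompt_messages=[UserPromptMessage(content='Hello')],
#       model_parameters={'temperature': 0.7, 'max_tokens': 256},
#       stream=False,
#   )
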
class XinferenceAILargeLanguageModel(LargeLanguageModel):
    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        """
        invoke LLM

        see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
        """
        if 'temperature' in model_parameters:
            if model_parameters['temperature'] < 0.01:
                model_parameters['temperature'] = 0.01
            elif model_parameters['temperature'] > 1.0:
                model_parameters['temperature'] = 0.99

        return self._generate(
            model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
            tools=tools, stop=stop, stream=stream, user=user,
            extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
                server_url=credentials['server_url'],
                model_uid=credentials['model_uid']
            )
        )

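    # Illustrative effect of the temperature clamp in _invoke() above:
    #   {'temperature': 0.0} is sent as {'temperature': 0.01}
    #   {'temperature': 1.5} is sent as {'temperature': 0.99}
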
    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        validate credentials

        credentials should be like:
        {
            'model_type': 'text-generation',
            'server_url': 'server url',
            'model_uid': 'model uid',
        }
        """
        try:
            if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

            extra_param = XinferenceHelper.get_xinference_extra_parameter(
                server_url=credentials['server_url'],
                model_uid=credentials['model_uid']
            )
            if 'completion_type' not in credentials:
                if 'chat' in extra_param.model_ability:
                    credentials['completion_type'] = 'chat'
                elif 'generate' in extra_param.model_ability:
                    credentials['completion_type'] = 'completion'
                else:
                    raise ValueError(
                        f'xinference model ability {extra_param.model_ability} is not supported, '
                        f'check if you have the right model type'
                    )

            if extra_param.support_function_call:
                credentials['support_function_call'] = True

            if extra_param.context_length:
                credentials['context_length'] = extra_param.context_length

        except RuntimeError as e:
            raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
        except KeyError as e:
            raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
        except Exception as e:
            raise e

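    # How validate_credentials() above fills in 'completion_type', based on the model
    # abilities reported by the Xinference server (illustrative):
    #   model_ability contains 'chat'     -> credentials['completion_type'] = 'chat'
    #   model_ability contains 'generate' -> credentials['completion_type'] = 'completion'
    #   anything else                     -> ValueError
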
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
        """
        get number of tokens

        Because a Xinference LLM is a customized model, we cannot detect which tokenizer it uses,
        so the GPT-2 tokenizer is used as an approximation.
        """
        return self._num_tokens_from_messages(prompt_messages, tools)

    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
                                  is_completion_model: bool = False) -> int:
        def tokens(text: str):
            return self._get_num_tokens_by_gpt2(text)

        if is_completion_model:
            return sum([tokens(str(message.content)) for message in messages])

        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                if isinstance(value, list):
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']

                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += tokens(t_key)
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += tokens(f_key)
                                    num_tokens += tokens(f_value)
                            else:
                                num_tokens += tokens(t_key)
                                num_tokens += tokens(t_value)
                if key == "function_call":
                    for t_key, t_value in value.items():
                        num_tokens += tokens(t_key)
                        if t_key == "function":
                            for f_key, f_value in t_value.items():
                                num_tokens += tokens(f_key)
                                num_tokens += tokens(f_value)
                        else:
                            num_tokens += tokens(t_key)
                            num_tokens += tokens(t_value)
                else:
                    num_tokens += tokens(str(value))

                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(tools)

        return num_tokens

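    # Rough shape of the count performed above (GPT-2 is only an approximation of whatever
    # tokenizer the deployed model actually uses):
    #   num_tokens ~= sum(3 + tokens of each key's value [+ 1 if 'name' present] per message)
    #                 + 3 + tool definition tokens
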
    def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
        """
        Calculate num tokens for tool calling

        :param tools: tools for tool calling
        :return: number of tokens
        """
        def tokens(text: str):
            return self._get_num_tokens_by_gpt2(text)

        num_tokens = 0
        for tool in tools:
            # calculate num tokens for function object
            num_tokens += tokens('name')
            num_tokens += tokens(tool.name)
            num_tokens += tokens('description')
            num_tokens += tokens(tool.description)
            parameters = tool.parameters
            num_tokens += tokens('parameters')
            num_tokens += tokens('type')
            num_tokens += tokens(parameters.get("type"))
            if 'properties' in parameters:
                num_tokens += tokens('properties')
                for key, value in parameters.get('properties').items():
                    num_tokens += tokens(key)
                    for field_key, field_value in value.items():
                        num_tokens += tokens(field_key)
                        if field_key == 'enum':
                            for enum_field in field_value:
                                num_tokens += 3
                                num_tokens += tokens(enum_field)
                        else:
                            num_tokens += tokens(field_key)
                            num_tokens += tokens(str(field_value))
            if 'required' in parameters:
                num_tokens += tokens('required')
                for required_field in parameters['required']:
                    num_tokens += 3
                    num_tokens += tokens(required_field)

        return num_tokens

    def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
        """
        convert prompt message to text
        """
        text = ''
        for item in message:
            if isinstance(item, UserPromptMessage):
                text += item.content
            elif isinstance(item, SystemPromptMessage):
                text += item.content
            elif isinstance(item, AssistantPromptMessage):
                text += item.content
            else:
                raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
        return text

    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for the OpenAI-compatible API
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                raise ValueError("User message content must be str")
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls and len(message.tool_calls) > 0:
                message_dict["function_call"] = {
                    "name": message.tool_calls[0].function.name,
                    "arguments": message.tool_calls[0].function.arguments
                }
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
        else:
            raise ValueError(f"Unknown message type {type(message)}")

        return message_dict

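    # Illustrative output of _convert_prompt_message_to_dict() above:
    #   UserPromptMessage(content='Hi')         -> {'role': 'user', 'content': 'Hi'}
    #   SystemPromptMessage(content='Be terse') -> {'role': 'system', 'content': 'Be terse'}
    #   AssistantPromptMessage with tool_calls  -> {'role': 'assistant', 'content': ...,
    #                                               'function_call': {'name': ..., 'arguments': ...}}
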
    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                ),
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=credentials.get('context_length', 2048),
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            )
        ]

        completion_type = None

        if 'completion_type' in credentials:
            if credentials['completion_type'] == 'chat':
                completion_type = LLMMode.CHAT.value
            elif credentials['completion_type'] == 'completion':
                completion_type = LLMMode.COMPLETION.value
            else:
                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
        else:
            extra_args = XinferenceHelper.get_xinference_extra_parameter(
                server_url=credentials['server_url'],
                model_uid=credentials['model_uid']
            )

            if 'chat' in extra_args.model_ability:
                completion_type = LLMMode.CHAT.value
            elif 'generate' in extra_args.model_ability:
                completion_type = LLMMode.COMPLETION.value
            else:
                raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')

        support_function_call = credentials.get('support_function_call', False)
        context_length = credentials.get('context_length', 2048)

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=[
                ModelFeature.TOOL_CALL
            ] if support_function_call else [],
            model_properties={
                ModelPropertyKey.MODE: completion_type,
                ModelPropertyKey.CONTEXT_SIZE: context_length
            },
            parameter_rules=rules
        )

        return entity

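    # Note on the schema built above: the 'max_tokens' upper bound and CONTEXT_SIZE both fall
    # back to 2048 when validate_credentials() did not record a 'context_length' reported by
    # the Xinference server.
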
    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                  model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
                  tools: list[PromptMessageTool] | None = None,
                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        """
        generate text from LLM

        see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`

        extra_model_kwargs can be obtained via `XinferenceHelper.get_xinference_extra_parameter`
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('server_url is required in credentials')

        if credentials['server_url'].endswith('/'):
            credentials['server_url'] = credentials['server_url'][:-1]

        client = OpenAI(
            base_url=f'{credentials["server_url"]}/v1',
            api_key='abc',  # placeholder API key
            max_retries=3,
            timeout=60,
        )

        xinference_client = Client(
            base_url=credentials['server_url'],
        )

        xinference_model = xinference_client.get_model(credentials['model_uid'])

        generate_config = {
            'temperature': model_parameters.get('temperature', 1.0),
            'top_p': model_parameters.get('top_p', 0.7),
            'max_tokens': model_parameters.get('max_tokens', 512),
        }

        if stop:
            generate_config['stop'] = stop

        if tools and len(tools) > 0:
            generate_config['tools'] = [
                {
                    'type': 'function',
                    'function': helper.dump_model(tool)
                } for tool in tools
            ]

        if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
            resp = client.chat.completions.create(
                model=credentials['model_uid'],
                messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
                stream=stream,
                user=user,
                **generate_config,
            )
            if stream:
                if tools and len(tools) > 0:
                    raise InvokeBadRequestError('xinference tool calls do not support stream mode')
                return self._handle_chat_stream_response(model=model, credentials=credentials,
                                                         prompt_messages=prompt_messages,
                                                         tools=tools, resp=resp)
            return self._handle_chat_generate_response(model=model, credentials=credentials,
                                                       prompt_messages=prompt_messages,
                                                       tools=tools, resp=resp)
        elif isinstance(xinference_model, RESTfulGenerateModelHandle):
            resp = client.completions.create(
                model=credentials['model_uid'],
                prompt=self._convert_prompt_message_to_text(prompt_messages),
                stream=stream,
                user=user,
                **generate_config,
            )
            if stream:
                return self._handle_completion_stream_response(model=model, credentials=credentials,
                                                               prompt_messages=prompt_messages,
                                                               tools=tools, resp=resp)
            return self._handle_completion_generate_response(model=model, credentials=credentials,
                                                             prompt_messages=prompt_messages,
                                                             tools=tools, resp=resp)
        else:
            raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')

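    # Illustrative shape of the request built by _generate() above for a chat-capable model
    # (values are assumptions, not taken from a real call):
    #   POST {server_url}/v1/chat/completions
    #   {
    #       "model": "<model_uid>",
    #       "messages": [{"role": "user", "content": "Hello"}],
    #       "temperature": 0.7, "top_p": 0.7, "max_tokens": 512,
    #       "tools": [{"type": "function", "function": {...}}]   # only when tools are passed
    #   }
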
    def _extract_response_tool_calls(self,
                                     response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
            -> list[AssistantPromptMessage.ToolCall]:
        """
        Extract tool calls from response

        :param response_tool_calls: response tool calls
        :return: list of tool calls
        """
        tool_calls = []
        if response_tool_calls:
            for response_tool_call in response_tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call.function.name,
                    arguments=response_tool_call.function.arguments
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call.id,
                    type=response_tool_call.type,
                    function=function
                )
                tool_calls.append(tool_call)

        return tool_calls

    def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
            -> AssistantPromptMessage.ToolCall:
        """
        Extract function call from response

        :param response_function_call: response function call
        :return: tool call
        """
        tool_call = None
        if response_function_call:
            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=response_function_call.name,
                arguments=response_function_call.arguments
            )

            tool_call = AssistantPromptMessage.ToolCall(
                id=response_function_call.name,
                type="function",
                function=function
            )

        return tool_call

    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                       tools: list[PromptMessageTool],
                                       resp: ChatCompletion) -> LLMResult:
        """
        handle normal chat generate response
        """
        if len(resp.choices) == 0:
            raise InvokeServerUnavailableError("Empty response")

        assistant_message = resp.choices[0].message

        # convert tool call to assistant message tool call
        tool_calls = assistant_message.tool_calls
        assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
        function_call = assistant_message.function_call
        if function_call:
            assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_message.content,
            tool_calls=assistant_prompt_message_tool_calls
        )

        prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
        completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)

        usage = self._calc_response_usage(model=model, credentials=credentials,
                                          prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            system_fingerprint=resp.system_fingerprint,
            usage=usage,
            message=assistant_prompt_message,
        )

        return response

    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                     tools: list[PromptMessageTool],
                                     resp: Iterator[ChatCompletionChunk]) -> Generator:
        """
        handle stream chat generate response
        """
        full_response = ''

        for chunk in resp:
            if len(chunk.choices) == 0:
                continue

            delta = chunk.choices[0]

            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                continue

            # check if there is a tool call in the response
            function_call = None
            tool_calls = []
            if delta.delta.tool_calls:
                tool_calls += delta.delta.tool_calls
            if delta.delta.function_call:
                function_call = delta.delta.function_call

            assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls)
            if function_call:
                assistant_message_tool_calls += [self._extract_response_function_call(function_call)]

            # transform assistant message to prompt message
            assistant_prompt_message = AssistantPromptMessage(
                content=delta.delta.content if delta.delta.content else '',
                tool_calls=assistant_message_tool_calls
            )

            if delta.finish_reason is not None:
                # temp_assistant_prompt_message is used to calculate usage
                temp_assistant_prompt_message = AssistantPromptMessage(
                    content=full_response,
                    tool_calls=assistant_message_tool_calls
                )

                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])

                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                  prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=assistant_prompt_message,
                        finish_reason=delta.finish_reason,
                        usage=usage
                    ),
                )
            else:
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=assistant_prompt_message,
                    ),
                )

            # the final chunk may carry no content, so guard against None
            full_response += delta.delta.content if delta.delta.content else ''

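    # Streaming behaviour of the handler above (illustrative): intermediate chunks carry only
    # the text delta; the chunk whose finish_reason is set additionally carries the usage
    # computed from the accumulated full_response.
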
    def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                             tools: list[PromptMessageTool],
                                             resp: Completion) -> LLMResult:
        """
        handle normal completion generate response
        """
        if len(resp.choices) == 0:
            raise InvokeServerUnavailableError("Empty response")

        assistant_message = resp.choices[0].text

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_message,
            tool_calls=[]
        )

        prompt_tokens = self._get_num_tokens_by_gpt2(
            self._convert_prompt_message_to_text(prompt_messages)
        )
        completion_tokens = self._num_tokens_from_messages(
            messages=[assistant_prompt_message], tools=[], is_completion_model=True
        )
        usage = self._calc_response_usage(
            model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
        )

        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            system_fingerprint=resp.system_fingerprint,
            usage=usage,
            message=assistant_prompt_message,
        )

        return response

    def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                           tools: list[PromptMessageTool],
                                           resp: Iterator[Completion]) -> Generator:
        """
        handle stream completion generate response
        """
        full_response = ''

        for chunk in resp:
            if len(chunk.choices) == 0:
                continue

            delta = chunk.choices[0]

            # transform assistant message to prompt message
            assistant_prompt_message = AssistantPromptMessage(
                content=delta.text if delta.text else '',
                tool_calls=[]
            )

            if delta.finish_reason is not None:
                # temp_assistant_prompt_message is used to calculate usage
                temp_assistant_prompt_message = AssistantPromptMessage(
                    content=full_response,
                    tool_calls=[]
                )

                prompt_tokens = self._get_num_tokens_by_gpt2(
                    self._convert_prompt_message_to_text(prompt_messages)
                )
                completion_tokens = self._num_tokens_from_messages(
                    messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
                )
                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                  prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=assistant_prompt_message,
                        finish_reason=delta.finish_reason,
                        usage=usage
                    ),
                )
            else:
                if delta.text is None or delta.text == '':
                    continue

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=chunk.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=assistant_prompt_message,
                    ),
                )

            # the final chunk may carry no text, so guard against None
            full_response += delta.text if delta.text else ''

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                APIConnectionError,
                APITimeoutError,
            ],
            InvokeServerUnavailableError: [
                InternalServerError,
                ConflictError,
                NotFoundError,
                UnprocessableEntityError,
                PermissionDeniedError
            ],
            InvokeRateLimitError: [
                RateLimitError
            ],
            InvokeAuthorizationError: [
                AuthenticationError
            ],
            InvokeBadRequestError: [
                ValueError
            ]
        }