Adapt chat_completions request/response payloads

glide-the 2024-03-31 17:55:32 +08:00
parent a2df71d9ea
commit 2f1c9bfd11
4 changed files with 317 additions and 27 deletions

View File

@@ -0,0 +1,22 @@
import typing
from typing import Optional

from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    Finish,
)
from model_providers.core.utils.generic import jsonify

if typing.TYPE_CHECKING:
    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage


def create_stream_chunk(
    request_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional[Finish] = None,
) -> str:
    """Serialize one OpenAI-style streaming chunk to a JSON string."""
    choice = ChatCompletionStreamResponseChoice(
        index=index, delta=delta, finish_reason=finish_reason
    )
    chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
    return jsonify(chunk)
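
# Illustrative usage sketch (not part of this commit; the model name and
# content below are hypothetical). EventSourceResponse later frames each
# returned string as an SSE "data:" line.
def _demo_create_stream_chunk() -> str:
    from model_providers.core.bootstrap.openai_protocol import (
        ChatCompletionMessage,
        Role,
    )

    # Returns a JSON string shaped like:
    # {"id": "request_id", "model": "my-model",
    #  "choices": [{"index": 0,
    #               "delta": {"role": "assistant", "content": "Hello"},
    #               "finish_reason": null}]}
    return create_stream_chunk(
        request_id="request_id",
        model="my-model",  # hypothetical model name
        delta=ChatCompletionMessage(role=Role.ASSISTANT, content="Hello"),
    )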

View File

@@ -5,14 +5,14 @@ import multiprocessing as mp
import os
import pprint
import threading
from typing import Any, Dict, Optional, Union, Tuple, Type, List, cast, Generator, AsyncGenerator
import tiktoken
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server
from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
@@ -22,20 +22,212 @@ from model_providers.core.bootstrap.openai_protocol import (
EmbeddingsResponse,
FunctionAvailable,
ModelCard,
    ModelList,
    ChatMessage,
    ChatCompletionMessage,
    Role,
    Finish,
    ChatCompletionResponseChoice,
    UsageInfo,
)
from model_providers.core.model_manager import ModelInstance, ModelManager
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from model_providers.core.model_runtime.entities.message_entities import (
    UserPromptMessage,
    PromptMessage,
    AssistantPromptMessage,
    ToolPromptMessage,
    SystemPromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    TextPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessageTool,
)
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeError
from model_providers.core.utils.generic import dictify, jsonify
logger = logging.getLogger(__name__)
MessageLike = Union[ChatMessage, PromptMessage]
MessageLikeRepresentation = Union[
MessageLike,
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
str,
]
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
"""
    Convert a PromptMessage to a dict for the OpenAI-compatible API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
        # tool results are sent using the OpenAI "function" role
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
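
# Illustrative sketch (not part of this commit): runtime prompt messages
# flatten to plain OpenAI-style dicts; the contents are hypothetical.
def _demo_convert_prompt_message_to_dict() -> None:
    assert _convert_prompt_message_to_dict(UserPromptMessage(content="hello")) == {
        "role": "user",
        "content": "hello",
    }
    assert _convert_prompt_message_to_dict(SystemPromptMessage(content="be brief")) == {
        "role": "system",
        "content": "be brief",
    }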
def _create_template_from_message_type(
message_type: str, template: Union[str, list]
) -> PromptMessage:
"""Create a message prompt template from a message type and template string.
Args:
message_type: str the type of the message template (e.g., "human", "ai", etc.)
template: str the template string.
Returns:
a message prompt template of the appropriate type.
"""
if isinstance(template, str):
content = template
elif isinstance(template, list):
content = []
for tmpl in template:
            if isinstance(tmpl, str) or (isinstance(tmpl, dict) and "text" in tmpl):
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501
content.append(
TextPromptMessageContent(data=text)
)
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(dict, tmpl)["image_url"]
if isinstance(img_template, str):
img_template_obj = ImagePromptMessageContent(data=img_template)
elif isinstance(img_template, dict):
img_template = dict(img_template)
if "url" in img_template:
url = img_template["url"]
else:
url = None
img_template_obj = ImagePromptMessageContent(data=url)
else:
                    raise ValueError(f"invalid image_url template: {img_template!r}")
content.append(img_template_obj)
else:
                raise ValueError(f"unsupported content part: {tmpl!r}")
else:
        raise ValueError(f"template must be a str or list, got {type(template)}")
if message_type in ("human", "user"):
_message = UserPromptMessage(content=content)
elif message_type in ("ai", "assistant"):
_message = AssistantPromptMessage(content=content)
elif message_type == "system":
_message = SystemPromptMessage(content=content)
elif message_type in ("function", "tool"):
_message = ToolPromptMessage(content=content)
else:
        raise ValueError(
            f"Unexpected message type: {message_type}. Expected one of"
            f" 'human', 'user', 'ai', 'assistant', 'system', 'function', or 'tool'."
        )
return _message
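
# Illustrative sketch (not part of this commit; the URL is hypothetical):
# a list-typed template is split into text and image content parts, mirroring
# the OpenAI multimodal message format.
def _demo_create_template_from_message_type() -> None:
    message = _create_template_from_message_type(
        "user",
        [
            {"text": "What is in this picture?"},
            {"image_url": {"url": "https://example.com/cat.png"}},
        ],
    )
    assert isinstance(message, UserPromptMessage)
    assert isinstance(message.content[0], TextPromptMessageContent)
    assert isinstance(message.content[1], ImagePromptMessageContent)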
def _convert_to_message(
message: MessageLikeRepresentation,
) -> PromptMessage:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- 2-tuple of (message class, template)
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, ChatMessage):
_message = _create_template_from_message_type(message.role.to_origin_role(), message.content)
elif isinstance(message, PromptMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
raise ValueError(
f"Expected message type string, got {message_type_str}"
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
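
# Illustrative sketch (not part of this commit): every supported shorthand
# normalizes to a runtime PromptMessage.
def _demo_convert_to_message() -> None:
    assert isinstance(_convert_to_message("hi"), UserPromptMessage)
    assert isinstance(_convert_to_message(("system", "be concise")), SystemPromptMessage)
    passthrough = UserPromptMessage(content="hi")
    assert _convert_to_message(passthrough) is passthrough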
async def _stream_openai_chat_completion(
    response: Generator,
) -> AsyncGenerator[str, None]:
    request_id, model = None, None
    for chunk in response:
        if not isinstance(chunk, LLMResultChunk):
            yield "[ERROR]"
            return
        if model is None:
            model = chunk.model
        if request_id is None:
            # No upstream request id is available, so a fixed placeholder is used;
            # emit the role-only first chunk that OpenAI-style streams begin with.
            request_id = "request_id"
            yield create_stream_chunk(
                request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content="")
            )
        new_token = chunk.delta.message.content
        if new_token:
            delta = ChatCompletionMessage(
                role=Role.value_of(chunk.delta.message.role.to_origin_role()),
                content=new_token,
                tool_calls=chunk.delta.message.tool_calls,
            )
            yield create_stream_chunk(
                request_id=request_id,
                model=model,
                delta=delta,
                index=chunk.delta.index,
                finish_reason=chunk.delta.finish_reason,
            )
    # Close the stream with an empty stop chunk and the OpenAI "[DONE]" sentinel.
    yield create_stream_chunk(
        request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
    )
    yield "[DONE]"
async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
    # Pack the blocking invocation result into a single-choice OpenAI-style response.
    choice = ChatCompletionResponseChoice(
        index=0,
        message=ChatCompletionMessage(
            **_convert_prompt_message_to_dict(message=response.message)
        ),
        finish_reason=Finish.STOP,
    )
usage = UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return ChatCompletionResponse(
id="request_id",
model=response.model,
choices=[choice],
usage=usage,
)
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
"""
@@ -143,7 +335,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return ModelList(data=models_list)
async def create_embeddings(
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@@ -153,7 +345,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return EmbeddingsResponse(**dictify(response))
async def create_chat_completion(
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -162,38 +354,47 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
model_instance = self._provider_manager.get_model_instance(
provider=provider, model_type=ModelType.LLM, model=chat_request.model
)
        # Invoke model
prompt_messages = [_convert_to_message(message) for message in chat_request.messages]
        # chat_request.tools/functions may be absent; guard before iterating
        tools = [
            PromptMessageTool(
                name=f.function.name,
                description=f.function.description,
                parameters=f.function.parameters,
            )
            for f in (chat_request.tools or [])
        ]
        if chat_request.functions:
            tools.extend(
                PromptMessageTool(
                    name=f.name, description=f.description, parameters=f.parameters
                )
                for f in chat_request.functions
            )
try:
response = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
model_parameters={**chat_request.to_model_parameters_dict()},
tools=tools,
stop=chat_request.stop,
stream=chat_request.stream,
user="abc-123",
)
            if chat_request.stream:
                return EventSourceResponse(
                    _stream_openai_chat_completion(response),
                    media_type="text/event-stream",
                )
            else:
                return await _openai_chat_completion(response)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except InvokeError as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
def run(
    cfg: Dict,
    logging_conf: Optional[dict] = None,
    started_event: mp.Event = None,
):
logging.config.dictConfig(logging_conf) # type: ignore
try:

View File

@@ -13,12 +13,62 @@ class Role(str, Enum):
FUNCTION = "function"
TOOL = "tool"
@classmethod
def value_of(cls, origin_role: str) -> "Role":
if origin_role == "user":
return cls.USER
elif origin_role == "assistant":
return cls.ASSISTANT
elif origin_role == "system":
return cls.SYSTEM
elif origin_role == "function":
return cls.FUNCTION
elif origin_role == "tool":
return cls.TOOL
else:
raise ValueError(f"invalid origin role {origin_role}")
def to_origin_role(self) -> str:
if self == self.USER:
return "user"
elif self == self.ASSISTANT:
return "assistant"
elif self == self.SYSTEM:
return "system"
elif self == self.FUNCTION:
return "function"
elif self == self.TOOL:
return "tool"
else:
raise ValueError(f"invalid role {self}")
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
TOOL = "tool_calls"
@classmethod
def value_of(cls, origin_finish: str) -> "Finish":
if origin_finish == "stop":
return cls.STOP
elif origin_finish == "length":
return cls.LENGTH
elif origin_finish == "tool_calls":
return cls.TOOL
else:
raise ValueError(f"invalid origin finish {origin_finish}")
def to_origin_finish(self) -> str:
if self == self.STOP:
return "stop"
elif self == self.LENGTH:
return "length"
elif self == self.TOOL:
return "tool_calls"
else:
raise ValueError(f"invalid finish {self}")
class ModelCard(BaseModel):
id: str
@@ -95,7 +145,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
    stop: Optional[list[str]] = None
stream: Optional[bool] = False
def to_model_parameters_dict(self, *args, **kwargs):

View File

@@ -28,6 +28,23 @@ class PromptMessageRole(Enum):
return mode
raise ValueError(f"invalid prompt message type value {value}")
def to_origin_role(self) -> str:
"""
Get origin role from prompt message role.
:return: origin role
"""
if self == self.SYSTEM:
return "system"
elif self == self.USER:
return "user"
elif self == self.ASSISTANT:
return "assistant"
elif self == self.TOOL:
return "tool"
else:
raise ValueError(f"invalid role {self}")
class PromptMessageTool(BaseModel):
"""