make format

glide-the 2024-03-31 17:55:57 +08:00
parent 2f1c9bfd11
commit 056b15b99b
2 changed files with 111 additions and 55 deletions


@@ -2,8 +2,11 @@ import typing
from subprocess import Popen
from typing import Optional
from model_providers.core.bootstrap.openai_protocol import ChatCompletionStreamResponseChoice, \
ChatCompletionStreamResponse, Finish
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
)
from model_providers.core.utils.generic import jsonify
if typing.TYPE_CHECKING:
@@ -11,12 +14,14 @@ if typing.TYPE_CHECKING:
def create_stream_chunk(
request_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional[Finish] = None,
request_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional[Finish] = None,
) -> str:
choice = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
choice = ChatCompletionStreamResponseChoice(
index=index, delta=delta, finish_reason=finish_reason
)
chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
return jsonify(chunk)
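
The hunk above covers the whole of the first changed file (apparently model_providers/bootstrap_web/common.py, since the second file imports create_stream_chunk from model_providers.bootstrap_web.common); it only reflows the helper into a wrapped call style, with no behaviour change. A minimal usage sketch, with placeholder id and model strings, mirroring the call the streaming handler makes later in this diff:

from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage, Role

# Opening chunk of a stream: an empty assistant delta, exactly as the second
# file emits further down in this diff.
payload = create_stream_chunk(
    request_id="chatcmpl-placeholder",  # placeholder, not a real request id
    model="example-model",              # placeholder model name
    delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""),
)
# payload is the JSON string produced by jsonify(ChatCompletionStreamResponse(...)).

The remaining hunks belong to the second changed file, which defines the RESTFulOpenAIBootstrapBaseWeb server.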


@@ -5,7 +5,18 @@ import multiprocessing as mp
import os
import pprint
import threading
from typing import Any, Dict, Optional, Union, Tuple, Type, List, cast, Generator, AsyncGenerator
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
@@ -15,21 +26,36 @@ from uvicorn import Config, Server
from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatMessage,
EmbeddingsRequest,
EmbeddingsResponse,
Finish,
FunctionAvailable,
ModelCard,
ModelList, ChatMessage, ChatCompletionMessage, Role, Finish, ChatCompletionResponseChoice, UsageInfo,
ModelList,
Role,
UsageInfo,
)
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from model_providers.core.model_runtime.entities.message_entities import (
UserPromptMessage, PromptMessage, AssistantPromptMessage, ToolPromptMessage, SystemPromptMessage,
PromptMessageContent, PromptMessageContentType, TextPromptMessageContent, ImagePromptMessageContent,
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
@@ -81,7 +107,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
def _create_template_from_message_type(
message_type: str, template: Union[str, list]
message_type: str, template: Union[str, list]
) -> PromptMessage:
"""Create a message prompt template from a message type and template string.
@@ -102,9 +128,7 @@ def _create_template_from_message_type(
text: str = tmpl
else:
text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501
content.append(
TextPromptMessageContent(data=text)
)
content.append(TextPromptMessageContent(data=text))
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(dict, tmpl)["image_url"]
if isinstance(img_template, str):
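
The loop reformatted above normalizes mixed text and image templates into prompt message content parts. A minimal sketch of the input shape those branches handle, assuming only what the visible code shows (the URL is a placeholder):

template = [
    "Describe this picture:",                       # plain string -> TextPromptMessageContent
    {"type": "text", "text": "and keep it short"},  # dict with "text" -> TextPromptMessageContent
    {"image_url": "https://example.com/cat.png"},   # dict with "image_url" -> ImagePromptMessageContent
]
# _create_template_from_message_type(message_type, template) wraps these parts
# into a PromptMessage subclass chosen by message_type (not shown in this hunk).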
@@ -142,7 +166,7 @@ def _create_template_from_message_type(
def _convert_to_message(
message: MessageLikeRepresentation,
message: MessageLikeRepresentation,
) -> Union[PromptMessage]:
"""Instantiate a message from a variety of message formats.
@@ -161,7 +185,9 @@ def _convert_to_message(
an instance of a message or a message template
"""
if isinstance(message, ChatMessage):
_message = _create_template_from_message_type(message.role.to_origin_role(), message.content)
_message = _create_template_from_message_type(
message.role.to_origin_role(), message.content
)
elif isinstance(message, PromptMessage):
_message = message
@@ -174,16 +200,16 @@
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
raise ValueError(
f"Expected message type string, got {message_type_str}"
)
raise ValueError(f"Expected message type string, got {message_type_str}")
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[str, None]:
async def _stream_openai_chat_completion(
response: Generator,
) -> AsyncGenerator[str, None]:
request_id, model = None, None
for chunk in response:
if not isinstance(chunk, LLMResultChunk):
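
The hunk above finishes _convert_to_message and opens the streaming helper. As a minimal sketch of the inputs _convert_to_message accepts, based only on the branches visible in this diff (the ChatMessage constructor fields and Role.USER are assumptions inferred from the attribute access shown):

user_msg = _convert_to_message(ChatMessage(role=Role.USER, content="Hello"))  # assumed kwargs
passthrough = _convert_to_message(UserPromptMessage(content="Hi again"))       # returned unchanged
# The message_type_str/template branch suggests a (type, template) pair is also
# handled; unsupported inputs raise NotImplementedError, and a non-string
# message type raises ValueError.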
@@ -194,27 +220,41 @@ async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[
model = chunk.model
if request_id is None:
request_id = "request_id"
yield create_stream_chunk(request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content=""))
yield create_stream_chunk(
request_id,
model,
ChatCompletionMessage(role=Role.ASSISTANT, content=""),
)
new_token = chunk.delta.message.content
if new_token:
delta = ChatCompletionMessage(role=Role.value_of(chunk.delta.message.role.to_origin_role()),
content=new_token,
tool_calls=chunk.delta.message.tool_calls)
yield create_stream_chunk(request_id=request_id,
model=model, delta=delta,
index=chunk.delta.index,
finish_reason=chunk.delta.finish_reason)
delta = ChatCompletionMessage(
role=Role.value_of(chunk.delta.message.role.to_origin_role()),
content=new_token,
tool_calls=chunk.delta.message.tool_calls,
)
yield create_stream_chunk(
request_id=request_id,
model=model,
delta=delta,
index=chunk.delta.index,
finish_reason=chunk.delta.finish_reason,
)
yield create_stream_chunk(request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP)
yield create_stream_chunk(
request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
choice = ChatCompletionResponseChoice(
index=0, message=ChatCompletionMessage(**_convert_prompt_message_to_dict(message=response.message)),
finish_reason=Finish.STOP
index=0,
message=ChatCompletionMessage(
**_convert_prompt_message_to_dict(message=response.message)
),
finish_reason=Finish.STOP,
)
usage = UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
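
The hunk above reflows both the streaming and the non-streaming completion helpers without changing behaviour. A minimal sketch of consuming the streaming generator, assuming a response produced by model_instance.invoke_llm(..., stream=True) as elsewhere in this file:

async def consume(response):
    # _stream_openai_chat_completion is an async generator: it yields one JSON
    # payload for the opening empty assistant delta, one per new token, one
    # carrying finish_reason=Finish.STOP, and finally the literal "[DONE]".
    async for sse_payload in _stream_openai_chat_completion(response):
        print(sse_payload)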
@@ -335,7 +375,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return ModelList(data=models_list)
async def create_embeddings(
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@@ -345,7 +385,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return EmbeddingsResponse(**dictify(response))
async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -354,22 +394,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
model_instance = self._provider_manager.get_model_instance(
provider=provider, model_type=ModelType.LLM, model=chat_request.model
)
prompt_messages = [_convert_to_message(message) for message in chat_request.messages]
prompt_messages = [
_convert_to_message(message) for message in chat_request.messages
]
tools = [PromptMessageTool(name=f.function.name,
description=f.function.description,
parameters=f.function.parameters
)
for f in chat_request.tools]
tools = [
PromptMessageTool(
name=f.function.name,
description=f.function.description,
parameters=f.function.parameters,
)
for f in chat_request.tools
]
if chat_request.functions:
tools.extend([PromptMessageTool(name=f.name,
description=f.description,
parameters=f.parameters
) for f in chat_request.functions])
tools.extend(
[
PromptMessageTool(
name=f.name, description=f.description, parameters=f.parameters
)
for f in chat_request.functions
]
)
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={**chat_request.to_model_parameters_dict()},
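
The hunk above only reformats how create_chat_completion normalizes request.tools (nested under .function) and the legacy request.functions (flat) into PromptMessageTool objects. A minimal illustration of the two input shapes, using plain namespaces because the exact request model constructors are not shown in this diff:

from types import SimpleNamespace

tool_style = SimpleNamespace(      # matches f.function.name / .description / .parameters
    function=SimpleNamespace(
        name="get_weather",
        description="Look up the weather for a city",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    )
)
function_style = SimpleNamespace(  # matches f.name / .description / .parameters
    name="get_time",
    description="Return the current time",
    parameters={"type": "object", "properties": {}},
)
# Both shapes end up as PromptMessageTool(name=..., description=..., parameters=...)
# in the tools list passed to model_instance.invoke_llm(...).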
@@ -380,21 +427,25 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
if chat_request.stream:
return EventSourceResponse(_stream_openai_chat_completion(response), media_type="text/event-stream")
return EventSourceResponse(
_stream_openai_chat_completion(response),
media_type="text/event-stream",
)
else:
return await _openai_chat_completion(response)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except InvokeError as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
def run(
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
):
logging.config.dictConfig(logging_conf) # type: ignore
try:
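
The diff ends with the run() signature, which is only re-indented. A minimal launch sketch, assuming it is meant to be driven from a parent process as its started_event parameter suggests (the cfg contents are a placeholder; the real keys are defined elsewhere in the project):

import multiprocessing as mp

if __name__ == "__main__":
    started = mp.Event()
    server = mp.Process(
        target=run,
        kwargs={
            "cfg": {},                       # placeholder: real provider/server config not shown in this diff
            "logging_conf": {"version": 1},  # minimal config so logging.config.dictConfig() succeeds
            "started_event": started,
        },
    )
    server.start()
    started.wait()  # assumption: run() sets the event once the HTTP server is up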