update readme and add more demos

This commit is contained in:
zh-zheng 2024-09-06 20:29:35 +08:00
parent f541d519a0
commit 9ac6cf3e38
12 changed files with 2259 additions and 6 deletions


@@ -415,16 +415,47 @@ We have supported fine-tuning MiniCPM3 using [LLaMA-Factory](https://github.com/
### Advanced Features
We recommend using [vLLM](#vllm) for the following advanced features.
#### Function calling
We provide example code for using function calls with MiniCPM3, see [`demo/function_calling.py`](./demo/function_calling.py).
We provide example code for using function calls with MiniCPM3:
```bash
cd demo/function_call
python function_call.py
```
If you want to start a function-calling service, use the following commands:
```bash
cd demo/function_call
pip install -r requirements.txt
python openai_api_server.py \
--model openbmb/MiniCPM3-4B \
--served-model-name MiniCPM3-4B \
--chat-template chatml.jinja \
--dtype auto \
--api-key token-abc123 \
--tensor-parallel-size 1 \
--trust-remote-code
```
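Once the service is running, it exposes an OpenAI-compatible `/v1/chat/completions` endpoint, so any OpenAI-style client can send tool-augmented requests. The snippet below is a minimal client sketch, assuming the `openai` Python package is installed and the server is listening on vLLM's default port 8000; the `get_weather` tool schema is a made-up placeholder for illustration:
```python
from openai import OpenAI

# Assumes the server launched above; adjust base_url if you changed the host or port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

# Hypothetical tool definition used only to illustrate the request format.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="MiniCPM3-4B",
    messages=[{"role": "user", "content": "What's the weather like in Beijing today?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)
```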
Below is a demo of using a search engine to answer a question:
![function_call](./assets/function_call.gif)
#### Code Interpreter
We provide example code for using the code interpreter with MiniCPM3, see [`demo/code_interpreter.py`](./demo/code_interpreter.py).
We provide example code for using the code interpreter with MiniCPM3. To run the demo:
```bash
cd demo/code_interpreter
pip install -r requirements.txt
python code_interpreter.py openbmb/MiniCPM3-4B
```
Below is an example of using the code interpreter to generate a QR code:
![code_interpreter](./assets/code_interpreter.gif)


@@ -415,15 +415,48 @@ print(responds)
Model fine-tuning currently supports [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory); see [LLaMA-Factory fine-tuning](https://modelbest.feishu.cn/docx/Z7USdW4lloZzkZxQ14icJ3senjb?from=from_copylink) for usage instructions.
### Advanced Features
We recommend using [vLLM](#vllm) for the following advanced features.
#### Function Calling
We provide example code for calling tools with MiniCPM3; see [`demo/function_calling.py`](./demo/function_calling.py).
We provide example code for calling tools with MiniCPM3:
```bash
cd demo/function_call
python function_call.py
```
If you want to start an inference service that can call tools, use the following commands:
```bash
cd demo/function_call
pip install -r requirements.txt
python openai_api_server.py \
--model openbmb/MiniCPM3-4B \
--served-model-name MiniCPM3-4B \
--chat-template chatml.jinja \
--dtype auto \
--api-key token-abc123 \
--tensor-parallel-size 1 \
--trust-remote-code
```
Below is a demo of using a search tool to answer a question:
![function_call](./assets/function_call.gif)
#### Code Interpreter
We provide example code for using the code interpreter with MiniCPM3; see [`demo/code_interpreter.py`](./demo/code_interpreter.py).
We provide example code for using the code interpreter with MiniCPM3. To run the demo:
```bash
cd demo/code_interpreter
pip install -r requirements.txt
python code_interpreter.py openbmb/MiniCPM3-4B
```
Below is a demo of using the code interpreter to generate a QR code:
![code_interpreter](./assets/code_interpreter.gif)

BIN assets/function_call.gif Normal file (binary image, 1.6 MiB; not shown)

demo/code_interpreter/requirements.txt Normal file

@@ -0,0 +1 @@
fire

demo/function_call/chatml.jinja Normal file

@@ -0,0 +1,2 @@
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
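This template emits the standard ChatML prompt format. As a rough illustration, the sketch below renders it directly with `jinja2` on a made-up two-message conversation (the conversation and the inlined template string are assumptions for the example only):
```python
from jinja2 import Template

# The two template lines above, inlined as one string
# (the whitespace between the two lines is dropped for brevity).
chat_template = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] }}"
    "{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n' }}{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
```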

demo/function_call/openai_api_server.py Normal file

@@ -0,0 +1,370 @@
import asyncio
import importlib
import inspect
import re
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
from multiprocessing import Process
from typing import AsyncIterator, Set
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from openai_protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from openai_serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
seed=0,
dtype="float16").embedding_mode
@asynccontextmanager
async def lifespan(app: FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
yield
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model, args.trust_remote_code)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
yield async_engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
port))
rpc_server_process.start()
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()
try:
yield async_engine_client
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
# Wait for server process to join
rpc_server_process.join()
router = APIRouter()
def mount_metrics(app: FastAPI):
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(metrics_route)
@router.get("/health")
async def health() -> Response:
"""Health check."""
await async_engine_client.check_health()
return Response(status_code=200)
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_tokenization.create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
return JSONResponse(content=generator.model_dump())
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_tokenization.create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
return JSONResponse(content=generator.model_dump())
@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump())
@router.get("/version")
async def show_version():
ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver)
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
return JSONResponse(content=generator.model_dump())
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
mount_metrics(app)
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
root_path = "" if args.root_path is None else args.root_path
if request.method == "OPTIONS":
return await call_next(request)
if not request.url.path.startswith(f"{root_path}/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
return app
async def init_app(
async_engine_client: AsyncEngineClient,
args: Namespace,
) -> FastAPI:
app = build_app(args)
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
model_config = await async_engine_client.get_model_config()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat(
async_engine_client,
model_config,
served_model_names,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
app.root_path = args.root_path
return app
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
async with build_async_engine_client(args) as async_engine_client:
app = await init_app(async_engine_client, args)
shutdown_task = await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
asyncio.run(run_server(args))

demo/function_call/openai_protocol.py Normal file

@@ -0,0 +1,738 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
try:
from sphinx.ext.autodoc.mock import _MockModule
if isinstance(torch, _MockModule):
_LONG_INFO = _MOCK_LONG_INFO
else:
_LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
_LONG_INFO = torch.iinfo(torch.long)
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
class ErrorResponse(OpenAIBaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False
class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
max_model_len: Optional[int] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(OpenAIBaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = True
class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition
class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
ChatCompletionNamedToolChoiceParam]] = "none"
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo: bool = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."),
)
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."),
)
documents: Optional[List[Dict[str, str]]] = Field(
default=None,
description=
("A list of dicts representing documents that will be accessible to "
"the model if it is performing RAG (retrieval-augmented generation)."
" If the template does not support RAG, this argument will have no "
"effect. We recommend that each document should be a dict containing "
"\"title\" and \"text\" keys."),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-chat-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
# We now allow logprobs being true without top_logprobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode='before')
@classmethod
def validate_stream_options(cls, values):
if (values.get('stream_options') is not None
and not values.get('stream')):
raise ValueError(
"stream_options can only be set if stream is true")
return values
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
# you can only use one kind of guided decoding
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and "tool_choice" in data and data[
"tool_choice"] != "none":
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif data["top_logprobs"] < 0:
raise ValueError(
"`top_logprobs` must be a positive value.")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not data["logprobs"] >= 0:
raise ValueError("if passed, `logprobs` must be a positive value.")
return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is true.")
return data
class EmbeddingRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
dimensions: Optional[int] = None
user: Optional[str] = None
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(OpenAIBaseModel):
index: int
object: str = "embedding"
embedding: Union[List[float], str]
class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[EmbeddingResponseData]
usage: UsageInfo
class FunctionCall(OpenAIBaseModel):
name: str
arguments: str
class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
function: FunctionCall
class ChatMessage(OpenAIBaseModel):
role: str
content: str
tool_calls: List[ToolCall] = Field(default_factory=list)
tool_call_id: Optional[str] = None
class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
content: Optional[List[ChatCompletionLogProbsContent]] = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
tool_call_id: Optional[str] = None
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameters of the request.
body: ChatCompletionRequest
class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response.
status_code: int = 200
# An unique identifier for the API request.
request_id: str
# The body of the response.
body: Optional[ChatCompletionResponse] = None
class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: Optional[BatchResponseData]
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]
class TokenizeCompletionRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)
class TokenizeChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True)
add_special_tokens: bool = Field(default=False)
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
class TokenizeResponse(OpenAIBaseModel):
count: int
max_model_len: int
tokens: List[int]
class DetokenizeRequest(OpenAIBaseModel):
model: str
tokens: List[int]
class DetokenizeResponse(OpenAIBaseModel):
prompt: str

demo/function_call/openai_serving_chat.py Normal file

@@ -0,0 +1,691 @@
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (
ConversationMessage,
load_chat_template,
parse_chat_messages,
)
from vllm.entrypoints.logger import RequestLogger
from openai_protocol import (
ChatCompletionLogProb,
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ErrorResponse,
FunctionCall,
ToolCall,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
LoRAModulePath,
OpenAIServing,
PromptAdapterPath,
)
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from utils import decode_function_call
import json
logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
super().__init__(
async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
)
self.response_role = response_role
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Optional[Request] = None
) -> Union[ErrorResponse, AsyncGenerator[str, None], ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
conversation, mm_futures = parse_chat_messages(
request.messages, model_config, tokenizer
)
print('conversation:', conversation)
# parse_chat_messages ignores tool_calls and tool_call_id
# we have to fix this
conversation = request.messages
for msg in conversation:
if 'tool_calls' in msg and msg['tool_calls'] is not None:
msg['tool_calls'] = [tc for tc in msg['tool_calls']]
print('fixed conversation:', conversation)
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
assert isinstance(prompt, str)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
mm_data: Optional[MultiModalDataDict] = None
try:
if len(mm_futures):
# since we support only single mm data currently
assert (
len(mm_futures) == 1
), "Multiple 'image_url' input is currently not supported."
mm_data = await mm_futures[0]
except Exception as e:
logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e))
request_id = f"chat-{random_uuid()}"
try:
guided_decode_logits_processor = await self._guided_decode_logits_processor(
request, tokenizer
)
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len
- len(prompt_inputs["prompt_token_ids"]),
)
self._log_inputs(
request_id,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
}
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.async_engine_client.is_tracing_enabled()
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
if (
not is_tracing_enabled
and raw_request
and contains_trace_headers(raw_request.headers)
):
log_tracing_disabled_warning()
result_generator = self.async_engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer
)
else:
try:
return await self.chat_completion_full_generator(
request,
raw_request,
result_generator,
request_id,
conversation,
tokenizer,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
first_iteration = True
# Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
try:
async for res in result_generator:
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
if first_iteration:
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
logprobs=None,
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name,
)
if (
request.stream_options
and request.stream_options.include_usage
):
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the
# last message
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=last_msg_content),
logprobs=None,
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name,
)
if (
request.stream_options
and request.stream_options.include_usage
):
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_token_ids = output.token_ids[previous_num_tokens[i] :]
out_logprobs = (
output.logprobs[previous_num_tokens[i] :]
if output.logprobs
else None
)
if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None
# if have tool
if request.tools is not None and len(request.tools) > 0:
if output.finish_reason is not None:
msg = decode_function_call(output.text)
if "tool_calls" in msg and msg["tool_calls"] is not None and len(msg["tool_calls"]) > 0:
delta_message = DeltaMessage(
content=msg.get("thought", ""),
tool_calls=[
ToolCall(
id=f"chatcmpl-tool-{random_uuid()}",
function=FunctionCall(
name=fc["function"]["name"],
arguments=json.dumps(fc["function"]["arguments"], ensure_ascii=False),
)
)
for fc in msg["tool_calls"]
]
)
else:
delta_message = DeltaMessage(content=msg.get("content", ""))
else:
# only return the last one
continue
else:
delta_text = output.text[len(previous_texts[i]) :]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
if (
request.tool_choice
and type(request.tool_choice)
is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name,
)
if (
request.stream_options
and request.stream_options.include_usage
):
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name,
)
if (
request.stream_options
and request.stream_options.include_usage
):
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
if request.stream_options and request.stream_options.include_usage:
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
)
final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None
choices: List[ChatCompletionResponseChoice] = []
role = self.get_chat_request_role(request)
for output in final_res.outputs:
token_ids = output.token_ids
out_logprobs = output.logprobs
if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_chat_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
tokenizer=tokenizer,
)
else:
logprobs = None
#if (
# request.tool_choice
# and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
#):
# message = ChatMessage(
# role=role,
# content="",
# tool_calls=[
# ToolCall(
# function=FunctionCall(
# name=request.tool_choice.function.name,
# arguments=output.text,
# )
# )
# ],
# )
#elif not request.tool_choice or request.tool_choice == "none":
# message = ChatMessage(role=role, content=output.text)
msg = decode_function_call(output.text)
if "tool_calls" in msg and msg["tool_calls"] is not None and len(msg["tool_calls"]) > 0:
message = ChatMessage(
role=role,
content=msg.get("thought", ""),
tool_calls=[
ToolCall(
function=FunctionCall(
name=fc["function"]["name"],
arguments=json.dumps(fc["function"]["arguments"], ensure_ascii=False),
)
)
for fc in msg["tool_calls"]
],
)
else:
message = ChatMessage(role=role, content=msg.get("content", ""))
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
choices.append(choice_data)
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs
)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
def _get_top_logprobs(
self,
logprobs: Dict[int, Logprob],
top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer,
) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=(
token := self._get_decoded_token(
p[1],
p[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids,
)
),
logprob=max(p[1].logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")),
)
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
def _create_chat_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token, bytes=list(token.encode("utf-8", errors="replace"))
)
)
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self._get_decoded_token(
step_top_logprobs[token_id],
token_id,
tokenizer,
self.return_tokens_as_token_ids,
),
logprob=max(step_top_logprobs[token_id].logprob, -9999.0),
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace"
)
),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs, tokenizer
),
)
)
return ChatCompletionLogProbs(content=logprobs_content)

demo/function_call/requirements.txt Normal file

@@ -0,0 +1 @@
datamodel_code_generator

demo/function_call/utils.py Normal file

@@ -0,0 +1,386 @@
import ast
import json
import keyword
import traceback
import uuid
from collections import deque
from copy import deepcopy
from logging import getLogger
from typing import Any, Dict, List, Optional, Union
from datamodel_code_generator import DataModelType
from datamodel_code_generator.format import PythonVersion
from datamodel_code_generator.model import get_data_model_types
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
from jsonschema import Draft202012Validator, exceptions, validate
from transformers import LlamaTokenizer
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import TensorType
logger = getLogger(__name__)
def decode_function_call(
sequence: str,
tool_call_start="<|tool_call_start|>",
tool_call_end="<|tool_call_end|>",
thought_start="<|thought_start|>",
thought_end="<|thought_end|>",
):
if thought_end in sequence and thought_start in sequence:
thought_string, sequence = sequence.rsplit(thought_end, 1)
thought_string = thought_string.split(thought_start, 1)[1]
else:
thought_string = ""
if tool_call_start in sequence and tool_call_end in sequence:
tool_call_string, content = sequence.rsplit(tool_call_end, 1)
tool_call_string = tool_call_string.split(tool_call_start, 1)[1]
try:
tool_calls = []
tool_call_string = tool_call_string.strip()
if tool_call_string.startswith("```"):
tool_call_string = tool_call_string.lstrip("```").strip()
if tool_call_string.startswith("python"):
tool_call_string = tool_call_string.lstrip("python").strip()
if tool_call_string.endswith("```"):
tool_call_string = tool_call_string.rstrip("```").strip()
for kw in keyword.kwlist:
tool_call_string = tool_call_string.replace(
"," + kw + "=", "," + kw + "_="
)
tool_call_string = tool_call_string.replace(
" " + kw + "=", " " + kw + "_="
)
tool_call_string = tool_call_string.replace(
"(" + kw + "=", "(" + kw + "_="
)
parsed = ast.parse(tool_call_string)
for elem in parsed.body:
assert isinstance(elem.value, ast.Call)
calls = resolve_ast_call(elem.value)
for func_name, func_args in calls.items():
new_args = {}
for k, v in func_args.items():
for kw in keyword.kwlist:
if k == kw + "_":
k = kw
new_args[k] = v
this_one = {"name": func_name, "arguments": new_args}
tool_calls.append(this_one)
return {
"content": content.strip(),
"tool_calls": [
{
"type": "function",
"function": tool_call,
"id": "call_" + uuid.uuid4().hex,
}
for tool_call in tool_calls
],
"role": "assistant",
}
except:
logger.error(traceback.format_exc())
return {
"content": content.strip(),
"role": "assistant",
"thought": thought_string,
}
else:
return {
"content": sequence.strip(),
"role": "assistant",
"thought": thought_string,
}
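# Illustrative usage sketch (hypothetical input, for reference only):
#
#   raw = (
#       "<|thought_start|>\nI should search for this.\n<|thought_end|>\n"
#       "<|tool_call_start|>\n```python\nsearch(query='MiniCPM3')\n```\n<|tool_call_end|>\n"
#   )
#   msg = decode_function_call(raw)
#
# Here msg["role"] is "assistant", msg["content"] is the text following
# <|tool_call_end|> (empty in this case), and msg["tool_calls"][0]["function"]
# is {"name": "search", "arguments": {"query": "MiniCPM3"}} with a generated
# "call_..." id. If no tool-call block is present, the plain content is
# returned together with the extracted "thought" string instead.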
def check_messages(conversation: List[Dict[str, str]], tools: List[Dict]):
if tools is not None:
for tool in tools:
if "type" not in tool or tool["type"] != "function":
raise ValueError(f"Tool {tool} is not valid")
if "name" not in tool["function"]:
raise ValueError(f"Tool {tool} is not valid")
if "parameters" not in tool["function"] or not check_tool(
tool["function"]["parameters"]["properties"]
):
raise ValueError(f"Tool {tool} is not valid")
for message in conversation:
if (
message["role"] == "assistant"
and "tool_calls" in message
and len(message["tool_calls"]) > 0
):
for tool_call in message["tool_calls"]:
if "id" not in tool_call:
raise ValueError(f"Tool call {tool_call} is not valid")
if tool_call["type"] != "function":
raise ValueError(f"Tool call {tool_call} is not valid")
if "function" not in tool_call:
raise ValueError(f"Tool call {tool_call} is not valid")
if not check_tool(tool_call["function"]):
raise ValueError(
f"Tool call function {tool_call['function']} is not valid"
)
elif message["role"] == "tool":
if "tool_call_id" not in message:
raise ValueError(f"Tool message {message['content']} is not valid")
def check_tool(tool_schema):
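    """Return True if the given schema is a valid Draft 2020-12 JSON schema."""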
try:
Draft202012Validator.check_schema(tool_schema)
return True
except exceptions.SchemaError as e:
print(f"SchemaError: {e}")
return False
def check_args(args, tool_schema):
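    """Return True if the arguments validate against the tool's JSON schema."""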
try:
validate(instance=args, schema=tool_schema)
return True
except exceptions.ValidationError as e:
print(f"Data failed validation: {e}")
return False
def message_format(msg, system_suffix="", user_prefix=""):
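    """Render one chat message into MiniCPM3 prompt form: assistant thoughts and
    tool calls are wrapped in the special markers, and the tool description is
    appended to the system message (system_suffix) or prefixed to the user
    message (user_prefix)."""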
if "thought" in msg and msg["thought"] is not None and len(msg["thought"]) > 0:
thought_prefix = f"<|thought_start|>\n{msg['thought']}\n<|thought_end|>\n"
else:
thought_prefix = ""
if msg["role"] == "assistant":
content = msg.get("content", "")
if content is None:
content = ""
if (
"tool_calls" in msg
and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0
):
def add_quotes(variable):
if isinstance(variable, str):
return repr(variable)
else:
return str(variable)
tool_calls = []
for _tool_call in msg["tool_calls"]:
if _tool_call is None:
continue
tool_call = _tool_call["function"]
tool_name = tool_call["name"]
if "arguments" not in tool_call or tool_call["arguments"] is None:
continue
if isinstance(tool_call["arguments"], str):
try:
tool_call["arguments"] = json.loads(tool_call["arguments"])
                    except Exception:
continue
args = ",".join(
[k + "=" + add_quotes(v) for k, v in tool_call["arguments"].items()]
)
tool_calls.append(f"{tool_name}({args})")
content = (
thought_prefix
+ "<|tool_call_start|>\n```python\n"
+ "\n".join(tool_calls).strip()
+ "\n```\n<|tool_call_end|>\n"
+ content
)
# msg["tool_call_string"] = "\n".join(tool_calls).strip()
msg["content"] = content
else:
content = thought_prefix + content
msg["content"] = content
elif msg["role"] == "user":
msg["content"] = user_prefix + "\n" + msg["content"]
elif msg["role"] == "system":
msg["content"] = msg["content"] + "\n" + system_suffix
msg["content"] = msg["content"].strip()
return msg
def jsonschema_to_code(jsonschema: dict) -> str:
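    """Generate Pydantic model source code for a JSON schema via datamodel-code-generator."""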
input_text = json.dumps(jsonschema)
data_model_types = get_data_model_types(
DataModelType.PydanticBaseModel,
PythonVersion.PY_310,
)
parser = JsonSchemaParser(
source=input_text,
data_model_type=data_model_types.data_model,
data_model_root_type=data_model_types.root_model,
data_model_field_type=data_model_types.field_model,
data_type_manager_type=data_model_types.data_type_manager,
target_python_version=PythonVersion.PY_310,
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
field_constraints=True,
)
results = parser.parse()
return results
def transform_function(function: dict):
"""turn json format of function into signature"""
params, default_params = [], []
for prop_name, prop in function["parameters"]["properties"].items():
if "default" in prop:
default_params.append(f'{prop_name}={repr(prop["default"])}')
elif prop_name not in function["parameters"].get("required", []):
default_params.append(f"{prop_name}={repr(None)}")
else:
params.append(prop_name)
ps = ", ".join(params + default_params)
res = "def {f_name}({ps}):\n".format(f_name=function["name"], ps=ps)
f_des = function.get("description", "")
content = jsonschema_to_code(function["parameters"])
if "class" in content:
i = content.index("class")
# print(content[:i])
content = content[i:]
classes, args = content.split("class Model(BaseModel):", 1)
lint_msg = f' """\n {f_des}\n Args:\n{args}\n """\n'
res += lint_msg
if len(classes) > 0:
res = classes + res
return res
def input_format(messages: List[Dict], tools: List[Dict], add_to_system=True):
"""
    Process the input messages and tools, and convert them into the prompt
    messages expected by MiniCPM3 (the tool descriptions are injected into the
    system or user message). The tools list may be empty.
parameters:
messages: List[Dict]
the input messages
For example:
tools: List[Dict]
the tools list you can use
For example:
"""
messages = deepcopy(messages)
tools = deepcopy(tools)
if tools is not None and len(tools) > 0:
header = "from enum import Enum\nfrom typing import List, Dict, Optional\nfrom pydantic import BaseModel, Field\n\n"
tools_string = header
for tool in tools:
try:
tools_string += "\n\n" + transform_function(tool)
            except Exception:
pass
tools_template = """# Functions
Here is a list of functions that you can invoke:
```python
{tools}
```
# Function Call Rule and Output Format
- If the user's question can be answered without calling any function, please answer the user's question directly. In this situation, you should return your thought and answer the user's question directly.
- If the user's question cannot be answered without calling any function, and the user has not provided enough information to call functions, please ask the user for more information. In this situation, you should return your thought and ask the user for more information.
- If the user's question cannot be answered without calling any function, and the user has provided enough information to call functions to solve it, you should call the functions. In this situation, you should return your thought and call the functions.
- Use default parameters unless the user has specified otherwise.
- You should answer in the following format:
<|thought_start|>
{{explain why the user's question can be answered without calling a function or why you should ask the user for more information or why you should call one or more functions and your plan to solve the user's question.}}
<|thought_end|>
<|tool_call_start|>
```python
func1(params_name=params_value, params_name2=params_value2...)
func2(params)
```
<|tool_call_end|>
{{answer the user's question directly or ask the user for more information}}
"""
tools_string = tools_template.format(tools=tools_string).strip()
else:
tools_string = ""
if add_to_system:
return [
message_format(msg, system_suffix=tools_string, user_prefix="")
for msg in messages
]
else:
return [
message_format(msg, system_suffix="", user_prefix=tools_string)
for msg in messages
]
# This is a modified version of
# https://github.com/ShishirPatil/gorilla/blob/main/berkeley-function-call-leaderboard/bfcl/model_handler/utils.py
# Thanks to the gorilla team for the original implementation
def resolve_ast_call(elem):
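    """Resolve an ast.Call node into a {function_name: {arg_name: value}} mapping."""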
# Handle nested attributes for deeply nested module paths
func_parts = []
func_part = elem.func
while isinstance(func_part, ast.Attribute):
func_parts.append(func_part.attr)
func_part = func_part.value
if isinstance(func_part, ast.Name):
func_parts.append(func_part.id)
func_name = ".".join(reversed(func_parts))
args_dict = {}
for arg in elem.keywords:
output = resolve_ast_by_type(arg.value)
args_dict[arg.arg] = output
return {func_name: args_dict}
def resolve_ast_by_type(value):
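    """Convert an AST node for an argument value into a plain Python value."""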
if isinstance(value, ast.Constant):
if value.value is Ellipsis:
output = "..."
else:
output = value.value
elif isinstance(value, ast.UnaryOp):
output = -value.operand.value
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
output = {
resolve_ast_by_type(k): resolve_ast_by_type(v)
for k, v in zip(value.keys, value.values)
}
elif isinstance(
value, ast.NameConstant
): # Added this condition to handle boolean values
output = value.value
elif isinstance(
value, ast.BinOp
): # Added this condition to handle function calls as arguments
output = eval(ast.unparse(value))
elif isinstance(value, ast.Name):
output = value.id
elif isinstance(value, ast.Call):
if len(value.keywords) == 0:
output = ast.unparse(value)
else:
output = resolve_ast_call(value)
elif isinstance(value, ast.Tuple):
output = tuple(resolve_ast_by_type(v) for v in value.elts)
    elif isinstance(value, ast.Lambda):
        # ast.Lambda.body is a single expression node, not a list of statements
        output = eval(ast.unparse(value.body))
elif isinstance(value, ast.Ellipsis):
output = "..."
elif isinstance(value, ast.Subscript):
try:
output = ast.unparse(value.body[0].value)
except:
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
else:
raise Exception(f"Unsupported AST type: {type(value)}")
return output
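

# Minimal usage sketch: how check_messages, input_format and decode_function_call
# fit together. The tool schema, the messages, and the hand-written model output
# below are illustrative placeholders only. input_format is assumed here to take
# the bare function schemas, since transform_function reads `name`/`parameters`
# from them directly.
if __name__ == "__main__":
    demo_tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Query the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]
    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather in Beijing?"},
    ]

    # Validate the conversation and the tool definitions.
    check_messages(demo_messages, demo_tools)

    # Render the prompt messages with the tool signatures injected into the system turn.
    rendered = input_format(
        demo_messages, [t["function"] for t in demo_tools], add_to_system=True
    )
    print(rendered[0]["content"])

    # Decode a hand-written model output containing a thought and one tool call.
    fake_output = (
        "<|thought_start|>\nI need the live weather, so I will call get_weather."
        "\n<|thought_end|>\n<|tool_call_start|>\n```python\n"
        "get_weather(city='Beijing')\n```\n<|tool_call_end|>\nLet me check that."
    )
    print(decode_function_call(fake_output))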