mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00
update readme and add more demos
This commit is contained in:
parent
f541d519a0
commit
9ac6cf3e38
37
README-en.md
37
README-en.md
@ -415,16 +415,47 @@ We have supported fine-tuning MiniCPM3 using [LLaMA-Factory](https://github.com/
|
|||||||
|
|
||||||
### Advanced Features
|
### Advanced Features
|
||||||
|
|
||||||
|
We recommend using [vLLM](#vllm) for the following advanced features.
|
||||||
|
|
||||||
#### Function calling
|
#### Function calling
|
||||||
|
|
||||||
We provide example code for using function calls with MiniCPM3, see [`demo/function_call.py`](./demo/function_calling.py).
|
We provide example code for using function calls with MiniCPM3:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd demo/function_call
|
||||||
|
python function_call.py
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to start a function call service, use the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd demo/function_call
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python openai_api_server.py \
|
||||||
|
--model openbmb/MiniCPM3-4B \
|
||||||
|
--served-model-name MiniCPM3-4B \
|
||||||
|
--chat-template chatml.jinja \
|
||||||
|
--dtype auto \
|
||||||
|
--api-key token-abc123 \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--trust-remote-code
|
||||||
|
```
|
||||||
|
|
||||||
|
Below is a demo of using a search engine to answer the question:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
#### Code Interpreter
|
#### Code Interpreter
|
||||||
|
|
||||||
We provide example code for using the code interpreter with MiniCPM3, see [`demo/code_interpreter.py`](./demo/code_interpreter.py).
|
We provide example code for using the code interpreter with MiniCPM3:
|
||||||
|
|
||||||
Below is a demo:
|
```bash
|
||||||
|
cd demo/code_interpreter
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python code_interpreter.py openbmb/MiniCPM3-4B
|
||||||
|
```
|
||||||
|
|
||||||
|
Below is an example of using the code interpreter to generate a QR code:
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|||||||
39
README.md
39
README.md
@ -415,15 +415,48 @@ print(responds)
|
|||||||
目前模型微调支持 [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory),使用方法参考 [LLaMA-Factory 微调](https://modelbest.feishu.cn/docx/Z7USdW4lloZzkZxQ14icJ3senjb?from=from_copylink)。
|
目前模型微调支持 [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory),使用方法参考 [LLaMA-Factory 微调](https://modelbest.feishu.cn/docx/Z7USdW4lloZzkZxQ14icJ3senjb?from=from_copylink)。
|
||||||
|
|
||||||
### 进阶功能
|
### 进阶功能
|
||||||
|
|
||||||
|
对于以下进阶功能,我们推荐使用 [vLLM](#vllm)。
|
||||||
|
|
||||||
#### 工具调用
|
#### 工具调用
|
||||||
|
|
||||||
我们提供了使用 MiniCPM3 调用工具的示例代码,见[`demo/function_calling.py`](./demo/function_calling.py)。
|
我们提供了使用 MiniCPM3 调用工具的示例代码:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd demo/function_call
|
||||||
|
python function_call.py
|
||||||
|
```
|
||||||
|
|
||||||
|
如果你想启动一个能够调用工具的推理服务,使用以下代码:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd demo/function_call
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python openai_api_server.py \
|
||||||
|
--model openbmb/MiniCPM3-4B \
|
||||||
|
--served-model-name MiniCPM3-4B \
|
||||||
|
--chat-template chatml.jinja \
|
||||||
|
--dtype auto \
|
||||||
|
--api-key token-abc123 \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--trust-remote-code
|
||||||
|
```
|
||||||
|
|
||||||
|
下面是一个调用搜索工具回答问题的演示:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
#### 代码解释器
|
#### 代码解释器
|
||||||
|
|
||||||
我们提供了一个 MiniCPM3 使用代码解释器的示例代码,见[`demo/code_interpreter.py`](./demo/code_interpreter.py)。
|
我们提供了一个 MiniCPM3 使用代码解释器的示例代码:
|
||||||
|
|
||||||
下面是一个 Demo:
|
```bash
|
||||||
|
cd demo/code_interpreter
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python code_interpreter.py openbmb/MiniCPM3-4B
|
||||||
|
```
|
||||||
|
|
||||||
|
下面是一个使用代码解释器生成二维码的演示:
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|||||||
BIN
assets/function_call.gif
Normal file
BIN
assets/function_call.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.6 MiB |
1
demo/code_interpreter/requirements.txt
Normal file
1
demo/code_interpreter/requirements.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
fire
|
||||||
2
demo/function_call/chatml.jinja
Normal file
2
demo/function_call/chatml.jinja
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
|
||||||
|
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
||||||
370
demo/function_call/openai_api_server.py
Normal file
370
demo/function_call/openai_api_server.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
from argparse import Namespace
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from http import HTTPStatus
|
||||||
|
from multiprocessing import Process
|
||||||
|
from typing import AsyncIterator, Set
|
||||||
|
|
||||||
|
from fastapi import APIRouter, FastAPI, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
|
from prometheus_client import make_asgi_app
|
||||||
|
from starlette.routing import Mount
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.protocol import AsyncEngineClient
|
||||||
|
from vllm.entrypoints.launcher import serve_http
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from openai_protocol import (ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
CompletionRequest,
|
||||||
|
DetokenizeRequest,
|
||||||
|
DetokenizeResponse,
|
||||||
|
EmbeddingRequest, ErrorResponse,
|
||||||
|
TokenizeRequest,
|
||||||
|
TokenizeResponse)
|
||||||
|
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
|
||||||
|
from vllm.entrypoints.openai.rpc.server import run_rpc_server
|
||||||
|
# yapf: enable
|
||||||
|
from openai_serving_chat import OpenAIServingChat
|
||||||
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
|
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
|
from vllm.entrypoints.openai.serving_tokenization import (
|
||||||
|
OpenAIServingTokenization)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.usage.usage_lib import UsageContext
|
||||||
|
from vllm.utils import FlexibleArgumentParser, get_open_port
|
||||||
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
|
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||||
|
|
||||||
|
async_engine_client: AsyncEngineClient
|
||||||
|
engine_args: AsyncEngineArgs
|
||||||
|
openai_serving_chat: OpenAIServingChat
|
||||||
|
openai_serving_completion: OpenAIServingCompletion
|
||||||
|
openai_serving_embedding: OpenAIServingEmbedding
|
||||||
|
openai_serving_tokenization: OpenAIServingTokenization
|
||||||
|
|
||||||
|
logger = init_logger('vllm.entrypoints.openai.api_server')
|
||||||
|
|
||||||
|
_running_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
|
||||||
|
return ModelConfig(model=model_name,
|
||||||
|
tokenizer=model_name,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
seed=0,
|
||||||
|
dtype="float16").embedding_mode
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
|
async def _force_log():
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
await async_engine_client.do_log_stats()
|
||||||
|
|
||||||
|
if not engine_args.disable_log_stats:
|
||||||
|
task = asyncio.create_task(_force_log())
|
||||||
|
_running_tasks.add(task)
|
||||||
|
task.add_done_callback(_running_tasks.remove)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
|
||||||
|
# Context manager to handle async_engine_client lifecycle
|
||||||
|
# Ensures everything is shutdown and cleaned up on error/exit
|
||||||
|
global engine_args
|
||||||
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
# Backend itself still global for the silly lil' health handler
|
||||||
|
global async_engine_client
|
||||||
|
|
||||||
|
# If manually triggered or embedding model, use AsyncLLMEngine in process.
|
||||||
|
# TODO: support embedding model via RPC.
|
||||||
|
if (model_is_embedding(args.model, args.trust_remote_code)
|
||||||
|
or args.disable_frontend_multiprocessing):
|
||||||
|
async_engine_client = AsyncLLMEngine.from_engine_args(
|
||||||
|
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
|
||||||
|
yield async_engine_client
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||||
|
else:
|
||||||
|
# Start RPCServer in separate process (holds the AsyncLLMEngine).
|
||||||
|
port = get_open_port(envs.VLLM_RPC_PORT)
|
||||||
|
rpc_server_process = Process(target=run_rpc_server,
|
||||||
|
args=(engine_args,
|
||||||
|
UsageContext.OPENAI_API_SERVER,
|
||||||
|
port))
|
||||||
|
rpc_server_process.start()
|
||||||
|
|
||||||
|
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
|
||||||
|
async_engine_client = AsyncEngineRPCClient(port)
|
||||||
|
await async_engine_client.setup()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield async_engine_client
|
||||||
|
finally:
|
||||||
|
# Ensure rpc server process was terminated
|
||||||
|
rpc_server_process.terminate()
|
||||||
|
|
||||||
|
# Close all open connections to the backend
|
||||||
|
async_engine_client.close()
|
||||||
|
|
||||||
|
# Wait for server process to join
|
||||||
|
rpc_server_process.join()
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def mount_metrics(app: FastAPI):
|
||||||
|
# Add prometheus asgi middleware to route /metrics requests
|
||||||
|
metrics_route = Mount("/metrics", make_asgi_app())
|
||||||
|
# Workaround for 307 Redirect for /metrics
|
||||||
|
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
|
||||||
|
app.routes.append(metrics_route)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
async def health() -> Response:
|
||||||
|
"""Health check."""
|
||||||
|
await async_engine_client.check_health()
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tokenize")
|
||||||
|
async def tokenize(request: TokenizeRequest):
|
||||||
|
generator = await openai_serving_tokenization.create_tokenize(request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
else:
|
||||||
|
assert isinstance(generator, TokenizeResponse)
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/detokenize")
|
||||||
|
async def detokenize(request: DetokenizeRequest):
|
||||||
|
generator = await openai_serving_tokenization.create_detokenize(request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
else:
|
||||||
|
assert isinstance(generator, DetokenizeResponse)
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/models")
|
||||||
|
async def show_available_models():
|
||||||
|
models = await openai_serving_completion.show_available_models()
|
||||||
|
return JSONResponse(content=models.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/version")
|
||||||
|
async def show_version():
|
||||||
|
ver = {"version": VLLM_VERSION}
|
||||||
|
return JSONResponse(content=ver)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/chat/completions")
|
||||||
|
async def create_chat_completion(request: ChatCompletionRequest,
|
||||||
|
raw_request: Request):
|
||||||
|
generator = await openai_serving_chat.create_chat_completion(
|
||||||
|
request, raw_request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
if request.stream:
|
||||||
|
return StreamingResponse(content=generator,
|
||||||
|
media_type="text/event-stream")
|
||||||
|
else:
|
||||||
|
assert isinstance(generator, ChatCompletionResponse)
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/completions")
|
||||||
|
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||||
|
generator = await openai_serving_completion.create_completion(
|
||||||
|
request, raw_request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
if request.stream:
|
||||||
|
return StreamingResponse(content=generator,
|
||||||
|
media_type="text/event-stream")
|
||||||
|
else:
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/embeddings")
|
||||||
|
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||||
|
generator = await openai_serving_embedding.create_embedding(
|
||||||
|
request, raw_request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
else:
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
def build_app(args: Namespace) -> FastAPI:
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.include_router(router)
|
||||||
|
app.root_path = args.root_path
|
||||||
|
|
||||||
|
mount_metrics(app)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=args.allowed_origins,
|
||||||
|
allow_credentials=args.allow_credentials,
|
||||||
|
allow_methods=args.allowed_methods,
|
||||||
|
allow_headers=args.allowed_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(_, exc):
|
||||||
|
err = openai_serving_chat.create_error_response(message=str(exc))
|
||||||
|
return JSONResponse(err.model_dump(),
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
if token := envs.VLLM_API_KEY or args.api_key:
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def authentication(request: Request, call_next):
|
||||||
|
root_path = "" if args.root_path is None else args.root_path
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
return await call_next(request)
|
||||||
|
if not request.url.path.startswith(f"{root_path}/v1"):
|
||||||
|
return await call_next(request)
|
||||||
|
if request.headers.get("Authorization") != "Bearer " + token:
|
||||||
|
return JSONResponse(content={"error": "Unauthorized"},
|
||||||
|
status_code=401)
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
for middleware in args.middleware:
|
||||||
|
module_path, object_name = middleware.rsplit(".", 1)
|
||||||
|
imported = getattr(importlib.import_module(module_path), object_name)
|
||||||
|
if inspect.isclass(imported):
|
||||||
|
app.add_middleware(imported)
|
||||||
|
elif inspect.iscoroutinefunction(imported):
|
||||||
|
app.middleware("http")(imported)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid middleware {middleware}. "
|
||||||
|
f"Must be a function or a class.")
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
async def init_app(
|
||||||
|
async_engine_client: AsyncEngineClient,
|
||||||
|
args: Namespace,
|
||||||
|
) -> FastAPI:
|
||||||
|
app = build_app(args)
|
||||||
|
|
||||||
|
if args.served_model_name is not None:
|
||||||
|
served_model_names = args.served_model_name
|
||||||
|
else:
|
||||||
|
served_model_names = [args.model]
|
||||||
|
|
||||||
|
model_config = await async_engine_client.get_model_config()
|
||||||
|
|
||||||
|
if args.disable_log_requests:
|
||||||
|
request_logger = None
|
||||||
|
else:
|
||||||
|
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||||
|
|
||||||
|
global openai_serving_chat
|
||||||
|
global openai_serving_completion
|
||||||
|
global openai_serving_embedding
|
||||||
|
global openai_serving_tokenization
|
||||||
|
|
||||||
|
openai_serving_chat = OpenAIServingChat(
|
||||||
|
async_engine_client,
|
||||||
|
model_config,
|
||||||
|
served_model_names,
|
||||||
|
args.response_role,
|
||||||
|
lora_modules=args.lora_modules,
|
||||||
|
prompt_adapters=args.prompt_adapters,
|
||||||
|
request_logger=request_logger,
|
||||||
|
chat_template=args.chat_template,
|
||||||
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
|
)
|
||||||
|
openai_serving_completion = OpenAIServingCompletion(
|
||||||
|
async_engine_client,
|
||||||
|
model_config,
|
||||||
|
served_model_names,
|
||||||
|
lora_modules=args.lora_modules,
|
||||||
|
prompt_adapters=args.prompt_adapters,
|
||||||
|
request_logger=request_logger,
|
||||||
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
|
)
|
||||||
|
openai_serving_embedding = OpenAIServingEmbedding(
|
||||||
|
async_engine_client,
|
||||||
|
model_config,
|
||||||
|
served_model_names,
|
||||||
|
request_logger=request_logger,
|
||||||
|
)
|
||||||
|
openai_serving_tokenization = OpenAIServingTokenization(
|
||||||
|
async_engine_client,
|
||||||
|
model_config,
|
||||||
|
served_model_names,
|
||||||
|
lora_modules=args.lora_modules,
|
||||||
|
request_logger=request_logger,
|
||||||
|
chat_template=args.chat_template,
|
||||||
|
)
|
||||||
|
app.root_path = args.root_path
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||||
|
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||||
|
logger.info("args: %s", args)
|
||||||
|
|
||||||
|
async with build_async_engine_client(args) as async_engine_client:
|
||||||
|
app = await init_app(async_engine_client, args)
|
||||||
|
|
||||||
|
shutdown_task = await serve_http(
|
||||||
|
app,
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
log_level=args.uvicorn_log_level,
|
||||||
|
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
|
||||||
|
ssl_keyfile=args.ssl_keyfile,
|
||||||
|
ssl_certfile=args.ssl_certfile,
|
||||||
|
ssl_ca_certs=args.ssl_ca_certs,
|
||||||
|
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||||
|
**uvicorn_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NB: Await server shutdown only after the backend context is exited
|
||||||
|
await shutdown_task
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# NOTE(simon):
|
||||||
|
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||||
|
parser = make_arg_parser(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
asyncio.run(run_server(args))
|
||||||
738
demo/function_call/openai_protocol.py
Normal file
738
demo/function_call/openai_protocol.py
Normal file
@ -0,0 +1,738 @@
|
|||||||
|
# Adapted from
|
||||||
|
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||||
|
import time
|
||||||
|
from argparse import Namespace
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
from vllm.entrypoints.openai.logits_processors import get_logits_processors
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.sampling_params import LogitsProcessor, SamplingParams
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
# torch is mocked during docs generation,
|
||||||
|
# so we have to provide the values as literals
|
||||||
|
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sphinx.ext.autodoc.mock import _MockModule
|
||||||
|
|
||||||
|
if isinstance(torch, _MockModule):
|
||||||
|
_LONG_INFO = _MOCK_LONG_INFO
|
||||||
|
else:
|
||||||
|
_LONG_INFO = torch.iinfo(torch.long)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
_LONG_INFO = torch.iinfo(torch.long)
|
||||||
|
|
||||||
|
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
|
||||||
|
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIBaseModel(BaseModel):
|
||||||
|
# OpenAI API does not allow extra fields
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(OpenAIBaseModel):
|
||||||
|
object: str = "error"
|
||||||
|
message: str
|
||||||
|
type: str
|
||||||
|
param: Optional[str] = None
|
||||||
|
code: int
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPermission(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
|
||||||
|
object: str = "model_permission"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
allow_create_engine: bool = False
|
||||||
|
allow_sampling: bool = True
|
||||||
|
allow_logprobs: bool = True
|
||||||
|
allow_search_indices: bool = False
|
||||||
|
allow_view: bool = True
|
||||||
|
allow_fine_tuning: bool = False
|
||||||
|
organization: str = "*"
|
||||||
|
group: Optional[str] = None
|
||||||
|
is_blocking: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCard(OpenAIBaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "model"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
owned_by: str = "vllm"
|
||||||
|
root: Optional[str] = None
|
||||||
|
parent: Optional[str] = None
|
||||||
|
max_model_len: Optional[int] = None
|
||||||
|
permission: List[ModelPermission] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelList(OpenAIBaseModel):
|
||||||
|
object: str = "list"
|
||||||
|
data: List[ModelCard] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class UsageInfo(OpenAIBaseModel):
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
completion_tokens: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormat(OpenAIBaseModel):
|
||||||
|
# type must be "json_object" or "text"
|
||||||
|
type: Literal["text", "json_object"]
|
||||||
|
|
||||||
|
|
||||||
|
class StreamOptions(OpenAIBaseModel):
|
||||||
|
include_usage: Optional[bool] = True
|
||||||
|
continuous_usage_stats: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionDefinition(OpenAIBaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
parameters: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionToolsParam(OpenAIBaseModel):
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
function: FunctionDefinition
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionNamedFunction(OpenAIBaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
|
||||||
|
function: ChatCompletionNamedFunction
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(OpenAIBaseModel):
|
||||||
|
# Ordered by official OpenAI API documentation
|
||||||
|
# https://platform.openai.com/docs/api-reference/chat/create
|
||||||
|
messages: List[ChatCompletionMessageParam]
|
||||||
|
model: str
|
||||||
|
frequency_penalty: Optional[float] = 0.0
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None
|
||||||
|
logprobs: Optional[bool] = False
|
||||||
|
top_logprobs: Optional[int] = 0
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
n: Optional[int] = 1
|
||||||
|
presence_penalty: Optional[float] = 0.0
|
||||||
|
response_format: Optional[ResponseFormat] = None
|
||||||
|
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||||
|
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
stream_options: Optional[StreamOptions] = None
|
||||||
|
temperature: Optional[float] = 0.7
|
||||||
|
top_p: Optional[float] = 1.0
|
||||||
|
tools: Optional[List[ChatCompletionToolsParam]] = None
|
||||||
|
tool_choice: Optional[Union[Literal["none"],
|
||||||
|
ChatCompletionNamedToolChoiceParam]] = "none"
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
# doc: begin-chat-completion-sampling-params
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
use_beam_search: bool = False
|
||||||
|
top_k: int = -1
|
||||||
|
min_p: float = 0.0
|
||||||
|
repetition_penalty: float = 1.0
|
||||||
|
length_penalty: float = 1.0
|
||||||
|
early_stopping: bool = False
|
||||||
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
include_stop_str_in_output: bool = False
|
||||||
|
ignore_eos: bool = False
|
||||||
|
min_tokens: int = 0
|
||||||
|
skip_special_tokens: bool = True
|
||||||
|
spaces_between_special_tokens: bool = True
|
||||||
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||||
|
# doc: end-chat-completion-sampling-params
|
||||||
|
|
||||||
|
# doc: begin-chat-completion-extra-params
|
||||||
|
echo: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"If true, the new message will be prepended with the last message "
|
||||||
|
"if they belong to the same role."),
|
||||||
|
)
|
||||||
|
add_generation_prompt: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=
|
||||||
|
("If true, the generation prompt will be added to the chat template. "
|
||||||
|
"This is a parameter used by chat template in tokenizer config of the "
|
||||||
|
"model."),
|
||||||
|
)
|
||||||
|
add_special_tokens: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||||
|
"on top of what is added by the chat template. "
|
||||||
|
"For most models, the chat template takes care of adding the "
|
||||||
|
"special tokens so this should be set to false (as is the "
|
||||||
|
"default)."),
|
||||||
|
)
|
||||||
|
documents: Optional[List[Dict[str, str]]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=
|
||||||
|
("A list of dicts representing documents that will be accessible to "
|
||||||
|
"the model if it is performing RAG (retrieval-augmented generation)."
|
||||||
|
" If the template does not support RAG, this argument will have no "
|
||||||
|
"effect. We recommend that each document should be a dict containing "
|
||||||
|
"\"title\" and \"text\" keys."),
|
||||||
|
)
|
||||||
|
chat_template: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"A Jinja template to use for this conversion. "
|
||||||
|
"If this is not passed, the model's default chat template will be "
|
||||||
|
"used instead."),
|
||||||
|
)
|
||||||
|
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=("Additional kwargs to pass to the template renderer. "
|
||||||
|
"Will be accessible by the chat template."),
|
||||||
|
)
|
||||||
|
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=("If specified, the output will follow the JSON schema."),
|
||||||
|
)
|
||||||
|
guided_regex: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the output will follow the regex pattern."),
|
||||||
|
)
|
||||||
|
guided_choice: Optional[List[str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the output will be exactly one of the choices."),
|
||||||
|
)
|
||||||
|
guided_grammar: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the output will follow the context free grammar."),
|
||||||
|
)
|
||||||
|
guided_decoding_backend: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, will override the default guided decoding backend "
|
||||||
|
"of the server for this specific request. If set, must be either "
|
||||||
|
"'outlines' / 'lm-format-enforcer'"))
|
||||||
|
guided_whitespace_pattern: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, will override the default whitespace pattern "
|
||||||
|
"for guided json decoding."))
|
||||||
|
|
||||||
|
# doc: end-chat-completion-extra-params
|
||||||
|
|
||||||
|
def to_sampling_params(
|
||||||
|
self, tokenizer: PreTrainedTokenizer,
|
||||||
|
guided_decode_logits_processor: Optional[LogitsProcessor],
|
||||||
|
default_max_tokens: int) -> SamplingParams:
|
||||||
|
max_tokens = self.max_tokens
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = default_max_tokens
|
||||||
|
|
||||||
|
# We now allow logprobs being true without top_logrobs.
|
||||||
|
logits_processors = get_logits_processors(
|
||||||
|
logit_bias=self.logit_bias,
|
||||||
|
allowed_token_ids=None,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
if guided_decode_logits_processor:
|
||||||
|
logits_processors.append(guided_decode_logits_processor)
|
||||||
|
|
||||||
|
return SamplingParams(
|
||||||
|
n=self.n,
|
||||||
|
best_of=self.best_of,
|
||||||
|
presence_penalty=self.presence_penalty,
|
||||||
|
frequency_penalty=self.frequency_penalty,
|
||||||
|
repetition_penalty=self.repetition_penalty,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
top_k=self.top_k,
|
||||||
|
min_p=self.min_p,
|
||||||
|
seed=self.seed,
|
||||||
|
stop=self.stop,
|
||||||
|
stop_token_ids=self.stop_token_ids,
|
||||||
|
logprobs=self.top_logprobs if self.logprobs else None,
|
||||||
|
prompt_logprobs=self.top_logprobs if self.echo else None,
|
||||||
|
ignore_eos=self.ignore_eos,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
min_tokens=self.min_tokens,
|
||||||
|
use_beam_search=self.use_beam_search,
|
||||||
|
early_stopping=self.early_stopping,
|
||||||
|
skip_special_tokens=self.skip_special_tokens,
|
||||||
|
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||||
|
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||||
|
length_penalty=self.length_penalty,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode='before')
|
||||||
|
@classmethod
|
||||||
|
def validate_stream_options(cls, values):
|
||||||
|
if (values.get('stream_options') is not None
|
||||||
|
and not values.get('stream')):
|
||||||
|
raise ValueError(
|
||||||
|
"stream_options can only be set if stream is true")
|
||||||
|
return values
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_guided_decoding_count(cls, data):
|
||||||
|
guide_count = sum([
|
||||||
|
"guided_json" in data and data["guided_json"] is not None,
|
||||||
|
"guided_regex" in data and data["guided_regex"] is not None,
|
||||||
|
"guided_choice" in data and data["guided_choice"] is not None
|
||||||
|
])
|
||||||
|
# you can only use one kind of guided decoding
|
||||||
|
if guide_count > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"You can only use one kind of guided decoding "
|
||||||
|
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||||
|
# you can only either use guided decoding or tools, not both
|
||||||
|
if guide_count > 1 and "tool_choice" in data and data[
|
||||||
|
"tool_choice"] != "none":
|
||||||
|
raise ValueError(
|
||||||
|
"You can only either use guided decoding or tools, not both.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_tool_choice(cls, data):
|
||||||
|
if "tool_choice" in data and data["tool_choice"] != "none":
|
||||||
|
if not isinstance(data["tool_choice"], dict):
|
||||||
|
raise ValueError("Currently only named tools are supported.")
|
||||||
|
if "tools" not in data or data["tools"] is None:
|
||||||
|
raise ValueError(
|
||||||
|
"When using `tool_choice`, `tools` must be set.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_logprobs(cls, data):
|
||||||
|
if "top_logprobs" in data and data["top_logprobs"] is not None:
|
||||||
|
if "logprobs" not in data or data["logprobs"] is False:
|
||||||
|
raise ValueError(
|
||||||
|
"when using `top_logprobs`, `logprobs` must be set to true."
|
||||||
|
)
|
||||||
|
elif data["top_logprobs"] < 0:
|
||||||
|
raise ValueError(
|
||||||
|
"`top_logprobs` must be a value a positive value.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionRequest(OpenAIBaseModel):
|
||||||
|
# Ordered by official OpenAI API documentation
|
||||||
|
# https://platform.openai.com/docs/api-reference/completions/create
|
||||||
|
model: str
|
||||||
|
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
echo: Optional[bool] = False
|
||||||
|
frequency_penalty: Optional[float] = 0.0
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None
|
||||||
|
logprobs: Optional[int] = None
|
||||||
|
max_tokens: Optional[int] = 16
|
||||||
|
n: int = 1
|
||||||
|
presence_penalty: Optional[float] = 0.0
|
||||||
|
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||||
|
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
stream_options: Optional[StreamOptions] = None
|
||||||
|
suffix: Optional[str] = None
|
||||||
|
temperature: Optional[float] = 1.0
|
||||||
|
top_p: Optional[float] = 1.0
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
# doc: begin-completion-sampling-params
|
||||||
|
use_beam_search: bool = False
|
||||||
|
top_k: int = -1
|
||||||
|
min_p: float = 0.0
|
||||||
|
repetition_penalty: float = 1.0
|
||||||
|
length_penalty: float = 1.0
|
||||||
|
early_stopping: bool = False
|
||||||
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
include_stop_str_in_output: bool = False
|
||||||
|
ignore_eos: bool = False
|
||||||
|
min_tokens: int = 0
|
||||||
|
skip_special_tokens: bool = True
|
||||||
|
spaces_between_special_tokens: bool = True
|
||||||
|
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||||
|
allowed_token_ids: Optional[List[int]] = None
|
||||||
|
# doc: end-completion-sampling-params
|
||||||
|
|
||||||
|
# doc: begin-completion-extra-params
|
||||||
|
add_special_tokens: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||||
|
"the prompt."),
|
||||||
|
)
|
||||||
|
response_format: Optional[ResponseFormat] = Field(
|
||||||
|
default=None,
|
||||||
|
description=
|
||||||
|
("Similar to chat completion, this parameter specifies the format of "
|
||||||
|
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
|
||||||
|
"supported."),
|
||||||
|
)
|
||||||
|
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=("If specified, the output will follow the JSON schema."),
|
||||||
|
)
|
||||||
|
guided_regex: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the output will follow the regex pattern."),
|
||||||
|
)
|
||||||
|
guided_choice: Optional[List[str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the output will be exactly one of the choices."),
|
||||||
|
)
|
||||||
|
guided_grammar: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the output will follow the context free grammar."),
|
||||||
|
)
|
||||||
|
guided_decoding_backend: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, will override the default guided decoding backend "
|
||||||
|
"of the server for this specific request. If set, must be one of "
|
||||||
|
"'outlines' / 'lm-format-enforcer'"))
|
||||||
|
guided_whitespace_pattern: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, will override the default whitespace pattern "
|
||||||
|
"for guided json decoding."))
|
||||||
|
|
||||||
|
# doc: end-completion-extra-params
|
||||||
|
|
||||||
|
def to_sampling_params(
|
||||||
|
self, tokenizer: PreTrainedTokenizer,
|
||||||
|
guided_decode_logits_processor: Optional[LogitsProcessor],
|
||||||
|
default_max_tokens: int) -> SamplingParams:
|
||||||
|
max_tokens = self.max_tokens
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = default_max_tokens
|
||||||
|
|
||||||
|
echo_without_generation = self.echo and self.max_tokens == 0
|
||||||
|
|
||||||
|
logits_processors = get_logits_processors(
|
||||||
|
logit_bias=self.logit_bias,
|
||||||
|
allowed_token_ids=self.allowed_token_ids,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
if guided_decode_logits_processor:
|
||||||
|
logits_processors.append(guided_decode_logits_processor)
|
||||||
|
|
||||||
|
return SamplingParams(
|
||||||
|
n=self.n,
|
||||||
|
best_of=self.best_of,
|
||||||
|
presence_penalty=self.presence_penalty,
|
||||||
|
frequency_penalty=self.frequency_penalty,
|
||||||
|
repetition_penalty=self.repetition_penalty,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
top_k=self.top_k,
|
||||||
|
min_p=self.min_p,
|
||||||
|
seed=self.seed,
|
||||||
|
stop=self.stop,
|
||||||
|
stop_token_ids=self.stop_token_ids,
|
||||||
|
logprobs=self.logprobs,
|
||||||
|
ignore_eos=self.ignore_eos,
|
||||||
|
max_tokens=max_tokens if not echo_without_generation else 1,
|
||||||
|
min_tokens=self.min_tokens,
|
||||||
|
use_beam_search=self.use_beam_search,
|
||||||
|
early_stopping=self.early_stopping,
|
||||||
|
prompt_logprobs=self.logprobs if self.echo else None,
|
||||||
|
skip_special_tokens=self.skip_special_tokens,
|
||||||
|
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||||
|
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||||
|
length_penalty=self.length_penalty,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_guided_decoding_count(cls, data):
|
||||||
|
guide_count = sum([
|
||||||
|
"guided_json" in data and data["guided_json"] is not None,
|
||||||
|
"guided_regex" in data and data["guided_regex"] is not None,
|
||||||
|
"guided_choice" in data and data["guided_choice"] is not None
|
||||||
|
])
|
||||||
|
if guide_count > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"You can only use one kind of guided decoding "
|
||||||
|
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_logprobs(cls, data):
|
||||||
|
if "logprobs" in data and data[
|
||||||
|
"logprobs"] is not None and not data["logprobs"] >= 0:
|
||||||
|
raise ValueError("if passed, `logprobs` must be a positive value.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_stream_options(cls, data):
|
||||||
|
if data.get("stream_options") and not data.get("stream"):
|
||||||
|
raise ValueError(
|
||||||
|
"Stream options can only be defined when stream is true.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingRequest(OpenAIBaseModel):
|
||||||
|
# Ordered by official OpenAI API documentation
|
||||||
|
# https://platform.openai.com/docs/api-reference/embeddings
|
||||||
|
model: str
|
||||||
|
input: Union[List[int], List[List[int]], str, List[str]]
|
||||||
|
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
|
||||||
|
dimensions: Optional[int] = None
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
# doc: begin-embedding-pooling-params
|
||||||
|
additional_data: Optional[Any] = None
|
||||||
|
|
||||||
|
# doc: end-embedding-pooling-params
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams(additional_data=self.additional_data)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionLogProbs(OpenAIBaseModel):
|
||||||
|
text_offset: List[int] = Field(default_factory=list)
|
||||||
|
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||||
|
tokens: List[str] = Field(default_factory=list)
|
||||||
|
top_logprobs: List[Optional[Dict[str,
|
||||||
|
float]]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseChoice(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[CompletionLogProbs] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
stop_reason: Optional[Union[int, str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"The stop string or token id that caused the completion "
|
||||||
|
"to stop, None if the completion finished for some other reason "
|
||||||
|
"including encountering the EOS token"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||||
|
object: str = "text_completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[CompletionResponseChoice]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseStreamChoice(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[CompletionLogProbs] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
stop_reason: Optional[Union[int, str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"The stop string or token id that caused the completion "
|
||||||
|
"to stop, None if the completion finished for some other reason "
|
||||||
|
"including encountering the EOS token"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionStreamResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||||
|
object: str = "text_completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[CompletionResponseStreamChoice]
|
||||||
|
usage: Optional[UsageInfo] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponseData(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
object: str = "embedding"
|
||||||
|
embedding: Union[List[float], str]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||||
|
object: str = "list"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
data: List[EmbeddingResponseData]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCall(OpenAIBaseModel):
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
function: FunctionCall
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(OpenAIBaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
tool_call_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionLogProb(OpenAIBaseModel):
|
||||||
|
token: str
|
||||||
|
logprob: float = -9999.0
|
||||||
|
bytes: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
|
||||||
|
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionLogProbs(OpenAIBaseModel):
|
||||||
|
content: Optional[List[ChatCompletionLogProbsContent]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponseChoice(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
message: ChatMessage
|
||||||
|
logprobs: Optional[ChatCompletionLogProbs] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
stop_reason: Optional[Union[int, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||||
|
object: Literal["chat.completion"] = "chat.completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[ChatCompletionResponseChoice]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class DeltaMessage(OpenAIBaseModel):
|
||||||
|
role: Optional[str] = None
|
||||||
|
content: Optional[str] = None
|
||||||
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
tool_call_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
delta: DeltaMessage
|
||||||
|
logprobs: Optional[ChatCompletionLogProbs] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
stop_reason: Optional[Union[int, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionStreamResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||||
|
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[ChatCompletionResponseStreamChoice]
|
||||||
|
usage: Optional[UsageInfo] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequestInput(OpenAIBaseModel):
|
||||||
|
"""
|
||||||
|
The per-line object of the batch input file.
|
||||||
|
|
||||||
|
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# A developer-provided per-request id that will be used to match outputs to
|
||||||
|
# inputs. Must be unique for each request in a batch.
|
||||||
|
custom_id: str
|
||||||
|
|
||||||
|
# The HTTP method to be used for the request. Currently only POST is
|
||||||
|
# supported.
|
||||||
|
method: str
|
||||||
|
|
||||||
|
# The OpenAI API relative URL to be used for the request. Currently
|
||||||
|
# /v1/chat/completions is supported.
|
||||||
|
url: str
|
||||||
|
|
||||||
|
# The parameters of the request.
|
||||||
|
body: ChatCompletionRequest
|
||||||
|
|
||||||
|
|
||||||
|
class BatchResponseData(OpenAIBaseModel):
|
||||||
|
# HTTP status code of the response.
|
||||||
|
status_code: int = 200
|
||||||
|
|
||||||
|
# An unique identifier for the API request.
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
# The body of the response.
|
||||||
|
body: Optional[ChatCompletionResponse] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequestOutput(OpenAIBaseModel):
|
||||||
|
"""
|
||||||
|
The per-line object of the batch output and error files
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
|
||||||
|
# A developer-provided per-request id that will be used to match outputs to
|
||||||
|
# inputs.
|
||||||
|
custom_id: str
|
||||||
|
|
||||||
|
response: Optional[BatchResponseData]
|
||||||
|
|
||||||
|
# For requests that failed with a non-HTTP error, this will contain more
|
||||||
|
# information on the cause of the failure.
|
||||||
|
error: Optional[Any]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeCompletionRequest(OpenAIBaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
|
||||||
|
add_special_tokens: bool = Field(default=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeChatRequest(OpenAIBaseModel):
|
||||||
|
model: str
|
||||||
|
messages: List[ChatCompletionMessageParam]
|
||||||
|
|
||||||
|
add_generation_prompt: bool = Field(default=True)
|
||||||
|
add_special_tokens: bool = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeResponse(OpenAIBaseModel):
|
||||||
|
count: int
|
||||||
|
max_model_len: int
|
||||||
|
tokens: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
class DetokenizeRequest(OpenAIBaseModel):
|
||||||
|
model: str
|
||||||
|
tokens: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
class DetokenizeResponse(OpenAIBaseModel):
|
||||||
|
prompt: str
|
||||||
691
demo/function_call/openai_serving_chat.py
Normal file
691
demo/function_call/openai_serving_chat.py
Normal file
@ -0,0 +1,691 @@
|
|||||||
|
import time
|
||||||
|
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
|
||||||
|
from typing import Sequence as GenericSequence
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.protocol import AsyncEngineClient
|
||||||
|
from vllm.entrypoints.chat_utils import (
|
||||||
|
ConversationMessage,
|
||||||
|
load_chat_template,
|
||||||
|
parse_chat_messages,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from openai_protocol import (
|
||||||
|
ChatCompletionLogProb,
|
||||||
|
ChatCompletionLogProbs,
|
||||||
|
ChatCompletionLogProbsContent,
|
||||||
|
ChatCompletionNamedToolChoiceParam,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
ChatCompletionResponseChoice,
|
||||||
|
ChatCompletionResponseStreamChoice,
|
||||||
|
ChatCompletionStreamResponse,
|
||||||
|
ChatMessage,
|
||||||
|
DeltaMessage,
|
||||||
|
ErrorResponse,
|
||||||
|
FunctionCall,
|
||||||
|
ToolCall,
|
||||||
|
UsageInfo,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.serving_engine import (
|
||||||
|
LoRAModulePath,
|
||||||
|
OpenAIServing,
|
||||||
|
PromptAdapterPath,
|
||||||
|
)
|
||||||
|
from vllm.inputs import PromptInputs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.multimodal import MultiModalDataDict
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.sequence import Logprob
|
||||||
|
from vllm.tracing import (
|
||||||
|
contains_trace_headers,
|
||||||
|
extract_trace_headers,
|
||||||
|
log_tracing_disabled_warning,
|
||||||
|
)
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
from utils import decode_function_call
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingChat(OpenAIServing):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
async_engine_client: AsyncEngineClient,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
served_model_names: List[str],
|
||||||
|
response_role: str,
|
||||||
|
*,
|
||||||
|
lora_modules: Optional[List[LoRAModulePath]],
|
||||||
|
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||||
|
request_logger: Optional[RequestLogger],
|
||||||
|
chat_template: Optional[str],
|
||||||
|
return_tokens_as_token_ids: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
async_engine_client=async_engine_client,
|
||||||
|
model_config=model_config,
|
||||||
|
served_model_names=served_model_names,
|
||||||
|
lora_modules=lora_modules,
|
||||||
|
prompt_adapters=prompt_adapters,
|
||||||
|
request_logger=request_logger,
|
||||||
|
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.response_role = response_role
|
||||||
|
|
||||||
|
# If this is None we use the tokenizer's default chat template
|
||||||
|
self.chat_template = load_chat_template(chat_template)
|
||||||
|
|
||||||
|
async def create_chat_completion(
|
||||||
|
self, request: ChatCompletionRequest, raw_request: Optional[Request] = None
|
||||||
|
) -> Union[ErrorResponse, AsyncGenerator[str, None], ChatCompletionResponse]:
|
||||||
|
"""Completion API similar to OpenAI's API.
|
||||||
|
|
||||||
|
See https://platform.openai.com/docs/api-reference/chat/create
|
||||||
|
for the API specification. This API mimics the OpenAI
|
||||||
|
ChatCompletion API.
|
||||||
|
|
||||||
|
NOTE: Currently we do not support the following feature:
|
||||||
|
- function_call (Users should implement this by themselves)
|
||||||
|
"""
|
||||||
|
error_check_ret = await self._check_model(request)
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
lora_request,
|
||||||
|
prompt_adapter_request,
|
||||||
|
) = self._maybe_get_adapters(request)
|
||||||
|
|
||||||
|
model_config = self.model_config
|
||||||
|
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
|
||||||
|
|
||||||
|
conversation, mm_futures = parse_chat_messages(
|
||||||
|
request.messages, model_config, tokenizer
|
||||||
|
)
|
||||||
|
print('conversation:', conversation)
|
||||||
|
# parse_chat_messages ignores tool_calls and tool_call_id
|
||||||
|
# we have to fix this
|
||||||
|
conversation = request.messages
|
||||||
|
for msg in conversation:
|
||||||
|
if 'tool_calls' in msg and msg['tool_calls'] is not None:
|
||||||
|
msg['tool_calls'] = [tc for tc in msg['tool_calls']]
|
||||||
|
print('fixed conversation:', conversation)
|
||||||
|
|
||||||
|
tool_dicts = (
|
||||||
|
None
|
||||||
|
if request.tools is None
|
||||||
|
else [tool.model_dump() for tool in request.tools]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
conversation=conversation,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
tools=tool_dicts,
|
||||||
|
documents=request.documents,
|
||||||
|
chat_template=request.chat_template or self.chat_template,
|
||||||
|
**(request.chat_template_kwargs or {}),
|
||||||
|
)
|
||||||
|
assert isinstance(prompt, str)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in applying chat template from request: %s", e)
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
mm_data: Optional[MultiModalDataDict] = None
|
||||||
|
try:
|
||||||
|
if len(mm_futures):
|
||||||
|
# since we support only single mm data currently
|
||||||
|
assert (
|
||||||
|
len(mm_futures) == 1
|
||||||
|
), "Multiple 'image_url' input is currently not supported."
|
||||||
|
mm_data = await mm_futures[0]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in loading multi-modal data: %s", e)
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
request_id = f"chat-{random_uuid()}"
|
||||||
|
try:
|
||||||
|
guided_decode_logits_processor = await self._guided_decode_logits_processor(
|
||||||
|
request, tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_inputs = self._tokenize_prompt_input(
|
||||||
|
request,
|
||||||
|
tokenizer,
|
||||||
|
prompt,
|
||||||
|
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||||
|
add_special_tokens=request.add_special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampling_params = request.to_sampling_params(
|
||||||
|
tokenizer,
|
||||||
|
guided_decode_logits_processor,
|
||||||
|
default_max_tokens=self.max_model_len
|
||||||
|
- len(prompt_inputs["prompt_token_ids"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._log_inputs(
|
||||||
|
request_id,
|
||||||
|
prompt_inputs,
|
||||||
|
params=sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine_inputs: PromptInputs = {
|
||||||
|
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
|
||||||
|
}
|
||||||
|
if mm_data is not None:
|
||||||
|
engine_inputs["multi_modal_data"] = mm_data
|
||||||
|
|
||||||
|
is_tracing_enabled = await self.async_engine_client.is_tracing_enabled()
|
||||||
|
trace_headers = None
|
||||||
|
if is_tracing_enabled and raw_request:
|
||||||
|
trace_headers = extract_trace_headers(raw_request.headers)
|
||||||
|
if (
|
||||||
|
not is_tracing_enabled
|
||||||
|
and raw_request
|
||||||
|
and contains_trace_headers(raw_request.headers)
|
||||||
|
):
|
||||||
|
log_tracing_disabled_warning()
|
||||||
|
|
||||||
|
result_generator = self.async_engine_client.generate(
|
||||||
|
engine_inputs,
|
||||||
|
sampling_params,
|
||||||
|
request_id,
|
||||||
|
lora_request=lora_request,
|
||||||
|
trace_headers=trace_headers,
|
||||||
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
# Streaming response
|
||||||
|
if request.stream:
|
||||||
|
return self.chat_completion_stream_generator(
|
||||||
|
request, result_generator, request_id, conversation, tokenizer
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return await self.chat_completion_full_generator(
|
||||||
|
request,
|
||||||
|
raw_request,
|
||||||
|
result_generator,
|
||||||
|
request_id,
|
||||||
|
conversation,
|
||||||
|
tokenizer,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
|
||||||
|
if request.add_generation_prompt:
|
||||||
|
return self.response_role
|
||||||
|
else:
|
||||||
|
return request.messages[-1]["role"]
|
||||||
|
|
||||||
|
async def chat_completion_stream_generator(
|
||||||
|
self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
result_generator: AsyncIterator[RequestOutput],
|
||||||
|
request_id: str,
|
||||||
|
conversation: List[ConversationMessage],
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
model_name = self.served_model_names[0]
|
||||||
|
created_time = int(time.time())
|
||||||
|
chunk_object_type = "chat.completion.chunk"
|
||||||
|
first_iteration = True
|
||||||
|
|
||||||
|
# Send response for each token for each request.n (index)
|
||||||
|
num_choices = 1 if request.n is None else request.n
|
||||||
|
previous_texts = [""] * num_choices
|
||||||
|
previous_num_tokens = [0] * num_choices
|
||||||
|
finish_reason_sent = [False] * num_choices
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for res in result_generator:
|
||||||
|
# We need to do it here, because if there are exceptions in
|
||||||
|
# the result_generator, it needs to be sent as the FIRST
|
||||||
|
# response (by the try...catch).
|
||||||
|
if first_iteration:
|
||||||
|
# Send first response for each request.n (index) with
|
||||||
|
# the role
|
||||||
|
role = self.get_chat_request_role(request)
|
||||||
|
for i in range(num_choices):
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=DeltaMessage(role=role),
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
request.stream_options
|
||||||
|
and request.stream_options.include_usage
|
||||||
|
):
|
||||||
|
if request.stream_options.continuous_usage_stats:
|
||||||
|
prompt_tokens = len(res.prompt_token_ids)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=prompt_tokens,
|
||||||
|
)
|
||||||
|
chunk.usage = usage
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
|
||||||
|
data = chunk.model_dump_json(exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
# Send response to echo the input portion of the
|
||||||
|
# last message
|
||||||
|
if request.echo:
|
||||||
|
last_msg_content = ""
|
||||||
|
if (
|
||||||
|
conversation
|
||||||
|
and conversation[-1].get("content")
|
||||||
|
and conversation[-1].get("role") == role
|
||||||
|
):
|
||||||
|
last_msg_content = conversation[-1]["content"]
|
||||||
|
|
||||||
|
if last_msg_content:
|
||||||
|
for i in range(num_choices):
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=DeltaMessage(content=last_msg_content),
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
request.stream_options
|
||||||
|
and request.stream_options.include_usage
|
||||||
|
):
|
||||||
|
if request.stream_options.continuous_usage_stats:
|
||||||
|
prompt_tokens = len(res.prompt_token_ids)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=prompt_tokens,
|
||||||
|
)
|
||||||
|
chunk.usage = usage
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
|
||||||
|
data = chunk.model_dump_json(exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
first_iteration = False
|
||||||
|
|
||||||
|
for output in res.outputs:
|
||||||
|
i = output.index
|
||||||
|
|
||||||
|
if finish_reason_sent[i]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
delta_token_ids = output.token_ids[previous_num_tokens[i] :]
|
||||||
|
out_logprobs = (
|
||||||
|
output.logprobs[previous_num_tokens[i] :]
|
||||||
|
if output.logprobs
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.logprobs and request.top_logprobs is not None:
|
||||||
|
assert out_logprobs is not None, "Did not output logprobs"
|
||||||
|
logprobs = self._create_chat_logprobs(
|
||||||
|
token_ids=delta_token_ids,
|
||||||
|
top_logprobs=out_logprobs,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_output_top_logprobs=request.top_logprobs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
# if have tool
|
||||||
|
if request.tools is not None and len(request.tools) > 0:
|
||||||
|
if output.finish_reason is not None:
|
||||||
|
msg = decode_function_call(output.text)
|
||||||
|
if "tool_calls" in msg and msg["tool_calls"] is not None and len(msg["tool_calls"]) > 0:
|
||||||
|
delta_message = DeltaMessage(
|
||||||
|
content=msg.get("thought", ""),
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
id=f"chatcmpl-tool-{random_uuid()}",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=fc["function"]["name"],
|
||||||
|
arguments=json.dumps(fc["function"]["arguments"], ensure_ascii=False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for fc in msg["tool_calls"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
delta_message = DeltaMessage(content=msg.get("content", ""))
|
||||||
|
else:
|
||||||
|
# only return the last one
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
delta_text = output.text[len(previous_texts[i]) :]
|
||||||
|
previous_texts[i] = output.text
|
||||||
|
previous_num_tokens[i] = len(output.token_ids)
|
||||||
|
|
||||||
|
if (
|
||||||
|
request.tool_choice
|
||||||
|
and type(request.tool_choice)
|
||||||
|
is ChatCompletionNamedToolChoiceParam
|
||||||
|
):
|
||||||
|
delta_message = DeltaMessage(
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name=request.tool_choice.function.name,
|
||||||
|
arguments=delta_text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
delta_message = DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
if output.finish_reason is None:
|
||||||
|
# Send token-by-token response for each request.n
|
||||||
|
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=delta_message,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
request.stream_options
|
||||||
|
and request.stream_options.include_usage
|
||||||
|
):
|
||||||
|
if request.stream_options.continuous_usage_stats:
|
||||||
|
prompt_tokens = len(res.prompt_token_ids)
|
||||||
|
completion_tokens = len(output.token_ids)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
chunk.usage = usage
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
|
||||||
|
data = chunk.model_dump_json(exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
else:
|
||||||
|
# Send the finish response for each request.n only once
|
||||||
|
prompt_tokens = len(res.prompt_token_ids)
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=delta_message,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason=output.finish_reason,
|
||||||
|
stop_reason=output.stop_reason,
|
||||||
|
)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
request.stream_options
|
||||||
|
and request.stream_options.include_usage
|
||||||
|
):
|
||||||
|
if request.stream_options.continuous_usage_stats:
|
||||||
|
prompt_tokens = len(res.prompt_token_ids)
|
||||||
|
completion_tokens = len(output.token_ids)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
chunk.usage = usage
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
data = chunk.model_dump_json(exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
finish_reason_sent[i] = True
|
||||||
|
|
||||||
|
if request.stream_options and request.stream_options.include_usage:
|
||||||
|
final_usage = UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=previous_num_tokens[i],
|
||||||
|
total_tokens=prompt_tokens + previous_num_tokens[i],
|
||||||
|
)
|
||||||
|
|
||||||
|
final_usage_chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[],
|
||||||
|
model=model_name,
|
||||||
|
usage=final_usage,
|
||||||
|
)
|
||||||
|
final_usage_data = final_usage_chunk.model_dump_json(
|
||||||
|
exclude_unset=True, exclude_none=True
|
||||||
|
)
|
||||||
|
yield f"data: {final_usage_data}\n\n"
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
data = self.create_streaming_error_response(str(e))
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
# Send the final done message after all response.n are finished
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
async def chat_completion_full_generator(
|
||||||
|
self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
raw_request: Optional[Request],
|
||||||
|
result_generator: AsyncIterator[RequestOutput],
|
||||||
|
request_id: str,
|
||||||
|
conversation: List[ConversationMessage],
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||||
|
model_name = self.served_model_names[0]
|
||||||
|
created_time = int(time.time())
|
||||||
|
final_res: Optional[RequestOutput] = None
|
||||||
|
|
||||||
|
async for res in result_generator:
|
||||||
|
if raw_request is not None and await raw_request.is_disconnected():
|
||||||
|
# Abort the request if the client disconnects.
|
||||||
|
await self.async_engine_client.abort(request_id)
|
||||||
|
return self.create_error_response("Client disconnected")
|
||||||
|
final_res = res
|
||||||
|
assert final_res is not None
|
||||||
|
|
||||||
|
choices: List[ChatCompletionResponseChoice] = []
|
||||||
|
|
||||||
|
role = self.get_chat_request_role(request)
|
||||||
|
for output in final_res.outputs:
|
||||||
|
token_ids = output.token_ids
|
||||||
|
out_logprobs = output.logprobs
|
||||||
|
|
||||||
|
if request.logprobs and request.top_logprobs is not None:
|
||||||
|
assert out_logprobs is not None, "Did not output logprobs"
|
||||||
|
logprobs = self._create_chat_logprobs(
|
||||||
|
token_ids=token_ids,
|
||||||
|
top_logprobs=out_logprobs,
|
||||||
|
num_output_top_logprobs=request.top_logprobs,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
#if (
|
||||||
|
# request.tool_choice
|
||||||
|
# and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
|
||||||
|
#):
|
||||||
|
# message = ChatMessage(
|
||||||
|
# role=role,
|
||||||
|
# content="",
|
||||||
|
# tool_calls=[
|
||||||
|
# ToolCall(
|
||||||
|
# function=FunctionCall(
|
||||||
|
# name=request.tool_choice.function.name,
|
||||||
|
# arguments=output.text,
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
#elif not request.tool_choice or request.tool_choice == "none":
|
||||||
|
# message = ChatMessage(role=role, content=output.text)
|
||||||
|
msg = decode_function_call(output.text)
|
||||||
|
if "tool_calls" in msg and msg["tool_calls"] is not None and len(msg["tool_calls"]) > 0:
|
||||||
|
message = ChatMessage(
|
||||||
|
role=role,
|
||||||
|
content=msg.get("thought", ""),
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
function=FunctionCall(
|
||||||
|
name=fc["function"]["name"],
|
||||||
|
arguments=json.dumps(fc["function"]["arguments"], ensure_ascii=False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for fc in msg["tool_calls"]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = ChatMessage(role=role, content=msg.get("content", ""))
|
||||||
|
|
||||||
|
choice_data = ChatCompletionResponseChoice(
|
||||||
|
index=output.index,
|
||||||
|
message=message,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason=output.finish_reason,
|
||||||
|
stop_reason=output.stop_reason,
|
||||||
|
)
|
||||||
|
choices.append(choice_data)
|
||||||
|
|
||||||
|
if request.echo:
|
||||||
|
last_msg_content = ""
|
||||||
|
if (
|
||||||
|
conversation
|
||||||
|
and conversation[-1].get("content")
|
||||||
|
and conversation[-1].get("role") == role
|
||||||
|
):
|
||||||
|
last_msg_content = conversation[-1]["content"]
|
||||||
|
|
||||||
|
for choice in choices:
|
||||||
|
full_message = last_msg_content + choice.message.content
|
||||||
|
choice.message.content = full_message
|
||||||
|
|
||||||
|
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||||
|
num_generated_tokens = sum(
|
||||||
|
len(output.token_ids) for output in final_res.outputs
|
||||||
|
)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=num_generated_tokens,
|
||||||
|
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||||
|
)
|
||||||
|
response = ChatCompletionResponse(
|
||||||
|
id=request_id,
|
||||||
|
created=created_time,
|
||||||
|
model=model_name,
|
||||||
|
choices=choices,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _get_top_logprobs(
|
||||||
|
self,
|
||||||
|
logprobs: Dict[int, Logprob],
|
||||||
|
top_logprobs: Optional[int],
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
) -> List[ChatCompletionLogProb]:
|
||||||
|
return [
|
||||||
|
ChatCompletionLogProb(
|
||||||
|
token=(
|
||||||
|
token := self._get_decoded_token(
|
||||||
|
p[1],
|
||||||
|
p[0],
|
||||||
|
tokenizer,
|
||||||
|
return_as_token_id=self.return_tokens_as_token_ids,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
logprob=max(p[1].logprob, -9999.0),
|
||||||
|
bytes=list(token.encode("utf-8", errors="replace")),
|
||||||
|
)
|
||||||
|
for i, p in enumerate(logprobs.items())
|
||||||
|
if top_logprobs and i < top_logprobs
|
||||||
|
]
|
||||||
|
|
||||||
|
def _create_chat_logprobs(
|
||||||
|
self,
|
||||||
|
token_ids: GenericSequence[int],
|
||||||
|
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
num_output_top_logprobs: Optional[int] = None,
|
||||||
|
) -> ChatCompletionLogProbs:
|
||||||
|
"""Create OpenAI-style logprobs."""
|
||||||
|
|
||||||
|
logprobs_content = []
|
||||||
|
|
||||||
|
for i, token_id in enumerate(token_ids):
|
||||||
|
step_top_logprobs = top_logprobs[i]
|
||||||
|
if step_top_logprobs is None:
|
||||||
|
token = tokenizer.decode(token_id)
|
||||||
|
if self.return_tokens_as_token_ids:
|
||||||
|
token = f"token_id:{token_id}"
|
||||||
|
logprobs_content.append(
|
||||||
|
ChatCompletionLogProbsContent(
|
||||||
|
token=token, bytes=list(token.encode("utf-8", errors="replace"))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprobs_content.append(
|
||||||
|
ChatCompletionLogProbsContent(
|
||||||
|
token=self._get_decoded_token(
|
||||||
|
step_top_logprobs[token_id],
|
||||||
|
token_id,
|
||||||
|
tokenizer,
|
||||||
|
self.return_tokens_as_token_ids,
|
||||||
|
),
|
||||||
|
logprob=max(step_top_logprobs[token_id].logprob, -9999.0),
|
||||||
|
bytes=list(
|
||||||
|
step_top_logprobs[token_id].decoded_token.encode(
|
||||||
|
"utf-8", errors="replace"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
top_logprobs=self._get_top_logprobs(
|
||||||
|
step_top_logprobs, num_output_top_logprobs, tokenizer
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatCompletionLogProbs(content=logprobs_content)
|
||||||
1
demo/function_call/requirements.txt
Normal file
1
demo/function_call/requirements.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
datamodel_code_generator
|
||||||
386
demo/function_call/utils.py
Normal file
386
demo/function_call/utils.py
Normal file
@ -0,0 +1,386 @@
|
|||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import keyword
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from copy import deepcopy
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from datamodel_code_generator import DataModelType
|
||||||
|
from datamodel_code_generator.format import PythonVersion
|
||||||
|
from datamodel_code_generator.model import get_data_model_types
|
||||||
|
from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
|
||||||
|
from jsonschema import Draft202012Validator, exceptions, validate
|
||||||
|
|
||||||
|
from transformers import LlamaTokenizer
|
||||||
|
from transformers.tokenization_utils_base import BatchEncoding
|
||||||
|
from transformers.utils import TensorType
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_function_call(
|
||||||
|
sequence: str,
|
||||||
|
tool_call_start="<|tool_call_start|>",
|
||||||
|
tool_call_end="<|tool_call_end|>",
|
||||||
|
thought_start="<|thought_start|>",
|
||||||
|
thought_end="<|thought_end|>",
|
||||||
|
):
|
||||||
|
if thought_end in sequence and thought_start in sequence:
|
||||||
|
thought_string, sequence = sequence.rsplit(thought_end, 1)
|
||||||
|
thought_string = thought_string.split(thought_start, 1)[1]
|
||||||
|
else:
|
||||||
|
thought_string = ""
|
||||||
|
if tool_call_start in sequence and tool_call_end in sequence:
|
||||||
|
tool_call_string, content = sequence.rsplit(tool_call_end, 1)
|
||||||
|
tool_call_string = tool_call_string.split(tool_call_start, 1)[1]
|
||||||
|
try:
|
||||||
|
tool_calls = []
|
||||||
|
tool_call_string = tool_call_string.strip()
|
||||||
|
if tool_call_string.startswith("```"):
|
||||||
|
tool_call_string = tool_call_string.lstrip("```").strip()
|
||||||
|
if tool_call_string.startswith("python"):
|
||||||
|
tool_call_string = tool_call_string.lstrip("python").strip()
|
||||||
|
if tool_call_string.endswith("```"):
|
||||||
|
tool_call_string = tool_call_string.rstrip("```").strip()
|
||||||
|
for kw in keyword.kwlist:
|
||||||
|
tool_call_string = tool_call_string.replace(
|
||||||
|
"," + kw + "=", "," + kw + "_="
|
||||||
|
)
|
||||||
|
tool_call_string = tool_call_string.replace(
|
||||||
|
" " + kw + "=", " " + kw + "_="
|
||||||
|
)
|
||||||
|
tool_call_string = tool_call_string.replace(
|
||||||
|
"(" + kw + "=", "(" + kw + "_="
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed = ast.parse(tool_call_string)
|
||||||
|
|
||||||
|
for elem in parsed.body:
|
||||||
|
assert isinstance(elem.value, ast.Call)
|
||||||
|
calls = resolve_ast_call(elem.value)
|
||||||
|
|
||||||
|
for func_name, func_args in calls.items():
|
||||||
|
new_args = {}
|
||||||
|
for k, v in func_args.items():
|
||||||
|
for kw in keyword.kwlist:
|
||||||
|
if k == kw + "_":
|
||||||
|
k = kw
|
||||||
|
new_args[k] = v
|
||||||
|
|
||||||
|
this_one = {"name": func_name, "arguments": new_args}
|
||||||
|
tool_calls.append(this_one)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": content.strip(),
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": tool_call,
|
||||||
|
"id": "call_" + uuid.uuid4().hex,
|
||||||
|
}
|
||||||
|
for tool_call in tool_calls
|
||||||
|
],
|
||||||
|
"role": "assistant",
|
||||||
|
}
|
||||||
|
except:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return {
|
||||||
|
"content": content.strip(),
|
||||||
|
"role": "assistant",
|
||||||
|
"thought": thought_string,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"content": sequence.strip(),
|
||||||
|
"role": "assistant",
|
||||||
|
"thought": thought_string,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def check_messages(conversation: List[Dict[str, str]], tools: List[Dict]):
|
||||||
|
if tools is not None:
|
||||||
|
for tool in tools:
|
||||||
|
if "type" not in tool or tool["type"] != "function":
|
||||||
|
raise ValueError(f"Tool {tool} is not valid")
|
||||||
|
if "name" not in tool["function"]:
|
||||||
|
raise ValueError(f"Tool {tool} is not valid")
|
||||||
|
if "parameters" not in tool["function"] or not check_tool(
|
||||||
|
tool["function"]["parameters"]["properties"]
|
||||||
|
):
|
||||||
|
raise ValueError(f"Tool {tool} is not valid")
|
||||||
|
for message in conversation:
|
||||||
|
if (
|
||||||
|
message["role"] == "assistant"
|
||||||
|
and "tool_calls" in message
|
||||||
|
and len(message["tool_calls"]) > 0
|
||||||
|
):
|
||||||
|
for tool_call in message["tool_calls"]:
|
||||||
|
if "id" not in tool_call:
|
||||||
|
raise ValueError(f"Tool call {tool_call} is not valid")
|
||||||
|
if tool_call["type"] != "function":
|
||||||
|
raise ValueError(f"Tool call {tool_call} is not valid")
|
||||||
|
if "function" not in tool_call:
|
||||||
|
raise ValueError(f"Tool call {tool_call} is not valid")
|
||||||
|
if not check_tool(tool_call["function"]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool call function {tool_call['function']} is not valid"
|
||||||
|
)
|
||||||
|
elif message["role"] == "tool":
|
||||||
|
if "tool_call_id" not in message:
|
||||||
|
raise ValueError(f"Tool message {message['content']} is not valid")
|
||||||
|
|
||||||
|
|
||||||
|
def check_tool(tool_schema):
|
||||||
|
try:
|
||||||
|
Draft202012Validator.check_schema(tool_schema)
|
||||||
|
return True
|
||||||
|
except exceptions.SchemaError as e:
|
||||||
|
print(f"SchemaError: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_args(args, tool_schema):
|
||||||
|
try:
|
||||||
|
validate(instance=args, schema=tool_schema)
|
||||||
|
return True
|
||||||
|
except exceptions.ValidationError as e:
|
||||||
|
print(f"Data failed validation: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def message_format(msg, system_suffix="", user_prefix=""):
|
||||||
|
if "thought" in msg and msg["thought"] is not None and len(msg["thought"]) > 0:
|
||||||
|
thought_prefix = f"<|thought_start|>\n{msg['thought']}\n<|thought_end|>\n"
|
||||||
|
else:
|
||||||
|
thought_prefix = ""
|
||||||
|
if msg["role"] == "assistant":
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if content is None:
|
||||||
|
content = ""
|
||||||
|
if (
|
||||||
|
"tool_calls" in msg
|
||||||
|
and msg["tool_calls"] is not None
|
||||||
|
and len(msg["tool_calls"]) > 0
|
||||||
|
):
|
||||||
|
|
||||||
|
def add_quotes(variable):
|
||||||
|
if isinstance(variable, str):
|
||||||
|
return repr(variable)
|
||||||
|
else:
|
||||||
|
return str(variable)
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
for _tool_call in msg["tool_calls"]:
|
||||||
|
if _tool_call is None:
|
||||||
|
continue
|
||||||
|
tool_call = _tool_call["function"]
|
||||||
|
tool_name = tool_call["name"]
|
||||||
|
if "arguments" not in tool_call or tool_call["arguments"] is None:
|
||||||
|
continue
|
||||||
|
if isinstance(tool_call["arguments"], str):
|
||||||
|
try:
|
||||||
|
tool_call["arguments"] = json.loads(tool_call["arguments"])
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
args = ",".join(
|
||||||
|
[k + "=" + add_quotes(v) for k, v in tool_call["arguments"].items()]
|
||||||
|
)
|
||||||
|
tool_calls.append(f"{tool_name}({args})")
|
||||||
|
|
||||||
|
content = (
|
||||||
|
thought_prefix
|
||||||
|
+ "<|tool_call_start|>\n```python\n"
|
||||||
|
+ "\n".join(tool_calls).strip()
|
||||||
|
+ "\n```\n<|tool_call_end|>\n"
|
||||||
|
+ content
|
||||||
|
)
|
||||||
|
# msg["tool_call_string"] = "\n".join(tool_calls).strip()
|
||||||
|
msg["content"] = content
|
||||||
|
else:
|
||||||
|
content = thought_prefix + content
|
||||||
|
msg["content"] = content
|
||||||
|
elif msg["role"] == "user":
|
||||||
|
msg["content"] = user_prefix + "\n" + msg["content"]
|
||||||
|
elif msg["role"] == "system":
|
||||||
|
msg["content"] = msg["content"] + "\n" + system_suffix
|
||||||
|
msg["content"] = msg["content"].strip()
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def jsonschema_to_code(jsonschema: dict) -> str:
|
||||||
|
input_text = json.dumps(jsonschema)
|
||||||
|
data_model_types = get_data_model_types(
|
||||||
|
DataModelType.PydanticBaseModel,
|
||||||
|
PythonVersion.PY_310,
|
||||||
|
)
|
||||||
|
parser = JsonSchemaParser(
|
||||||
|
source=input_text,
|
||||||
|
data_model_type=data_model_types.data_model,
|
||||||
|
data_model_root_type=data_model_types.root_model,
|
||||||
|
data_model_field_type=data_model_types.field_model,
|
||||||
|
data_type_manager_type=data_model_types.data_type_manager,
|
||||||
|
target_python_version=PythonVersion.PY_310,
|
||||||
|
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
|
||||||
|
field_constraints=True,
|
||||||
|
)
|
||||||
|
results = parser.parse()
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def transform_function(function: dict):
|
||||||
|
"""turn json format of function into signature"""
|
||||||
|
params, default_params = [], []
|
||||||
|
for prop_name, prop in function["parameters"]["properties"].items():
|
||||||
|
if "default" in prop:
|
||||||
|
default_params.append(f'{prop_name}={repr(prop["default"])}')
|
||||||
|
elif prop_name not in function["parameters"].get("required", []):
|
||||||
|
default_params.append(f"{prop_name}={repr(None)}")
|
||||||
|
else:
|
||||||
|
params.append(prop_name)
|
||||||
|
ps = ", ".join(params + default_params)
|
||||||
|
res = "def {f_name}({ps}):\n".format(f_name=function["name"], ps=ps)
|
||||||
|
f_des = function.get("description", "")
|
||||||
|
content = jsonschema_to_code(function["parameters"])
|
||||||
|
if "class" in content:
|
||||||
|
i = content.index("class")
|
||||||
|
# print(content[:i])
|
||||||
|
content = content[i:]
|
||||||
|
classes, args = content.split("class Model(BaseModel):", 1)
|
||||||
|
lint_msg = f' """\n {f_des}\n Args:\n{args}\n """\n'
|
||||||
|
res += lint_msg
|
||||||
|
if len(classes) > 0:
|
||||||
|
res = classes + res
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def input_format(messages: List[Dict], tools: List[Dict], add_to_system=True):
|
||||||
|
"""
|
||||||
|
Process the input messages, global_arguments, tools, tool_choice,
|
||||||
|
and convert it into a input string.
|
||||||
|
The global arguments and tools can not be both empty.
|
||||||
|
parameters:
|
||||||
|
messages: List[Dict]
|
||||||
|
the input messages
|
||||||
|
For example:
|
||||||
|
tools: List[Dict]
|
||||||
|
the tools list you can use
|
||||||
|
For example:
|
||||||
|
"""
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
tools = deepcopy(tools)
|
||||||
|
if tools is not None and len(tools) > 0:
|
||||||
|
header = "from enum import Enum\nfrom typing import List, Dict, Optional\nfrom pydantic import BaseModel, Field\n\n"
|
||||||
|
tools_string = header
|
||||||
|
for tool in tools:
|
||||||
|
try:
|
||||||
|
tools_string += "\n\n" + transform_function(tool)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
tools_template = """# Functions
|
||||||
|
Here is a list of functions that you can invoke:
|
||||||
|
```python
|
||||||
|
{tools}
|
||||||
|
```
|
||||||
|
|
||||||
|
# Function Call Rule and Output Format
|
||||||
|
- If the user's question can be answered without calling any function, please answer the user's question directly. In this situation, you should return your thought and answer the user's question directly.
|
||||||
|
- If the user cannot be answered without calling any function, and the user does not provide enough information to call functions, please ask the user for more information. In this situation, you should return your thought and ask the user for more information.
|
||||||
|
- If the user's question cannot be answered without calling any function, and the user has provided enough information to call functions to solve it, you should call the functions. In this situation, the assistant should return your thought and call the functions.
|
||||||
|
- Use default parameters unless the user has specified otherwise.
|
||||||
|
- You should answer in the following format:
|
||||||
|
|
||||||
|
<|thought_start|>
|
||||||
|
{{explain why the user's question can be answered without calling a function or why you should ask the user for more information or why you should call one or more functions and your plan to solve the user's question.}}
|
||||||
|
<|thought_end|>
|
||||||
|
<|tool_call_start|>
|
||||||
|
```python
|
||||||
|
func1(params_name=params_value, params_name2=params_value2...)
|
||||||
|
func2(params)
|
||||||
|
```
|
||||||
|
<|tool_call_end|>
|
||||||
|
{{answer the user's question directly or ask the user for more information}}
|
||||||
|
"""
|
||||||
|
tools_string = tools_template.format(tools=tools_string).strip()
|
||||||
|
else:
|
||||||
|
tools_string = ""
|
||||||
|
|
||||||
|
if add_to_system:
|
||||||
|
return [
|
||||||
|
message_format(msg, system_suffix=tools_string, user_prefix="")
|
||||||
|
for msg in messages
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
message_format(msg, system_suffix="", user_prefix=tools_string)
|
||||||
|
for msg in messages
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# This is a modified version of
|
||||||
|
# https://github.com/ShishirPatil/gorilla/blob/main/berkeley-function-call-leaderboard/bfcl/model_handler/utils.py
|
||||||
|
# Thanks to the gorilla team for the original implementation
|
||||||
|
def resolve_ast_call(elem):
|
||||||
|
# Handle nested attributes for deeply nested module paths
|
||||||
|
func_parts = []
|
||||||
|
func_part = elem.func
|
||||||
|
while isinstance(func_part, ast.Attribute):
|
||||||
|
func_parts.append(func_part.attr)
|
||||||
|
func_part = func_part.value
|
||||||
|
if isinstance(func_part, ast.Name):
|
||||||
|
func_parts.append(func_part.id)
|
||||||
|
func_name = ".".join(reversed(func_parts))
|
||||||
|
args_dict = {}
|
||||||
|
for arg in elem.keywords:
|
||||||
|
output = resolve_ast_by_type(arg.value)
|
||||||
|
args_dict[arg.arg] = output
|
||||||
|
return {func_name: args_dict}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_ast_by_type(value):
|
||||||
|
if isinstance(value, ast.Constant):
|
||||||
|
if value.value is Ellipsis:
|
||||||
|
output = "..."
|
||||||
|
else:
|
||||||
|
output = value.value
|
||||||
|
elif isinstance(value, ast.UnaryOp):
|
||||||
|
output = -value.operand.value
|
||||||
|
elif isinstance(value, ast.List):
|
||||||
|
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||||
|
elif isinstance(value, ast.Dict):
|
||||||
|
output = {
|
||||||
|
resolve_ast_by_type(k): resolve_ast_by_type(v)
|
||||||
|
for k, v in zip(value.keys, value.values)
|
||||||
|
}
|
||||||
|
elif isinstance(
|
||||||
|
value, ast.NameConstant
|
||||||
|
): # Added this condition to handle boolean values
|
||||||
|
output = value.value
|
||||||
|
elif isinstance(
|
||||||
|
value, ast.BinOp
|
||||||
|
): # Added this condition to handle function calls as arguments
|
||||||
|
output = eval(ast.unparse(value))
|
||||||
|
elif isinstance(value, ast.Name):
|
||||||
|
output = value.id
|
||||||
|
elif isinstance(value, ast.Call):
|
||||||
|
if len(value.keywords) == 0:
|
||||||
|
output = ast.unparse(value)
|
||||||
|
else:
|
||||||
|
output = resolve_ast_call(value)
|
||||||
|
elif isinstance(value, ast.Tuple):
|
||||||
|
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||||
|
elif isinstance(value, ast.Lambda):
|
||||||
|
output = eval(ast.unparse(value.body[0].value))
|
||||||
|
elif isinstance(value, ast.Ellipsis):
|
||||||
|
output = "..."
|
||||||
|
elif isinstance(value, ast.Subscript):
|
||||||
|
try:
|
||||||
|
output = ast.unparse(value.body[0].value)
|
||||||
|
except:
|
||||||
|
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||||
|
return output
|
||||||
Loading…
x
Reference in New Issue
Block a user