Merge pull request #122 from zRzRzRzRzRzRzR/main

OpenAI API Support
LDLINGLINGLING 2024-06-21 11:03:55 +08:00 committed by GitHub
commit cf9a5be5be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 406 additions and 34 deletions

View File

@@ -1,7 +1,4 @@
from typing import Dict
from typing import List
from typing import Tuple
import argparse
import gradio as gr
import torch
@@ -16,7 +13,7 @@ import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-2B-dpo-fp16")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
@@ -55,7 +52,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
str: real-time generation results of hf model
"""
inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
enc = tokenizer(inputs, return_tensors="pt").to("cuda")
enc = tokenizer(inputs, return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(
enc,
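A quick aside on the .to(next(model.parameters()).device) change above: it lets the demo send inputs to wherever the weights were actually placed instead of assuming a single CUDA device. A minimal sketch (the model id is the new default from this diff; everything else is illustrative):

```python
# Minimal sketch: resolve the input device from the model itself, so the demo works on CPU,
# a single GPU, or a model sharded with device_map="auto".
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "openbmb/MiniCPM-2B-dpo-fp16"  # new default from this diff
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

device = next(model.parameters()).device  # wherever the first parameter landed
enc = tokenizer("hello", return_tensors="pt").to(device)
print(tokenizer.decode(model.generate(**enc, max_new_tokens=16)[0]))
```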

View File

@@ -0,0 +1,55 @@
"""
This is a simple client for the OpenAI-compatible API. Because of MiniCPM-2B's limitations, this script:
1. Does not support tool (function) calling
2. Does not use a system prompt
3. Supports a maximum text length of 4096 tokens
To run it you need to:
1. Start the local server first (it loads the model with AutoModelForCausalLM.from_pretrained and is not otherwise optimized; modify as needed)
2. Send requests with this script
"""
from openai import OpenAI
base_url = "http://127.0.0.1:8000/v1/"
client = OpenAI(api_key="MiniCPM-2B", base_url=base_url)
def chat(use_stream=True):
messages = [
{
"role": "user",
"content": "tell me a story"
}
]
response = client.chat.completions.create(
model="MiniCPM-2B",
messages=messages,
stream=use_stream,
max_tokens=4096, # need less than 4096 tokens
temperature=0.8,
top_p=0.8
)
if response:
if use_stream:
for chunk in response:
print(chunk.choices[0].delta.content)
else:
content = response.choices[0].message.content
print(content)
else:
print("Error:", response.status_code)
def embedding():
response = client.embeddings.create(
model="bge-m3",
input=["hello, I am MiniCPM-2B"],
)
embeddings = response.data[0].embedding
print("Embedding_Success", len(embeddings))
if __name__ == "__main__":
chat(use_stream=True)
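This client assumes the FastAPI server added in the next file is already listening on 127.0.0.1:8000. A hypothetical launch snippet (the server's filename is not shown in this diff, so openai_api_server.py is an assumption; the environment variables are the ones that server reads):

```python
# Hypothetical launcher: "openai_api_server.py" is an assumed filename for the FastAPI script
# added below; MODEL_PATH / TOKENIZER_PATH / EMBEDDING_PATH are the env vars it reads.
import os
import subprocess

env = dict(
    os.environ,
    MODEL_PATH="openbmb/MiniCPM-2B-dpo-fp16",
    EMBEDDING_PATH="BAAI/bge-m3",
)
subprocess.run(["python", "openai_api_server.py"], env=env, check=True)
```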

View File

@@ -0,0 +1,296 @@
import gc
import json
import os
import time
from threading import Thread
import tiktoken
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from loguru import logger
from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
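# raise sse_starlette's keep-alive ping interval (seconds) so ping events rarely interleave with streamed token chunks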
MODEL_PATH = os.environ.get('MODEL_PATH', 'openbmb/MiniCPM-2B-dpo-fp16')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-m3')
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
# clean cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class FunctionCallResponse(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system", "function"]
content: str = None
name: Optional[str] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class EmbeddingRequest(BaseModel):
input: List[str]
model: str
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class EmbeddingResponse(BaseModel):
data: list
model: str
object: str
usage: CompletionUsage
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.8
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
repetition_penalty: Optional[float] = 1.1
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
index: int
class ChatCompletionResponse(BaseModel):
model: str
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
@app.get("/v1/models", response_model=ModelList)
async def list_models():
model_card = ModelCard(
id="MiniCPM-2B"
)
return ModelList(
data=[model_card]
)
def generate_minicpm(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, params: dict):
messages = params["messages"]
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
max_new_tokens = int(params.get("max_tokens", 256))
inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
enc = tokenizer(inputs, return_tensors="pt").to(model.device)
input_echo_len = len(enc["input_ids"][0])
if input_echo_len >= model.config.max_length:
logger.error(f"Input length larger than {model.config.max_length}")
return
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
**enc,
"do_sample": True if temperature > 1e-5 else False,
"top_k": 0,
"top_p": top_p,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
"max_new_tokens": max_new_tokens,
"pad_token_id": tokenizer.eos_token_id,
"streamer": streamer,
}
eos_token = tokenizer.eos_token
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
response = ""
for new_text in streamer:
new_text = new_text.split(eos_token)[0] if eos_token in new_text else new_text
response += new_text
completion_text = response[5 + len(inputs):]  # strip the echoed prompt (and its leading special tokens)
completion_tokens = len(tokenizer.encode(completion_text))  # token counts; re-encoding per chunk is acceptable for a demo
yield {
"text": completion_text,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": completion_tokens,
"total_tokens": input_echo_len + completion_tokens,
},
"finish_reason": "",
}
thread.join()
gc.collect()
torch.cuda.empty_cache()
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
embeddings = [embedding_model.encode(text) for text in request.input]
embeddings = [embedding.tolist() for embedding in embeddings]
def num_tokens_from_string(string: str) -> int:
encoding = tiktoken.get_encoding('cl100k_base')
num_tokens = len(encoding.encode(string))
return num_tokens
response = {
"data": [
{
"object": "embedding",
"embedding": embedding,
"index": index
}
for index, embedding in enumerate(embeddings)
],
"model": request.model,
"object": "list",
"usage": CompletionUsage(
prompt_tokens=sum(num_tokens_from_string(text) for text in request.input),  # same counting as total_tokens
completion_tokens=0,
total_tokens=sum(num_tokens_from_string(text) for text in request.input),
)
}
return response
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")
gen_params = dict(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens or 2048,
echo=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
logger.debug(f"==== request ====\n{gen_params}")
input_tokens = sum(len(tokenizer.encode(msg.content)) for msg in request.messages)
if request.stream:
async def stream_response():
previous_text = ""
for new_response in generate_minicpm(model, tokenizer, gen_params):
delta_text = new_response["text"][len(previous_text):]
previous_text = new_response["text"]
delta = DeltaMessage(content=delta_text, role="assistant")
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=delta,
finish_reason=None
)
chunk = {
"model": request.model,
"id": "",
"choices": [choice_data.dict(exclude_none=True)],
"object": "chat.completion.chunk"
}
yield json.dumps(chunk) + "\n"
return EventSourceResponse(stream_response(), media_type="text/event-stream")
else:
generated_text = ""
for response in generate_minicpm(model, tokenizer, gen_params):
generated_text = response["text"]
generated_text = generated_text.strip()
output_tokens = len(tokenizer.encode(generated_text))
usage = UsageInfo(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=output_tokens + input_tokens
)
message = ChatMessage(role="assistant", content=generated_text)
logger.debug(f"==== message ====\n{message}")
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop",
)
return ChatCompletionResponse(
model=request.model,
id="",
choices=[choice_data],
object="chat.completion",
usage=usage
)
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto",
trust_remote_code=True)
embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
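For reference, a minimal sketch (not part of the PR) that exercises the three endpoints defined above with plain HTTP; the paths, model ids, and response fields follow the route handlers and pydantic models in this file, and it assumes the server is running locally on port 8000 as in uvicorn.run():

```python
# Minimal sketch hitting /v1/models, /v1/chat/completions and /v1/embeddings on the server above.
import requests

base = "http://127.0.0.1:8000/v1"

print(requests.get(f"{base}/models").json())  # ModelList containing the "MiniCPM-2B" card

chat = requests.post(
    f"{base}/chat/completions",
    json={
        "model": "MiniCPM-2B",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
        "max_tokens": 128,
    },
)
print(chat.json()["choices"][0]["message"]["content"])

emb = requests.post(
    f"{base}/embeddings",
    json={"model": "bge-m3", "input": ["hello, I am MiniCPM-2B"]},
)
print(len(emb.json()["data"][0]["embedding"]))  # embedding dimension
```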

View File

@@ -1,36 +1,46 @@
from typing import Dict
from typing import List
from typing import Tuple
import argparse
import gradio as gr
from vllm import LLM, SamplingParams
import torch
from transformers import AutoTokenizer
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()
parser.add_argument("--max_tokens", type=int, default=2048)
# for the MiniCPM-1B and MiniCPM-2B models, max_tokens should be set to 2048
args = parser.parse_args()
# init model torch dtype
torch_dtype = args.torch_dtype
if torch_dtype =="" or torch_dtype == "bfloat16":
torch_dtype = "bfloat16"
if torch_dtype == "" or torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = "float32"
torch_dtype = torch.float32
elif torch_dtype == "float16":
torch_dtype = torch.float16
else:
raise ValueError(f"Invalid torch dtype: {torch_dtype}")
# init model and tokenizer
path = args.model_path
llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
llm = LLM(
model=path,
tensor_parallel_size=1,
dtype=torch_dtype,
trust_remote_code=True,
gpu_memory_utilization=0.9,
max_model_len=args.max_tokens
)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
server_name = args.server_name
server_port = args.server_port
# init gradio demo host and port
server_name=args.server_name
server_port=args.server_port
def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
"""generate model output with huggingface api
@@ -43,19 +53,14 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
Yields:
str: real-time generation results of vllm model
"""
prompt = ""
"""
assert len(dialog) % 2 == 1
for info in dialog:
if info["role"] == "user":
prompt += "<用户>" + info["content"]
else:
prompt += "<AI>" + info["content"]
prompt += "<AI>"
prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
token_ids = tokenizer.convert_tokens_to_ids(["<|im_end|>"])
params_dict = {
"n": 1,
"best_of": 1,
"presence_penalty": 1.0,
"presence_penalty": 1.0,
"frequency_penalty": 0.0,
"temperature": temperature,
"top_p": top_p,
@@ -63,8 +68,8 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop": None,
"stop_token_ids": None,
"stop": "<|im_end|>",
"stop_token_ids": token_ids,
"ignore_eos": False,
"max_tokens": max_dec_len,
"logprobs": None,
@@ -89,7 +94,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, m
Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
"""
"""
assert query != "", "Input must not be empty!!!"
# apply chat template
model_input = []
@@ -114,7 +119,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len
Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
"""
"""
assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
# apply chat template
model_input = []
@@ -133,7 +138,7 @@ def clear_history():
Returns:
List: empty chat history
"""
"""
return []
@@ -145,7 +150,7 @@ def reverse_last_round(chat_history):
Returns:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
"""
"""
assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
return chat_history[:-1]
@@ -158,7 +163,7 @@ with gr.Blocks(theme="soft") as demo:
with gr.Column(scale=1):
top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
max_dec_len = gr.Slider(1, args.max_tokens, value=args.max_tokens, step=1, label="max_tokens")
with gr.Column(scale=5):
chatbot = gr.Chatbot(bubble_full_width=False, height=400)
user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
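The key change in this file is prompt construction: the hand-built "<用户>…<AI>" string is replaced by the tokenizer's chat template, with generation stopped on <|im_end|>. A standalone sketch of that flow under the new defaults (the model id and sampling values come from the diff; the rest is illustrative):

```python
# Minimal sketch of the new prompting flow: chat template instead of a hand-built
# "<用户>...<AI>" string, with <|im_end|> used as the stop token.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "openbmb/MiniCPM-1B-sft-bf16"  # new default from this diff
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

dialog = [{"role": "user", "content": "Introduce yourself in one sentence."}]
prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
stop_ids = tokenizer.convert_tokens_to_ids(["<|im_end|>"])

llm = LLM(model=model_path, trust_remote_code=True, dtype="bfloat16", max_model_len=2048)
params = SamplingParams(
    temperature=0.5,
    top_p=0.8,
    max_tokens=1024,
    stop=["<|im_end|>"],
    stop_token_ids=stop_ids,
)
for output in llm.generate([prompt], params):
    print(output.outputs[0].text)
```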

requirements.txt Normal file
View File

@@ -0,0 +1,19 @@
# for MiniCPM-2B hf inference
torch>=2.0.0
transformers>=4.36.2
gradio>=4.26.0
# for vllm inference
# vllm>=0.4.0.post1
# for openai api inference
openai>=1.17.1
tiktoken>=0.6.0
loguru>=0.7.2
sentence_transformers>=2.6.1
sse_starlette>=2.1.0
# for MiniCPM-V hf inference
Pillow>=10.3.0
timm>=0.9.16
sentencepiece>=0.2.0