MiniCPM/demo/openai_api_demo/openai_api_server_demo.py
zR b9b53e2e19 Simple OpenAI-style inference application
Loads the model and runs inference with transformers on Linux; only regular chat has been tested. Function calling did not appear to be supported at runtime, so it was not implemented.
2024-04-11 19:31:08 +08:00


import gc
import json
import os
import time
from threading import Thread
import tiktoken
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from loguru import logger
from sse_starlette.sse import EventSourceResponse

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

MODEL_PATH = os.environ.get('MODEL_PATH', 'openbmb/MiniCPM-2B-dpo-fp16')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-m3')

@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # clean cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class EmbeddingRequest(BaseModel):
    input: List[str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
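
# A minimal sketch of a payload that ChatCompletionRequest above will accept; the
# "model" field is only echoed back in the response, so any string works, and the
# name `example_request` is purely illustrative:
#
#     example_request = {
#         "model": "MiniCPM-2B",
#         "messages": [{"role": "user", "content": "Hello"}],
#         "temperature": 0.8,
#         "top_p": 0.8,
#         "max_tokens": 512,
#         "stream": False,
#     }
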
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(
        id="MiniCPM-2B"
    )
    return ModelList(
        data=[model_card]
    )


def generate_minicpm(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, params: dict):
    messages = params["messages"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))

    # Render the conversation with the model's chat template, then tokenize it.
    inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    enc = tokenizer(inputs, return_tensors="pt").to(model.device)
    input_echo_len = len(enc["input_ids"][0])
    if input_echo_len >= model.config.max_length:
        logger.error(f"Input length larger than {model.config.max_length}")
        return

    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = {
        **enc,
        "do_sample": True if temperature > 1e-5 else False,
        "top_k": 0,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }
    eos_token = tokenizer.eos_token

    # Run generation in a background thread and stream partial output as it arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        new_text = new_text.split(eos_token)[0] if eos_token in new_text else new_text
        response += new_text
        yield {
            # The streamer echoes the prompt, so slice it off; the fixed offset accounts
            # for the special tokens the tokenizer prepends when encoding.
            "text": response[5 + len(inputs):],
            # Rough character-based counts only; the non-streaming endpoint recomputes
            # exact token usage with the tokenizer.
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": len(response),
                "total_tokens": input_echo_len + len(response),
            },
            "finish_reason": "",
        }
    thread.join()
    gc.collect()
    torch.cuda.empty_cache()
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
embeddings = [embedding_model.encode(text) for text in request.input]
embeddings = [embedding.tolist() for embedding in embeddings]
def num_tokens_from_string(string: str) -> int:
encoding = tiktoken.get_encoding('cl100k_base')
num_tokens = len(encoding.encode(string))
return num_tokens
response = {
"data": [
{
"object": "embedding",
"embedding": embedding,
"index": index
}
for index, embedding in enumerate(embeddings)
],
"model": request.model,
"object": "list",
"usage": CompletionUsage(
prompt_tokens=sum(len(text.split()) for text in request.input),
completion_tokens=0,
total_tokens=sum(num_tokens_from_string(text) for text in request.input),
)
}
return response
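
# A minimal client sketch for the embeddings endpoint (assumes the server is reachable
# at http://localhost:8000 and the `requests` package is installed; the "model" field
# is only echoed back):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/v1/embeddings",
#         json={"model": "bge-m3", "input": ["hello world"]},
#     )
#     print(len(resp.json()["data"][0]["embedding"]))
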
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")
gen_params = dict(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens or 2048,
echo=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
logger.debug(f"==== request ====\n{gen_params}")
input_tokens = sum(len(tokenizer.encode(msg.content)) for msg in request.messages)
if request.stream:
async def stream_response():
previous_text = ""
for new_response in generate_minicpm(model, tokenizer, gen_params):
delta_text = new_response["text"][len(previous_text):]
previous_text = new_response["text"]
delta = DeltaMessage(content=delta_text, role="assistant")
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=delta,
finish_reason=None
)
chunk = {
"model": request.model,
"id": "",
"choices": [choice_data.dict(exclude_none=True)],
"object": "chat.completion.chunk"
}
yield json.dumps(chunk) + "\n"
return EventSourceResponse(stream_response(), media_type="text/event-stream")
else:
generated_text = ""
for response in generate_minicpm(model, tokenizer, gen_params):
generated_text = response["text"]
generated_text = generated_text.strip()
output_tokens = len(tokenizer.encode(generated_text))
usage = UsageInfo(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=output_tokens + input_tokens
)
message = ChatMessage(role="assistant", content=generated_text)
logger.debug(f"==== message ====\n{message}")
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop",
)
return ChatCompletionResponse(
model=request.model,
id="",
choices=[choice_data],
object="chat.completion",
usage=usage
)
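
# A minimal client sketch for the chat endpoint (assumes the server is reachable at
# http://localhost:8000 and the `openai` Python package v1+ is installed; the model
# name is only echoed back, so any string works):
#
#     from openai import OpenAI
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#     completion = client.chat.completions.create(
#         model="MiniCPM-2B",
#         messages=[{"role": "user", "content": "Hello, who are you?"}],
#     )
#     print(completion.choices[0].message.content)
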


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto",
                                                 trust_remote_code=True)
    embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
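
# Launch sketch (the checkpoint paths are examples; any local directory or Hugging Face
# Hub id accepted by from_pretrained should work):
#
#     MODEL_PATH=openbmb/MiniCPM-2B-dpo-fp16 EMBEDDING_PATH=BAAI/bge-m3 \
#         python openai_api_server_demo.py
#
# The server then listens on http://0.0.0.0:8000 with a single worker.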