mirror of https://github.com/RYDE-WORK/MiniCPM.git (synced 2026-01-22 14:30:05 +08:00)
commit cf9a5be5be
@@ -1,7 +1,4 @@
from typing import Dict
from typing import List
from typing import Tuple

import argparse
import gradio as gr
import torch
@@ -16,7 +13,7 @@ import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-2B-dpo-fp16")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
@@ -55,7 +52,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
        str: real-time generation results of hf model
    """
    inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
    enc = tokenizer(inputs, return_tensors="pt").to("cuda")
    enc = tokenizer(inputs, return_tensors="pt").to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(
        enc,
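An illustrative sketch, not part of the commit, of why the hunk above moves the inputs to next(model.parameters()).device instead of a hard-coded "cuda"; the model id simply mirrors the demo's default:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openbmb/MiniCPM-2B-dpo-fp16"  # same default as the demo's --model_path
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

# With device_map="auto" the weights may sit on CPU, cuda:1, or be sharded,
# so the inputs should follow the model's actual device.
device = next(model.parameters()).device
enc = tokenizer("hello", return_tensors="pt").to(device)
out = model.generate(**enc, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))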
demo/openai_api_demo/openai_api_request_demo.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""
A simple client for the OpenAI-compatible API. Because of MiniCPM-2B's limitations, this script:
1. has no tool-calling support
2. has no system prompt
3. supports a maximum text length of 4096 tokens

To run it:
1. Start the local server first. The server loads the model with AutoModelForCausalLM.from_pretrained and is not otherwise optimized; adapt it as needed.
2. Then send requests with this script.
"""

from openai import OpenAI

base_url = "http://127.0.0.1:8000/v1/"
client = OpenAI(api_key="MiniCPM-2B", base_url=base_url)

def chat(use_stream=True):
    messages = [
        {
            "role": "user",
            "content": "tell me a story"
        }
    ]
    response = client.chat.completions.create(
        model="MiniCPM-2B",
        messages=messages,
        stream=use_stream,
        max_tokens=4096,  # need less than 4096 tokens
        temperature=0.8,
        top_p=0.8
    )
    if response:
        if use_stream:
            for chunk in response:
                print(chunk.choices[0].delta.content)
        else:
            content = response.choices[0].message.content
            print(content)
    else:
        print("Error:", response.status_code)


def embedding():
    response = client.embeddings.create(
        model="bge-m3",
        input=["hello, I am MiniCPM-2B"],
    )
    embeddings = response.data[0].embedding
    print("Embedding_Success:", len(embeddings))


if __name__ == "__main__":
    chat(use_stream=True)
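A small follow-up sketch, not part of the commit, assuming the server from openai_api_server_demo.py below is already running at the same base_url: it collects the streamed deltas into one string instead of printing raw chunks, guarding against chunks whose delta content is None.

from openai import OpenAI

client = OpenAI(api_key="MiniCPM-2B", base_url="http://127.0.0.1:8000/v1/")

def chat_stream_to_text(prompt: str) -> str:
    # Accumulate the assistant reply from the streaming chunks.
    stream = client.chat.completions.create(
        model="MiniCPM-2B",
        messages=[{"role": "user", "content": prompt}],
        stream=True,
        max_tokens=4096,
        temperature=0.8,
        top_p=0.8,
    )
    reply = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:  # the final chunk may carry no content
            reply += delta
    return reply

if __name__ == "__main__":
    print(chat_stream_to_text("tell me a story"))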
demo/openai_api_demo/openai_api_server_demo.py (new file, 296 lines)
@@ -0,0 +1,296 @@
import gc
import json
import os
import time
from threading import Thread

import tiktoken
import torch
import uvicorn

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from loguru import logger
from sse_starlette.sse import EventSourceResponse

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

MODEL_PATH = os.environ.get('MODEL_PATH', 'openbmb/MiniCPM-2B-dpo-fp16')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-m3')


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # clean cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: str = None
    name: Optional[str] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class EmbeddingRequest(BaseModel):
    input: List[str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None

@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(
        id="MiniCPM-2B"
    )
    return ModelList(
        data=[model_card]
    )


def generate_minicpm(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, params: dict):
    messages = params["messages"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    enc = tokenizer(inputs, return_tensors="pt").to(model.device)
    input_echo_len = len(enc["input_ids"][0])

    if input_echo_len >= model.config.max_length:
        logger.error(f"Input length larger than {model.config.max_length}")
        return
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = {
        **enc,
        "do_sample": True if temperature > 1e-5 else False,
        "top_k": 0,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }
    eos_token = tokenizer.eos_token
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        new_text = new_text.split(eos_token)[0] if eos_token in new_text else new_text
        response += new_text
        current_length = len(new_text)
        yield {
            "text": response[5 + len(inputs):],
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": current_length - input_echo_len,
                "total_tokens": len(response),
            },
            "finish_reason": "",
        }
    thread.join()
    gc.collect()
    torch.cuda.empty_cache()


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    embeddings = [embedding_model.encode(text) for text in request.input]
    embeddings = [embedding.tolist() for embedding in embeddings]

    def num_tokens_from_string(string: str) -> int:
        encoding = tiktoken.get_encoding('cl100k_base')
        num_tokens = len(encoding.encode(string))
        return num_tokens

    response = {
        "data": [
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": CompletionUsage(
            prompt_tokens=sum(len(text.split()) for text in request.input),
            completion_tokens=0,
            total_tokens=sum(num_tokens_from_string(text) for text in request.input),
        )
    }
    return response


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 2048,
        echo=False,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
    )
    logger.debug(f"==== request ====\n{gen_params}")
    input_tokens = sum(len(tokenizer.encode(msg.content)) for msg in request.messages)
    if request.stream:
        async def stream_response():
            previous_text = ""
            for new_response in generate_minicpm(model, tokenizer, gen_params):
                delta_text = new_response["text"][len(previous_text):]
                previous_text = new_response["text"]
                delta = DeltaMessage(content=delta_text, role="assistant")
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=delta,
                    finish_reason=None
                )
                chunk = {
                    "model": request.model,
                    "id": "",
                    "choices": [choice_data.dict(exclude_none=True)],
                    "object": "chat.completion.chunk"
                }
                yield json.dumps(chunk) + "\n"

        return EventSourceResponse(stream_response(), media_type="text/event-stream")

    else:
        generated_text = ""
        for response in generate_minicpm(model, tokenizer, gen_params):
            generated_text = response["text"]
        generated_text = generated_text.strip()
        output_tokens = len(tokenizer.encode(generated_text))
        usage = UsageInfo(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=output_tokens + input_tokens
        )
        message = ChatMessage(role="assistant", content=generated_text)
        logger.debug(f"==== message ====\n{message}")
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=message,
            finish_reason="stop",
        )
        return ChatCompletionResponse(
            model=request.model,
            id="",
            choices=[choice_data],
            object="chat.completion",
            usage=usage
        )


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto",
                                                 trust_remote_code=True)
    embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
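A minimal sketch, not part of the commit, of consuming the SSE stream that the /v1/chat/completions endpoint above emits without the openai client; it assumes the server is running locally and that requests is installed (it is not among the commit's pinned requirements).

import json
import requests

def stream_chat(prompt: str, url: str = "http://127.0.0.1:8000/v1/chat/completions") -> None:
    # Post a streaming chat request and print each delta as the SSE "data:" events arrive.
    payload = {
        "model": "MiniCPM-2B",
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for raw_line in resp.iter_lines(decode_unicode=True):
            if not raw_line or not raw_line.startswith("data:"):
                continue  # skip blank keep-alive lines and SSE comments/pings
            body = raw_line[len("data:"):].strip()
            if not body:
                continue  # empty data line produced by the trailing newline in each chunk
            chunk = json.loads(body)
            delta = chunk["choices"][0]["delta"].get("content") or ""
            print(delta, end="", flush=True)
    print()

if __name__ == "__main__":
    stream_chat("tell me a story")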
@@ -1,36 +1,46 @@
from typing import Dict
from typing import List
from typing import Tuple

import argparse
import gradio as gr
from vllm import LLM, SamplingParams

import torch
from transformers import AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()
parser.add_argument("--max_tokens", type=int, default=2048)
# for MiniCPM-1B and MiniCPM-2B model, max_tokens should be set to 2048

args = parser.parse_args()

# init model torch dtype
torch_dtype = args.torch_dtype
if torch_dtype =="" or torch_dtype == "bfloat16":
    torch_dtype = "bfloat16"
if torch_dtype == "" or torch_dtype == "bfloat16":
    torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
    torch_dtype = "float32"
    torch_dtype = torch.float32
elif torch_dtype == "float16":
    torch_dtype = torch.float16
else:
    raise ValueError(f"Invalid torch dtype: {torch_dtype}")

# init model and tokenizer
path = args.model_path
llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
llm = LLM(
    model=path,
    tensor_parallel_size=1,
    dtype=torch_dtype,
    trust_remote_code=True,
    gpu_memory_utilization=0.9,
    max_model_len=args.max_tokens
)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

server_name = args.server_name
server_port = args.server_port

# init gradio demo host and port
server_name=args.server_name
server_port=args.server_port

def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
    """generate model output with huggingface api
@@ -43,19 +53,14 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):

    Yields:
        str: real-time generation results of hf model
    """
    prompt = ""
    """
    assert len(dialog) % 2 == 1
    for info in dialog:
        if info["role"] == "user":
            prompt += "<用户>" + info["content"]
        else:
            prompt += "<AI>" + info["content"]
    prompt += "<AI>"
    prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
    token_ids = tokenizer.convert_tokens_to_ids(["<|im_end|>"])
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
@@ -63,8 +68,8 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "stop": None,
        "stop_token_ids": None,
        "stop": "<|im_end|>",
        "stop_token_ids": token_ids,
        "ignore_eos": False,
        "max_tokens": max_dec_len,
        "logprobs": None,
@@ -89,7 +94,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, m

    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
    """
    """
    assert query != "", "Input must not be empty!!!"
    # apply chat template
    model_input = []
@@ -114,7 +119,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len

    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
    """
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
    # apply chat template
    model_input = []
@@ -133,7 +138,7 @@ def clear_history():

    Returns:
        List: empty chat history
    """
    """
    return []


@@ -145,7 +150,7 @@ def reverse_last_round(chat_history):

    Returns:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
    """
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
    return chat_history[:-1]

@@ -158,7 +163,7 @@ with gr.Blocks(theme="soft") as demo:
    with gr.Column(scale=1):
        top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
        temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
        max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
        max_dec_len = gr.Slider(1, args.max_tokens, value=args.max_tokens, step=1, label="max_tokens")
    with gr.Column(scale=5):
        chatbot = gr.Chatbot(bubble_full_width=False, height=400)
        user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
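A standalone sketch, not part of the commit, of the generation path the diff above switches to: the prompt is built with the tokenizer's chat template instead of hand-written <用户>/<AI> tags, and generation stops on <|im_end|>. The model id and length limit mirror the demo's defaults.

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "openbmb/MiniCPM-1B-sft-bf16"  # the demo's default --model_path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
llm = LLM(model=model_path, dtype="bfloat16", trust_remote_code=True, max_model_len=2048)

dialog = [{"role": "user", "content": "hello"}]
# Render the dialog with the model's own chat template.
prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)

params = SamplingParams(
    temperature=0.5,
    top_p=0.8,
    max_tokens=512,
    stop=["<|im_end|>"],  # same stop string the demo registers
)
outputs = llm.generate([prompt], params)
print(outputs[0].outputs[0].text)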
requirements.txt (new file, 19 lines)
@@ -0,0 +1,19 @@
# for MiniCPM-2B hf inference
torch>=2.0.0
transformers>=4.36.2
gradio>=4.26.0

# for vllm inference
# vllm>=0.4.0.post1

# for openai api inference
openai>=1.17.1
tiktoken>=0.6.0
loguru>=0.7.2
sentence_transformers>=2.6.1
sse_starlette>=2.1.0

# for MiniCPM-V hf inference
Pillow>=10.3.0
timm>=0.9.16
sentencepiece>=0.2.0