Commit cf9a5be5be (mirror of https://github.com/RYDE-WORK/MiniCPM.git)
Changes to the Hugging Face-based Gradio demo:

@@ -1,7 +1,4 @@
-from typing import Dict
 from typing import List
-from typing import Tuple
 
 import argparse
 import gradio as gr
 import torch
@@ -16,7 +13,7 @@ import warnings
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-2B-dpo-fp16")
 parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
@@ -55,7 +52,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
         str: real-time generation results of hf model
     """
     inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
-    enc = tokenizer(inputs, return_tensors="pt").to("cuda")
+    enc = tokenizer(inputs, return_tensors="pt").to(next(model.parameters()).device)
     streamer = TextIteratorStreamer(tokenizer)
     generation_kwargs = dict(
         enc,
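The `.to(next(model.parameters()).device)` change above decouples input placement from a hard-coded GPU: the encoded prompt is moved to whatever device the model's parameters actually live on, which also covers CPU-only runs and models loaded with `device_map="auto"`. A minimal sketch of the idea, using a small stand-in module instead of MiniCPM:

import torch
from torch import nn

# Stand-in for the demo's AutoModelForCausalLM; only the parameter device matters here.
model = nn.Linear(4, 4)
if torch.cuda.is_available():
    model = model.to("cuda:0")

device = next(model.parameters()).device   # device of the first parameter
inputs = torch.randn(1, 4).to(device)      # inputs follow the model, not a hard-coded "cuda"
print(model(inputs).device)                # same device on CPU-only and GPU machines alike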
demo/openai_api_demo/openai_api_request_demo.py (new file)
@@ -0,0 +1,55 @@
"""
This is a simple OpenAI-style API client. Because of MiniCPM-2B's limitations, this script:
1. has no tool-calling support;
2. has no system prompt;
3. supports a maximum text length of 4096.

To run this code:
1. Start the local server first. It loads the model with AutoModelForCausalLM.from_pretrained and is not otherwise optimized, so adapt it as needed.
2. Then send requests with this script.
"""

from openai import OpenAI

base_url = "http://127.0.0.1:8000/v1/"
client = OpenAI(api_key="MiniCPM-2B", base_url=base_url)


def chat(use_stream=True):
    messages = [
        {
            "role": "user",
            "content": "tell me a story"
        }
    ]
    response = client.chat.completions.create(
        model="MiniCPM-2B",
        messages=messages,
        stream=use_stream,
        max_tokens=4096,  # must be no more than 4096 tokens
        temperature=0.8,
        top_p=0.8
    )
    if response:
        if use_stream:
            for chunk in response:
                print(chunk.choices[0].delta.content)
        else:
            content = response.choices[0].message.content
            print(content)
    else:
        print("Error:", response.status_code)


def embedding():
    response = client.embeddings.create(
        model="bge-m3",
        input=["hello, I am MiniCPM-2B"],
    )
    embeddings = response.data[0].embedding
    print("Embedding_Success:", len(embeddings))


if __name__ == "__main__":
    chat(use_stream=True)
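One caveat on the streaming branch of `chat()` above: with the `openai` v1 Python client, `delta.content` can be `None` on some chunks (typically the final one), so the plain `print` may emit a literal `None`. A hedged variant of the loop, reusing the `client` defined in the script; the function name is illustrative only:

def chat_stream_safe():
    # Same request as chat(use_stream=True), but skip None/empty deltas and print one line.
    response = client.chat.completions.create(
        model="MiniCPM-2B",
        messages=[{"role": "user", "content": "tell me a story"}],
        stream=True,
        max_tokens=4096,
        temperature=0.8,
        top_p=0.8,
    )
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta:  # guard against None in the final chunk
            print(delta, end="", flush=True)
    print()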
demo/openai_api_demo/openai_api_server_demo.py (new file)
@@ -0,0 +1,296 @@
import gc
import json
import os
import time
from threading import Thread

import tiktoken
import torch
import uvicorn

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from loguru import logger
from sse_starlette.sse import EventSourceResponse

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

MODEL_PATH = os.environ.get('MODEL_PATH', 'openbmb/MiniCPM-2B-dpo-fp16')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-m3')


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # clean cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: str = None
    name: Optional[str] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class EmbeddingRequest(BaseModel):
    input: List[str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(
        id="MiniCPM-2B"
    )
    return ModelList(
        data=[model_card]
    )


def generate_minicpm(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, params: dict):
    messages = params["messages"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    enc = tokenizer(inputs, return_tensors="pt").to(model.device)
    input_echo_len = len(enc["input_ids"][0])

    if input_echo_len >= model.config.max_length:
        logger.error(f"Input length larger than {model.config.max_length}")
        return
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = {
        **enc,
        "do_sample": True if temperature > 1e-5 else False,
        "top_k": 0,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }
    eos_token = tokenizer.eos_token
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        new_text = new_text.split(eos_token)[0] if eos_token in new_text else new_text
        response += new_text
        current_length = len(new_text)
        yield {
            "text": response[5 + len(inputs):],
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": current_length - input_echo_len,
                "total_tokens": len(response),
            },
            "finish_reason": "",
        }
    thread.join()
    gc.collect()
    torch.cuda.empty_cache()


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    embeddings = [embedding_model.encode(text) for text in request.input]
    embeddings = [embedding.tolist() for embedding in embeddings]

    def num_tokens_from_string(string: str) -> int:
        encoding = tiktoken.get_encoding('cl100k_base')
        num_tokens = len(encoding.encode(string))
        return num_tokens

    response = {
        "data": [
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": CompletionUsage(
            prompt_tokens=sum(len(text.split()) for text in request.input),
            completion_tokens=0,
            total_tokens=sum(num_tokens_from_string(text) for text in request.input),
        )
    }
    return response


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 2048,
        echo=False,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
    )
    logger.debug(f"==== request ====\n{gen_params}")
    input_tokens = sum(len(tokenizer.encode(msg.content)) for msg in request.messages)
    if request.stream:
        async def stream_response():
            previous_text = ""
            for new_response in generate_minicpm(model, tokenizer, gen_params):
                delta_text = new_response["text"][len(previous_text):]
                previous_text = new_response["text"]
                delta = DeltaMessage(content=delta_text, role="assistant")
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=delta,
                    finish_reason=None
                )
                chunk = {
                    "model": request.model,
                    "id": "",
                    "choices": [choice_data.dict(exclude_none=True)],
                    "object": "chat.completion.chunk"
                }
                yield json.dumps(chunk) + "\n"

        return EventSourceResponse(stream_response(), media_type="text/event-stream")

    else:
        generated_text = ""
        for response in generate_minicpm(model, tokenizer, gen_params):
            generated_text = response["text"]
        generated_text = generated_text.strip()
        output_tokens = len(tokenizer.encode(generated_text))
        usage = UsageInfo(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=output_tokens + input_tokens
        )
        message = ChatMessage(role="assistant", content=generated_text)
        logger.debug(f"==== message ====\n{message}")
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=message,
            finish_reason="stop",
        )
        return ChatCompletionResponse(
            model=request.model,
            id="",
            choices=[choice_data],
            object="chat.completion",
            usage=usage
        )


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto",
                                                 trust_remote_code=True)
    embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
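`generate_minicpm` above is a plain Python generator, so it can also be exercised without FastAPI when debugging the server. A minimal sketch, assuming the same default `MODEL_PATH` and a GPU with enough memory; the `params` dict mirrors what `create_chat_completion` builds:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load exactly as the __main__ block of the server script does.
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-2B-dpo-fp16")
model = AutoModelForCausalLM.from_pretrained("openbmb/MiniCPM-2B-dpo-fp16",
                                             torch_dtype=torch.bfloat16,
                                             device_map="auto", trust_remote_code=True)

params = {
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "temperature": 0.8,
    "top_p": 0.8,
    "repetition_penalty": 1.1,
    "max_tokens": 64,
}

last = None
for last in generate_minicpm(model, tokenizer, params):
    pass  # each yield carries the accumulated "text" plus rough token counts
if last is not None:
    print(last["text"])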
Changes to the vLLM-based Gradio demo:

@@ -1,36 +1,46 @@
-from typing import Dict
 from typing import List
-from typing import Tuple
 
 import argparse
 import gradio as gr
 from vllm import LLM, SamplingParams
+import torch
+from transformers import AutoTokenizer
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16")
 parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
-args = parser.parse_args()
+parser.add_argument("--max_tokens", type=int, default=2048)
+# for MiniCPM-1B and MiniCPM-2B model, max_tokens should be set to 2048
+
+args = parser.parse_args()
 
 # init model torch dtype
 torch_dtype = args.torch_dtype
-if torch_dtype =="" or torch_dtype == "bfloat16":
-    torch_dtype = "bfloat16"
+if torch_dtype == "" or torch_dtype == "bfloat16":
+    torch_dtype = torch.bfloat16
 elif torch_dtype == "float32":
-    torch_dtype = "float32"
+    torch_dtype = torch.float32
+elif torch_dtype == "float16":
+    torch_dtype = torch.float16
 else:
     raise ValueError(f"Invalid torch dtype: {torch_dtype}")
 
 # init model and tokenizer
 path = args.model_path
-llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
+llm = LLM(
+    model=path,
+    tensor_parallel_size=1,
+    dtype=torch_dtype,
+    trust_remote_code=True,
+    gpu_memory_utilization=0.9,
+    max_model_len=args.max_tokens
+)
+tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
 
-# init gradio demo host and port
-server_name=args.server_name
-server_port=args.server_port
+server_name = args.server_name
+server_port = args.server_port
 
 
 def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     """generate model output with huggingface api
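The dtype handling in the updated code above maps the `--torch_dtype` string to an actual `torch.dtype`. For reference, an equivalent lookup-table formulation, shown only as a sketch (not part of this commit; `args` as parsed above):

import torch

# "" falls back to bfloat16, matching the if/elif chain in the demo.
DTYPE_MAP = {
    "": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "float16": torch.float16,
}
if args.torch_dtype not in DTYPE_MAP:
    raise ValueError(f"Invalid torch dtype: {args.torch_dtype}")
torch_dtype = DTYPE_MAP[args.torch_dtype]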
@@ -43,19 +53,14 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
 
     Yields:
         str: real-time generation results of hf model
     """
-    prompt = ""
     assert len(dialog) % 2 == 1
-    for info in dialog:
-        if info["role"] == "user":
-            prompt += "<用户>" + info["content"]
-        else:
-            prompt += "<AI>" + info["content"]
-    prompt += "<AI>"
+    prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
+    token_ids = tokenizer.convert_tokens_to_ids(["<|im_end|>"])
     params_dict = {
         "n": 1,
         "best_of": 1,
         "presence_penalty": 1.0,
         "frequency_penalty": 0.0,
         "temperature": temperature,
         "top_p": top_p,
@@ -63,8 +68,8 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
         "use_beam_search": False,
         "length_penalty": 1,
         "early_stopping": False,
-        "stop": None,
-        "stop_token_ids": None,
+        "stop": "<|im_end|>",
+        "stop_token_ids": token_ids,
         "ignore_eos": False,
         "max_tokens": max_dec_len,
         "logprobs": None,
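How `vllm_gen` consumes `params_dict` lies outside the hunks shown here; the keys, including the new `stop` and `stop_token_ids`, correspond to fields of vLLM's `SamplingParams` in the 0.4.x line referenced in requirements.txt. A hedged sketch of that pattern, with `llm` and `prompt` as defined earlier in the file:

from vllm import SamplingParams

# Sketch only: turn the dict into sampling parameters and run one generation.
sampling_params = SamplingParams(**params_dict)
outputs = llm.generate(prompt, sampling_params)
print(outputs[0].outputs[0].text)  # generation halts at <|im_end|> via stop/stop_token_ids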
@@ -89,7 +94,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, m
 
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
     """
     assert query != "", "Input must not be empty!!!"
     # apply chat template
     model_input = []
@@ -114,7 +119,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len
 
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
     """
     assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
     # apply chat template
     model_input = []
@@ -133,7 +138,7 @@ def clear_history():
 
     Returns:
         List: empty chat history
     """
     return []
 
 
@@ -145,7 +150,7 @@ def reverse_last_round(chat_history):
 
     Returns:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
     """
     assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
     return chat_history[:-1]
 
@@ -158,7 +163,7 @@ with gr.Blocks(theme="soft") as demo:
     with gr.Column(scale=1):
         top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
         temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
-        max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
+        max_dec_len = gr.Slider(1, args.max_tokens, value=args.max_tokens, step=1, label="max_tokens")
     with gr.Column(scale=5):
         chatbot = gr.Chatbot(bubble_full_width=False, height=400)
         user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
requirements.txt (new file)
@@ -0,0 +1,19 @@
# for MiniCPM-2B hf inference
torch>=2.0.0
transformers>=4.36.2
gradio>=4.26.0

# for vllm inference
# vllm>=0.4.0.post1

# for openai api inference
openai>=1.17.1
tiktoken>=0.6.0
loguru>=0.7.2
sentence_transformers>=2.6.1
sse_starlette>=2.1.0

# for MiniCPM-V hf inference
Pillow>=10.3.0
timm>=0.9.16
sentencepiece>=0.2.0