diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py
index 78a5c9b..3b04cf4 100644
--- a/demo/hf_based_demo.py
+++ b/demo/hf_based_demo.py
@@ -1,7 +1,4 @@
-from typing import Dict
 from typing import List
-from typing import Tuple
-
 import argparse
 import gradio as gr
 import torch
@@ -16,7 +13,7 @@ import warnings
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-2B-dpo-fp16")
 parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
@@ -55,7 +52,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
         str: real-time generation results of hf model
     """
     inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
-    enc = tokenizer(inputs, return_tensors="pt").to("cuda")
+    enc = tokenizer(inputs, return_tensors="pt").to(next(model.parameters()).device)
     streamer = TextIteratorStreamer(tokenizer)
     generation_kwargs = dict(
         enc,
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ffed218
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,16 @@
+# for MiniCPM-2B hf inference
+torch>=2.0.0
+transformers>=4.36.2
+gradio>=4.26.0
+
+# for openai api inference
+openai>=1.17.1
+tiktoken>=0.6.0
+loguru>=0.7.2
+sentence_transformers>=2.6.1
+sse_starlette>=2.1.0
+
+# for MiniCPM-V hf inference
+Pillow>=10.3.0
+timm>=0.9.16
+sentencepiece>=0.2.0
\ No newline at end of file