feat: allow user to set torch dtype

This commit is contained in:
Su Yang 2024-02-02 14:46:41 +08:00
parent 7333ec793b
commit d64fe362fc
No known key found for this signature in database
GPG Key ID: DBCDD8CBF440F8DE
2 changed files with 23 additions and 2 deletions

View File

@ -14,12 +14,22 @@ from transformers import (
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument(
    "--torch_dtype",
    type=str,
    default="bfloat16",
    help='Model weight dtype: "bfloat16" (default) or "float32".',
)
args = parser.parse_args()

# init model torch dtype
# Map the --torch_dtype string onto the torch dtype object handed to
# from_pretrained. An empty string is accepted and means bfloat16.
_TORCH_DTYPES = {
    "": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
if args.torch_dtype not in _TORCH_DTYPES:
    raise ValueError(f"Invalid torch dtype: {args.torch_dtype}")
torch_dtype = _TORCH_DTYPES[args.torch_dtype]

# init model and tokenizer
path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(path)
# Load the model exactly once, using the dtype selected above; a second load
# with a hard-coded dtype would only waste time and memory.
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True
)
def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):

View File

@ -9,11 +9,22 @@ from vllm import LLM, SamplingParams
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument(
    "--torch_dtype",
    type=str,
    default="bfloat16",
    help='Model weight dtype: "bfloat16" (default) or "float32".',
)
args = parser.parse_args()

# init model torch dtype
# The dtype is passed to LLM as a string, so it only needs validating.
# An empty string is accepted and means bfloat16.
_SUPPORTED_DTYPES = ("bfloat16", "float32")
torch_dtype = args.torch_dtype or "bfloat16"
if torch_dtype not in _SUPPORTED_DTYPES:
    raise ValueError(f"Invalid torch dtype: {torch_dtype}")

# init model and tokenizer
path = args.model_path
# Build the engine exactly once, using the dtype selected above; a second
# construction with a hard-coded dtype would only waste time and memory.
llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):