Merge pull request #21 from soulteary/feat/allow-set-torch-dtype

feat: allow user set torch dtype #20
This commit is contained in:
ywfang 2024-02-02 17:00:45 +08:00 committed by GitHub
commit d614cdcd35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 4 deletions

View File

@ -14,15 +14,24 @@ from transformers import (
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]))
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()
# init model torch dtype
torch_dtype = args.torch_dtype
if torch_dtype =="" or torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
else:
raise ValueError(f"Invalid torch dtype: {torch_dtype}")
# init model and tokenizer
path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
# init gradio demo host and port
server_name=args.server_name

View File

@ -9,14 +9,24 @@ from vllm import LLM, SamplingParams
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]))
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()
# init model torch dtype
torch_dtype = args.torch_dtype
if torch_dtype =="" or torch_dtype == "bfloat16":
torch_dtype = "bfloat16"
elif torch_dtype == "float32":
torch_dtype = "float32"
else:
raise ValueError(f"Invalid torch dtype: {torch_dtype}")
# init model and tokenizer
path = args.model_path
llm = LLM(model=path, tensor_parallel_size=1, dtype="bfloat16")
llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
# init gradio demo host and port
server_name=args.server_name