Mirror of https://github.com/RYDE-WORK/MiniCPM.git (synced 2026-01-19 21:03:39 +08:00)
fix bug on quantize_eval.py
parent 2a86c6a287, commit e53c77c7e8
@@ -3,27 +3,27 @@ import torch.nn as nn
 from tqdm import tqdm
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
-#import GPUtil
+import GPUtil
 import argparse
 
 parser = argparse.ArgumentParser(description="======== Quantization perplexity test ========")
 parser.add_argument(
     "--model_path",
     type=str,
-    default='/root/ld/ld_model_pretrained/miniCPM-bf16',
+    default='',
     help="Path to the model before quantization."
 )
 parser.add_argument(
     "--awq_path",
     type=str,
-    default='/root/ld/ld_project/pull_request/MiniCPM/quantize/awq_cpm_2b_4bit',
+    default='',
     help="Path where the AWQ-quantized model is saved."
 )
 # we will support gptq later
 parser.add_argument(
     "--gptq_path",
     type=str,
-    default='/root/ld/ld_project/AutoGPTQ/examples/quantization/minicpm_2b_4bit',
+    default='',
     help="Path where the GPTQ-quantized model is saved."
 )
 parser.add_argument(
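The substance of the fix so far: the author's machine-specific default paths are emptied, so each backend is evaluated only when its path is passed explicitly. A hypothetical invocation of the fixed script (paths are illustrative, not from the repo):

python quantize_eval.py --model_path /path/to/minicpm-bf16 --awq_path /path/to/minicpm-awq-4bit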
@@ -82,7 +82,7 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    if args.model_path:
+    if args.model_path != "":
         model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16, device_map='cuda', trust_remote_code=True)
         tokenizer = AutoTokenizer.from_pretrained(args.model_path)
         print("pretrained model:", args.model_path.split('/')[-1])
@@ -91,7 +91,7 @@ if __name__ == "__main__":
         evaluate_perplexity(model, tokenizer, args.data_path)
         del model
 
-    if args.awq_path:
+    if args.awq_path != "":
         from awq import AutoAWQForCausalLM
 
         model = AutoAWQForCausalLM.from_quantized(args.awq_path, fuse_layers=True, device_map={"": 'cuda:0'})
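The evaluate_perplexity helper called above lives outside these hunks. As context only, a minimal sketch of such a routine, assuming (hypothetically) that args.data_path is a JSON-lines file with a "text" field and that non-overlapping 2048-token windows are scored:

import torch
from tqdm import tqdm
from datasets import load_dataset

def evaluate_perplexity(model, tokenizer, data_path, seq_len=2048):
    # Concatenate the corpus and tokenize it once (assumes a 'text' column).
    data = load_dataset("json", data_files=data_path, split="train")
    enc = tokenizer("\n\n".join(data["text"]), return_tensors="pt").input_ids.to(model.device)
    nlls = []
    for i in tqdm(range(0, enc.size(1) - seq_len, seq_len)):
        window = enc[:, i : i + seq_len]
        with torch.no_grad():
            # Passing labels makes the model return the mean token NLL as .loss;
            # scale by the window length to accumulate a total NLL.
            nlls.append(model(window, labels=window).loss * seq_len)
    # Perplexity = exp(total NLL / total tokens scored).
    ppl = torch.exp(torch.stack(nlls).sum() / (len(nlls) * seq_len))
    print(f"perplexity: {ppl.item():.4f}")
    return ppl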
@@ -103,7 +103,7 @@ if __name__ == "__main__":
         del model
 
     # we will support the autogptq later
-    if args.gptq_path:
+    if args.gptq_path != "":
         from auto_gptq import AutoGPTQForCausalLM
 
         tokenizer = AutoTokenizer.from_pretrained(args.gptq_path, use_fast=True)
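The hunk ends at the tokenizer load; presumably the branch continues by loading the quantized checkpoint through AutoGPTQ's standard entry point, roughly:

# Assumption: continuation of the gptq branch, not shown in this diff.
model = AutoGPTQForCausalLM.from_quantized(args.gptq_path, device="cuda:0")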