diff --git a/quantize/awq_quantize.py b/quantize/awq_quantize.py index dc03506..38afbc9 100644 --- a/quantize/awq_quantize.py +++ b/quantize/awq_quantize.py @@ -7,9 +7,9 @@ import os model_path = '/root/ld/ld_model_pretrained/MiniCPM-1B-sft-bf16' # model_path or model_id quant_path = '/root/ld/ld_project/pull_request/MiniCPM/quantize/awq_cpm_1b_4bit' # quant_save_path -quant_data_path='/root/ld/ld_project/pull_request/MiniCPM/quantize/quantize_data/alpaca' -quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } #"w_bit":4 or 8 -quant_samples=512 #how many samples to use for calibration +quant_data_path='/root/ld/ld_project/pull_request/MiniCPM/quantize/quantize_data/wikitext' +quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } # "w_bit":4 or 8 +quant_samples=512 # how many samples to use for calibration # Load model model = AutoAWQForCausalLM.from_pretrained(model_path) @@ -17,7 +17,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True,dev # Define data loading methods def load_alpaca(quant_data_path): - data = load_dataset(quant_data_path, split="train") #Set the absolute path to alpaca or huggingface id + data = load_dataset(quant_data_path, split="train") # Set the absolute path to alpaca or huggingface id # concatenate data def concatenate_data(x): @@ -26,12 +26,12 @@ def load_alpaca(quant_data_path): concatenated = data.map(concatenate_data)[:quant_samples] return [text for text in concatenated["text"]] -def load_wikitext(): - data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train") +def load_wikitext(quant_data_path): + data = load_dataset(quant_data_path, split="train") return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20][:quant_samples] # Quantize -model.quantize(tokenizer, quant_config=quant_config, calib_data=load_alpaca(quant_data_path=quant_data_path)) +model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext(quant_data_path=quant_data_path)) # Save quantized model model.save_quantized(quant_path)