diff --git a/quantize/awq_quantize.py b/quantize/awq_quantize.py
index 38afbc9..0a1543b 100644
--- a/quantize/awq_quantize.py
+++ b/quantize/awq_quantize.py
@@ -7,10 +7,11 @@ import os
 
 model_path = '/root/ld/ld_model_pretrained/MiniCPM-1B-sft-bf16' # model_path or model_id
 quant_path = '/root/ld/ld_project/pull_request/MiniCPM/quantize/awq_cpm_1b_4bit' # quant_save_path
-quant_data_path='/root/ld/ld_project/pull_request/MiniCPM/quantize/quantize_data/wikitext'
+quant_data_path='/root/ld/ld_project/pull_request/MiniCPM/quantize/quantize_data/wikitext' # path to the bundled calibration dataset
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } # "w_bit":4 or 8
 quant_samples=512 # how many samples to use for calibration
-
+custom_data=[{'question':'你叫什么名字。','answer':'我是openbmb开源的小钢炮minicpm。'},
+             {'question':'你有什么特色。','answer':'我很小,但是我很强。'}]
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True,device_map={"": "cuda:0"})
@@ -21,7 +22,7 @@ def load_alpaca(quant_data_path):
 
     # concatenate data
     def concatenate_data(x):
-        return {"text": '<用户>'+x['instruction'] + '<AI>' + x['input'] + '\n' + x['output']}
+        return {"text": '<用户>'+x['instruction'] + x['input'] + '<AI>' + '\n' + x['output']}
 
     concatenated = data.map(concatenate_data)[:quant_samples]
     return [text for text in concatenated["text"]]
@@ -30,6 +31,9 @@ def load_wikitext(quant_data_path):
     data = load_dataset(quant_data_path, split="train")
     return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20][:quant_samples]
 
+def load_cust_data(custom_data):
+    quant_data=['<用户>'+i['question'] + '<AI>' + i['answer'] for i in custom_data]
+    return quant_data[:quant_samples]
 # Quantize
 model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext(quant_data_path=quant_data_path))
 
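Note that the diff defines `load_cust_data` but leaves the final `model.quantize(...)` call calibrating on wikitext. Below is a minimal sketch of how one might switch calibration to the custom Q&A pairs and then persist the result; the `save_quantized`/`save_pretrained` calls follow AutoAWQ's standard API and are an addition for illustration, not part of this diff.

```python
# Sketch: calibrate on the hand-written Q&A pairs instead of wikitext.
# Assumes model, tokenizer, quant_config, custom_data and quant_path are
# already defined as in the script above.
model.quantize(
    tokenizer,
    quant_config=quant_config,
    calib_data=load_cust_data(custom_data),  # list[str] in '<用户>...<AI>...' format
)

# Persist the quantized weights and the tokenizer so the model can later be
# reloaded with AutoAWQForCausalLM.from_quantized(quant_path).
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```

AutoAWQ accepts `calib_data` either as a dataset name or as a pre-built list of strings, so the templated samples can be passed directly without further preprocessing.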