Add wikitext dataset as calibration data for AWQ quantization

This commit is contained in:
root 2024-06-24 11:26:07 +08:00
parent f062357093
commit 5f239e3742

View File

@@ -7,7 +7,7 @@ import os
model_path = '/root/ld/ld_model_pretrained/MiniCPM-1B-sft-bf16' # model_path or model_id model_path = '/root/ld/ld_model_pretrained/MiniCPM-1B-sft-bf16' # model_path or model_id
quant_path = '/root/ld/ld_project/pull_request/MiniCPM/quantize/awq_cpm_1b_4bit' # quant_save_path quant_path = '/root/ld/ld_project/pull_request/MiniCPM/quantize/awq_cpm_1b_4bit' # quant_save_path
quant_data_path='/root/ld/ld_project/pull_request/MiniCPM/quantize/quantize_data/alpaca' quant_data_path='/root/ld/ld_project/pull_request/MiniCPM/quantize/quantize_data/wikitext'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } # "w_bit":4 or 8 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } # "w_bit":4 or 8
quant_samples=512 # how many samples to use for calibration quant_samples=512 # how many samples to use for calibration
@@ -26,12 +26,12 @@ def load_alpaca(quant_data_path):
concatenated = data.map(concatenate_data)[:quant_samples] concatenated = data.map(concatenate_data)[:quant_samples]
return [text for text in concatenated["text"]] return [text for text in concatenated["text"]]
def load_wikitext(): def load_wikitext(quant_data_path):
data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train") data = load_dataset(quant_data_path, split="train")
return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20][:quant_samples] return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20][:quant_samples]
# Quantize # Quantize
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_alpaca(quant_data_path=quant_data_path)) model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext(quant_data_path=quant_data_path))
# Save quantized model # Save quantized model
model.save_quantized(quant_path) model.save_quantized(quant_path)