diff --git a/finetune/finetune.py b/finetune/finetune.py
index 5fabdcd..2c5fb02 100644
--- a/finetune/finetune.py
+++ b/finetune/finetune.py
@@ -7,7 +7,7 @@ import torch
 import transformers
 from torch.utils.data import Dataset
 from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
-                          TrainingArguments)
+                          TrainingArguments, BitsAndBytesConfig)
 
 
 @dataclass
@@ -38,6 +38,7 @@ class TrainingArguments(transformers.TrainingArguments):
         },
     )
     use_lora: bool = field(default=False)
+    qlora: bool = field(default=False)
 
 
 class SupervisedDataset(Dataset):
@@ -121,6 +122,7 @@ def load_model_and_tokenizer(
     model_path: str,
     max_length: int = 4096,
     use_lora: bool = True,
+    qlora: bool = False,
     bf16: bool = False,
     fp16: bool = False,
 ):
@@ -135,11 +137,32 @@ def load_model_and_tokenizer(
         dtype = torch.float16
     else:
         dtype = torch.float32
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype=dtype,
-        trust_remote_code=True,
-    )
+    if qlora:
+        assert use_lora, "use_lora must be True when qlora is True"
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,  # quantize model weights to 4-bit
+            load_in_8bit=False,  # do not use 8-bit quantization
+            bnb_4bit_compute_dtype=torch.float16,  # compute dtype for matmuls on dequantized weights
+            bnb_4bit_quant_storage=torch.uint8,  # storage dtype for the quantized weights
+            bnb_4bit_quant_type="nf4",  # NF4: 4-bit NormalFloat, suited to normally distributed weights
+            bnb_4bit_use_double_quant=True,  # double quantization: also quantize the zero-point and scaling constants
+            llm_int8_enable_fp32_cpu_offload=False,  # if True, CPU-offloaded parameters stay in fp32 while the rest use int8
+            llm_int8_has_fp16_weight=False,  # whether to keep fp16 main weights for mixed-precision int8 training
+            # llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"],  # modules to exclude from quantization
+            llm_int8_threshold=6.0,  # outlier threshold for the LLM.int8() algorithm; larger activations stay unquantized
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+            quantization_config=quantization_config,
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+        )
 
     if use_lora:
         from peft import LoraConfig, TaskType, get_peft_model
@@ -170,6 +193,7 @@ if __name__ == "__main__":
         model_path=model_args.model_name_or_path,
         max_length=training_args.model_max_length,
         use_lora=training_args.use_lora,
+        qlora=training_args.qlora,
         bf16=training_args.bf16,
         fp16=training_args.fp16
     )
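
For context, here is a minimal sketch of how the new `qlora` path is typically exercised end to end. It is an illustration, not part of the patch: it assumes the LoRA adapter is attached with `peft` as in the existing `use_lora` branch, adds the commonly recommended `prepare_model_for_kbit_training` step (which the patch itself does not include), and uses a placeholder `MODEL_PATH` plus illustrative LoRA hyperparameters.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import (LoraConfig, TaskType, get_peft_model,
                  prepare_model_for_kbit_training)

MODEL_PATH = "path/to/base-model"  # placeholder

# Same 4-bit NF4 configuration as in the patch above.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    quantization_config=quantization_config,
)

# Recommended for k-bit training: casts norm layers to fp32 and
# enables input gradients for gradient checkpointing.
model = prepare_model_for_kbit_training(model)

# Illustrative LoRA hyperparameters; target modules depend on the base model.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only adapter weights are trainable
```

If the script parses `TrainingArguments` with `transformers.HfArgumentParser` (not shown in this diff), the new dataclass field should surface as a CLI flag, so a run would be launched with something like `--use_lora True --qlora True`.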