增加了qlora的finetune

This commit is contained in:
root 2024-07-26 14:57:46 +08:00
parent c95a1f1cb7
commit 65e0e2570a

View File

@ -7,7 +7,7 @@ import torch
import transformers import transformers
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer, from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
TrainingArguments) TrainingArguments,BitsAndBytesConfig)
@dataclass @dataclass
@ -38,6 +38,7 @@ class TrainingArguments(transformers.TrainingArguments):
}, },
) )
use_lora: bool = field(default=False) use_lora: bool = field(default=False)
qlora: bool = field(default=False)
class SupervisedDataset(Dataset): class SupervisedDataset(Dataset):
@ -121,6 +122,7 @@ def load_model_and_tokenizer(
model_path: str, model_path: str,
max_length: int = 4096, max_length: int = 4096,
use_lora: bool = True, use_lora: bool = True,
qlora: bool = False,
bf16: bool = False, bf16: bool = False,
fp16: bool = False, fp16: bool = False,
): ):
@ -135,11 +137,33 @@ def load_model_and_tokenizer(
dtype = torch.float16 dtype = torch.float16
else: else:
dtype = torch.float32 dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained( if qlora:
model_path, assert use_lora, "use_lora must be True when use_qlora is True"
torch_dtype=dtype, quantization_config = BitsAndBytesConfig(
trust_remote_code=True, load_in_4bit=True, # 是否进行4bit量化
) load_in_8bit=False, # 是否进行8bit量化
bnb_4bit_compute_dtype=torch.float16, # 计算精度设置
bnb_4bit_quant_storage=torch.uint8, # 量化权重的储存格式
bnb_4bit_quant_type="nf4", # 量化格式这里用的是正太分布的int4
bnb_4bit_use_double_quant=True, # 是否采用双量化即对zeropoint和scaling参数进行量化
llm_int8_enable_fp32_cpu_offload=False, # 是否llm使用int8cpu上保存的参数使用fp32
llm_int8_has_fp16_weight=False, # 是否启用混合精度
#llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"], # 不进行量化的模块
llm_int8_threshold=6.0, # llm.int8()算法中的离群值,根据这个值区分是否进行量化
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
trust_remote_code=True,
quantization_config=quantization_config,
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
trust_remote_code=True,
)
if use_lora: if use_lora:
from peft import LoraConfig, TaskType, get_peft_model from peft import LoraConfig, TaskType, get_peft_model
@ -170,6 +194,7 @@ if __name__ == "__main__":
model_path=model_args.model_name_or_path, model_path=model_args.model_name_or_path,
max_length=training_args.model_max_length, max_length=training_args.model_max_length,
use_lora=training_args.use_lora, use_lora=training_args.use_lora,
qlora=training_args.qlora,
bf16=training_args.bf16, bf16=training_args.bf16,
fp16=training_args.fp16 fp16=training_args.fp16
) )