Added QLoRA finetuning

root 2024-07-26 14:57:46 +08:00
parent c95a1f1cb7
commit 65e0e2570a


@@ -7,7 +7,7 @@ import torch
import transformers
from torch.utils.data import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
TrainingArguments)
TrainingArguments, BitsAndBytesConfig)
@dataclass
@@ -38,6 +38,7 @@ class TrainingArguments(transformers.TrainingArguments):
},
)
use_lora: bool = field(default=False)
qlora: bool = field(default=False)
class SupervisedDataset(Dataset):
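The new `qlora` field sits next to `use_lora` on the script's TrainingArguments dataclass, so it surfaces as a command-line flag once the dataclass is parsed. A minimal sketch, assuming the script parses arguments with transformers.HfArgumentParser (the parsing code is outside this diff):

# Sketch only: how the new boolean field would surface as a CLI flag,
# assuming TrainingArguments is parsed with HfArgumentParser.
from transformers import HfArgumentParser

parser = HfArgumentParser((TrainingArguments,))
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--use_lora", "True", "--qlora", "True"]
)
print(training_args.use_lora, training_args.qlora)  # True True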
@@ -121,6 +122,7 @@ def load_model_and_tokenizer(
model_path: str,
max_length: int = 4096,
use_lora: bool = True,
qlora: bool = False,
bf16: bool = False,
fp16: bool = False,
):
@@ -135,11 +137,33 @@ def load_model_and_tokenizer(
dtype = torch.float16
else:
dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
trust_remote_code=True,
)
if qlora:
assert use_lora, "use_lora must be True when qlora is True"
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,  # quantize the weights to 4-bit
load_in_8bit=False,  # do not use 8-bit quantization
bnb_4bit_compute_dtype=torch.float16,  # compute dtype for the 4-bit layers
bnb_4bit_quant_storage=torch.uint8,  # storage dtype for the quantized weights
bnb_4bit_quant_type="nf4",  # quantization format: NF4, a normal-distribution-based 4-bit type
bnb_4bit_use_double_quant=True,  # double quantization: also quantize the zero-point and scaling factors
llm_int8_enable_fp32_cpu_offload=False,  # whether parameters offloaded to CPU are kept in fp32
llm_int8_has_fp16_weight=False,  # whether to run LLM.int8() with fp16 main weights (mixed precision)
#llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"],  # modules to skip during quantization
llm_int8_threshold=6.0,  # outlier threshold for the LLM.int8() algorithm; values beyond it are left unquantized
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
trust_remote_code=True,
quantization_config=quantization_config,
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
trust_remote_code=True,
)
if use_lora:
from peft import LoraConfig, TaskType, get_peft_model
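The hunk above only switches the from_pretrained call; the LoRA wrapping itself happens in the unchanged `if use_lora:` branch. When the base model is loaded in 4-bit, peft's prepare_model_for_kbit_training is commonly applied before get_peft_model to stabilize training. A minimal sketch of that pattern; the rank, alpha, and target_modules below are illustrative, not taken from this repo:

# Sketch of a typical QLoRA wrapping step (not part of this commit);
# hyperparameters and target_modules are illustrative placeholders.
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

# Cast norms and other non-quantized params to fp32, freeze base weights,
# and enable input gradients so the 4-bit base model trains stably.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                                 # illustrative LoRA rank
    lora_alpha=32,                        # illustrative scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # hypothetical module names; depends on the model
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters remain trainable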
@@ -170,6 +194,7 @@ if __name__ == "__main__":
model_path=model_args.model_name_or_path,
max_length=training_args.model_max_length,
use_lora=training_args.use_lora,
qlora=training_args.qlora,
bf16=training_args.bf16,
fp16=training_args.fp16
)
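With the call site updated, the 4-bit path is exercised whenever both flags are set. A quick sanity check of the loader, assuming it returns a (model, tokenizer) pair as its name suggests; the model path is a placeholder:

# Hypothetical direct call to the loader; the path is a placeholder.
model, tokenizer = load_model_and_tokenizer(
    model_path="/path/to/base-model",  # placeholder
    max_length=4096,
    use_lora=True,   # required: the assert above forbids qlora without use_lora
    qlora=True,      # takes the BitsAndBytesConfig branch (4-bit NF4 + double quant)
    fp16=True,       # matches bnb_4bit_compute_dtype=torch.float16
)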