mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00
增加了qlora的finetune
This commit is contained in:
parent
c95a1f1cb7
commit
65e0e2570a
@ -7,7 +7,7 @@ import torch
|
||||
import transformers
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
|
||||
TrainingArguments)
|
||||
TrainingArguments,BitsAndBytesConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -38,6 +38,7 @@ class TrainingArguments(transformers.TrainingArguments):
|
||||
},
|
||||
)
|
||||
use_lora: bool = field(default=False)
|
||||
qlora: bool = field(default=False)
|
||||
|
||||
|
||||
class SupervisedDataset(Dataset):
|
||||
@ -121,6 +122,7 @@ def load_model_and_tokenizer(
|
||||
model_path: str,
|
||||
max_length: int = 4096,
|
||||
use_lora: bool = True,
|
||||
qlora: bool = False,
|
||||
bf16: bool = False,
|
||||
fp16: bool = False,
|
||||
):
|
||||
@ -135,6 +137,28 @@ def load_model_and_tokenizer(
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
if qlora:
|
||||
assert use_lora, "use_lora must be True when use_qlora is True"
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True, # 是否进行4bit量化
|
||||
load_in_8bit=False, # 是否进行8bit量化
|
||||
bnb_4bit_compute_dtype=torch.float16, # 计算精度设置
|
||||
bnb_4bit_quant_storage=torch.uint8, # 量化权重的储存格式
|
||||
bnb_4bit_quant_type="nf4", # 量化格式,这里用的是正太分布的int4
|
||||
bnb_4bit_use_double_quant=True, # 是否采用双量化,即对zeropoint和scaling参数进行量化
|
||||
llm_int8_enable_fp32_cpu_offload=False, # 是否llm使用int8,cpu上保存的参数使用fp32
|
||||
llm_int8_has_fp16_weight=False, # 是否启用混合精度
|
||||
#llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"], # 不进行量化的模块
|
||||
llm_int8_threshold=6.0, # llm.int8()算法中的离群值,根据这个值区分是否进行量化
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=True,
|
||||
quantization_config=quantization_config,
|
||||
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
@ -170,6 +194,7 @@ if __name__ == "__main__":
|
||||
model_path=model_args.model_name_or_path,
|
||||
max_length=training_args.model_max_length,
|
||||
use_lora=training_args.use_lora,
|
||||
qlora=training_args.qlora,
|
||||
bf16=training_args.bf16,
|
||||
fp16=training_args.fp16
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user