mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00
Added QLoRA fine-tuning
This commit is contained in:
parent
c95a1f1cb7
commit
65e0e2570a
@@ -7,7 +7,7 @@ import torch
 import transformers
 from torch.utils.data import Dataset
 from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
-                          TrainingArguments)
+                          TrainingArguments, BitsAndBytesConfig)
 
 
 @dataclass
@@ -38,6 +38,7 @@ class TrainingArguments(transformers.TrainingArguments):
         },
     )
     use_lora: bool = field(default=False)
+    qlora: bool = field(default=False)
 
 
 class SupervisedDataset(Dataset):
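For reference, the new qlora flag plugs into the script's existing argument dataclasses. A minimal sketch of how it is parsed from the command line, assuming the script uses transformers.HfArgumentParser as the model_args/training_args names in the __main__ hunk below suggest (the output_dir value is a placeholder):

from dataclasses import dataclass, field

import transformers
from transformers import HfArgumentParser


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    use_lora: bool = field(default=False)
    qlora: bool = field(default=False)  # new flag added by this commit


parser = HfArgumentParser(TrainingArguments)
# HfArgumentParser turns each dataclass field into a CLI option, so the new
# flag becomes available as --qlora True (or just --qlora).
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--use_lora", "True", "--qlora", "True"]
)
print(training_args.use_lora, training_args.qlora)  # True True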
@@ -121,6 +122,7 @@ def load_model_and_tokenizer(
     model_path: str,
     max_length: int = 4096,
     use_lora: bool = True,
+    qlora: bool = False,
     bf16: bool = False,
     fp16: bool = False,
 ):
@@ -135,11 +137,33 @@ def load_model_and_tokenizer(
         dtype = torch.float16
     else:
         dtype = torch.float32
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype=dtype,
-        trust_remote_code=True,
-    )
+    if qlora:
+        assert use_lora, "use_lora must be True when qlora is True"
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,  # quantize the weights to 4-bit
+            load_in_8bit=False,  # do not use 8-bit quantization
+            bnb_4bit_compute_dtype=torch.float16,  # compute dtype for the 4-bit matmuls
+            bnb_4bit_quant_storage=torch.uint8,  # storage format of the quantized weights
+            bnb_4bit_quant_type="nf4",  # quantization format: NF4, a normally distributed 4-bit type
+            bnb_4bit_use_double_quant=True,  # double quantization: also quantize the zero-point and scaling parameters
+            llm_int8_enable_fp32_cpu_offload=False,  # whether to keep fp32 parameters on the CPU for llm.int8()
+            llm_int8_has_fp16_weight=False,  # whether to keep fp16 weights for mixed precision
+            # llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"],  # modules excluded from quantization
+            llm_int8_threshold=6.0,  # outlier threshold of the llm.int8() algorithm, used to decide which values are quantized
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+            quantization_config=quantization_config,
+
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+        )
     if use_lora:
         from peft import LoraConfig, TaskType, get_peft_model
 
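The hunk above ends at the peft import, so the diff does not show how the 4-bit model is then wrapped with LoRA adapters. Below is a hedged sketch of the usual QLoRA wiring with peft's prepare_model_for_kbit_training and get_peft_model; the model id, LoRA rank, and target_modules are illustrative assumptions, not values taken from this commit:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import (LoraConfig, TaskType, get_peft_model,
                  prepare_model_for_kbit_training)

# 4-bit NF4 quantization, mirroring the settings added in the diff above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16",  # assumed model id; any causal LM works
    torch_dtype=torch.float16,
    trust_remote_code=True,
    quantization_config=bnb_config,
)

# Casts norm/embedding layers to fp32 and prepares the quantized model for
# gradient-checkpointed adapter training.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                  # illustrative rank
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed attention projection names
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapter weights are trainable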
@@ -170,6 +194,7 @@ if __name__ == "__main__":
         model_path=model_args.model_name_or_path,
         max_length=training_args.model_max_length,
         use_lora=training_args.use_lora,
+        qlora=training_args.qlora,
         bf16=training_args.bf16,
         fp16=training_args.fp16
     )
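With these pieces in place, a QLoRA run is requested by adding --qlora (together with --use_lora) to the existing fine-tuning command line; the 4-bit base weights stay frozen and only the LoRA adapter parameters receive gradients, which is why load_model_and_tokenizer asserts that use_lora is set whenever qlora is.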