MiniCPM/finetune/finetune.py
# -*- coding: utf-8 -*-
import json
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    train_data_path: str = field(
        default="data/AdvertiseGenChatML/train.json",
        metadata={"help": "Path to the training data."},
    )
    eval_data_path: str = field(
        default="data/AdvertiseGenChatML/dev.json",
        metadata={"help": "Path to the evaluation data."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)
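
# Example (sketch): HfArgumentParser exposes every dataclass field above as a
# command-line flag, so a run could be launched roughly like this. The model
# id, paths, and hyperparameter values are illustrative assumptions, not
# fixed by this script:
#
#   python finetune.py \
#       --model_name_or_path openbmb/MiniCPM-2B-sft-bf16 \
#       --train_data_path data/AdvertiseGenChatML/train.json \
#       --eval_data_path data/AdvertiseGenChatML/dev.json \
#       --output_dir output \
#       --model_max_length 512 \
#       --use_lora true \
#       --bf16 true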


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length=4096,
        user_tokens=[1786, 4194, 95388],  # hard-coded ids of the user role marker for the target tokenizer
        assistant_tokens=[1786, 10850, 95388],  # hard-coded ids of the assistant role marker
    ):
        super(SupervisedDataset, self).__init__()
        with open(data_path, encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100  # positions with this label are ignored by the loss
        # Sanity check: decode the first preprocessed example.
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = [id_ for id_ in item["label_ids"] if id_ != self.ignore_index]
        print("label:", self.tokenizer.decode(labels))

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = [self.tokenizer.bos_token_id]
        # Mask the bos position so labels stay aligned with input_ids
        # (the model shifts labels internally when computing the loss).
        label_ids = [self.ignore_index]
        for message in example["messages"]:
            role = message["role"]
            content = message["content"]
            content_ids = self.tokenizer.encode(content, add_special_tokens=False)
            if role == "user":
                # User turns are context only: every position is masked out.
                input_ids += self.user_tokens + content_ids
                label_ids += [self.ignore_index] * (
                    len(self.user_tokens) + len(content_ids)
                )
            else:
                # Assistant turns are supervised: the role marker is masked,
                # the content tokens are the training targets.
                input_ids += self.assistant_tokens + content_ids
                label_ids += (
                    [self.ignore_index] * len(self.assistant_tokens) + content_ids
                )
        # Truncate, then right-pad to a fixed length.
        input_ids = input_ids[: self.model_max_length]
        label_ids = label_ids[: self.model_max_length]
        input_ids += [self.tokenizer.eos_token_id] * (
            self.model_max_length - len(input_ids)
        )
        label_ids += [self.ignore_index] * (self.model_max_length - len(label_ids))
        input_ids = torch.LongTensor(input_ids)
        label_ids = torch.LongTensor(label_ids)
        # Padding reuses eos, so eos positions are excluded from attention;
        # note this also masks any genuine eos inside the sequence.
        attention_mask = input_ids.ne(self.tokenizer.eos_token_id)
        # The Trainer's default collator renames "label_ids" to "labels".
        return {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])
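
# Expected data layout (inferred from preprocessing above; the content strings
# are placeholders): each record in train.json / dev.json carries a "messages"
# list of role/content turns, e.g.
#
#   {
#     "messages": [
#       {"role": "user", "content": "Describe this product."},
#       {"role": "assistant", "content": "A lightweight summer dress ..."}
#     ]
#   }
#
# For such a record, the bos token, both role markers, and the whole user turn
# get label -100 (ignored by the loss); only the assistant content tokens are
# supervised.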


def load_model_and_tokenizer(
    model_path: str,
    max_length: int = 4096,  # currently unused; kept for call-site compatibility
    use_lora: bool = True,
    bf16: bool = False,
    fp16: bool = False,
):
    """Load model and tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Reuse eos as the pad token, matching the dataset's eos-padding.
    tokenizer.pad_token = tokenizer.eos_token
    assert not (bf16 and fp16), "bf16 or fp16, not both"
    if bf16:
        dtype = torch.bfloat16
    elif fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        lora_config = LoraConfig(
            init_lora_weights="gaussian",
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj"],
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        # trainable params: 2,949,120 || all params: 3,010,652,928 || trainable%: 0.09795616002669305
        model.print_trainable_parameters()
        # model.enable_input_require_grads()  # needed when training an adapter
    return model, tokenizer
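
# Quick smoke test (sketch; the model id, prompt, and lengths are assumptions):
#
#   model, tokenizer = load_model_and_tokenizer(
#       "openbmb/MiniCPM-2B-sft-bf16", use_lora=False, bf16=True
#   )
#   ids = tokenizer("Hello", return_tensors="pt").input_ids
#   out = model.generate(ids, max_new_tokens=16)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))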


if __name__ == "__main__":
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model, tokenizer = load_model_and_tokenizer(
        model_path=model_args.model_name_or_path,
        max_length=training_args.model_max_length,
        use_lora=training_args.use_lora,
        # Pass the precision flags through so the dtype selection takes effect.
        bf16=training_args.bf16,
        fp16=training_args.fp16,
    )
    train_dataset = SupervisedDataset(
        data_path=data_args.train_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )
    eval_dataset = SupervisedDataset(
        data_path=data_args.eval_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    # Save the incremental PEFT weights; more details at https://huggingface.co/blog/peft
    # model.save_pretrained("output_dir")
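
# Reusing the saved adapter (sketch, assuming use_lora and an "output_dir"
# save as in the comment above): the adapter can be folded back into the base
# weights with peft's merge_and_unload, e.g.
#
#   from peft import AutoPeftModelForCausalLM
#   merged = AutoPeftModelForCausalLM.from_pretrained(
#       "output_dir", trust_remote_code=True
#   ).merge_and_unload()
#   merged.save_pretrained("merged_model")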