# MiniCPM/finetune/finetune.py
# -*- coding: utf-8 -*-
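"""Supervised fine-tuning script for MiniCPM using the Hugging Face Trainer,
with optional LoRA via peft.

The argument defaults suggest the script was adapted from Baichuan2's
finetune.py and the AdvertiseGen example data in ChatML-style JSON format.
"""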
import json
from typing import Dict, Optional
from dataclasses import dataclass, field

import torch
from torch.utils.data import Dataset
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    train_data_path: str = field(
        default="data/AdvertiseGenChatML/train.json",
        metadata={"help": "Path to the training data."},
    )
    eval_data_path: str = field(
        default="data/AdvertiseGenChatML/dev.json",
        metadata={"help": "Path to the test data."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)
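

# The SupervisedDataset below expects each record of the JSON files to follow the
# ChatML-style layout read in preprocessing() ("messages" -> "role" / "content").
# A minimal sketch of one record (placeholder values, not taken from the real
# AdvertiseGenChatML data):
#
#   {
#     "messages": [
#       {"role": "user", "content": "<task input>"},
#       {"role": "assistant", "content": "<target output>"}
#     ]
#   }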
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length=4096,
        user_tokens=[1786, 4194, 95388],
        assistant_tokens=[1786, 10850, 95388],
    ):
        super(SupervisedDataset, self).__init__()
        with open(data_path, encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        # Token ids of the chat role markers (expected to decode to "<用户>" and
        # "<AI>" with MiniCPM's tokenizer).
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        # Print one preprocessed sample as a sanity check.
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = []
        for id_ in item["label_ids"]:
            if id_ == self.ignore_index:
                continue
            labels.append(id_)
        print("label:", self.tokenizer.decode(labels))

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = [self.tokenizer.bos_token_id]
        label_ids = [self.tokenizer.bos_token_id]
        for message in example["messages"]:
            role = message["role"]
            content = message["content"]
            content_ids = self.tokenizer.encode(content, add_special_tokens=False)
            if role == "user":
                input_ids += self.user_tokens + content_ids
                # Label an EOS at the first user-marker position (so, after the
                # causal-LM shift, the model learns to stop at the end of the
                # preceding assistant turn) and ignore the rest of the user turn.
                # The label list must stay the same length as the input list.
                label_ids += [self.tokenizer.eos_token_id] + [self.ignore_index] * (
                    len(self.user_tokens) + len(content_ids) - 1
                )
            else:
                input_ids += self.assistant_tokens + content_ids
                # Ignore the assistant marker; compute loss on the assistant reply.
                label_ids += [self.ignore_index] * len(
                    self.assistant_tokens
                ) + content_ids
        input_ids.append(self.tokenizer.eos_token_id)
        label_ids.append(self.tokenizer.eos_token_id)
        # Truncate, then right-pad both sequences to model_max_length.
        # EOS doubles as the pad token; padded label positions are ignored.
        input_ids = input_ids[: self.model_max_length]
        label_ids = label_ids[: self.model_max_length]
        input_ids += [self.tokenizer.eos_token_id] * (
            self.model_max_length - len(input_ids)
        )
        label_ids += [self.ignore_index] * (self.model_max_length - len(label_ids))
        input_ids = torch.LongTensor(input_ids)
        label_ids = torch.LongTensor(label_ids)
        # Because EOS is also the pad token, every EOS position (including the
        # real one at the end of the dialogue) is masked out of attention here.
        attention_mask = input_ids.ne(self.tokenizer.eos_token_id)
        # The Trainer's default collator renames "label_ids" to the "labels" key
        # used for the loss computation.
        return {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])


def load_model_and_tokenizer(
    model_path: str,
    max_length: int = 4096,
    use_lora: bool = True,
    bf16: bool = False,
    fp16: bool = False,
):
    """Load the model and tokenizer, optionally wrapping the model with LoRA."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Reuse EOS as the pad token (the dataset pads with EOS as well).
    tokenizer.pad_token = tokenizer.eos_token
    assert not (bf16 and fp16), "bf16 or fp16, not both"
    if bf16:
        dtype = torch.bfloat16
    elif fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        lora_config = LoraConfig(
            init_lora_weights="gaussian",
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj"],
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        # trainable params: 2,949,120 || all params: 3,010,652,928 || trainable%: 0.09795616002669305
        model.print_trainable_parameters()
        # Make inputs require grad so gradient checkpointing works with the frozen base model.
        model.enable_input_require_grads()
    return model, tokenizer
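

# A sketch of how this script is typically launched; the flag names come from the
# dataclasses above plus the standard transformers.TrainingArguments, while the
# model id and paths are illustrative:
#
#   torchrun --nproc_per_node=1 finetune.py \
#       --model_name_or_path openbmb/MiniCPM-2B-sft-bf16 \
#       --train_data_path data/AdvertiseGenChatML/train.json \
#       --eval_data_path data/AdvertiseGenChatML/dev.json \
#       --output_dir output/advertise_gen \
#       --model_max_length 512 \
#       --use_lora true \
#       --bf16 true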
if __name__ == "__main__":
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model, tokenizer = load_model_and_tokenizer(
        model_path=model_args.model_name_or_path,
        max_length=training_args.model_max_length,
        use_lora=training_args.use_lora,
        bf16=training_args.bf16,
        fp16=training_args.fp16,
    )
    train_dataset = SupervisedDataset(
        data_path=data_args.train_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )
    eval_dataset = SupervisedDataset(
        data_path=data_args.eval_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
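    # Checkpoints during training are written according to TrainingArguments
    # (save_strategy / save_steps). A final trainer.save_model(training_args.output_dir)
    # could be added here if an explicit export of the fine-tuned (or LoRA adapter)
    # weights is wanted.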