# -*- coding: utf-8 -*-
import json
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    train_data_path: str = field(
        default="data/AdvertiseGenChatML/train.json",
        metadata={"help": "Path to the training data."},
    )
    eval_data_path: str = field(
        default="data/AdvertiseGenChatML/dev.json",
        metadata={"help": "Path to the test data."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning on ChatML-style conversations."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length=4096,
        user_tokens=[1786, 4194, 95388],
        assistant_tokens=[1786, 10850, 95388],
    ):
        super(SupervisedDataset, self).__init__()
        with open(data_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        # Hard-coded token ids of the role markers that open a user / assistant
        # turn; they are tokenizer-specific and must match the target model.
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        # Positions labelled -100 are ignored by the cross-entropy loss.
        self.ignore_index = -100

        # Sanity check: decode the first example so the prompt/label split is visible.
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = []
        for id_ in item["label_ids"]:
            if id_ == self.ignore_index:
                continue
            labels.append(id_)
        print("label:", self.tokenizer.decode(labels))

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = [self.tokenizer.bos_token_id]
        label_ids = []
        for message in example["messages"]:
            role = message["role"]
            content = message["content"]
            content_ids = self.tokenizer.encode(content, add_special_tokens=False)
            if role == "user":
                # User turns are part of the prompt: mask them out of the loss.
                input_ids += self.user_tokens + content_ids
                label_ids += [self.ignore_index] * (
                    len(self.user_tokens) + len(content_ids)
                )
            else:
                # Assistant turns are supervised, except for the role marker itself.
                input_ids += self.assistant_tokens + content_ids
                label_ids += (
                    [self.ignore_index] * len(self.assistant_tokens) + content_ids
                )

        # Truncate, then right-pad with eos (ignored in the loss) to a fixed length.
        input_ids = input_ids[: self.model_max_length]
        label_ids = label_ids[: self.model_max_length]
        input_ids += [self.tokenizer.eos_token_id] * (
            self.model_max_length - len(input_ids)
        )
        label_ids += [self.ignore_index] * (self.model_max_length - len(label_ids))

        input_ids = torch.LongTensor(input_ids)
        label_ids = torch.LongTensor(label_ids)
        # Padding uses eos, so masking eos positions masks the padding.
        attention_mask = input_ids.ne(self.tokenizer.eos_token_id)
        # The collator renames "label_ids" to "labels" before the forward pass.
        return {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])


def load_model_and_tokenizer(
    model_path: str,
    max_length: int = 4096,
    use_lora: bool = True,
    bf16: bool = False,
    fp16: bool = False,
):
    """Load the model and tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    assert not (bf16 and fp16), "Choose either bf16 or fp16, not both."
    if bf16:
        dtype = torch.bfloat16
    elif fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    if use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        lora_config = LoraConfig(
            init_lora_weights="gaussian",
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj"],
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        # trainable params: 2,949,120 || all params: 3,010,652,928 || trainable%: 0.09795616002669305
        model.print_trainable_parameters()
        # model.enable_input_require_grads()  # needed when the adapter is combined with gradient checkpointing

    return model, tokenizer


if __name__ == "__main__":
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model, tokenizer = load_model_and_tokenizer(
        model_path=model_args.model_name_or_path,
        max_length=training_args.model_max_length,
        use_lora=training_args.use_lora,
        bf16=training_args.bf16,
        fp16=training_args.fp16,
    )
    train_dataset = SupervisedDataset(
        data_path=data_args.train_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )
    eval_dataset = SupervisedDataset(
        data_path=data_args.eval_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    # To save only the incremental PEFT (LoRA) weights, see https://huggingface.co/blog/peft
    # model.save_pretrained("output_dir")
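    # Example launch command (a sketch, not from the original repo): the script
    # name `finetune.py` and the hyper-parameter values are illustrative
    # assumptions; the flags map onto the dataclasses above plus standard
    # transformers.TrainingArguments fields.
    #
    #   python finetune.py \
    #       --model_name_or_path baichuan-inc/Baichuan2-7B-Base \
    #       --train_data_path data/AdvertiseGenChatML/train.json \
    #       --eval_data_path data/AdvertiseGenChatML/dev.json \
    #       --output_dir output \
    #       --model_max_length 512 \
    #       --use_lora True \
    #       --per_device_train_batch_size 2 \
    #       --num_train_epochs 1 \
    #       --learning_rate 2e-5 \
    #       --logging_steps 10 \
    #       --bf16 True
    #
    # After LoRA training, the adapter can optionally be folded back into the
    # base weights (a sketch; assumes use_lora=True and a recent peft release,
    # and the "output_dir/merged" path is illustrative):
    #
    #   merged = model.merge_and_unload()
    #   merged.save_pretrained("output_dir/merged")
    #   tokenizer.save_pretrained("output_dir/merged")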