From 262840a805e2f8b234d4f8ebd05e7c4cc318b53e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E4=B8=B9?=
Date: Tue, 25 Jun 2024 10:29:25 +0800
Subject: [PATCH] Fix the wrong default model in finetune
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 finetune/finetune.py | 30 ++++++++++++++----------------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/finetune/finetune.py b/finetune/finetune.py
index e99e0e2..5fabdcd 100644
--- a/finetune/finetune.py
+++ b/finetune/finetune.py
@@ -6,12 +6,13 @@ from typing import Dict, Optional
 import torch
 import transformers
 from torch.utils.data import Dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
+from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
+                          TrainingArguments)
 
 
 @dataclass
 class ModelArguments:
-    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
+    model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-2B-sft-bf16")
 
 
 @dataclass
@@ -41,25 +42,21 @@ class TrainingArguments(transformers.TrainingArguments):
 
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
-    
+
     def __init__(
         self,
         data_path,
         tokenizer,
         model_max_length=4096,
-        user_tokens="<用户>",
-        assistant_tokens="<AI>",
+        user_tokens='<用户>',
+        assistant_tokens='<AI>',
     ):
         super(SupervisedDataset, self).__init__()
         self.data = json.load(open(data_path))
         self.tokenizer = tokenizer
         self.model_max_length = model_max_length
-        self.user_tokens = self.tokenizer.encode(
-            user_tokens
-        )  # 针对不同模型,都可以对应到<用户>的id
-        self.assistant_tokens = self.tokenizer.encode(
-            assistant_tokens
-        )  # 针对不同模型,都可以对应到<AI>的id
+        self.user_tokens = self.tokenizer.encode(user_tokens)  # 针对不同模型,都可以对应到<用户>的id
+        self.assistant_tokens = self.tokenizer.encode(assistant_tokens)  # 针对不同模型,都可以对应到<AI>的id
         self.ignore_index = -100
         item = self.preprocessing(self.data[0])
         print("input:", self.tokenizer.decode(item["input_ids"]))
@@ -89,9 +86,10 @@ class SupervisedDataset(Dataset):
                 ] * len(content_ids)
             else:
                 input_ids += self.assistant_tokens + content_ids
-                label_ids += [self.ignore_index] * len(
-                    self.assistant_tokens
-                ) + content_ids
+                label_ids += (
+                    [self.ignore_index] * len(self.assistant_tokens)
+                    + content_ids
+                )
 
         input_ids.append(self.tokenizer.eos_token_id)
         label_ids.append(self.tokenizer.eos_token_id)
@@ -173,7 +171,7 @@ if __name__ == "__main__":
         max_length=training_args.model_max_length,
         use_lora=training_args.use_lora,
         bf16=training_args.bf16,
-        fp16=training_args.fp16,
+        fp16=training_args.fp16
     )
 
     train_dataset = SupervisedDataset(
@@ -197,4 +195,4 @@ if __name__ == "__main__":
     trainer.train()
 
     # save the incremental PEFT weights, more details can be found in https://huggingface.co/blog/peft
-    # model.save_pretrained("output_dir")
+    # model.save_pretrained("output_dir")
\ No newline at end of file
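
For reference, a minimal sketch of loading the new default checkpoint directly with transformers, outside the Trainer; the torch_dtype choice and trust_remote_code=True (for MiniCPM's custom modeling code) are assumptions, not part of this patch:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name_or_path = "openbmb/MiniCPM-2B-sft-bf16"  # matches the new ModelArguments default
    # Assumption: the MiniCPM checkpoint ships custom modeling code, hence trust_remote_code=True.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,  # assumption: pairs with the bf16 training flag
        trust_remote_code=True,
    )
    # Quick check that the chat markers used by SupervisedDataset tokenize as expected.
    print(tokenizer.encode("<用户>"), tokenizer.encode("<AI>"))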