diff --git a/finetune/finetune.py b/finetune/finetune.py
index 7008ff2..57ceab4 100644
--- a/finetune/finetune.py
+++ b/finetune/finetune.py
@@ -42,21 +42,21 @@ class TrainingArguments(transformers.TrainingArguments):
 
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
-    
+
     def __init__(
         self,
         data_path,
         tokenizer,
         model_max_length=4096,
-        user_tokens=[1786, 4194, 95388],
-        assistant_tokens=[1786, 10850, 95388],
+        user_tokens='<用户>',
+        assistant_tokens='<AI>',
     ):
         super(SupervisedDataset, self).__init__()
         self.data = json.load(open(data_path))
         self.tokenizer = tokenizer
         self.model_max_length = model_max_length
-        self.user_tokens = user_tokens
-        self.assistant_tokens = assistant_tokens
+        self.user_tokens = self.tokenizer.encode(user_tokens) #针对不同模型,都可以对应到<用户>的id
+        self.assistant_tokens = self.tokenizer.encode(assistant_tokens) #针对不同模型,都可以对应到<AI>的id
         self.ignore_index = -100
         item = self.preprocessing(self.data[0])
         print("input:", self.tokenizer.decode(item["input_ids"]))
@@ -64,7 +64,6 @@ class SupervisedDataset(Dataset):
         for id_ in item["label_ids"]:
             if id_ == -100:
                 continue
-            labels.append(id_)
         print("label:", self.tokenizer.decode(labels))