diff --git a/finetune/finetune.py b/finetune/finetune.py index 3d050db..2cbd6cc 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -73,7 +73,7 @@ class SupervisedDataset(Dataset): def preprocessing(self, example): input_ids = [self.tokenizer.bos_token_id] - label_ids = [self.tokenizer.bos_token_id] + label_ids = [] for message in example["messages"]: role = message["role"] @@ -82,18 +82,17 @@ class SupervisedDataset(Dataset): if role == "user": input_ids += self.user_tokens + content_ids - label_ids += ( - [self.ignore_index] * len(self.user_tokens) - + [self.tokenizer.eos_token_id] - + [self.ignore_index] * len(content_ids) - ) + label_ids += [self.ignore_index] * len(self.user_tokens) + [ + self.ignore_index + ] * len(content_ids) else: input_ids += self.assistant_tokens + content_ids - label_ids += [self.ignore_index] * len( - self.assistant_tokens - ) + content_ids - input_ids.append(self.tokenizer.eos_token_id) - label_ids.append(self.tokenizer.eos_token_id) + label_ids += ( + [self.ignore_index] * len(self.assistant_tokens) + + content_ids + + [self.tokenizer.eos_token_id] + ) + input_ids = input_ids[: self.model_max_length] label_ids = label_ids[: self.model_max_length] # input_ids += [self.tokenizer.eos_token_id] * (len(label_ids) - len(input_ids))