Fix fine tune label offset bug

Xiang Long 2024-02-01 12:51:30 +08:00
parent 6f63f1d978
commit 24a71e964f

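Summary, as read from the diff below: the old preprocessing seeded label_ids with bos_token_id and emitted an eos_token_id label inside each user turn's mask, so the labels drifted one position out of alignment with input_ids on every user message. The fix starts label_ids empty, masks user turns entirely, and instead supervises eos_token_id at the end of each assistant turn, dropping the now-redundant label_ids.append(eos_token_id) after the loop.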

@@ -73,7 +73,7 @@ class SupervisedDataset(Dataset):
     def preprocessing(self, example):
         input_ids = [self.tokenizer.bos_token_id]
-        label_ids = [self.tokenizer.bos_token_id]
+        label_ids = []
         for message in example["messages"]:
             role = message["role"]
             content = message["content"]
@@ -82,18 +82,17 @@ class SupervisedDataset(Dataset):
             content_ids = self.tokenizer.encode(content)
             if role == "user":
                 input_ids += self.user_tokens + content_ids
-                label_ids += (
-                    [self.ignore_index] * len(self.user_tokens)
-                    + [self.tokenizer.eos_token_id]
-                    + [self.ignore_index] * len(content_ids)
-                )
+                label_ids += [self.ignore_index] * len(self.user_tokens) + [
+                    self.ignore_index
+                ] * len(content_ids)
             else:
                 input_ids += self.assistant_tokens + content_ids
-                label_ids += [self.ignore_index] * len(
-                    self.assistant_tokens
-                ) + content_ids
+                label_ids += (
+                    [self.ignore_index] * len(self.assistant_tokens)
+                    + content_ids
+                    + [self.tokenizer.eos_token_id]
+                )
         input_ids.append(self.tokenizer.eos_token_id)
-        label_ids.append(self.tokenizer.eos_token_id)
         input_ids = input_ids[: self.model_max_length]
         label_ids = label_ids[: self.model_max_length]
         # input_ids += [self.tokenizer.eos_token_id] * (len(label_ids) - len(input_ids))
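For reference, a minimal runnable sketch of the fixed label layout. The toy tokenizer, the single-id role markers, and ignore_index = -100 are illustrative assumptions rather than the repository's real values; only the shape of the label construction follows the diff. Read this way, every supervised label ends up one step ahead of its input position (pre-shifted next-token targets), which presumes a training loss that does not shift labels again.

IGNORE_INDEX = -100
BOS, EOS = 1, 2
USER_TOKENS, ASSISTANT_TOKENS = [3], [4]  # assumed single-id role markers


def toy_encode(text):
    """Stand-in for tokenizer.encode(): one fake id per character."""
    return [100 + ord(c) % 50 for c in text]


def preprocess(messages, model_max_length=4096):
    input_ids = [BOS]
    label_ids = []  # no BOS seed: labels sit one step ahead of inputs
    for message in messages:
        content_ids = toy_encode(message["content"])
        if message["role"] == "user":
            # User turns are context only: every position is masked.
            input_ids += USER_TOKENS + content_ids
            label_ids += [IGNORE_INDEX] * (len(USER_TOKENS) + len(content_ids))
        else:
            # Assistant turns supervise the reply plus a trailing EOS,
            # teaching the model to stop after it answers.
            input_ids += ASSISTANT_TOKENS + content_ids
            label_ids += (
                [IGNORE_INDEX] * len(ASSISTANT_TOKENS)
                + content_ids
                + [EOS]
            )
    input_ids.append(EOS)
    return input_ids[:model_max_length], label_ids[:model_max_length]


if __name__ == "__main__":
    inputs, labels = preprocess(
        [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "yo"},
        ]
    )
    # Every supervised label is the token that follows its position:
    for i, lab in enumerate(labels):
        if lab != IGNORE_INDEX:
            assert lab == inputs[i + 1]
    print(inputs)  # [1, 3, 104, 105, 4, 121, 111, 2]
    print(labels)  # [-100, -100, -100, -100, 121, 111, 2]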