Fix fine tune label offset bug

Xiang Long 2024-02-01 12:51:30 +08:00
parent 6f63f1d978
commit 24a71e964f


@@ -73,7 +73,7 @@ class SupervisedDataset(Dataset):
     def preprocessing(self, example):
         input_ids = [self.tokenizer.bos_token_id]
-        label_ids = [self.tokenizer.bos_token_id]
+        label_ids = []
         for message in example["messages"]:
             role = message["role"]
@@ -82,18 +82,17 @@ class SupervisedDataset(Dataset):
             if role == "user":
                 input_ids += self.user_tokens + content_ids
-                label_ids += (
-                    [self.ignore_index] * len(self.user_tokens)
-                    + [self.tokenizer.eos_token_id]
-                    + [self.ignore_index] * len(content_ids)
-                )
+                label_ids += [self.ignore_index] * len(self.user_tokens) + [
+                    self.ignore_index
+                ] * len(content_ids)
             else:
                 input_ids += self.assistant_tokens + content_ids
-                label_ids += [self.ignore_index] * len(
-                    self.assistant_tokens
-                ) + content_ids
-        input_ids.append(self.tokenizer.eos_token_id)
-        label_ids.append(self.tokenizer.eos_token_id)
+                label_ids += (
+                    [self.ignore_index] * len(self.assistant_tokens)
+                    + content_ids
+                    + [self.tokenizer.eos_token_id]
+                )
         input_ids = input_ids[: self.model_max_length]
         label_ids = label_ids[: self.model_max_length]
         # input_ids += [self.tokenizer.eos_token_id] * (len(label_ids) - len(input_ids))
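
For illustration, a minimal, self-contained sketch of the label construction as it stands after this patch. It is not the repository's code: the tokenizer is replaced by a toy vocabulary, and BOS_ID=1, EOS_ID=2, the role-marker ids, IGNORE_INDEX=-100 and MODEL_MAX_LENGTH=32 are all assumed stand-ins for values the real SupervisedDataset takes from its tokenizer and arguments.

# Toy re-implementation of the patched label construction, for illustration only.
# All concrete values below (BOS_ID, EOS_ID, the role-marker ids, IGNORE_INDEX,
# MODEL_MAX_LENGTH) are assumptions; the real code derives them from its tokenizer.

BOS_ID = 1
EOS_ID = 2
IGNORE_INDEX = -100      # positions masked out of the loss
USER_TOKENS = [3]        # stand-in for the encoded user role marker
ASSISTANT_TOKENS = [4]   # stand-in for the encoded assistant role marker
MODEL_MAX_LENGTH = 32

def preprocessing(messages, encode):
    # Mirrors the patched loop: every prompt position gets IGNORE_INDEX,
    # assistant content (plus a trailing EOS) is kept as the supervision target.
    input_ids = [BOS_ID]
    label_ids = []
    for message in messages:
        content_ids = encode(message["content"])
        if message["role"] == "user":
            input_ids += USER_TOKENS + content_ids
            label_ids += [IGNORE_INDEX] * (len(USER_TOKENS) + len(content_ids))
        else:
            input_ids += ASSISTANT_TOKENS + content_ids
            label_ids += (
                [IGNORE_INDEX] * len(ASSISTANT_TOKENS)
                + content_ids
                + [EOS_ID]
            )
    return input_ids[:MODEL_MAX_LENGTH], label_ids[:MODEL_MAX_LENGTH]

if __name__ == "__main__":
    toy_vocab = {"hi": [11, 12], "yo": [13, 14]}  # hypothetical token ids
    messages = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "yo"},
    ]
    inputs, labels = preprocessing(messages, lambda text: toy_vocab[text])
    print(inputs)  # [1, 3, 11, 12, 4, 13, 14]
    print(labels)  # [-100, -100, -100, -100, 13, 14, 2]

Note that with this construction label_ids carries no BOS and ends each assistant turn with an EOS that input_ids does not contain, so label_ids[i] holds the token expected after reading input_ids[i] (a one-position shift baked into the data). Whether the training loop elsewhere in the script consumes the labels without shifting them again cannot be confirmed from this excerpt alone.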