mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00
Fix fine tune label offset bug
This commit is contained in:
parent
6f63f1d978
commit
24a71e964f
@ -73,7 +73,7 @@ class SupervisedDataset(Dataset):
|
||||
|
||||
def preprocessing(self, example):
|
||||
input_ids = [self.tokenizer.bos_token_id]
|
||||
label_ids = [self.tokenizer.bos_token_id]
|
||||
label_ids = []
|
||||
|
||||
for message in example["messages"]:
|
||||
role = message["role"]
|
||||
@ -82,18 +82,17 @@ class SupervisedDataset(Dataset):
|
||||
|
||||
if role == "user":
|
||||
input_ids += self.user_tokens + content_ids
|
||||
label_ids += (
|
||||
[self.ignore_index] * len(self.user_tokens)
|
||||
+ [self.tokenizer.eos_token_id]
|
||||
+ [self.ignore_index] * len(content_ids)
|
||||
)
|
||||
label_ids += [self.ignore_index] * len(self.user_tokens) + [
|
||||
self.ignore_index
|
||||
] * len(content_ids)
|
||||
else:
|
||||
input_ids += self.assistant_tokens + content_ids
|
||||
label_ids += [self.ignore_index] * len(
|
||||
self.assistant_tokens
|
||||
) + content_ids
|
||||
input_ids.append(self.tokenizer.eos_token_id)
|
||||
label_ids.append(self.tokenizer.eos_token_id)
|
||||
label_ids += (
|
||||
[self.ignore_index] * len(self.assistant_tokens)
|
||||
+ content_ids
|
||||
+ [self.tokenizer.eos_token_id]
|
||||
)
|
||||
|
||||
input_ids = input_ids[: self.model_max_length]
|
||||
label_ids = label_ids[: self.model_max_length]
|
||||
# input_ids += [self.tokenizer.eos_token_id] * (len(label_ids) - len(input_ids))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user