mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-31 20:03:13 +08:00
Fix fine tune label offset bug
This commit is contained in:
parent
6f63f1d978
commit
24a71e964f
@ -73,7 +73,7 @@ class SupervisedDataset(Dataset):
|
|||||||
|
|
||||||
def preprocessing(self, example):
|
def preprocessing(self, example):
|
||||||
input_ids = [self.tokenizer.bos_token_id]
|
input_ids = [self.tokenizer.bos_token_id]
|
||||||
label_ids = [self.tokenizer.bos_token_id]
|
label_ids = []
|
||||||
|
|
||||||
for message in example["messages"]:
|
for message in example["messages"]:
|
||||||
role = message["role"]
|
role = message["role"]
|
||||||
@ -82,18 +82,17 @@ class SupervisedDataset(Dataset):
|
|||||||
|
|
||||||
if role == "user":
|
if role == "user":
|
||||||
input_ids += self.user_tokens + content_ids
|
input_ids += self.user_tokens + content_ids
|
||||||
label_ids += (
|
label_ids += [self.ignore_index] * len(self.user_tokens) + [
|
||||||
[self.ignore_index] * len(self.user_tokens)
|
self.ignore_index
|
||||||
+ [self.tokenizer.eos_token_id]
|
] * len(content_ids)
|
||||||
+ [self.ignore_index] * len(content_ids)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
input_ids += self.assistant_tokens + content_ids
|
input_ids += self.assistant_tokens + content_ids
|
||||||
label_ids += [self.ignore_index] * len(
|
label_ids += (
|
||||||
self.assistant_tokens
|
[self.ignore_index] * len(self.assistant_tokens)
|
||||||
) + content_ids
|
+ content_ids
|
||||||
input_ids.append(self.tokenizer.eos_token_id)
|
+ [self.tokenizer.eos_token_id]
|
||||||
label_ids.append(self.tokenizer.eos_token_id)
|
)
|
||||||
|
|
||||||
input_ids = input_ids[: self.model_max_length]
|
input_ids = input_ids[: self.model_max_length]
|
||||||
label_ids = label_ids[: self.model_max_length]
|
label_ids = label_ids[: self.model_max_length]
|
||||||
# input_ids += [self.tokenizer.eos_token_id] * (len(label_ids) - len(input_ids))
|
# input_ids += [self.tokenizer.eos_token_id] * (len(label_ids) - len(input_ids))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user