mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-02-02 04:53:19 +08:00
原始代码的usertoken是针对2b的,其他模型会有问题,现在根据不同模型都会调整
This commit is contained in:
parent
cf9a5be5be
commit
8ae10c60ff
@ -42,21 +42,21 @@ class TrainingArguments(transformers.TrainingArguments):
|
|||||||
|
|
||||||
class SupervisedDataset(Dataset):
|
class SupervisedDataset(Dataset):
|
||||||
"""Dataset for supervised fine-tuning."""
|
"""Dataset for supervised fine-tuning."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_path,
|
data_path,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
model_max_length=4096,
|
model_max_length=4096,
|
||||||
user_tokens=[1786, 4194, 95388],
|
user_tokens='<用户>',
|
||||||
assistant_tokens=[1786, 10850, 95388],
|
assistant_tokens='<AI>',
|
||||||
):
|
):
|
||||||
super(SupervisedDataset, self).__init__()
|
super(SupervisedDataset, self).__init__()
|
||||||
self.data = json.load(open(data_path))
|
self.data = json.load(open(data_path))
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.model_max_length = model_max_length
|
self.model_max_length = model_max_length
|
||||||
self.user_tokens = user_tokens
|
self.user_tokens = self.tokenizer(user_tokens)['input_ids']#针对不同模型,都可以对应到<用户>的id
|
||||||
self.assistant_tokens = assistant_tokens
|
self.assistant_tokens = self.tokenizer(assistant_tokens)['input_ids']#针对不同模型,都可以对应到<AI>的id
|
||||||
self.ignore_index = -100
|
self.ignore_index = -100
|
||||||
item = self.preprocessing(self.data[0])
|
item = self.preprocessing(self.data[0])
|
||||||
print("input:", self.tokenizer.decode(item["input_ids"]))
|
print("input:", self.tokenizer.decode(item["input_ids"]))
|
||||||
@ -64,7 +64,6 @@ class SupervisedDataset(Dataset):
|
|||||||
for id_ in item["label_ids"]:
|
for id_ in item["label_ids"]:
|
||||||
if id_ == -100:
|
if id_ == -100:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
labels.append(id_)
|
labels.append(id_)
|
||||||
print("label:", self.tokenizer.decode(labels))
|
print("label:", self.tokenizer.decode(labels))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user