原始代码的usertoken是针对2b的,其他模型会有问题,现在根据不同模型都会调整

This commit is contained in:
root 2024-06-21 15:31:24 +08:00
parent cf9a5be5be
commit 8ae10c60ff

View File

@ -48,15 +48,15 @@ class SupervisedDataset(Dataset):
data_path, data_path,
tokenizer, tokenizer,
model_max_length=4096, model_max_length=4096,
user_tokens=[1786, 4194, 95388], user_tokens='<用户>',
assistant_tokens=[1786, 10850, 95388], assistant_tokens='<AI>',
): ):
super(SupervisedDataset, self).__init__() super(SupervisedDataset, self).__init__()
self.data = json.load(open(data_path)) self.data = json.load(open(data_path))
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.model_max_length = model_max_length self.model_max_length = model_max_length
self.user_tokens = user_tokens self.user_tokens = self.tokenizer(user_tokens)['input_ids']#针对不同模型,都可以对应到<用户>的id
self.assistant_tokens = assistant_tokens self.assistant_tokens = self.tokenizer(assistant_tokens)['input_ids']#针对不同模型,都可以对应到<AI>的id
self.ignore_index = -100 self.ignore_index = -100
item = self.preprocessing(self.data[0]) item = self.preprocessing(self.data[0])
print("input:", self.tokenizer.decode(item["input_ids"])) print("input:", self.tokenizer.decode(item["input_ids"]))
@ -64,7 +64,6 @@ class SupervisedDataset(Dataset):
for id_ in item["label_ids"]: for id_ in item["label_ids"]:
if id_ == -100: if id_ == -100:
continue continue
labels.append(id_) labels.append(id_)
print("label:", self.tokenizer.decode(labels)) print("label:", self.tokenizer.decode(labels))