mirror of https://github.com/RYDE-WORK/MiniCPM.git (synced 2026-01-19 12:53:36 +08:00)
Fix the wrong default model in finetune
This commit is contained in:
parent 3347021a0c
commit 262840a805
@@ -6,12 +6,13 @@ from typing import Dict, Optional
 import torch
 import transformers
 from torch.utils.data import Dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
+from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
+                          TrainingArguments)


 @dataclass
 class ModelArguments:
-    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
+    model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-2B-sft-bf16")


 @dataclass
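Note: the corrected default only takes effect when no --model_name_or_path is passed on the command line. A minimal sketch of how the dataclass default is picked up through the usual HfArgumentParser pattern (the parser usage below is an illustration, not code copied from this file):

from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-2B-sft-bf16")


parser = transformers.HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path)  # "openbmb/MiniCPM-2B-sft-bf16" unless overridden on the CLI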
@@ -41,25 +42,21 @@ class TrainingArguments(transformers.TrainingArguments):

 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""

     def __init__(
         self,
         data_path,
         tokenizer,
         model_max_length=4096,
-        user_tokens="<用户>",
-        assistant_tokens="<AI>",
+        user_tokens='<用户>',
+        assistant_tokens='<AI>',
     ):
         super(SupervisedDataset, self).__init__()
         self.data = json.load(open(data_path))
         self.tokenizer = tokenizer
         self.model_max_length = model_max_length
-        self.user_tokens = self.tokenizer.encode(
-            user_tokens
-        )  # for different models, this maps to the id of <用户>
-        self.assistant_tokens = self.tokenizer.encode(
-            assistant_tokens
-        )  # for different models, this maps to the id of <AI>
+        self.user_tokens = self.tokenizer.encode(user_tokens)  # for different models, this maps to the id of <用户>
+        self.assistant_tokens = self.tokenizer.encode(assistant_tokens)  # for different models, this maps to the id of <AI>
         self.ignore_index = -100
         item = self.preprocessing(self.data[0])
         print("input:", self.tokenizer.decode(item["input_ids"]))
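The constructor encodes the literal role markers once and reuses the resulting id lists when packing every turn of a conversation. A rough, standalone illustration of that step, assuming the markers are known to the tokenizer of the new default checkpoint (the checkpoint name and the trust_remote_code flag are assumptions, not part of this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True)

# Each role marker is encoded once; the id lists are later prepended to every turn.
user_tokens = tokenizer.encode('<用户>')      # ids marking the start of a user turn
assistant_tokens = tokenizer.encode('<AI>')   # ids marking the start of an assistant turn

# Decoding the ids should round-trip back to the original markers.
print(user_tokens, tokenizer.decode(user_tokens))
print(assistant_tokens, tokenizer.decode(assistant_tokens))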
@@ -89,9 +86,10 @@ class SupervisedDataset(Dataset):
                 ] * len(content_ids)
             else:
                 input_ids += self.assistant_tokens + content_ids
-                label_ids += [self.ignore_index] * len(
-                    self.assistant_tokens
-                ) + content_ids
+                label_ids += (
+                    [self.ignore_index] * len(self.assistant_tokens)
+                    + content_ids
+                )

         input_ids.append(self.tokenizer.eos_token_id)
         label_ids.append(self.tokenizer.eos_token_id)
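Both the old and the new formulation build identical labels: positions covered by the assistant role marker are set to ignore_index (-100, the value PyTorch's cross-entropy loss skips), while the assistant content ids are kept, so only the reply text contributes to the loss. A small self-contained check of that masking logic, using made-up token ids purely for illustration:

IGNORE_INDEX = -100

# Hypothetical ids standing in for tokenizer output.
assistant_tokens = [5]          # id(s) of the '<AI>' marker
content_ids = [11, 12, 13]      # ids of the assistant reply

input_ids = assistant_tokens + content_ids
label_ids = (
    [IGNORE_INDEX] * len(assistant_tokens)
    + content_ids
)

assert len(input_ids) == len(label_ids)
print(input_ids)   # [5, 11, 12, 13]
print(label_ids)   # [-100, 11, 12, 13]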
@@ -173,7 +171,7 @@ if __name__ == "__main__":
         max_length=training_args.model_max_length,
         use_lora=training_args.use_lora,
         bf16=training_args.bf16,
-        fp16=training_args.fp16,
+        fp16=training_args.fp16
     )

     train_dataset = SupervisedDataset(
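The bf16/fp16 flags from TrainingArguments are forwarded into the model-loading helper here. One common way such flags are mapped to a torch dtype is sketched below; the helper name and the exact mapping are assumptions for illustration, not code from this repository:

import torch
from transformers import AutoModelForCausalLM


def load_model(model_name_or_path, bf16=False, fp16=False):
    # Pick a dtype from the (mutually exclusive) precision flags; fall back to fp32.
    if bf16:
        torch_dtype = torch.bfloat16
    elif fp16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32
    return AutoModelForCausalLM.from_pretrained(
        model_name_or_path, torch_dtype=torch_dtype, trust_remote_code=True
    )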
@@ -197,4 +195,4 @@ if __name__ == "__main__":

     trainer.train()
     # save the incremental PEFT weights, more details can be found in https://huggingface.co/blog/peft
-    # model.save_pretrained("output_dir")
+    # model.save_pretrained("output_dir")
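The trailing comment points at saving only the incremental PEFT (LoRA) weights after training. A hedged sketch of how that save-and-reload round trip typically looks with the peft library; the LoRA config, target modules, and output path are placeholders, not values from this repository:

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True
)

# Wrap the base model with a LoRA adapter (target modules are a guess, not from this repo).
peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

# ... trainer.train() would run here ...

# Saving the PEFT-wrapped model writes only the small adapter weights, not the full model.
peft_model.save_pretrained("output_dir")

# The adapter can later be re-attached to a freshly loaded base checkpoint.
restored = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        "openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True
    ),
    "output_dir",
)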