Mirror of https://github.com/RYDE-WORK/MiniCPM.git, synced 2026-02-03 21:53:45 +08:00
Fixed the incorrect default model in finetune
This commit is contained in:
parent 3347021a0c
commit 262840a805
@@ -6,12 +6,13 @@ from typing import Dict, Optional
 import torch
 import transformers
 from torch.utils.data import Dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
+from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
+                          TrainingArguments)


 @dataclass
 class ModelArguments:
-    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
+    model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-2B-sft-bf16")


 @dataclass
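Note on the fix above: the old default pointed at a Baichuan2 checkpoint, which does not belong to this repository; the new default is MiniCPM's own checkpoint on the Hugging Face Hub. A minimal standalone sketch of how such a default is typically consumed (not the script's exact loading code; MiniCPM checkpoints generally require trust_remote_code=True):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "openbmb/MiniCPM-2B-sft-bf16"  # the corrected default
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # matches the bf16 suffix of the checkpoint
    trust_remote_code=True,
)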
@@ -41,25 +42,21 @@ class TrainingArguments(transformers.TrainingArguments):

 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""

     def __init__(
         self,
         data_path,
         tokenizer,
         model_max_length=4096,
-        user_tokens="<用户>",
-        assistant_tokens="<AI>",
+        user_tokens='<用户>',
+        assistant_tokens='<AI>',
     ):
         super(SupervisedDataset, self).__init__()
         self.data = json.load(open(data_path))
         self.tokenizer = tokenizer
         self.model_max_length = model_max_length
-        self.user_tokens = self.tokenizer.encode(
-            user_tokens
-        )  # regardless of the model, this maps to the id(s) of <用户>
-        self.assistant_tokens = self.tokenizer.encode(
-            assistant_tokens
-        )  # regardless of the model, this maps to the id(s) of <AI>
+        self.user_tokens = self.tokenizer.encode(user_tokens)  # regardless of the model, this maps to the id(s) of <用户>
+        self.assistant_tokens = self.tokenizer.encode(assistant_tokens)  # regardless of the model, this maps to the id(s) of <AI>
         self.ignore_index = -100
         item = self.preprocessing(self.data[0])
         print("input:", self.tokenizer.decode(item["input_ids"]))
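Note: the quote-style change on user_tokens/assistant_tokens is cosmetic, and collapsing the encode() calls is behavior-preserving; the role markers are tokenized once in __init__ and reused for every sample. A small standalone illustration of the idea (assuming a chat model whose vocabulary contains these markers, so encode() returns a short, stable id sequence):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True
)
user_tokens = tokenizer.encode("<用户>")      # id(s) of the user role marker
assistant_tokens = tokenizer.encode("<AI>")  # id(s) of the assistant role marker
print(user_tokens, assistant_tokens)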
@@ -89,9 +86,10 @@ class SupervisedDataset(Dataset):
                 ] * len(content_ids)
             else:
                 input_ids += self.assistant_tokens + content_ids
-                label_ids += [self.ignore_index] * len(
-                    self.assistant_tokens
-                ) + content_ids
+                label_ids += (
+                    [self.ignore_index] * len(self.assistant_tokens)
+                    + content_ids
+                )

         input_ids.append(self.tokenizer.eos_token_id)
         label_ids.append(self.tokenizer.eos_token_id)
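Note: this hunk only reflows the expression; the labels are unchanged. Positions covered by the assistant role marker get self.ignore_index (-100), the value PyTorch's cross-entropy loss skips by default, so only the assistant's content tokens contribute to the loss. A toy demonstration of the masking (made-up shapes and label values):

import torch
import torch.nn.functional as F

logits = torch.randn(5, 32)                   # 5 token positions, vocab size 32
labels = torch.tensor([-100, -100, 7, 3, 9])  # first two positions are masked
# The loss averages over the 3 unmasked positions only.
loss = F.cross_entropy(logits, labels, ignore_index=-100)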
@@ -173,7 +171,7 @@ if __name__ == "__main__":
         max_length=training_args.model_max_length,
         use_lora=training_args.use_lora,
         bf16=training_args.bf16,
-        fp16=training_args.fp16,
+        fp16=training_args.fp16
     )

     train_dataset = SupervisedDataset(
@@ -197,4 +195,4 @@ if __name__ == "__main__":

     trainer.train()
     # save the incremental PEFT weights, more details can be found in https://huggingface.co/blog/peft
     # model.save_pretrained("output_dir")
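Note on the commented-out save: with LoRA enabled, save_pretrained on a PEFT-wrapped model writes only the small adapter weights, not the full base model. A hedged sketch of the round trip ("output_dir" is a placeholder path, as in the comment above):

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Re-attach the saved adapter to the base model for inference.
base = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True
)
model = PeftModel.from_pretrained(base, "output_dir")
model = model.merge_and_unload()  # optional: fold the adapter into the base weights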