mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00
Fix sft_dataset issue and naming error
This commit is contained in:
parent
35804d3464
commit
74ecbcce5e
@ -1,13 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
|
||||
TrainingArguments)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -90,7 +90,6 @@ class SupervisedDataset(Dataset):
|
||||
label_ids += (
|
||||
[self.ignore_index] * len(self.assistant_tokens)
|
||||
+ content_ids
|
||||
+ [self.tokenizer.eos_token_id]
|
||||
)
|
||||
|
||||
input_ids = input_ids[: self.model_max_length]
|
||||
|
||||
@ -101,7 +101,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!bash lora_finetune_ds.sh"
|
||||
"!bash lora_finetune.sh"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user