mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-30 11:13:35 +08:00
Fix sft_dataset issue and naming error
This commit is contained in:
parent
35804d3464
commit
74ecbcce5e
@ -1,13 +1,13 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Optional
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
from torch.utils.data import Dataset
|
||||||
|
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
|
||||||
|
TrainingArguments)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -90,7 +90,6 @@ class SupervisedDataset(Dataset):
|
|||||||
label_ids += (
|
label_ids += (
|
||||||
[self.ignore_index] * len(self.assistant_tokens)
|
[self.ignore_index] * len(self.assistant_tokens)
|
||||||
+ content_ids
|
+ content_ids
|
||||||
+ [self.tokenizer.eos_token_id]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
input_ids = input_ids[: self.model_max_length]
|
input_ids = input_ids[: self.model_max_length]
|
||||||
|
|||||||
@ -101,7 +101,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!bash lora_finetune_ds.sh"
|
"!bash lora_finetune.sh"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user