Fix sft_dataset issue and naming error

This commit is contained in:
Xiang Long 2024-03-06 17:25:41 +08:00
parent 35804d3464
commit 74ecbcce5e
2 changed files with 5 additions and 6 deletions

View File

@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
import json
from typing import Dict, Optional
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from torch.utils.data import Dataset
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from torch.utils.data import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
TrainingArguments)
@dataclass
@ -90,7 +90,6 @@ class SupervisedDataset(Dataset):
label_ids += (
[self.ignore_index] * len(self.assistant_tokens)
+ content_ids
+ [self.tokenizer.eos_token_id]
)
input_ids = input_ids[: self.model_max_length]

View File

@ -101,7 +101,7 @@
"metadata": {},
"outputs": [],
"source": [
"!bash lora_finetune_ds.sh"
"!bash lora_finetune.sh"
]
}
],