Fix sft_dataset issue and naming error

This commit is contained in:
Xiang Long 2024-03-06 17:25:41 +08:00
parent 35804d3464
commit 74ecbcce5e
2 changed files with 5 additions and 6 deletions

View File

@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
from typing import Dict, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional
import torch import torch
from torch.utils.data import Dataset
import transformers import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer from torch.utils.data import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
TrainingArguments)
@dataclass @dataclass
@@ -90,7 +90,6 @@ class SupervisedDataset(Dataset):
label_ids += ( label_ids += (
[self.ignore_index] * len(self.assistant_tokens) [self.ignore_index] * len(self.assistant_tokens)
+ content_ids + content_ids
+ [self.tokenizer.eos_token_id]
) )
input_ids = input_ids[: self.model_max_length] input_ids = input_ids[: self.model_max_length]

View File

@@ -101,7 +101,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!bash lora_finetune_ds.sh" "!bash lora_finetune.sh"
] ]
} }
], ],