Fix finetune supervised dataset issue

Xiang Long 2024-03-16 01:58:08 +08:00
parent 74ecbcce5e
commit 36337f70ea
53 changed files with 170463 additions and 54 deletions

.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
+*__pycache__*
+*.pyc
+finetune/output/*
+wip.*

View File

@@ -64,11 +64,6 @@ pip install -r requirements.txt
 ## 数据集格式示例
-这里以 AdvertiseGen 数据集为例,
-您可以从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
-或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载 AdvertiseGen 数据集。
-将解压后的 AdvertiseGen 目录放到 `data` 目录下并自行转换为如下格式数据集。
 > 请注意,现在的微调代码中加入了验证集,因此,对于一组完整的微调数据集,必须包含训练数据集和验证数据集,测试数据集可以不填写。或者直接用验证数据集代替。
 ```

View File

@@ -66,11 +66,6 @@ For the data file, the example uses the following format
 ## Dataset Format Example
-Here, taking the AdvertiseGen dataset as an example,
-you can download the AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
-or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) . After extracting the AdvertiseGen directory, place it in the `data` directory and convert it into the following format dataset.
 > Please note, the fine-tuning code now includes a validation set, so for a complete set of fine-tuning datasets, it must contain training and validation datasets, while the test dataset is optional. Or, you can use the validation dataset in place of it.
 ```
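
The large suppressed diffs below presumably hold the converted data files referenced by the scripts (`data/AdvertiseGenChatML/train.json`, `data/AdvertiseGenChatML/dev.json`). As a rough sketch of the conversion step the READMEs now ask the user to do themselves: `SupervisedDataset` in `finetune.py` reads `example["messages"]` with per-message `role`/`content`, so each record must carry a `messages` list. The role names `user`/`assistant`, the raw AdvertiseGen field names `content`/`summary`, and the choice of JSON-lines input vs. a single JSON array output are assumptions here, not something this commit shows:

```python
import json

def convert(src_path: str, dst_path: str) -> None:
    """Hypothetical converter from raw AdvertiseGen records to the
    messages-based format consumed by SupervisedDataset in finetune.py."""
    records = []
    with open(src_path, encoding="utf-8") as f:
        for line in f:  # raw AdvertiseGen split assumed to be JSON lines
            raw = json.loads(line)
            records.append({
                "messages": [
                    {"role": "user", "content": raw["content"]},      # assumed field name
                    {"role": "assistant", "content": raw["summary"]}, # assumed field name
                ]
            })
    with open(dst_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)  # assumed: one JSON array per split

convert("data/AdvertiseGen/train.json", "data/AdvertiseGenChatML/train.json")
convert("data/AdvertiseGen/dev.json", "data/AdvertiseGenChatML/dev.json")
```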

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -73,7 +73,7 @@ class SupervisedDataset(Dataset):
     def preprocessing(self, example):
         input_ids = [self.tokenizer.bos_token_id]
-        label_ids = []
+        label_ids = [self.ignore_index]
         for message in example["messages"]:
             role = message["role"]
@@ -92,17 +92,22 @@ class SupervisedDataset(Dataset):
                 + content_ids
             )
+        input_ids.append(self.tokenizer.eos_token_id)
+        label_ids.append(self.tokenizer.eos_token_id)
+        # truncate to max len
         input_ids = input_ids[: self.model_max_length]
         label_ids = label_ids[: self.model_max_length]
-        # input_ids += [self.tokenizer.eos_token_id] * (len(label_ids) - len(input_ids))
+        attention_mask = [1] * len(input_ids)
+        # pad to max len
         input_ids += [self.tokenizer.eos_token_id] * (
             self.model_max_length - len(input_ids)
         )
         label_ids += [self.ignore_index] * (self.model_max_length - len(label_ids))
+        attention_mask += [0] * (self.model_max_length - len(attention_mask))
+        # convert to pt tensor
         input_ids = torch.LongTensor(input_ids)
         label_ids = torch.LongTensor(label_ids)
-        # print(f"len input_ids: {len(input_ids)}, len label_ids: {len(label_ids)}")
-        attention_mask = input_ids.ne(self.tokenizer.eos_token_id)
+        attention_mask = torch.LongTensor(attention_mask)
         return {
             "input_ids": input_ids,
             "label_ids": label_ids,
@@ -158,7 +163,6 @@ def load_model_and_tokenizer(
 if __name__ == "__main__":
     model_path = "/mnt/data/user/tc_agi/yh/models/MiniCPM"
-    max_length = 512
     parser = transformers.HfArgumentParser(
         (ModelArguments, DataArguments, TrainingArguments)
     )
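
This hunk is the substance of the commit: an EOS token is now appended to both `input_ids` and `label_ids`, the attention mask is built from the real (pre-padding) sequence length instead of being derived from `input_ids.ne(eos_token_id)`, and the label sequence starts with `ignore_index` so the BOS position does not contribute to the loss. Below is a minimal standalone sketch of just that padding/masking step, assuming `ignore_index` is -100 (PyTorch's `CrossEntropyLoss` default; the actual value set elsewhere in `finetune.py` is not shown in this hunk) and that `input_ids`/`label_ids` arrive already tokenized from the chat messages:

```python
import torch

IGNORE_INDEX = -100  # assumption: label value ignored by the loss (CrossEntropyLoss default)

def pad_and_mask(input_ids, label_ids, eos_token_id, model_max_length):
    # Append EOS so the model learns to terminate, on inputs and labels alike.
    input_ids = input_ids + [eos_token_id]
    label_ids = label_ids + [eos_token_id]
    # Truncate to the maximum length.
    input_ids = input_ids[:model_max_length]
    label_ids = label_ids[:model_max_length]
    # Mark real tokens before any padding is appended.
    attention_mask = [1] * len(input_ids)
    # Pad out to model_max_length; padded label positions are ignored by the loss.
    input_ids += [eos_token_id] * (model_max_length - len(input_ids))
    label_ids += [IGNORE_INDEX] * (model_max_length - len(label_ids))
    attention_mask += [0] * (model_max_length - len(attention_mask))
    return {
        "input_ids": torch.LongTensor(input_ids),
        "label_ids": torch.LongTensor(label_ids),
        "attention_mask": torch.LongTensor(attention_mask),
    }

# Toy usage with made-up token ids (bos=1, eos=2):
batch = pad_and_mask([1, 11, 12, 13], [IGNORE_INDEX, IGNORE_INDEX, 12, 13],
                     eos_token_id=2, model_max_length=8)
print(batch["attention_mask"])  # tensor([1, 1, 1, 1, 1, 0, 0, 0])
```

With the fix, an EOS token is part of the real sequence; since EOS doubles as the padding token here, deriving the mask from `ne(eos_token_id)` would zero out the genuine end-of-sequence position, hence the explicit mask built before padding. Presumably the missing EOS supervision is the "supervised dataset issue" the commit title refers to.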

View File

@@ -4,11 +4,13 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# MiniCPM-2B 参数高效微调(LoRA)消费级单卡示例\n",
+"# MiniCPM-2B 参数高效微调(LoRA)A100 80G 单卡示例\n",
+"\n",
+"显存更小的显卡可用 batch size 和 grad_accum 间时间换空间\n",
 "\n",
 "本 notebook 是一个使用 `AdvertiseGen` 数据集对 MiniCPM-2B 进行 LoRA 微调,使其具备专业的广告生成能力的代码示例。\n",
 "\n",
-"## 硬件需求\n",
+"## 最低硬件需求\n",
 "- 显存:12GB\n",
 "- 显卡架构:安培架构(推荐)\n",
 "- 内存:16GB"
@@ -20,31 +22,7 @@
 "source": [
 "## 1. 准备数据集\n",
 "\n",
-"下载 AdvertiseGen 数据集\n",
-"- [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)\n",
-"- [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1)\n",
-"\n",
-"下载后的数据集格式为 `.tar.gz` 的压缩格式,接下来的操作中,假设该压缩包被置于 `finetune/data/`。\n"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"# 校验文件完整性\n",
-"!md5sum data/AdvertiseGen.tar.gz "
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"# 解压数据集\n",
-"!tar xvf data/AdvertiseGen.tar.gz "
+"将数据集转换为更通用的格式\n"
 ]
 },
 {
@@ -103,6 +81,47 @@
 "source": [
 "!bash lora_finetune.sh"
 ]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## 推理验证"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"import torch\n",
+"from tqdm import tqdm\n",
+"from transformers import AutoModelForCausalLM, AutoTokenizer"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"path = \"output/AdvertiseGenLoRA/20240315224356/checkpoint-3000\"\n",
+"tokenizer = AutoTokenizer.from_pretrained(path)\n",
+"model = AutoModelForCausalLM.from_pretrained(\n",
+" path, torch_dtype=torch.bfloat16, device_map=\"cuda\", trust_remote_code=True\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"res, history = model.chat(tokenizer, query=\"<用户>类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞<AI>\", max_length=80, top_p=0.5)\n",
+"res, history"
+]
 }
 ],
 "metadata": {
@@ -121,7 +140,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.13"
+"version": "3.10.12"
 }
 },
 "nbformat": 4,

View File

@@ -2,13 +2,13 @@ formatted_time=$(date +"%Y%m%d%H%M%S")
 echo $formatted_time
-deepspeed --include localhost:0 finetune.py \
-    --model_name_or_path <your_model_name_or_path> \
+deepspeed --include localhost:1 finetune.py \
+    --model_name_or_path MiniCPM-2B-sft-bf16 \
     --output_dir output/AdvertiseGenLoRA/$formatted_time/ \
     --train_data_path data/AdvertiseGenChatML/train.json \
     --eval_data_path data/AdvertiseGenChatML/dev.json \
-    --learning_rate 1e-3 --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 1 --fp16 --use_lora \
+    --learning_rate 5e-5 --per_device_train_batch_size 32 \
+    --per_device_eval_batch_size 64 --model_max_length 384 --bf16 --use_lora \
     --gradient_accumulation_steps 1 --warmup_steps 100 \
     --max_steps 3000 --weight_decay 0.01 \
     --evaluation_strategy steps --eval_steps 500 \

File diff suppressed because it is too large

View File

@@ -0,0 +1,17 @@
+formatted_time=$(date +"%Y%m%d%H%M%S")
+echo $formatted_time
+deepspeed --include localhost:1 --master_port 19888 finetune.py \
+    --model_name_or_path MiniCPM-2B-sft-bf16 \
+    --output_dir output/OCNLILoRA/$formatted_time/ \
+    --train_data_path data/ocnli_public_chatml/train.json \
+    --eval_data_path data/ocnli_public_chatml/dev.json \
+    --learning_rate 5e-5 --per_device_train_batch_size 80 \
+    --per_device_eval_batch_size 128 --model_max_length 128 --bf16 --use_lora \
+    --gradient_accumulation_steps 1 --warmup_steps 100 \
+    --max_steps 1000 --weight_decay 0.01 \
+    --evaluation_strategy steps --eval_steps 500 \
+    --save_strategy steps --save_steps 500 --seed 42 \
+    --log_level info --logging_strategy steps --logging_steps 10 \
+    --deepspeed configs/ds_config_zero3_offload.json

View File

@@ -2,16 +2,16 @@ formatted_time=$(date +"%Y%m%d%H%M%S")
 echo $formatted_time
-deepspeed --include localhost:0,1,2,3 finetune.py \
-    --model_name_or_path <your_model_name_or_path> \
+deepspeed --include localhost:0,1 finetune.py \
+    --model_name_or_path MiniCPM-2B-sft-bf16 \
     --output_dir output/AdvertiseGenSFT/$formatted_time/ \
     --train_data_path data/AdvertiseGenChatML/train.json \
     --eval_data_path data/AdvertiseGenChatML/dev.json \
-    --learning_rate 1e-3 --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 4 --bf16 \
-    --gradient_accumulation_steps 8 --warmup_steps 100 \
+    --learning_rate 5e-5 --per_device_train_batch_size 14 \
+    --per_device_eval_batch_size 32 --bf16 \
+    --gradient_accumulation_steps 2 --warmup_steps 100 \
     --max_steps 3000 --weight_decay 0.01 \
-    --evaluation_strategy steps --eval_steps 500 \
+    --evaluation_strategy steps --eval_steps 100 \
     --save_strategy steps --save_steps 500 --seed 42 \
     --log_level info --logging_strategy steps --logging_steps 10 \
-    --deepspeed configs/ds_config_zero3_offload.json
+    --deepspeed configs/ds_config_zero2.json
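
The hyperparameter changes across these scripts move in the same direction: a smaller learning rate, larger per-device batches, and bf16. For readers adapting them to other hardware, the quantity to hold roughly constant is the effective optimizer batch size, i.e. the product of per-device batch size, gradient accumulation steps, and the number of GPUs listed in `--include`. A quick sanity check with the values from the updated scripts (the helper is illustrative, not part of the repo):

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_gpus: int) -> int:
    # Global number of samples contributing to each optimizer update.
    return per_device_batch * grad_accum_steps * num_gpus

# lora_finetune.sh: one GPU (localhost:1), batch 32, no accumulation -> 32
print(effective_batch_size(32, 1, 1))
# sft_finetune.sh: two GPUs (localhost:0,1), batch 14, accumulation 2 -> 56
print(effective_batch_size(14, 2, 2))
```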