mirror of https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00

Add fine tune scripts

This commit is contained in:
parent 2bd210871e
commit fb341897a6
32 finetune/configs/ds_config_zero2.json Normal file
@@ -0,0 +1,32 @@
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": 1.0,
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "reduce_scatter": true,
        "contiguous_gradients": true
    },
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": false,
        "profile_step": 1,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": true,
        "output_file": null
    }
}
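Note on the "auto" fields above: they are not literal values; the Hugging Face Trainer's DeepSpeed integration fills them in from the matching TrainingArguments at launch. A minimal sketch (not part of this commit; the output path and batch settings are illustrative only) of how finetune.py ends up wiring this config in:

# Sketch: how ds_config_zero2.json attaches to the HF Trainer used by finetune.py.
# "auto" entries (batch sizes, gradient accumulation, fp16/bf16 enablement) are
# resolved from these arguments by the transformers DeepSpeed integration.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output/demo",                  # illustrative path
    per_device_train_batch_size=1,             # -> train_micro_batch_size_per_gpu
    gradient_accumulation_steps=8,             # -> gradient_accumulation_steps
    fp16=True,                                 # -> fp16.enabled
    deepspeed="configs/ds_config_zero2.json",  # this file
)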
38 finetune/configs/ds_config_zero2_offload.json Normal file
@@ -0,0 +1,38 @@
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": 1.0,
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "reduce_scatter": true,
        "contiguous_gradients": true,
        "offload_optimizer": {
            "device": "cpu"
        },
        "offload_param": {
            "device": "cpu"
        }
    },
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": false,
        "profile_step": 1,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": true,
        "output_file": null
    }
}
22 finetune/configs/ds_config_zero3.json Normal file
@@ -0,0 +1,22 @@
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": 1.0,
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "flops_profiler": {
        "enabled": false,
        "profile_step": 1,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": true,
        "output_file": null
    }
}
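None of the launch scripts in this commit reference this ZeRO-3 config (they use the ZeRO-2 variants). A hypothetical sketch of using it instead: point the same deepspeed plumbing at this file, typically with bf16 since the config defines only a bf16 block; stage 3 also shards parameters, and stage3_gather_16bit_weights_on_model_save writes a consolidated 16-bit checkpoint when the Trainer saves.

# Hypothetical sketch (not part of the commit): swap in the ZeRO-3 config.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output/demo_zero3",            # illustrative path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    bf16=True,                                 # this config only defines a bf16 block
    deepspeed="configs/ds_config_zero3.json",
)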
198 finetune/finetune.py Normal file
@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*-
import json
from typing import Dict, Optional
from datetime import datetime
from dataclasses import dataclass, field

import torch
from torch.utils.data import Dataset

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    train_data_path: str = field(
        default="data/AdvertiseGenChatML/train.json",
        metadata={"help": "Path to the training data."},
    )
    eval_data_path: str = field(
        default="data/AdvertiseGenChatML/dev.json",
        metadata={"help": "Path to the test data."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length=4096,
        user_tokens=[1786, 4194, 95388],
        assistant_tokens=[1786, 10850, 95388],
    ):
        super(SupervisedDataset, self).__init__()
        self.data = json.load(open(data_path))
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        # Decode the first sample as a sanity check of the chat template.
        item = self.preprocessing(self.data[0])
        print("input:", self.tokenizer.decode(item["input_ids"]))
        labels = []
        for id_ in item["label_ids"]:
            if id_ == -100:
                continue

            labels.append(id_)
        print("label:", self.tokenizer.decode(labels))

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = [self.tokenizer.bos_token_id]
        label_ids = [self.tokenizer.bos_token_id]

        for message in example["messages"]:
            role = message["role"]
            content = message["content"]
            content_ids = self.tokenizer.encode(content, add_special_tokens=False)

            if role == "user":
                # Prompt tokens are masked out of the loss; keeping label_ids the
                # same length as input_ids keeps the two sequences aligned.
                input_ids += self.user_tokens + content_ids
                label_ids += [self.ignore_index] * (
                    len(self.user_tokens) + len(content_ids)
                )
            else:
                input_ids += self.assistant_tokens + content_ids
                label_ids += [self.ignore_index] * len(
                    self.assistant_tokens
                ) + content_ids
        input_ids.append(self.tokenizer.eos_token_id)
        label_ids.append(self.tokenizer.eos_token_id)
        # Truncate, then right-pad both sequences to model_max_length.
        input_ids = input_ids[: self.model_max_length]
        label_ids = label_ids[: self.model_max_length]
        input_ids += [self.tokenizer.eos_token_id] * (
            self.model_max_length - len(input_ids)
        )
        label_ids += [self.ignore_index] * (self.model_max_length - len(label_ids))
        input_ids = torch.LongTensor(input_ids)
        label_ids = torch.LongTensor(label_ids)
        # Attention mask over non-pad positions (the pad token is the eos token).
        attention_mask = input_ids.ne(self.tokenizer.eos_token_id)
        return {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])


def load_model_and_tokenizer(
    model_path: str,
    max_length: int = 4096,
    use_lora: bool = True,
    bf16: bool = False,
    fp16: bool = False,
):
    """Load model and tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    assert not (bf16 and fp16), "bf16 or fp16, not both"
    if bf16:
        dtype = torch.bfloat16
    elif fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        lora_config = LoraConfig(
            init_lora_weights="gaussian",
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj"],
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        # trainable params: 2,949,120 || all params: 3,010,652,928 || trainable%: 0.09795616002669305
        model.print_trainable_parameters()
        model.enable_input_require_grads()

    return model, tokenizer


if __name__ == "__main__":
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model, tokenizer = load_model_and_tokenizer(
        model_path=model_args.model_name_or_path,
        max_length=training_args.model_max_length,
        use_lora=training_args.use_lora,
    )

    train_dataset = SupervisedDataset(
        data_path=data_args.train_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )
    eval_dataset = SupervisedDataset(
        data_path=data_args.eval_data_path,
        tokenizer=tokenizer,
        model_max_length=training_args.model_max_length,
    )

    formatted_time = datetime.now().strftime("%Y%m%d%H%M%S")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # compute_metrics=compute_metrics,
    )

    trainer.train()
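After trainer.train() finishes, the LoRA adapter weights are written under --output_dir (one checkpoint directory per save_steps). A minimal inference sketch, not part of this commit; the base-model and adapter paths are placeholders, and the prompt is rebuilt with the same user/assistant marker tokens that SupervisedDataset.preprocessing uses:

# Sketch only: paths are placeholders; MiniCPM is loaded with trust_remote_code
# exactly as in load_model_and_tokenizer above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_path = "path/to/MiniCPM"                                       # placeholder
adapter_path = "output/AdvertiseGenLoRA/<run_time>/checkpoint-500"  # placeholder

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_path,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
).to(device)
model = PeftModel.from_pretrained(model, adapter_path)  # attach the LoRA adapter
model.eval()

# Build the prompt the same way the training data was built:
# <bos> + user marker tokens + content ids + assistant marker tokens.
user_tokens = [1786, 4194, 95388]
assistant_tokens = [1786, 10850, 95388]
content_ids = tokenizer.encode("类型#裙*风格#简约", add_special_tokens=False)  # illustrative keywords
input_ids = torch.tensor(
    [[tokenizer.bos_token_id] + user_tokens + content_ids + assistant_tokens],
    device=device,
)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))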
129 finetune/lora_finetune.ipynb Normal file
@@ -0,0 +1,129 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MiniCPM-2B parameter-efficient fine-tuning (LoRA) on a single consumer GPU\n",
    "\n",
    "This notebook is a code example that fine-tunes MiniCPM-2B with LoRA on the `AdvertiseGen` dataset so that it learns to generate professional advertising copy.\n",
    "\n",
    "## Hardware requirements\n",
    "- GPU memory: 24 GB\n",
    "- GPU architecture: Ampere (recommended)\n",
    "- RAM: 16 GB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Prepare the dataset\n",
    "\n",
    "Download the AdvertiseGen dataset\n",
    "- [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)\n",
    "- [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1)\n",
    "\n",
    "The downloaded dataset is a `.tar.gz` archive; the steps below assume the archive has been placed under `finetune/data/`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify file integrity\n",
    "!md5sum data/AdvertiseGen.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the dataset into data/\n",
    "!tar xvf data/AdvertiseGen.tar.gz -C data/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to ChatML format\n",
    "import os\n",
    "import shutil\n",
    "import json\n",
    "\n",
    "input_dir = \"data/AdvertiseGen\"\n",
    "output_dir = \"data/AdvertiseGenChatML\"\n",
    "if os.path.exists(output_dir):\n",
    "    shutil.rmtree(output_dir)\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "for fn in [\"train.json\", \"dev.json\"]:\n",
    "    data_out_list = []\n",
    "    with open(os.path.join(input_dir, fn), \"r\") as f, open(os.path.join(output_dir, fn), \"w\") as fo:\n",
    "        for line in f:\n",
    "            if len(line.strip()) > 0:\n",
    "                data = json.loads(line)\n",
    "                data_out = {\n",
    "                    \"messages\": [\n",
    "                        {\n",
    "                            \"role\": \"user\",\n",
    "                            \"content\": data[\"content\"],\n",
    "                        },\n",
    "                        {\n",
    "                            \"role\": \"assistant\",\n",
    "                            \"content\": data[\"summary\"],\n",
    "                        },\n",
    "                    ]\n",
    "                }\n",
    "                data_out_list.append(data_out)\n",
    "        json.dump(data_out_list, fo, ensure_ascii=False, indent=4)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Fine-tune with LoRA\n",
    "\n",
    "Run it from the command line with a single script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!bash lora_finetune.sh"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
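For reference, the conversion cell above turns each raw AdvertiseGen line ({"content": ..., "summary": ...}) into the "messages" layout that SupervisedDataset in finetune.py consumes. A sketch of one converted entry (the field values are placeholders, not real dataset text):

# One element of data/AdvertiseGenChatML/train.json after conversion.
# The file itself is a single JSON array of such records.
example_record = {
    "messages": [
        {"role": "user", "content": "<keyword-style product description>"},  # from "content"
        {"role": "assistant", "content": "<reference advertising copy>"},    # from "summary"
    ]
}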
16 finetune/lora_finetune.sh Normal file
@@ -0,0 +1,16 @@
formatted_time=$(date +"%Y%m%d%H%M%S")
echo $formatted_time

CUDA_VISIBLE_DEVICES=0 python finetune.py \
    --model_name_or_path <your_model_name_or_path> \
    --output_dir output/AdvertiseGenLoRA/$formatted_time/ \
    --train_data_path data/AdvertiseGenChatML/train.json \
    --eval_data_path data/AdvertiseGenChatML/dev.json \
    --learning_rate 1e-3 --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 32 --fp16 \
    --gradient_accumulation_steps 8 --warmup_steps 100 \
    --max_steps 3000 --weight_decay 0.01 \
    --evaluation_strategy steps --eval_steps 500 \
    --save_strategy steps --save_steps 500 \
    --use_lora true --seed 42 \
    --log_level info --logging_strategy steps --logging_steps 10
17 finetune/lora_finetune_ds.sh Normal file
@@ -0,0 +1,17 @@
formatted_time=$(date +"%Y%m%d%H%M%S")
echo $formatted_time


deepspeed --include localhost:0,1 finetune.py \
    --model_name_or_path <your_model_name_or_path> \
    --output_dir output/AdvertiseGenLoRA/$formatted_time/ \
    --train_data_path data/AdvertiseGenChatML/train.json \
    --eval_data_path data/AdvertiseGenChatML/dev.json \
    --learning_rate 1e-3 --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 --fp16 --use_lora \
    --gradient_accumulation_steps 8 --warmup_steps 100 \
    --max_steps 3000 --weight_decay 0.01 \
    --evaluation_strategy steps --eval_steps 500 \
    --save_strategy steps --save_steps 500 --seed 42 \
    --log_level info --logging_strategy steps --logging_steps 10 \
    --deepspeed configs/ds_config_zero2_offload.json
9 finetune/requirements.txt Normal file
@@ -0,0 +1,9 @@
# for finetune
jieba>=0.42.1
ruamel_yaml>=0.18.5
rouge_chinese>=1.0.3
jupyter>=1.0.0
datasets>=2.16.1
peft>=0.7.1
deepspeed>=0.13.1
flash_attn>=2.5.1
17 finetune/sft_finetune.sh Normal file
@@ -0,0 +1,17 @@
formatted_time=$(date +"%Y%m%d%H%M%S")
echo $formatted_time


deepspeed --include localhost:1,2 finetune.py \
    --model_name_or_path <your_model_name_or_path> \
    --output_dir output/AdvertiseGenLoRA/$formatted_time/ \
    --train_data_path data/AdvertiseGenChatML/train.json \
    --eval_data_path data/AdvertiseGenChatML/dev.json \
    --learning_rate 1e-3 --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 32 --fp16 \
    --gradient_accumulation_steps 8 --warmup_steps 100 \
    --max_steps 3000 --weight_decay 0.01 \
    --evaluation_strategy steps --eval_steps 500 \
    --save_strategy steps --save_steps 500 --seed 42 \
    --log_level info --logging_strategy steps --logging_steps 10 \
    --deepspeed configs/ds_config_zero2.json