MiniCPM/finetune/lora_finetune.ipynb
2024-03-16 01:58:08 +08:00

{
"cells": [
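{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MiniCPM-2B Parameter-Efficient Fine-Tuning with LoRA (Single A100 80G GPU Example)\n",
"\n",
"On GPUs with less memory, trade time for space by lowering the batch size and raising `grad_accum` (gradient accumulation steps).\n",
"\n",
"This notebook is a code example that fine-tunes MiniCPM-2B with LoRA on the `AdvertiseGen` dataset so that the model can generate professional advertising copy.\n",
"\n",
"## Minimum hardware requirements\n",
"- GPU memory: 12 GB\n",
"- GPU architecture: Ampere (recommended)\n",
"- RAM: 16 GB"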
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Prepare the dataset\n",
"\n",
"Convert the dataset to a more general-purpose format.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert AdvertiseGen (one JSON object per line) to ChatML-style messages\n",
"import os\n",
"import shutil\n",
"import json\n",
"\n",
"input_dir = \"data/AdvertiseGen\"\n",
"output_dir = \"data/AdvertiseGenChatML\"\n",
"# Start from a clean output directory\n",
"if os.path.exists(output_dir):\n",
"    shutil.rmtree(output_dir)\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"for fn in [\"train.json\", \"dev.json\"]:\n",
"    data_out_list = []\n",
"    # UTF-8 is set explicitly because the dataset text is Chinese\n",
"    with open(os.path.join(input_dir, fn), \"r\", encoding=\"utf-8\") as f, open(os.path.join(output_dir, fn), \"w\", encoding=\"utf-8\") as fo:\n",
"        for line in f:\n",
"            # Skip blank lines\n",
"            if len(line.strip()) > 0:\n",
"                data = json.loads(line)\n",
"                # \"content\" becomes the user turn, \"summary\" the assistant turn\n",
"                data_out = {\n",
"                    \"messages\": [\n",
"                        {\n",
"                            \"role\": \"user\",\n",
"                            \"content\": data[\"content\"],\n",
"                        },\n",
"                        {\n",
"                            \"role\": \"assistant\",\n",
"                            \"content\": data[\"summary\"],\n",
"                        },\n",
"                    ]\n",
"                }\n",
"                data_out_list.append(data_out)\n",
"        # Write all converted records as a single JSON array\n",
"        json.dump(data_out_list, fo, ensure_ascii=False, indent=4)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Fine-tune with LoRA\n",
"\n",
"Run the whole fine-tuning job with a single command."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!bash lora_finetune.sh"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Verify with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from tqdm import tqdm\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = \"output/AdvertiseGenLoRA/20240315224356/checkpoint-3000\"\n",
"tokenizer = AutoTokenizer.from_pretrained(path)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    path, torch_dtype=torch.bfloat16, device_map=\"cuda\", trust_remote_code=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res, history = model.chat(tokenizer, query=\"<用户>类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞<AI>\", max_length=80, top_p=0.5)\n",
"res, history"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}