Merge pull request #176 from LDLINGLINGLING/main

Added a QLoRA training option
LDLINGLINGLING 2024-07-26 15:27:44 +08:00 committed by GitHub
commit 2b43378e3c
3 changed files with 65 additions and 15 deletions


@@ -1,4 +1,20 @@
"""
my package: langchain_demo
langchain 0.2.6
langchain-community 0.2.1
langchain-core 0.2.19
langchain-text-splitters 0.2.0
langchainplus-sdk 0.0.20
pypdf 4.3.0
pydantic 2.8.2
pydantic_core 2.20.1
transformers 4.41.1
triton 2.3.0
trl 0.8.6
vllm 0.5.0.post1+cu122
vllm-flash-attn 2.5.9
vllm_nccl_cu12 2.18.1.0.4.0
A consumer-grade GPU with as little as 6 GB of VRAM is enough to run this RAG demo smoothly.
Usage
@@ -29,27 +45,35 @@ import re
import gradio as gr
parser = ArgumentParser()
# Large language model arguments
parser.add_argument(
"--cpm_model_path",
type=str,
default="openbmb/MiniCPM-1B-sft-bf16",
help="MiniCPM模型路径或者huggingface id"
)
parser.add_argument(
"--cpm_device", type=str, default="cuda:0", choices=["auto", "cuda:0"]
"--cpm_device", type=str, default="cuda:0", choices=["auto", "cuda:0"],
help="MiniCPM模型所在设备默认为cuda:0"
)
parser.add_argument("--backend", type=str, default="torch", choices=["torch", "vllm"],
help="使用torch还是vllm后端默认为torch"
)
parser.add_argument("--backend", type=str, default="torch", choices=["torch", "vllm"])
# Embedding model arguments
parser.add_argument(
"--encode_model", type=str, default="BAAI/bge-base-zh"
"--encode_model", type=str, default="BAAI/bge-base-zh",
help="用于召回编码的embedding模型默认为BAAI/bge-base-zh,可输入本地地址"
)
parser.add_argument(
"--encode_model_device", type=str, default="cpu", choices=["cpu", "cuda:0"]
"--encode_model_device", type=str, default="cpu", choices=["cpu", "cuda:0"],
help="用于召回编码的embedding模型所在设备默认为cpu"
)
parser.add_argument("--query_instruction", type=str, default="")
parser.add_argument("--query_instruction", type=str, default="",help="召回时增加的前缀")
parser.add_argument(
"--file_path", type=str, default="/root/ld/pull_request/rag/红楼梦.pdf"
"--file_path", type=str, default="/root/ld/pull_request/rag/红楼梦.pdf",
help="需要检索的文本文件路径,gradio运行时无效"
)
# Generation arguments
@ -60,9 +84,9 @@ parser.add_argument("--max_new_tokens", type=int, default=4096)
parser.add_argument("--repetition_penalty", type=float, default=1.02)
# Retriever arguments
parser.add_argument("--embed_top_k", type=int, default=5)
parser.add_argument("--chunk_size", type=int, default=256)
parser.add_argument("--chunk_overlap", type=int, default=50)
parser.add_argument("--embed_top_k", type=int, default=5,help="召回几个最相似的文本")
parser.add_argument("--chunk_size", type=int, default=256,help="文本切分时切分的长度")
parser.add_argument("--chunk_overlap", type=int, default=50,help="文本切分的重叠长度")
args = parser.parse_args()
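As an illustration of how the retriever arguments above might be wired together (this is not code from the diff), the sketch below builds a splitter, an embedding model, and a top-k retriever from `args`. The `PyPDFLoader` and FAISS store are assumptions; the actual demo may use different components, and FAISS additionally requires `faiss-cpu`.

```python
# Hedged sketch: one plausible way args.file_path, chunk_size, chunk_overlap,
# embed_top_k and query_instruction could be consumed downstream.
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

docs = PyPDFLoader(args.file_path).load()  # load the PDF to retrieve from
splitter = RecursiveCharacterTextSplitter(
    chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
chunks = splitter.split_documents(docs)

embeddings = HuggingFaceBgeEmbeddings(
    model_name=args.encode_model,
    model_kwargs={"device": args.encode_model_device},
    query_instruction=args.query_instruction,
)
retriever = FAISS.from_documents(chunks, embeddings).as_retriever(
    search_kwargs={"k": args.embed_top_k}  # return embed_top_k most similar chunks
)
relevant_chunks = retriever.invoke("an example query")
```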


@@ -22,6 +22,7 @@ pip install -r requirements.txt
+ Full-parameter SFT: spread evenly over 4 GPUs, each GPU uses `30245MiB` of VRAM.
+ LoRA fine-tuning: a single GPU uses `10619MiB` of VRAM.
+ QLoRA fine-tuning + CPU offload: a single GPU uses `5500MiB` of VRAM.
> Note that these numbers are for reference only; memory usage can vary with different parameters. Adjust for your own hardware (see the measurement sketch below).
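To compare the figures above against your own setup, a small helper like the following (not part of this commit) can report peak GPU memory after a few training steps; it uses only standard PyTorch calls.

```python
# Hedged sketch: print peak allocated VRAM per CUDA device so the numbers above
# can be checked on your own hardware. No repository code is assumed.
import torch

def report_peak_vram() -> None:
    if not torch.cuda.is_available():
        print("no CUDA device found")
        return
    for i in range(torch.cuda.device_count()):
        peak_mib = torch.cuda.max_memory_allocated(i) / (1024 ** 2)
        print(f"cuda:{i} peak allocated: {peak_mib:.0f} MiB")

# Call torch.cuda.reset_peak_memory_stats() before training,
# then report_peak_vram() after a few optimizer steps.
```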


@@ -7,7 +7,7 @@ import torch
import transformers
from torch.utils.data import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
-                          TrainingArguments)
+                          TrainingArguments, BitsAndBytesConfig)
@dataclass
@@ -38,6 +38,7 @@ class TrainingArguments(transformers.TrainingArguments):
},
)
use_lora: bool = field(default=False)
+    qlora: bool = field(default=False)
class SupervisedDataset(Dataset):
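Since `qlora` is a plain boolean field on the `TrainingArguments` dataclass, it surfaces as a command-line flag the same way `use_lora` does. A hedged sketch, assuming the script parses this dataclass with `transformers.HfArgumentParser` as is customary (the exact parser setup is not shown in this hunk):

```python
# Hedged sketch: how the new qlora field could be set from the command line,
# assuming HfArgumentParser is used on the TrainingArguments subclass above.
from transformers import HfArgumentParser

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--use_lora", "True", "--qlora", "True"]
)
assert training_args.use_lora and training_args.qlora
```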
@@ -121,6 +122,7 @@ def load_model_and_tokenizer(
model_path: str,
max_length: int = 4096,
use_lora: bool = True,
+    qlora: bool = False,
bf16: bool = False,
fp16: bool = False,
):
@@ -135,11 +137,33 @@ def load_model_and_tokenizer(
dtype = torch.float16
else:
dtype = torch.float32
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype=dtype,
-        trust_remote_code=True,
-    )
+    if qlora:
+        assert use_lora, "use_lora must be True when qlora is True"
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,  # quantize the weights to 4 bit
+            load_in_8bit=False,  # do not use 8-bit quantization
+            bnb_4bit_compute_dtype=torch.float16,  # compute dtype
+            bnb_4bit_quant_storage=torch.uint8,  # storage dtype of the quantized weights
+            bnb_4bit_quant_type="nf4",  # quantization type: normal-float 4-bit (nf4)
+            bnb_4bit_use_double_quant=True,  # double quantization: also quantize the zero-point and scaling constants
+            llm_int8_enable_fp32_cpu_offload=False,  # keep CPU-offloaded modules in fp32 (disabled here)
+            llm_int8_has_fp16_weight=False,  # run LLM.int8() with fp16 main weights (mixed precision)
+            # llm_int8_skip_modules=["out_proj", "kv_proj", "lm_head"],  # modules to exclude from quantization
+            llm_int8_threshold=6.0,  # outlier threshold of the LLM.int8() algorithm; values above it stay un-quantized
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+            quantization_config=quantization_config,
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+        )
if use_lora:
from peft import LoraConfig, TaskType, get_peft_model
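The hunk ends before the LoRA block itself, so as a hedged illustration only: with a 4-bit base model, the usual peft recipe is to call `prepare_model_for_kbit_training` before attaching the adapter. The rank, alpha, and `target_modules` below are placeholder values, not this repository's settings.

```python
# Hedged sketch of typical QLoRA wiring (not the exact code in this repo):
# prepare the quantized model for k-bit training, then attach a LoRA adapter.
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

if qlora:
    # casts norm layers to fp32, enables input grads, etc., as recommended for k-bit training
    model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,             # placeholder rank
    lora_alpha=32,   # placeholder scaling
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # placeholder module names
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights should be trainable
```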
@@ -170,6 +194,7 @@ if __name__ == "__main__":
model_path=model_args.model_name_or_path,
max_length=training_args.model_max_length,
use_lora=training_args.use_lora,
+        qlora=training_args.qlora,
bf16=training_args.bf16,
fp16=training_args.fp16
)
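For reference, the patched helper can also be exercised directly. The call below is a sketch only: it assumes the function returns a `(model, tokenizer)` pair and uses the default MiniCPM checkpoint id, neither of which is shown in this diff.

```python
# Hedged sketch: load MiniCPM in QLoRA mode via the load_model_and_tokenizer helper
# patched above. The return shape and checkpoint name are assumptions.
model, tokenizer = load_model_and_tokenizer(
    model_path="openbmb/MiniCPM-1B-sft-bf16",  # assumed checkpoint id
    max_length=4096,
    use_lora=True,  # required: the qlora path asserts use_lora
    qlora=True,     # switches on the 4-bit BitsAndBytesConfig shown in the diff
    fp16=True,      # matches bnb_4bit_compute_dtype=torch.float16
)
print(type(model))  # a 4-bit quantized causal LM ready for LoRA fine-tuning
```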