Mirror of https://github.com/RYDE-WORK/MiniCPM.git (synced 2026-01-19 12:53:36 +08:00)
Commit ead8064a52
@@ -64,7 +64,7 @@ MiniCPM, jointly open-sourced by ModelBest (面壁智能) and the Tsinghua University NLP Lab,
 ||||
 |-------------|------------|-----------|-----------|
 |[Transformers](#Huggingface模型)|[Transformers](#transformer_finetune)|[MLC deployment](#MLC)|[GPTQ](#gptq)|
 |[vLLM](#vllm-推理)|[mlx_finetune](#mlx)|[llama.cpp](#llama.cpp)|[AWQ](#awq)|
-|[llama.cpp](#llama.cpp)|[llama_factory](https://github.com/OpenBMB/MiniCPM/tree/main/finetune/llama_factory_example/README.md)||[perplexity test](#quantize_test)|
+|[llama.cpp](#llama.cpp)|[llama_factory](./finetune/llama_factory_example/README.md)||[perplexity test](#quantize_test)|
 |[ollama](#ollama)||||
 |[fastllm](#fastllm)||||
 |[mlx_lm](#mlx_lm)||||
finetune/data_processing.ipynb (new file, 66 lines)
@@ -0,0 +1,66 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Prepare the dataset\n",
    "\n",
    "Convert the dataset into a more general format\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to ChatML format\n",
    "import os\n",
    "import shutil\n",
    "import json\n",
    "\n",
    "input_dir = \"data/AdvertiseGen\"\n",
    "output_dir = \"data/mlx_AdvertiseGen\"\n",
    "if os.path.exists(output_dir):\n",
    "    shutil.rmtree(output_dir)\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "for fn in [\"train.json\", \"dev.json\"]:\n",
    "    data_out_list = []\n",
    "    with open(os.path.join(input_dir, fn), \"r\") as f, open(os.path.join(output_dir, fn), \"w\") as fo:\n",
    "        for line in f:\n",
    "            if len(line.strip()) > 0:\n",
    "                data = json.loads(line)\n",
    "                data_out = {\"input\": data[\"content\"], \"prompt\": \"\\n请为以下关键词生成一条广告语。\", \"output\": data[\"summary\"]}\n",
    "                data_out_list.append(data_out)\n",
    "\n",
    "        for d in data_out_list:\n",
    "            json_str = json.dumps(d, ensure_ascii=False)  # serialize the dict to a JSON string\n",
    "            fo.write(json_str + '\\n')  # write it followed by a newline, one JSON object per line\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
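For reference, a minimal sketch of the record format the notebook emits; the raw record below is invented for illustration and is not taken from AdvertiseGen:

import json

# Hypothetical AdvertiseGen-style record (illustrative values only)
raw = {"content": "类型#裙*颜色#蓝色", "summary": "这条蓝色长裙优雅大方。"}

# Same field mapping as the notebook cell above
converted = {
    "input": raw["content"],
    "prompt": "\n请为以下关键词生成一条广告语。",
    "output": raw["summary"],
}
print(json.dumps(converted, ensure_ascii=False))

Each output file thus holds one such JSON object per line, matching the {input, prompt, output} keys that ConversationDataset reads in mlx_finetune.py below.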
@@ -7,7 +7,8 @@ Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-
 Use this code with the following commands:
 
 train:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --train --seed 2024 --iters 500
+First process the data by running data_processing.ipynb.
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/mlx_AdvertiseGen --train --seed 2024 --iters 500
 
 The output is as follows:
@@ -19,7 +20,7 @@ Iter 2: Val loss 4.001, Val took 1061.649s
 After training finishes, the folder will contain an adapters.npz file, which is used for the subsequent test. Then run the test command:
 
 test:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --test --seed 2024
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/mlx_AdvertiseGen --test --seed 2024
 
 The output is as follows:
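Once adapters.npz exists, the tuned model can also be queried directly. A hedged example, assuming the script keeps an --adapter-file flag corresponding to the args.adapter_file used in the training code below (the prompt text is illustrative):

python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --adapter-file adapters.npz --prompt "类型#裙*颜色#蓝色"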
@@ -318,7 +319,7 @@ def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
-        default="mlx_model",
+        default="/Users/liudan/Downloads/模型/llamaformat_minicpm",
         help="The path to the local model directory or Hugging Face repo.",
     )
     # Generation args
@@ -336,8 +337,7 @@ def build_parser():
         "--prompt",
         "-p",
         type=str,
-        help="The prompt for generation",
-        default=None,
+        help="The prompt for generation"
     )
 
     # Training args
@@ -349,7 +349,7 @@ def build_parser():
     parser.add_argument(
         "--data",
         type=str,
-        default="data/",
+        default="data/mlx_AdvertiseGen",
         help="Directory with {train, valid, test}.json files",
     )
     parser.add_argument(
@@ -424,9 +424,10 @@ class ConversationDataset:
 
     def __getitem__(self, idx: int):
         entry = self._data[idx]
-        content = entry.get("content", "")
-        summary = entry.get("summary", "")
-        return content, summary
+        content = entry.get("input", "")
+        summary = entry.get("output", "")
+        prompt = entry.get("prompt", "")
+        return prompt, content, summary
 
     def __len__(self):
         return len(self._data)
@@ -479,7 +480,9 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
         # Collect batches from dataset
         for i in range(0, len(indices) - batch_size + 1, batch_size):
             # Encode batch
-            batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
+            batch_samples = [dset[indices[i + j]] for j in range(batch_size)]
+            batch_format_text = ['<用户>{}<AI>{}'.format(i[1] + i[0], i[2]) for i in batch_samples]
+            batch = [tokenizer.encode(i) + [tokenizer.eos_token_id] for i in batch_format_text]
             lengths = [len(x) for x in batch]
             # Check if any sequence is longer than 2048 tokens
             if max(lengths) > 2048:
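To make the new batching concrete, a minimal sketch of how one (prompt, content, summary) tuple from ConversationDataset is rendered before tokenization; note that the content precedes the prompt in the template, and the sample values are illustrative:

# Tuple order matches ConversationDataset.__getitem__ above: (prompt, content, summary)
sample = ("\n请为以下关键词生成一条广告语。", "类型#裙*颜色#蓝色", "这条蓝色长裙优雅大方。")
text = '<用户>{}<AI>{}'.format(sample[1] + sample[0], sample[2])
print(text)
# <用户>类型#裙*颜色#蓝色
# 请为以下关键词生成一条广告语。<AI>这条蓝色长裙优雅大方。

The token ids of each such string then get tokenizer.eos_token_id appended, so the model learns where a completion ends.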
@@ -645,7 +648,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
             print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
 
 
-def generate(model, prompt, tokenizer, args):
+def generate_string(model, prompt, tokenizer, args):
     print(prompt, end="", flush=True)
 
     prompt = mx.array(tokenizer.encode(prompt))
@@ -736,4 +739,4 @@ if __name__ == "__main__":
 
     if args.prompt is not None:
         print("Generating")
-        generate(model, args.prompt, tokenizer, args)
+        generate_string(model, args.prompt, tokenizer, args)