Update MLX fine-tuning instructions

zR 2024-04-09 23:42:12 +08:00
parent fdaab94f1e
commit 6618dd93be
2 changed files with 31 additions and 10 deletions

View File

@@ -5,7 +5,28 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples.
Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx
Use this Code with command:
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000
train:
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --train --seed 2024 --iters 500
The output is as follows:
Training
Iter 1: Val loss 4.015, Val took 1067.669s
Iter 2: Val loss 4.001, Val took 1061.649s
...
After training finishes, an adapters.npz file is written to the working directory for use in testing. Then run the test command:
test:
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --test --seed 2024
The output is as follows:
Testing
Test loss 3.977, Test ppl 53.350.
"""
import argparse
import json
@@ -395,8 +416,6 @@ def build_parser():
return parser
class ConversationDataset:
def __init__(self, path: Path):
@@ -412,6 +431,7 @@ class ConversationDataset:
def __len__(self):
return len(self._data)
def load(args):
def load_and_check(name):
dataset_path = Path(args.data) / f"{name}.json"
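
Note on the mlx_finetune.py hunks above: the script expects the --data directory (here data/AdvertiseGen) to contain per-split {name}.json files, which the nested load_and_check() opens and wraps in ConversationDataset. The sketch below is a minimal, self-contained illustration of that loading pattern, not the repository's code; the split names, the record schema, and the load_conversation_datasets helper are assumptions, since the full class body and load() are not visible in this diff.

# Illustrative sketch only -- not the repository's code. Split names and the
# JSON record schema are assumptions; __len__ and the {name}.json path come
# from the hunks above.
import json
from pathlib import Path

class ConversationDataset:
    """Holds the records parsed from one {name}.json split file."""

    def __init__(self, path: Path):
        with open(path, "r", encoding="utf-8") as f:
            self._data = json.load(f)  # assumed: a JSON array of conversation records

    def __getitem__(self, idx):
        return self._data[idx]

    def __len__(self):
        return len(self._data)

def load_conversation_datasets(data_dir: str):
    """Hypothetical helper mirroring the nested load_and_check() in load()."""
    splits = []
    for name in ("train", "dev", "test"):  # split names assumed
        dataset_path = Path(data_dir) / f"{name}.json"
        if not dataset_path.exists():
            raise FileNotFoundError(f"missing dataset file: {dataset_path}")
        splits.append(ConversationDataset(dataset_path))
    return splits  # train, dev, test datasets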

View File

@@ -1,11 +1,12 @@
-transformers>=4.38.2
-torch>=2.0.0
+transformers>=4.39.1
+torch>=2.2.0
+triton>=2.2.0
 httpx>=0.27.0
-gradio>=4.21.0
+gradio>=4.26.0
 flash_attn>=2.4.1
-accelerate>=0.28.0
-sentence_transformers>=2.6.0
-sse_starlette>=2.0.0
+accelerate>=0.29.2
+sentence_transformers>=2.6.1
+sse_starlette>=2.1.0
 tiktoken>=0.6.0
-mlx_lm>=0.5.0
+mlx_lm>=0.8.0
 openai>=0.16.2
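
Note on the requirements bump above: several minimum versions are raised (transformers, torch, gradio, accelerate, sentence_transformers, sse_starlette, mlx_lm) and triton is added. As an optional sanity check before running mlx_finetune.py, a sketch like the following can confirm the active environment meets the new floors; it is not part of the commit and assumes the packaging library is available and Python 3.10+ for underscore/dash package-name normalization.

# Optional environment check against the bumped minimums above.
# Illustrative only; the version table is copied from the new requirements.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version  # assumes `packaging` is installed

MINIMUMS = {
    "transformers": "4.39.1",
    "torch": "2.2.0",
    "triton": "2.2.0",
    "gradio": "4.26.0",
    "accelerate": "0.29.2",
    "sentence_transformers": "2.6.1",
    "sse_starlette": "2.1.0",
    "mlx_lm": "0.8.0",
}

for package, minimum in MINIMUMS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: NOT INSTALLED (requires >= {minimum})")
        continue
    status = "OK" if Version(installed) >= Version(minimum) else f"too old, needs >= {minimum}"
    print(f"{package}: {installed} ({status})")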