mirror of
https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-02-05 22:53:34 +08:00
更新mlx 微调说明
This commit is contained in:
parent
fdaab94f1e
commit
6618dd93be
@ -5,7 +5,28 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples.
|
|||||||
Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx
|
Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx
|
||||||
|
|
||||||
Use this Code with command:
|
Use this Code with command:
|
||||||
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000
|
|
||||||
|
train:
|
||||||
|
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --train --seed 2024 --iters 500
|
||||||
|
|
||||||
|
输出结果如下:
|
||||||
|
|
||||||
|
Training
|
||||||
|
Iter 1: Val loss 4.015, Val took 1067.669s
|
||||||
|
Iter 2: Val loss 4.001, Val took 1061.649s
|
||||||
|
...
|
||||||
|
|
||||||
|
训练结束之后,文件夹下会有 adapters.npz 文件,用于后续的测试。接着,运行测试命令
|
||||||
|
|
||||||
|
test:
|
||||||
|
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --test --seed 2024
|
||||||
|
|
||||||
|
输出结果如下:
|
||||||
|
|
||||||
|
Testing
|
||||||
|
Test loss 3.977, Test ppl 53.350.
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
@ -395,8 +416,6 @@ def build_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationDataset:
|
class ConversationDataset:
|
||||||
|
|
||||||
def __init__(self, path: Path):
|
def __init__(self, path: Path):
|
||||||
@ -412,6 +431,7 @@ class ConversationDataset:
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._data)
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
def load(args):
|
def load(args):
|
||||||
def load_and_check(name):
|
def load_and_check(name):
|
||||||
dataset_path = Path(args.data) / f"{name}.json"
|
dataset_path = Path(args.data) / f"{name}.json"
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
transformers>=4.38.2
|
transformers>=4.39.1
|
||||||
torch>=2.0.0
|
torch>=2.2.0
|
||||||
triton>=2.2.0
|
triton>=2.2.0
|
||||||
httpx>=0.27.0
|
httpx>=0.27.0
|
||||||
gradio>=4.21.0
|
gradio>=4.26.0
|
||||||
flash_attn>=2.4.1
|
flash_attn>=2.4.1
|
||||||
accelerate>=0.28.0
|
accelerate>=0.29.2
|
||||||
sentence_transformers>=2.6.0
|
sentence_transformers>=2.6.1
|
||||||
sse_starlette>=2.0.0
|
sse_starlette>=2.1.0
|
||||||
tiktoken>=0.6.0
|
tiktoken>=0.6.0
|
||||||
mlx_lm>=0.5.0
|
mlx_lm>=0.8.0
|
||||||
|
openai>=0.16.2
|
||||||
Loading…
x
Reference in New Issue
Block a user