diff --git a/finetune/mlx_finetune.py b/finetune/mlx_finetune.py index db1efa2..651df31 100644 --- a/finetune/mlx_finetune.py +++ b/finetune/mlx_finetune.py @@ -5,7 +5,28 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples. Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx Use this Code with command: -python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000 + +train: +python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --train --seed 2024 --iters 500 + +The output looks like this: + +Training +Iter 1: Val loss 4.015, Val took 1067.669s +Iter 2: Val loss 4.001, Val took 1061.649s +... + +After training finishes, the folder will contain an adapters.npz file, which is used for the subsequent test. Next, run the test command. + +test: +python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --test --seed 2024 + +The output looks like this: + +Testing +Test loss 3.977, Test ppl 53.350. + + """ import argparse import json @@ -395,8 +416,6 @@ def build_parser(): return parser - - class ConversationDataset: def __init__(self, path: Path): @@ -412,6 +431,7 @@ class ConversationDataset: def __len__(self): return len(self._data) + def load(args): def load_and_check(name): dataset_path = Path(args.data) / f"{name}.json" diff --git a/requirements.txt b/requirements.txt index ea3b194..8f8baa5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,12 @@ -transformers>=4.38.2 -torch>=2.0.0 +transformers>=4.39.1 +torch>=2.2.0 triton>=2.2.0 httpx>=0.27.0 -gradio>=4.21.0 +gradio>=4.26.0 flash_attn>=2.4.1 -accelerate>=0.28.0 -sentence_transformers>=2.6.0 -sse_starlette>=2.0.0 +accelerate>=0.29.2 +sentence_transformers>=2.6.1 +sse_starlette>=2.1.0 tiktoken>=0.6.0 -mlx_lm>=0.5.0 \ No newline at end of file +mlx_lm>=0.8.0 +openai>=0.16.2 \ No newline at end of file