From 6618dd93be7a857e054f7f9115f3052feb1b05ae Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 9 Apr 2024 23:42:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0mlx=20=E5=BE=AE=E8=B0=83?= =?UTF-8?q?=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- finetune/mlx_finetune.py | 26 +++++++++++++++++++++++--- requirements.txt | 15 ++++++++------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/finetune/mlx_finetune.py b/finetune/mlx_finetune.py index db1efa2..651df31 100644 --- a/finetune/mlx_finetune.py +++ b/finetune/mlx_finetune.py @@ -5,7 +5,28 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples. Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx Use this Code with command: -python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000 + +train: +python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --train --seed 2024 --iters 500 + +输出结果如下: + +Training +Iter 1: Val loss 4.015, Val took 1067.669s +Iter 2: Val loss 4.001, Val took 1061.649s +... + +训练结束之后,文件夹下会有 adapters.npz 文件,用于后续的测试。接着,运行测试命令 + +test: +python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --test --seed 2024 + +输出结果如下: + +Testing +Test loss 3.977, Test ppl 53.350. + + """ import argparse import json @@ -395,8 +416,6 @@ def build_parser(): return parser - - class ConversationDataset: def __init__(self, path: Path): @@ -412,6 +431,7 @@ class ConversationDataset: def __len__(self): return len(self._data) + def load(args): def load_and_check(name): dataset_path = Path(args.data) / f"{name}.json" diff --git a/requirements.txt b/requirements.txt index ea3b194..8f8baa5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,12 @@ -transformers>=4.38.2 -torch>=2.0.0 +transformers>=4.39.1 +torch>=2.2.0 triton>=2.2.0 httpx>=0.27.0 -gradio>=4.21.0 +gradio>=4.26.0 flash_attn>=2.4.1 -accelerate>=0.28.0 -sentence_transformers>=2.6.0 -sse_starlette>=2.0.0 +accelerate>=0.29.2 +sentence_transformers>=2.6.1 +sse_starlette>=2.1.0 tiktoken>=0.6.0 -mlx_lm>=0.5.0 \ No newline at end of file +mlx_lm>=0.8.0 +openai>=0.16.2 \ No newline at end of file