修复了两个bug,一个是代码中存在两个generate函数,另外一个是<用户>问题<AI>这种格式没有用到该代码中去的bug

This commit is contained in:
刘丹 2024-06-27 21:17:26 +08:00
parent 6a48f35950
commit 80e506289a

View File

@ -7,7 +7,8 @@ Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-
Use this Code with command: Use this Code with command:
train: train:
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --train --seed 2024 --iters 500 首先处理数据运行data_processing.ipynb
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/mlx_AdvertiseGen --train --seed 2024 --iters 500
输出结果如下 输出结果如下
@ -19,7 +20,7 @@ Iter 2: Val loss 4.001, Val took 1061.649s
训练结束之后文件夹下会有 adapters.npz 文件用于后续的测试接着运行测试命令 训练结束之后文件夹下会有 adapters.npz 文件用于后续的测试接着运行测试命令
test: test:
python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --test --seed 2024 python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/mlx_AdvertiseGen --test --seed 2024
输出结果如下 输出结果如下
@ -318,7 +319,7 @@ def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument( parser.add_argument(
"--model", "--model",
default="mlx_model", default="/Users/liudan/Downloads/模型/llamaformat_minicpm",
help="The path to the local model directory or Hugging Face repo.", help="The path to the local model directory or Hugging Face repo.",
) )
# Generation args # Generation args
@ -336,8 +337,7 @@ def build_parser():
"--prompt", "--prompt",
"-p", "-p",
type=str, type=str,
help="The prompt for generation", help="The prompt for generation"
default=None,
) )
# Training args # Training args
@ -349,7 +349,7 @@ def build_parser():
parser.add_argument( parser.add_argument(
"--data", "--data",
type=str, type=str,
default="data/", default="data/mlx_AdvertiseGen",
help="Directory with {train, valid, test}.json files", help="Directory with {train, valid, test}.json files",
) )
parser.add_argument( parser.add_argument(
@ -424,9 +424,10 @@ class ConversationDataset:
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
entry = self._data[idx] entry = self._data[idx]
content = entry.get("content", "") content = entry.get("input", "")
summary = entry.get("summary", "") summary = entry.get("output", "")
return content, summary prompt = entry.get("prompt", "")
return prompt, content, summary
def __len__(self): def __len__(self):
return len(self._data) return len(self._data)
@ -479,7 +480,9 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
# Collect batches from dataset # Collect batches from dataset
for i in range(0, len(indices) - batch_size + 1, batch_size): for i in range(0, len(indices) - batch_size + 1, batch_size):
# Encode batch # Encode batch
batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)] batch_samples=[dset[indices[i + j]] for j in range(batch_size)]
batch_format_text=['<用户>{}<AI>{}'.format(i[1]+i[0],i[2]) for i in batch_samples]
batch = [tokenizer.encode(i)+[tokenizer.eos_token_id] for i in batch_format_text]
lengths = [len(x) for x in batch] lengths = [len(x) for x in batch]
# Check if any sequence is longer than 2048 tokens # Check if any sequence is longer than 2048 tokens
if max(lengths) > 2048: if max(lengths) > 2048:
@ -645,7 +648,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.") print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
def generate(model, prompt, tokenizer, args): def generate_string(model, prompt, tokenizer, args):
print(prompt, end="", flush=True) print(prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(prompt)) prompt = mx.array(tokenizer.encode(prompt))
@ -736,4 +739,4 @@ if __name__ == "__main__":
if args.prompt is not None: if args.prompt is not None:
print("Generating") print("Generating")
generate(model, args.prompt, tokenizer, args) generate_string(model, args.prompt, tokenizer, args)