Mirror of https://github.com/RYDE-WORK/MiniCPM.git
Fixed two bugs: the code contained two generate functions, and the <用户>question<AI> prompt format was never actually applied in the training code.
parent 6a48f35950
commit 80e506289a
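In short, the second fix changes how a training sample is turned into a token sequence. A minimal before/after sketch (the function names are illustrative, not from the repository):

# Before: a sample was encoded as-is; its (prompt, input, output) structure was ignored.
def encode_sample_old(tokenizer, text):
    return tokenizer.encode(text)

# After: the user side is wrapped in <用户>, the reply in <AI>, and an EOS
# token is appended so the model learns where an answer ends.
def encode_sample_new(tokenizer, prompt, content, summary):
    text = '<用户>{}<AI>{}'.format(content + prompt, summary)
    return tokenizer.encode(text) + [tokenizer.eos_token_id]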
@@ -7,7 +7,8 @@ Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx
 Use this Code with command:
 
 train:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --train --seed 2024 --iters 500
+First, preprocess the data by running data_processing.ipynb
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/mlx_AdvertiseGen --train --seed 2024 --iters 500
 
 The output looks like this:
 
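The notebook data_processing.ipynb itself is not part of this diff. Judging from the key renames in ConversationDataset below, it presumably rewrites the raw AdvertiseGen records ({"content": ..., "summary": ...}) into {"prompt": ..., "input": ..., "output": ...} files under data/mlx_AdvertiseGen. A hypothetical sketch of such a conversion (the file layout and the empty default instruction are assumptions):

import json

def convert(src_path, dst_path, instruction=""):
    # Hypothetical preprocessing; the real logic lives in data_processing.ipynb.
    with open(src_path, encoding="utf-8") as src, open(dst_path, "w", encoding="utf-8") as dst:
        for line in src:
            rec = json.loads(line)  # raw AdvertiseGen line: {"content": ..., "summary": ...}
            out = {
                "prompt": instruction,     # optional instruction text, assumed empty here
                "input": rec["content"],   # user-side content
                "output": rec["summary"],  # target completion
            }
            dst.write(json.dumps(out, ensure_ascii=False) + "\n")

# e.g. convert("data/AdvertiseGen/train.json", "data/mlx_AdvertiseGen/train.json")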
@@ -19,7 +20,7 @@ Iter 2: Val loss 4.001, Val took 1061.649s
 After training finishes, the folder will contain an adapters.npz file, which is used for the subsequent test. Next, run the test command
 
 test:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/AdvertiseGen --test --seed 2024
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data data/mlx_AdvertiseGen --test --seed 2024
 
 The output looks like this:
 
@@ -318,7 +319,7 @@ def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
-        default="mlx_model",
+        default="/Users/liudan/Downloads/模型/llamaformat_minicpm",
         help="The path to the local model directory or Hugging Face repo.",
     )
     # Generation args
@@ -336,8 +337,7 @@ def build_parser():
         "--prompt",
         "-p",
         type=str,
-        help="The prompt for generation",
-        default=None,
+        help="The prompt for generation"
     )
 
     # Training args
@@ -349,7 +349,7 @@ def build_parser():
     parser.add_argument(
         "--data",
         type=str,
-        default="data/",
+        default="data/mlx_AdvertiseGen",
        help="Directory with {train, valid, test}.json files",
     )
     parser.add_argument(
@@ -424,9 +424,10 @@ class ConversationDataset:
 
     def __getitem__(self, idx: int):
         entry = self._data[idx]
-        content = entry.get("content", "")
-        summary = entry.get("summary", "")
-        return content, summary
+        content = entry.get("input", "")
+        summary = entry.get("output", "")
+        prompt = entry.get("prompt", "")
+        return prompt, content, summary
 
     def __len__(self):
         return len(self._data)
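With this change, __getitem__ returns a 3-tuple instead of a pair, so each record in data/mlx_AdvertiseGen is expected to carry three keys. An illustrative entry (values invented):

entry = {
    "prompt": "写一段广告文案",          # optional instruction; .get() falls back to "" if absent
    "input": "类型#裙*风格#简约",        # user-side content
    "output": "这款简约风格的连衣裙……",  # target completion
}
# dataset[idx] now yields (prompt, input, output) rather than (content, summary)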
@@ -479,7 +480,9 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
     # Collect batches from dataset
     for i in range(0, len(indices) - batch_size + 1, batch_size):
         # Encode batch
-        batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
+        batch_samples = [dset[indices[i + j]] for j in range(batch_size)]
+        batch_format_text = ['<用户>{}<AI>{}'.format(i[1] + i[0], i[2]) for i in batch_samples]
+        batch = [tokenizer.encode(i) + [tokenizer.eos_token_id] for i in batch_format_text]
         lengths = [len(x) for x in batch]
         # Check if any sequence is longer than 2048 tokens
         if max(lengths) > 2048:
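This is the second bug fix from the commit message: each batch element is now rendered through the <用户>…<AI>… template before encoding, with an EOS id appended. A worked example of what one element becomes (sample values invented):

prompt, content, summary = "写一段广告文案", "类型#裙*风格#简约", "这款简约风格的连衣裙……"
text = '<用户>{}<AI>{}'.format(content + prompt, summary)
# text == '<用户>类型#裙*风格#简约写一段广告文案<AI>这款简约风格的连衣裙……'
# Note: the user side is content followed by prompt (i[1] + i[0] in the diff);
# tokenizer.eos_token_id is appended after encoding so generation can terminate.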
@@ -645,7 +648,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
             print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
 
 
-def generate(model, prompt, tokenizer, args):
+def generate_string(model, prompt, tokenizer, args):
     print(prompt, end="", flush=True)
 
     prompt = mx.array(tokenizer.encode(prompt))
@@ -736,4 +739,4 @@ if __name__ == "__main__":
 
     if args.prompt is not None:
         print("Generating")
-        generate(model, args.prompt, tokenizer, args)
+        generate_string(model, args.prompt, tokenizer, args)