Update the training-set format and the fp16 inference path

zR 2024-04-03 14:47:45 +08:00
parent 111657c02c
commit fdaab94f1e
2 changed files with 28 additions and 31 deletions

View File

@@ -7,26 +7,29 @@ import gradio as gr
 import torch
 from threading import Thread
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TextIteratorStreamer
 )
 import warnings

 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

 parser = argparse.ArgumentParser()
 parser.add_argument("--model_path", type=str, default="")
-parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
+parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
 args = parser.parse_args()

 # init model torch dtype
 torch_dtype = args.torch_dtype
-if torch_dtype =="" or torch_dtype == "bfloat16":
+if torch_dtype == "" or torch_dtype == "bfloat16":
     torch_dtype = torch.bfloat16
 elif torch_dtype == "float32":
     torch_dtype = torch.float32
+elif torch_dtype == "float16":
+    torch_dtype = torch.float16
 else:
     raise ValueError(f"Invalid torch dtype: {torch_dtype}")
@@ -36,8 +39,8 @@ tokenizer = AutoTokenizer.from_pretrained(path)
 model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)

 # init gradio demo host and port
-server_name=args.server_name
-server_port=args.server_port
+server_name = args.server_name
+server_port = args.server_port

 def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -50,7 +53,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
     Yields:
         str: real-time generation results of hf model
     """

     inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
     enc = tokenizer(inputs, return_tensors="pt").to("cuda")
     streamer = TextIteratorStreamer(tokenizer)
@@ -73,7 +76,8 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
         yield answer[4 + len(inputs):]


-def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
+def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float,
+             max_dec_len: int):
     """generate after hitting "submit" button

     Args:
@@ -85,7 +89,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
     """

     assert query != "", "Input must not be empty!!!"
     # apply chat template
     model_input = []
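
The hunk above shows only the first lines after the docstring; a hedged sketch of how the "# apply chat template" step presumably flattens [query, answer] pairs plus the new query into role-tagged messages (a reconstruction, not the file's verbatim code):

# hypothetical reconstruction of the "# apply chat template" step;
# sample history and query are invented for illustration
chat_history = [["Hi", "Hello! How can I help?"]]
query = "Tell me a joke."
model_input = []
for q, a in chat_history:
    model_input.append({"role": "user", "content": q})
    model_input.append({"role": "assistant", "content": a})
model_input.append({"role": "user", "content": query})
# model_input then feeds hf_gen, which applies the tokenizer's chat template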
@@ -111,7 +115,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
     """

     assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
     # apply chat template
     model_input = []
@@ -130,7 +134,7 @@ def clear_history():
     Returns:
         List: empty chat history
     """

     return []
@@ -142,7 +146,7 @@ def reverse_last_round(chat_history):
     Returns:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
     """

     assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
     return chat_history[:-1]
@ -166,8 +170,10 @@ with gr.Blocks(theme="soft") as demo:
regen = gr.Button("Regenerate") regen = gr.Button("Regenerate")
reverse = gr.Button("Reverse") reverse = gr.Button("Reverse")
submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot]) submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len],
regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot]) outputs=[user_input, chatbot])
regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len],
outputs=[user_input, chatbot])
clear.click(clear_history, inputs=[], outputs=[chatbot]) clear.click(clear_history, inputs=[], outputs=[chatbot])
reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot]) reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
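
For context on the imports this commit touches (Thread, TextIteratorStreamer): the demo streams tokens by running model.generate in a background thread and iterating the streamer. A minimal self-contained sketch of that pattern (the model path and generation kwargs are placeholders, not values from the file):

import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

path = "path/to/model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)

enc = tokenizer("Hello", return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
# run generation in a background thread so tokens can be consumed as they arrive
Thread(target=model.generate, kwargs=dict(**enc, streamer=streamer, max_new_tokens=64)).start()
answer = ""
for new_text in streamer:
    answer += new_text  # the real demo yields partial text to Gradio here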

View File

@@ -5,7 +5,7 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples.
 Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx
 Use this Code with command:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data AdvertiseGen --train --seed 2024 --iters 1000
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000
 """
 import argparse
 import json
@@ -329,7 +329,7 @@ def build_parser():
         "--data",
         type=str,
         default="data/",
-        help="Directory with {train, valid, test}.jsonl files",
+        help="Directory with {train, valid, test}.json files",
     )
     parser.add_argument(
         "--lora-layers",
@@ -395,35 +395,26 @@ def build_parser():
     return parser


 class ConversationDataset:
-    """
-    Light-weight wrapper to handle conversation data from a jsonl file.
-    Each data entry is expected to have a "conversations" list, with each item
-    containing "role" and "content".
-    """

     def __init__(self, path: Path):
         with open(path, "r") as fid:
             self._data = [json.loads(l) for l in fid]

     def __getitem__(self, idx: int):
-        conversation = self._data[idx]["conversations"]
-        user_texts = []
-        assistant_texts = []
-        for turn in conversation:
-            if turn["role"] == "user":
-                user_texts.append(turn["content"])
-            elif turn["role"] == "assistant":
-                assistant_texts.append(turn["content"])
-        return " ".join(user_texts), " ".join(assistant_texts)
+        entry = self._data[idx]
+        content = entry.get("content", "")
+        summary = entry.get("summary", "")
+        return content, summary

     def __len__(self):
         return len(self._data)


 def load(args):
     def load_and_check(name):
-        dataset_path = Path(args.data) / f"{name}.jsonl"
+        dataset_path = Path(args.data) / f"{name}.json"
         try:
             return ConversationDataset(dataset_path)
         except Exception as e:
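
Taken together, these hunks switch the loader from role-tagged "conversations" jsonl files to AdvertiseGen-style json, where each line holds a flat "content"/"summary" pair. A sketch of the layout the loader now expects (the sample record is invented for illustration):

import json
from pathlib import Path

# the loader looks for {train, valid, test}.json under --data,
# one JSON object per line; field values here are illustrative only
data_dir = Path("finetune/data/AdvertiseGen")
data_dir.mkdir(parents=True, exist_ok=True)
record = {"content": "type#dress*style#minimalist", "summary": "A minimalist dress that pairs easily with any outfit."}
with open(data_dir / "train.json", "w") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")

# ConversationDataset(data_dir / "train.json")[0] would then return
# ("type#dress*style#minimalist", "A minimalist dress that pairs easily with any outfit.")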