Update training set format and fp16 inference support

zR 2024-04-03 14:47:45 +08:00
parent 111657c02c
commit fdaab94f1e
2 changed files with 28 additions and 31 deletions

(file 1 of 2: Hugging Face Gradio chat demo; filename not shown)

@@ -7,26 +7,29 @@ import gradio as gr
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer
)
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()
# init model torch dtype
torch_dtype = args.torch_dtype
-if torch_dtype =="" or torch_dtype == "bfloat16":
+if torch_dtype == "" or torch_dtype == "bfloat16":
    torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
    torch_dtype = torch.float32
+elif torch_dtype == "float16":
+    torch_dtype = torch.float16
else:
    raise ValueError(f"Invalid torch dtype: {torch_dtype}")
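
Note: the extended if/elif chain could equally be expressed as a lookup table. A minimal sketch, not part of this commit (resolve_dtype and _DTYPES are illustrative names):

    import torch

    # Table-driven equivalent of the dtype selection above (sketch only);
    # the empty-string fallback to bfloat16 matches the diff's behavior.
    _DTYPES = {"": torch.bfloat16, "bfloat16": torch.bfloat16,
               "float32": torch.float32, "float16": torch.float16}

    def resolve_dtype(name: str) -> torch.dtype:
        try:
            return _DTYPES[name]
        except KeyError:
            raise ValueError(f"Invalid torch dtype: {name}")
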
@@ -36,8 +39,8 @@ tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
# init gradio demo host and port
-server_name=args.server_name
-server_port=args.server_port
+server_name = args.server_name
+server_port = args.server_port
def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
"""generate model output with huggingface api
@@ -50,7 +53,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
    Yields:
        str: real-time generation results of hf model
    """
    inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
    enc = tokenizer(inputs, return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer)
@@ -73,7 +76,8 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
        yield answer[4 + len(inputs):]
-def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
+def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float,
+             max_dec_len: int):
"""generate after hitting "submit" button
Args:
@@ -85,7 +89,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
    """
    assert query != "", "Input must not be empty!!!"
    # apply chat template
    model_input = []
@@ -111,7 +115,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
    # apply chat template
    model_input = []
@@ -130,7 +134,7 @@ def clear_history():
    Returns:
        List: empty chat history
    """
    return []
@@ -142,7 +146,7 @@ def reverse_last_round(chat_history):
    Returns:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
    return chat_history[:-1]
@@ -166,8 +170,10 @@ with gr.Blocks(theme="soft") as demo:
        regen = gr.Button("Regenerate")
        reverse = gr.Button("Reverse")
-    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
-    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
+    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len],
+                 outputs=[user_input, chatbot])
+    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len],
+                 outputs=[user_input, chatbot])
    clear.click(clear_history, inputs=[], outputs=[chatbot])
    reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
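
With the new float16 choice, the demo can run in half precision on GPUs without native bfloat16 support (e.g. pre-Ampere NVIDIA cards). A hypothetical invocation; the demo script's real filename is not shown in this diff, so demo.py is a placeholder:

    python demo.py --model_path <your-model-path> --torch_dtype float16 --server_name 127.0.0.1 --server_port 7860
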

(file 2 of 2: mlx_finetune.py)

@@ -5,7 +5,7 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples.
Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx
Use this Code with command:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data AdvertiseGen --train --seed 2024 --iters 1000
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000
"""
import argparse
import json
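
A note on the --data change above: load_and_check (in the last hunk below) joins --data with the split name, so each split file must sit directly under that directory. A minimal sketch of the resulting path, assuming the updated command's data directory:

    from pathlib import Path

    # With --data finetune/data/AdvertiseGen, as in the updated command above,
    # the train split is expected at finetune/data/AdvertiseGen/train.json.
    data_dir = Path("finetune/data/AdvertiseGen")
    print(data_dir / "train.json")  # finetune/data/AdvertiseGen/train.json
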
@@ -329,7 +329,7 @@ def build_parser():
        "--data",
        type=str,
        default="data/",
-        help="Directory with {train, valid, test}.jsonl files",
+        help="Directory with {train, valid, test}.json files",
    )
    parser.add_argument(
        "--lora-layers",
@@ -395,35 +395,26 @@ def build_parser():
    return parser
class ConversationDataset:
"""
Light-weight wrapper to handle conversation data from a jsonl file.
Each data entry is expected to have a "conversations" list, with each item
containing "role" and "content".
"""
    def __init__(self, path: Path):
        with open(path, "r") as fid:
            self._data = [json.loads(l) for l in fid]
    def __getitem__(self, idx: int):
-        conversation = self._data[idx]["conversations"]
-        user_texts = []
-        assistant_texts = []
-        for turn in conversation:
-            if turn["role"] == "user":
-                user_texts.append(turn["content"])
-            elif turn["role"] == "assistant":
-                assistant_texts.append(turn["content"])
-        return " ".join(user_texts), " ".join(assistant_texts)
+        entry = self._data[idx]
+        content = entry.get("content", "")
+        summary = entry.get("summary", "")
+        return content, summary
    def __len__(self):
        return len(self._data)
def load(args):
    def load_and_check(name):
-        dataset_path = Path(args.data) / f"{name}.jsonl"
+        dataset_path = Path(args.data) / f"{name}.json"
        try:
            return ConversationDataset(dataset_path)
        except Exception as e:
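
For reference, the rewritten ConversationDataset expects AdvertiseGen-style records: one JSON object per line with "content" (the prompt) and "summary" (the target). A minimal parsing sketch; the sample line is invented:

    import json

    # Hypothetical record in the format the new __getitem__ reads.
    line = '{"content": "example prompt", "summary": "example target"}'
    entry = json.loads(line)
    content = entry.get("content", "")
    summary = entry.get("summary", "")
    assert (content, summary) == ("example prompt", "example target")
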