mirror of https://github.com/RYDE-WORK/MiniCPM.git
synced 2026-01-19 12:53:36 +08:00

Update the training dataset format and the fp16 inference option

This commit is contained in:
parent 111657c02c
commit fdaab94f1e
@@ -7,26 +7,29 @@ import gradio as gr
 import torch
 from threading import Thread
 from transformers import (
-    AutoModelForCausalLM,
+    AutoModelForCausalLM,
     AutoTokenizer,
     TextIteratorStreamer
 )
 import warnings

 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

 parser = argparse.ArgumentParser()
 parser.add_argument("--model_path", type=str, default="")
-parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
+parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
 args = parser.parse_args()

 # init model torch dtype
 torch_dtype = args.torch_dtype
-if torch_dtype =="" or torch_dtype == "bfloat16":
+if torch_dtype == "" or torch_dtype == "bfloat16":
     torch_dtype = torch.bfloat16
 elif torch_dtype == "float32":
     torch_dtype = torch.float32
+elif torch_dtype == "float16":
+    torch_dtype = torch.float16
 else:
     raise ValueError(f"Invalid torch dtype: {torch_dtype}")
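Note: the hunk above adds "float16" as a third --torch_dtype choice and maps it to torch.float16. As an illustration only (this dict-based variant and its names are not part of the commit), the same resolution logic can be expressed as a lookup table:

# Sketch: an equivalent of the if/elif dtype chain added above.
# DTYPE_MAP and resolve_dtype are illustrative names, not code from the repo.
import torch

DTYPE_MAP = {
    "": torch.bfloat16,          # an empty value falls back to bfloat16, as in the demo
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "float16": torch.float16,    # the newly supported half-precision option
}

def resolve_dtype(name: str) -> torch.dtype:
    if name not in DTYPE_MAP:
        raise ValueError(f"Invalid torch dtype: {name}")
    return DTYPE_MAP[name]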
@@ -36,8 +39,8 @@ tokenizer = AutoTokenizer.from_pretrained(path)
 model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)

 # init gradio demo host and port
-server_name=args.server_name
-server_port=args.server_port
+server_name = args.server_name
+server_port = args.server_port

 def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -50,7 +53,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f

     Yields:
         str: real-time generation results of hf model
-    """
+    """
     inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
     enc = tokenizer(inputs, return_tensors="pt").to("cuda")
     streamer = TextIteratorStreamer(tokenizer)
@@ -73,7 +76,8 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
         yield answer[4 + len(inputs):]


-def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
+def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float,
+             max_dec_len: int):
     """generate after hitting "submit" button

     Args:
@@ -85,7 +89,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r

     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
-    """
+    """
     assert query != "", "Input must not be empty!!!"
     # apply chat template
     model_input = []
@@ -111,7 +115,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_

     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
-    """
+    """
     assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
     # apply chat template
     model_input = []
@@ -130,7 +134,7 @@ def clear_history():

     Returns:
         List: empty chat history
-    """
+    """
     return []


@@ -142,7 +146,7 @@ def reverse_last_round(chat_history):

     Returns:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
-    """
+    """
     assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
     return chat_history[:-1]

@@ -166,8 +170,10 @@ with gr.Blocks(theme="soft") as demo:
             regen = gr.Button("Regenerate")
             reverse = gr.Button("Reverse")

-    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
-    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
+    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len],
+                 outputs=[user_input, chatbot])
+    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len],
+                 outputs=[user_input, chatbot])
     clear.click(clear_history, inputs=[], outputs=[chatbot])
     reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])

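Note: the hunk above only re-wraps the Gradio click bindings for line length; behaviour is unchanged. The diff does not show how server_name and server_port are consumed, but a typical Gradio launch using them would look roughly like the sketch below (an assumption for context, not code from this commit):

# Sketch only: how the parsed host/port are commonly passed to Gradio.
# The demo script's actual launch call is not part of this diff.
demo.queue()
demo.launch(server_name=server_name, server_port=server_port)

The remaining hunks touch the MLX fine-tuning script (mlx_finetune.py, per its module docstring).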
@@ -5,7 +5,7 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples.
 Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx

 Use this Code with command:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data AdvertiseGen --train --seed 2024 --iters 1000
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000
 """
 import argparse
 import json
@@ -329,7 +329,7 @@ def build_parser():
         "--data",
         type=str,
         default="data/",
-        help="Directory with {train, valid, test}.jsonl files",
+        help="Directory with {train, valid, test}.json files",
     )
     parser.add_argument(
         "--lora-layers",
@@ -395,35 +395,26 @@ def build_parser():
     return parser


-
-
 class ConversationDataset:
     """
     Light-weight wrapper to handle conversation data from a jsonl file.
-    Each data entry is expected to have a "conversations" list, with each item
-    containing "role" and "content".
     """

     def __init__(self, path: Path):
         with open(path, "r") as fid:
             self._data = [json.loads(l) for l in fid]

     def __getitem__(self, idx: int):
-        conversation = self._data[idx]["conversations"]
-        user_texts = []
-        assistant_texts = []
-        for turn in conversation:
-            if turn["role"] == "user":
-                user_texts.append(turn["content"])
-            elif turn["role"] == "assistant":
-                assistant_texts.append(turn["content"])
-        return " ".join(user_texts), " ".join(assistant_texts)
+        entry = self._data[idx]
+        content = entry.get("content", "")
+        summary = entry.get("summary", "")
+        return content, summary

     def __len__(self):
         return len(self._data)


 def load(args):
     def load_and_check(name):
-        dataset_path = Path(args.data) / f"{name}.jsonl"
+        dataset_path = Path(args.data) / f"{name}.json"
         try:
             return ConversationDataset(dataset_path)
         except Exception as e:
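Note: the rewritten __getitem__ together with the .jsonl → .json rename switches the expected training data from chat-style "conversations" records to AdvertiseGen-style records with a "content" input and a "summary" target. As an illustration only (the field values below are placeholders, not real data), each line of {train, valid, test}.json is expected to parse like this:

# Sketch: the record shape ConversationDataset now reads, one JSON object per line.
import json

sample_line = '{"content": "<input keywords / attributes>", "summary": "<target advertising copy>"}'
record = json.loads(sample_line)
prompt, completion = record.get("content", ""), record.get("summary", "")  # mirrors __getitem__ above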