Mirror of https://github.com/RYDE-WORK/MiniCPM.git
Synced 2026-01-30 03:03:24 +08:00

Update the training-set format and add float16 inference support

This commit is contained in:
parent 111657c02c
commit fdaab94f1e
@@ -7,26 +7,29 @@ import gradio as gr
 import torch
 from threading import Thread
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TextIteratorStreamer
 )
 import warnings

 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

 parser = argparse.ArgumentParser()
 parser.add_argument("--model_path", type=str, default="")
-parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
+parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
 args = parser.parse_args()

 # init model torch dtype
 torch_dtype = args.torch_dtype
-if torch_dtype =="" or torch_dtype == "bfloat16":
+if torch_dtype == "" or torch_dtype == "bfloat16":
     torch_dtype = torch.bfloat16
 elif torch_dtype == "float32":
     torch_dtype = torch.float32
+elif torch_dtype == "float16":
+    torch_dtype = torch.float16
 else:
     raise ValueError(f"Invalid torch dtype: {torch_dtype}")

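The new float16 branch matters on GPUs without bfloat16 support (for example, pre-Ampere NVIDIA cards). As an illustrative sketch, not part of this commit, of how a caller could pick between the two half-precision dtypes automatically:

    import torch

    # Prefer bfloat16 where the hardware supports it; otherwise fall back to
    # float16, which this commit makes selectable via --torch_dtype.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        half_dtype = torch.bfloat16
    else:
        half_dtype = torch.float16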
@@ -36,8 +39,8 @@ tokenizer = AutoTokenizer.from_pretrained(path)
 model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)

 # init gradio demo host and port
-server_name=args.server_name
-server_port=args.server_port
+server_name = args.server_name
+server_port = args.server_port

 def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -50,7 +53,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):

     Yields:
         str: real-time generation results of hf model
     """
     inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
     enc = tokenizer(inputs, return_tensors="pt").to("cuda")
     streamer = TextIteratorStreamer(tokenizer)
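The body of hf_gen is largely elided in this diff; for context, TextIteratorStreamer is normally driven from a background thread, roughly as follows (the generation parameters here are placeholders; the real values come from the Gradio sliders):

    from threading import Thread

    # model.generate runs in a worker thread and pushes decoded text into the
    # streamer; iterating the streamer then yields partial output in real time.
    generation_kwargs = dict(**enc, streamer=streamer, do_sample=True,
                             top_p=0.8, temperature=0.5,
                             repetition_penalty=1.1, max_new_tokens=1024)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    answer = ""
    for new_text in streamer:
        answer += new_text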
@@ -73,7 +76,8 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
         yield answer[4 + len(inputs):]


-def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
+def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float,
+             max_dec_len: int):
     """generate after hitting "submit" button

     Args:
@@ -85,7 +89,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):

     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
     """
     assert query != "", "Input must not be empty!!!"
     # apply chat template
     model_input = []
@@ -111,7 +115,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):

     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
     """
     assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
     # apply chat template
     model_input = []
@@ -130,7 +134,7 @@ def clear_history():

     Returns:
         List: empty chat history
     """
     return []


@@ -142,7 +146,7 @@ def reverse_last_round(chat_history):

     Returns:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
     """
     assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
     return chat_history[:-1]

@@ -166,8 +170,10 @@ with gr.Blocks(theme="soft") as demo:
             regen = gr.Button("Regenerate")
             reverse = gr.Button("Reverse")

-    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
-    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
+    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len],
+                 outputs=[user_input, chatbot])
+    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len],
+                outputs=[user_input, chatbot])
     clear.click(clear_history, inputs=[], outputs=[chatbot])
     reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])

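With the extra dtype choice wired through, the demo can now be launched in half precision on float16-only hardware, for example (the demo script's filename is assumed here; substitute the actual path):

    python hf_based_demo.py --model_path <path-to-MiniCPM> --torch_dtype float16 --server_name 127.0.0.1 --server_port 7860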
@@ -5,7 +5,7 @@ Using Code is modified from https://github.com/ml-explore/mlx-examples.
 Using Model with https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx

 Use this Code with command:
-python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data AdvertiseGen --train --seed 2024 --iters 1000
+python mlx_finetune.py --model MiniCPM-2B-sft-bf16-llama-format-mlx --data finetune/data/AdvertiseGen --train --seed 2024 --iters 1000
 """
 import argparse
 import json
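The --data argument now points at a directory of AdvertiseGen-style files. Together with the loader changes below, the expected layout is roughly the following (an assumption pieced together from the help string and the f"{name}.json" pattern, not spelled out in the commit):

    finetune/data/AdvertiseGen/
        train.json   # one {"content": ..., "summary": ...} JSON object per line
        valid.json
        test.json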
@@ -329,7 +329,7 @@ def build_parser():
         "--data",
         type=str,
         default="data/",
-        help="Directory with {train, valid, test}.jsonl files",
+        help="Directory with {train, valid, test}.json files",
     )
     parser.add_argument(
         "--lora-layers",
@@ -395,35 +395,26 @@ def build_parser():
     return parser


 class ConversationDataset:
-    """
-    Light-weight wrapper to handle conversation data from a jsonl file.
-    Each data entry is expected to have a "conversations" list, with each item
-    containing "role" and "content".
-    """

     def __init__(self, path: Path):
         with open(path, "r") as fid:
             self._data = [json.loads(l) for l in fid]

     def __getitem__(self, idx: int):
-        conversation = self._data[idx]["conversations"]
-        user_texts = []
-        assistant_texts = []
-        for turn in conversation:
-            if turn["role"] == "user":
-                user_texts.append(turn["content"])
-            elif turn["role"] == "assistant":
-                assistant_texts.append(turn["content"])
-        return " ".join(user_texts), " ".join(assistant_texts)
+        entry = self._data[idx]
+        content = entry.get("content", "")
+        summary = entry.get("summary", "")
+        return content, summary

     def __len__(self):
         return len(self._data)


 def load(args):
     def load_and_check(name):
-        dataset_path = Path(args.data) / f"{name}.jsonl"
+        dataset_path = Path(args.data) / f"{name}.json"
         try:
             return ConversationDataset(dataset_path)
         except Exception as e:
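The rewritten __getitem__ drops the multi-turn "conversations" schema in favor of AdvertiseGen's flat record format. A minimal sketch of what one line of train.json is now expected to contain (the field values are invented for illustration):

    import json

    line = '{"content": "type#dress*color#blue", "summary": "A breezy blue dress ..."}'
    entry = json.loads(line)
    # Mirrors the new __getitem__: absent fields fall back to empty strings.
    content = entry.get("content", "")
    summary = entry.get("summary", "")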