From 600e00dba3c9c829bf8a852b67698280eb20043b Mon Sep 17 00:00:00 2001
From: "Y.W. Fang" <1157670798@qq.com>
Date: Mon, 5 Feb 2024 21:55:22 +0800
Subject: [PATCH] add repetition_penalty and set top_k=0 in hf-based demo

---
 demo/hf_based_demo.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py
index 8fe8782..6677d9b 100644
--- a/demo/hf_based_demo.py
+++ b/demo/hf_based_demo.py
@@ -39,7 +39,7 @@ model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, devi
 server_name=args.server_name
 server_port=args.server_port
 
-def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
+def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
     """generate model output with huggingface api
 
     Args:
@@ -57,8 +57,10 @@ def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     generation_kwargs = dict(
         enc,
         do_sample=True,
+        top_k=0,
         top_p=top_p,
         temperature=temperature,
+        repetition_penalty=repetition_penalty,
         max_new_tokens=max_dec_len,
         pad_token_id=tokenizer.eos_token_id,
         streamer=streamer,
@@ -71,7 +73,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
         yield answer[4 + len(inputs):]
 
 
-def generate(chat_history: List, query: str, top_p: float, temperature: float, max_dec_len: int):
+def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
     """generate after hitting "submit" button
 
     Args:
@@ -93,12 +95,12 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, m
     model_input.append({"role": "user", "content": query})
     # yield model generation
     chat_history.append([query, ""])
-    for answer in hf_gen(model_input, top_p, temperature, max_dec_len):
+    for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
         chat_history[-1][1] = answer.strip("</s>")
     yield gr.update(value=""), chat_history
 
 
-def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len: int):
+def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
     """re-generate the answer of last round's query
 
     Args:
@@ -118,7 +120,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len
         model_input.append({"role": "assistant", "content": a})
     model_input.append({"role": "user", "content": chat_history[-1][0]})
     # yield model generation
-    for answer in hf_gen(model_input, top_p, temperature, max_dec_len):
+    for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
         chat_history[-1][1] = answer.strip("</s>")
     yield gr.update(value=""), chat_history
 
@@ -152,7 +154,8 @@ with gr.Blocks(theme="soft") as demo:
     with gr.Row():
         with gr.Column(scale=1):
             top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
-            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="temperature")
+            temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
+            repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty")
             max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
         with gr.Column(scale=5):
             chatbot = gr.Chatbot(bubble_full_width=False, height=400)
@@ -163,8 +166,8 @@ with gr.Blocks(theme="soft") as demo:
                 regen = gr.Button("Regenerate")
                 reverse = gr.Button("Reverse")
 
-    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, max_dec_len], outputs=[user_input, chatbot])
-    regen.click(regenerate, inputs=[chatbot, top_p, temperature, max_dec_len], outputs=[user_input, chatbot])
+    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
+    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
     clear.click(clear_history, inputs=[], outputs=[chatbot])
     reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
 
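
Note: as a quick reference for the sampling setup this patch wires in, here is a
minimal, self-contained sketch. It is not part of the commit: the "gpt2" checkpoint
and the prompt are placeholder assumptions standing in for the demo's actual model,
while do_sample, top_k, top_p, temperature, repetition_penalty, max_new_tokens, and
pad_token_id are standard arguments of transformers' generate().

    # Sketch only: illustrates the generation kwargs added by the patch above.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    enc = tokenizer("Tell me a story about a robot.", return_tensors="pt")
    out = model.generate(
        **enc,
        do_sample=True,
        top_k=0,                 # 0 disables top-k filtering, so top_p alone prunes candidates
        top_p=0.8,               # nucleus sampling: keep smallest set with >= 80% probability mass
        temperature=0.5,         # matches the slider's new default
        repetition_penalty=1.1,  # values > 1.0 down-weight tokens already generated
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token; reuse eos
    )
    print(tokenizer.decode(out[0], skip_special_tokens=True))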