diff --git a/demo/vllm_based_demo.py b/demo/vllm_based_demo.py index efefb37..cf471c4 100644 --- a/demo/vllm_based_demo.py +++ b/demo/vllm_based_demo.py @@ -56,6 +56,7 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): """ assert len(dialog) % 2 == 1 prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False) + token_ids = tokenizer.convert_tokens_to_ids(["<|im_end|>"]) params_dict = { "n": 1, "best_of": 1, @@ -67,8 +68,8 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): "use_beam_search": False, "length_penalty": 1, "early_stopping": False, - "stop": None, - "stop_token_ids": None, + "stop": "<|im_end|>", + "stop_token_ids": token_ids, "ignore_eos": False, "max_tokens": max_dec_len, "logprobs": None,