diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index b2bc9d6..5db643a 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -18,6 +18,8 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
 
+warm_uped = False
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
@@ -99,6 +101,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     tokens = []
 
     def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if cuda_graph_runner is None:
+            use_cuda_graph = False
         if use_cuda_graph:
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
@@ -182,14 +186,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1
 
-        if use_cuda_graph:
-            cuda_graph_runner = CUDAGraphRunner()
-            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-        else:
-            cuda_graph_runner = None
+        cuda_graph_runner = None
 
         start_time = time.time()
-        for _ in range(1, max_new_tokens):
+        for i in range(1, max_new_tokens):
+            global warm_uped
+            if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
+                warm_uped = True
+                cuda_graph_runner = CUDAGraphRunner()
+                cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+
             next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
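
The diff moves CUDA graph capture out of the post-prefill setup and into the decode loop: on the first generation pass the first decode step runs eagerly and capture happens at `i == 2`, while later calls, with the module-level `warm_uped` flag already set, capture at `i == 1`; the new `cuda_graph_runner is None` guard in `decode_one_tokens` lets the pre-capture step fall back to the eager path. Below is a minimal, self-contained sketch of the same warm-up-then-capture idiom using PyTorch's raw `torch.cuda.CUDAGraph` API; the toy model and variable names are illustrative assumptions, not ktransformers code:

```python
import torch

# Sketch of warm-up-then-capture: run one eager pass first, capture the
# graph afterwards, then replay it on static buffers. Assumes a CUDA device.

model = torch.nn.Linear(8, 8).cuda().eval()
static_input = torch.randn(1, 8, device="cuda")

# Warm-up pass on a side stream, as PyTorch recommends before capture, so
# one-time lazy initialization does not get baked into the recorded graph.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    with torch.no_grad():
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture: record one forward pass into a replayable graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    with torch.no_grad():
        static_output = model(static_input)

# Replay: copy new data into the static input buffer and replay;
# static_output is updated in place by the replayed kernels.
static_input.copy_(torch.randn(1, 8, device="cuda"))
graph.replay()
print(static_output)
```

Deferring capture until after an eager decode step keeps warm-up work (kernel autotuning, allocator pool growth) out of the captured graph, which appears to be what the `warm_uped` bookkeeping in the diff is for.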