warm_up before capture

Atream 2025-02-14 15:52:21 +00:00
parent cadd55078f
commit 1946493f2d


@@ -18,6 +18,8 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
 
+warm_uped = False
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
@@ -99,6 +101,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     tokens = []
 
-    def decode_one_tokens(cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if cuda_graph_runner is None:
+            use_cuda_graph = False
         if use_cuda_graph:
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
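The signature change above is what makes the warm-up possible: the graph runner is now passed in explicitly, and a None runner forces the eager path even when the caller asked for CUDA graphs. A minimal sketch of that calling convention (hypothetical names, not the ktransformers source):

from typing import Callable, Optional

def decode_one(runner: Optional[Callable], x, step_eager: Callable, use_cuda_graph: bool = True):
    # No captured graph yet: there is nothing to replay, so fall back to eager.
    if runner is None:
        use_cuda_graph = False
    return runner(x) if use_cuda_graph else step_eager(x)

# Warm-up iteration: runner is still None, so the eager path runs even
# though the caller requested use_cuda_graph=True.
out = decode_one(None, 41, step_eager=lambda x: x + 1)  # -> 42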
@@ -182,14 +186,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     position_ids = cache_position.unsqueeze(0)
     seq_length += 1
 
-    if use_cuda_graph:
-        cuda_graph_runner = CUDAGraphRunner()
-        cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-    else:
-        cuda_graph_runner = None
+    cuda_graph_runner = None
 
     start_time = time.time()
-    for _ in range(1, max_new_tokens):
-        next_token = decode_one_tokens(next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
+    for i in range(1, max_new_tokens):
+        global warm_uped
+        if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
+            warm_uped = True
+            cuda_graph_runner = CUDAGraphRunner()
+            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+        next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
         generated_ids[:, cache_position] = next_token.int()
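Taken together, the diff moves graph capture from before the decode loop into it: the first decode iteration runs eagerly as a warm-up, and the graph is captured only on iteration 2, after lazy CUDA initialization and allocations have settled. The module-level warm_uped flag persists across calls to prefill_and_generate, so later calls in an already-warm process capture on iteration 1 immediately. A minimal sketch of the same pattern using the public torch.cuda.CUDAGraph API (CUDAGraphRunner is ktransformers' own wrapper; step_fn here is a hypothetical stand-in for one decode step):

import torch

def generate(step_fn, first_token, max_new_tokens):
    # step_fn(tensor) -> next-token tensor; shapes must stay static so the
    # same buffers can be reused across graph replays.
    static_in = first_token.clone()
    static_out = None
    graph = None
    token = first_token
    for i in range(1, max_new_tokens):
        if graph is None and i == 2:
            # Iteration 1 already ran eagerly and served as the warm-up,
            # triggering lazy context setup and allocator behavior that
            # must not happen during capture. Capture from iteration 2 on.
            graph = torch.cuda.CUDAGraph()
            static_in.copy_(token)
            with torch.cuda.graph(graph):
                static_out = step_fn(static_in)
        if graph is not None:
            static_in.copy_(token)  # write the new input into the captured buffer
            graph.replay()          # re-run the recorded kernels
            token = static_out.clone()
        else:
            token = step_fn(token)  # eager warm-up iteration
    return token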