Mirror of https://github.com/RYDE-WORK/ktransformers.git (synced 2026-02-02 04:28:01 +08:00)
Merge pull request #301 from kvcache-ai/fix-cuda-graph-bug

warm_up before capture

Commit: cc8d627e32
@@ -18,6 +18,8 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
 
+warm_uped = False
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
@@ -99,6 +101,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     tokens = []
 
     def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if cuda_graph_runner is None:
+            use_cuda_graph = False
         if use_cuda_graph:
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
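The two added lines make the eager path the fallback whenever no captured graph exists yet, which is exactly the state during the warm-up decode steps introduced in the next hunk. A minimal sketch of that guard pattern, with a hypothetical eager_forward stand-in for the function's real eager branch (not the repo's exact code):

import torch

def eager_forward(cur_token, position_ids, cache_position, past_key_values):
    # Hypothetical stand-in for the real eager model call.
    return torch.zeros(cur_token.shape[0], 8)

def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position,
                      past_key_values, use_cuda_graph: bool = True):
    if cuda_graph_runner is None:
        # Nothing has been captured yet (warm-up step): force the eager path.
        use_cuda_graph = False
    if use_cuda_graph:
        logits = cuda_graph_runner(cur_token, position_ids, cache_position)
    else:
        logits = eager_forward(cur_token, position_ids, cache_position, past_key_values)
    return logits.argmax(dim=-1)  # greedy pick, purely illustrative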
@@ -182,14 +186,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     position_ids = cache_position.unsqueeze(0)
     seq_length += 1
 
-    if use_cuda_graph:
-        cuda_graph_runner = CUDAGraphRunner()
-        cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-    else:
-        cuda_graph_runner = None
+    cuda_graph_runner = None
 
     start_time = time.time()
-    for _ in range(1, max_new_tokens):
+    for i in range(1, max_new_tokens):
+        global warm_uped
+        if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
+            warm_uped = True
+            cuda_graph_runner = CUDAGraphRunner()
+            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+
         next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
         generated_ids[:, cache_position] = next_token.int()
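With this hunk the CUDA graph is no longer captured immediately after prefill; the first decode step (or the first two, if the process has never warmed up) runs eagerly, and only then is the graph captured. Warming up before capture avoids recording lazy initialization, allocator growth and autotuning work into the graph. A minimal, self-contained sketch of the same warm-up-then-capture idea using torch.cuda.CUDAGraph directly; the tiny Linear model, shapes and warm-up count are illustrative and this is not the repo's CUDAGraphRunner API:

import torch

# Stand-in for the real decoder; requires a CUDA device.
model = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.zeros(1, 16, device="cuda")

# 1) Warm-up: run a few eager iterations on a side stream so lazy init,
#    allocations and autotuning happen outside the capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        with torch.no_grad():
            model(static_input)
torch.cuda.current_stream().wait_stream(s)

# 2) Capture: record the steady-state forward pass into the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    with torch.no_grad():
        static_output = model(static_input)

# 3) Replay: copy new data into the static input buffer and replay.
static_input.copy_(torch.randn(1, 16, device="cuda"))
graph.replay()
print(static_output)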