From f029588b61b0385c6cffa414438ab47877ae314b Mon Sep 17 00:00:00 2001 From: Xie Weiyu Date: Tue, 18 Feb 2025 11:39:45 +0800 Subject: [PATCH] fix server warmup --- .../backend/interfaces/ktransformers.py | 27 ++++++++++--------- .../server/backend/interfaces/transformers.py | 4 +-- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 6b8c45a..d4f5562 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -73,13 +73,13 @@ class KTransformersInterface(TransformersInterface): self._infer_lock = asyncio.Lock() - def decode_one_tokens(self, i): + def decode_one_tokens(self): device_map = self.model.gguf_loader.tensor_device_map torch_device = get_device("blk.0.self_attn", device_map) torch_device = "cuda:0" if torch_device == "cuda" else torch_device global warm_uped - if self.args.use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ): - warm_uped = True + if self.args.use_cuda_graph and warm_uped == True: + if not hasattr(self, "cuda_graph_runner"): self.cuda_graph_runner = CUDAGraphRunner() self.cuda_graph_runner.capture( @@ -93,15 +93,18 @@ class KTransformersInterface(TransformersInterface): use_cache=True, ) - if hasattr(self, "cuda_graph_runner"): - logits = self.cuda_graph_runner( - self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position - ) - self.cache.change_seq_length(1) - torch.cuda.synchronize() - logits = logits[0, -1, :] - return self.logits_to_token(logits) - + if hasattr(self, "cuda_graph_runner"): + logits = self.cuda_graph_runner( + self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position + ) + self.cache.change_seq_length(1) + torch.cuda.synchronize() + logits = logits[0, -1, :] + return self.logits_to_token(logits) + + if self.args.use_cuda_graph: + warm_uped = True + if self.use_static_cache: mask = torch.ones((1, self.seq_length)).to(torch_device) logits = self.model( diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index d00fc02..4021670 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -219,7 +219,7 @@ class TransformersInterface(BackendInterfaceBase): self.ever_generated_ids.add(last) return last - def decode_one_tokens(self, i): + def decode_one_tokens(self): if self.use_static_cache: mask = torch.ones((1, self.seq_length)).to(self.args.device) logits = self.model( @@ -299,7 +299,7 @@ class TransformersInterface(BackendInterfaceBase): num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - next_token = self.decode_one_tokens(i) + next_token = self.decode_one_tokens() self.profiler.inc("decode") if next_token == self.tokenizer.eos_token_id: assert self.args.batch_size == 1