From a52951834671e461a720649cbad8f64bec0684a3 Mon Sep 17 00:00:00 2001
From: Atream
Date: Wed, 19 Feb 2025 04:42:47 +0000
Subject: [PATCH] clean PR code and disable flashinfer

---
 ktransformers/operators/attention.py          | 24 ++++++-------------
 ktransformers/operators/flashinfer_wrapper.py |  2 +-
 .../server/backend/interfaces/transformers.py | 10 ++++----
 3 files changed, 13 insertions(+), 23 deletions(-)

diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 7915654..85378ee 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -58,18 +58,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
             kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
-            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
-            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
-            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
-                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
-            self.q_absorb.weight.data = q_absorb
-            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
-                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
-            self.out_absorb.weight.data = out_absorb
-        #del self.orig_module.kv_b_proj
-        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
-        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        return q_absorb, out_absorb
+            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
+            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
+
+        return self.q_absorb, self.out_absorb
 
     def forward_chunck(
             self,
@@ -105,7 +97,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformers version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
@@ -129,8 +121,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
 
         q_absorb, out_absorb = self.get_absorbed()
-        # if hasattr(self.orig_module, 'kv_b_proj'):
-        #     del self.orig_module.kv_b_proj
 
         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@@ -227,7 +217,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformers version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                )
@@ -379,7 +369,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformers version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py
index 8d49187..b3b9dd1 100644
--- a/ktransformers/operators/flashinfer_wrapper.py
+++ b/ktransformers/operators/flashinfer_wrapper.py
@@ -9,7 +9,7 @@
 flashinfer_enabled = False
 try:
     import flashinfer
-    flashinfer_enabled = True
+    flashinfer_enabled = False  # disabled for now, TODO: use new version of flashinfer and enable
     print("found flashinfer")
 except ImportError:
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index deb6cfa..8211933 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -381,13 +381,13 @@ class TransformersInterface(BackendInterfaceBase):
 
         self.profiler.create_and_start_timer("prefill")
 
-
+        if Config().user_force_think:
+            think = '<think>\n'
+            print(think, end="",flush=True)
+            yield think
+
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
             # output think token after prefill done
-            if Config().user_force_think:
-                think = '<think>\n'
-                print(think, end="",flush=True)
-                yield think
             if t is not None:
                 print(t, end="",flush=True)
                 yield t
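
Note (not part of the patch): below is a minimal sketch of how the absorbed projections returned by the refactored get_absorbed() are typically consumed in the matrix-absorption (MLA) attention path, for reviewers unfamiliar with the trick. The tensor names mirror attention.py (num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank), but the literal dimensions, the random inputs, and the omission of RoPE, masking, and softmax scaling are assumptions made purely for illustration.

import torch

# Illustrative dimensions only (loosely DeepSeek-V2-like); not read from any config.
bsz, q_len, kv_len = 1, 1, 16
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 8, 128, 128, 512

# get_absorbed() now returns plain views of kv_b_proj.weight instead of nn.Linear modules:
#   q_absorb:   [num_heads, qk_nope_head_dim, kv_lora_rank]
#   out_absorb: [num_heads, v_head_dim, kv_lora_rank]
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)

q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)   # query, no-RoPE part
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)          # latent (compressed) KV cache

# Absorb the KV up-projection into the query, then attend directly over the latent cache.
q_latent = torch.matmul(q_nope, q_absorb)                             # [bsz, heads, q_len, kv_lora_rank]
attn_weights = torch.matmul(q_latent, compressed_kv.unsqueeze(1).transpose(2, 3))
attn_weights = attn_weights.softmax(dim=-1)                           # real code also applies scale + mask
attn_latent = torch.matmul(attn_weights, compressed_kv.unsqueeze(1))  # [bsz, heads, q_len, kv_lora_rank]
attn_output = torch.matmul(attn_latent, out_absorb.mT)                # [bsz, heads, q_len, v_head_dim]
print(attn_output.shape)  # torch.Size([1, 8, 1, 128])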