diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index cc57997..7915654 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -374,6 +374,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+
+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
 
@@ -453,26 +463,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             k_pe.squeeze(0)
             compressed_kv.squeeze(0)
-            past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-            k_pe.unsqueeze(0)
-            compressed_kv.unsqueeze(0)
-
-            k_pe = k_pe[:, :q_len]
-            compressed_kv = compressed_kv[:, :q_len]
+            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+            compressed_kv, k_pe = torch.split(
+                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+            k_pe = k_pe[:, :kv_seq_len]
+            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+            compressed_kv = compressed_kv[:, :kv_seq_len]
         kv = (
             self.kv_b_proj(compressed_kv)
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         )
 
         k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
         query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-        key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+        key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
         key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+        value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
         attn_output = flash_attn_func(
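
Below is a minimal sketch of the shape bookkeeping this patch introduces, not the ktransformers implementation itself: `ToyCache` is a hypothetical stand-in for the real `past_key_value` object, and the dimensions (`kv_lora_rank=512`, `qk_rope_head_dim=64`) are illustrative. It shows the behavior the patch relies on: the cache update returns the full concatenated buffer (new tokens appended after the cached ones), and `torch.split` recovers the compressed-KV and RoPE-key parts at the extended length `kv_seq_len = q_len + past_len`.

```python
# Sketch only: ToyCache mimics the assumed cache interface (get_usable_length / update).
import torch

bsz, q_len, past_len = 1, 4, 12           # current chunk vs. tokens already cached
kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative DeepSeek-V2-style sizes

class ToyCache:
    """Stores compressed_kv and k_pe concatenated along the last dim, appended along seq dim."""
    def __init__(self):
        self.store = None  # [bsz, cached_len, kv_lora_rank + qk_rope_head_dim]

    def get_usable_length(self, new_len, layer_idx):
        return 0 if self.store is None else self.store.shape[1]

    def update(self, compressed_kv, k_pe, layer_idx, cache_kwargs=None):
        entry = torch.cat([compressed_kv, k_pe], dim=-1)
        self.store = entry if self.store is None else torch.cat([self.store, entry], dim=1)
        return self.store, None  # mirrors the (cache, _) pair the patched code unpacks

past_key_value = ToyCache()
# pretend `past_len` tokens were cached by earlier forward passes
past_key_value.update(torch.randn(bsz, past_len, kv_lora_rank),
                      torch.randn(bsz, past_len, qk_rope_head_dim), layer_idx=0)

# current chunk
compressed_kv = torch.randn(bsz, q_len, kv_lora_rank)
k_pe = torch.randn(bsz, q_len, qk_rope_head_dim)

# total key/value length = new tokens + usable cached tokens
kv_seq_len = q_len + past_key_value.get_usable_length(q_len, 0)  # 4 + 12 = 16

# cache update returns the concatenated buffer; split it back into the two parts
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, 0)
compressed_kv, k_pe = torch.split(
    compressed_kv_with_k_pe, [kv_lora_rank, qk_rope_head_dim], dim=-1
)
assert compressed_kv.shape == (bsz, kv_seq_len, kv_lora_rank)
assert k_pe.shape == (bsz, kv_seq_len, qk_rope_head_dim)
```

From the diff itself: after this split, `kv_b_proj` re-expands the cached `compressed_kv` into per-head keys and values of length `kv_seq_len`, queries stay at `q_len`, and `value_states` is zero-padded up to `q_head_dim` so `flash_attn_func` sees matching head dimensions for queries, keys, and values.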