diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 9f73d48..650f0ae 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -262,7 +262,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         """
 
         # flash attn doesn't support head_dim bigger than 256
-        # use vLLM triton attention kernel for MQA
+        # use triton attention kernel adapted from vLLM and SGLang for MQA
         decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                      page_table,
                                      position_ids.squeeze(0).to(torch.int32), attn_logits,
@@ -551,4 +551,4 @@ class KLlamaAttention(BaseInjectedModule):
 
         if not output_attentions:
             attn_weights = None
-        return attn_output, attn_weights, past_key_value
\ No newline at end of file
+        return attn_output, attn_weights, past_key_value
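
Note on the comment change above: flash-attn rejects head dims larger than 256, and the absorbed MLA decode path in KDeepseekV2Attention exceeds that, so decode_attention_fwd_grouped (a triton kernel adapted from vLLM and SGLang) handles the MQA-style decode instead. The plain-PyTorch reference below is a minimal sketch of what such a grouped (MQA/GQA) decode step computes; the function name, shapes, and scaling are assumptions for illustration, not the ktransformers or triton kernel API.

# Minimal PyTorch reference (an assumption, not the ktransformers/triton API):
# what a grouped (MQA/GQA) decode-attention step computes for head dims that
# flash-attn cannot handle (> 256).
import torch


def grouped_decode_attention_reference(q, k_cache, v_cache):
    """Naive grouped decode attention for one new token per sequence.

    q:       [batch, num_q_heads, head_dim]
    k_cache: [batch, seq_len, num_kv_heads, head_dim]
    v_cache: [batch, seq_len, num_kv_heads, v_head_dim]
    num_q_heads must be a multiple of num_kv_heads; each group of query
    heads shares one kv head, which is the MQA/GQA pattern the grouped
    triton kernel exploits.
    """
    _, num_q_heads, head_dim = q.shape
    num_kv_heads = k_cache.shape[2]
    group = num_q_heads // num_kv_heads
    # Expand shared kv heads so every query head indexes its own copy.
    k = k_cache.repeat_interleave(group, dim=2)          # [b, s, Hq, d]
    v = v_cache.repeat_interleave(group, dim=2)          # [b, s, Hq, dv]
    scores = torch.einsum("bhd,bshd->bhs", q, k) / head_dim ** 0.5
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhs,bshd->bhd", probs, v)       # [b, Hq, dv]


if __name__ == "__main__":
    # head_dim 576 > 256: the regime where flash-attn is unavailable and the
    # grouped triton kernel in the patch is used (sizes here are illustrative).
    q = torch.randn(2, 16, 576)
    k = torch.randn(2, 128, 1, 576)
    v = torch.randn(2, 128, 1, 512)
    print(grouped_decode_attention_reference(q, k, v).shape)  # torch.Size([2, 16, 512])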