From d90749d35d9282adfa722039dd568f4896ff04b4 Mon Sep 17 00:00:00 2001
From: Atream <80757050+Atream@users.noreply.github.com>
Date: Sat, 15 Feb 2025 15:41:01 +0800
Subject: [PATCH 1/2] Update triton_attention.py

---
 ktransformers/operators/triton_attention.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/ktransformers/operators/triton_attention.py b/ktransformers/operators/triton_attention.py
index d622be1..4437520 100644
--- a/ktransformers/operators/triton_attention.py
+++ b/ktransformers/operators/triton_attention.py
@@ -1,3 +1,9 @@
+# Adapted from
+# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+# which was originally adapted from
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
 import triton
 import triton.language as tl
 
@@ -376,4 +382,4 @@ def decode_attention_fwd_grouped(
     )
 
     _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
-                                num_kv_splits)
\ No newline at end of file
+                                num_kv_splits)

From 92399283b6c2c20168b95ac70ca282075400f74f Mon Sep 17 00:00:00 2001
From: Atream <80757050+Atream@users.noreply.github.com>
Date: Sat, 15 Feb 2025 15:43:35 +0800
Subject: [PATCH 2/2] Update attention.py

---
 ktransformers/operators/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 9f73d48..650f0ae 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -262,7 +262,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         """
 
         # flash attn doesn't support head_dim bigger than 256
-        # use vLLM triton attention kernel for MQA
+        # use triton attention kernel adapted from vLLM and SGLang for MQA
         decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                      page_table,
                                      position_ids.squeeze(0).to(torch.int32), attn_logits,
@@ -551,4 +551,4 @@ class KLlamaAttention(BaseInjectedModule):
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, past_key_value
\ No newline at end of file
+        return attn_output, attn_weights, past_key_value