Merge pull request #315 from kvcache-ai/Atream-add-adapted
Atream add adapted
Commit f9f9f746c0
@@ -262,7 +262,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
 """
 
 # flash attn doesn't support head_dim bigger than 256
-# use vLLM triton attention kernel for MQA
+# use triton attention kernel adapted from vLLM and SGLang for MQA
 decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                              page_table,
                              position_ids.squeeze(0).to(torch.int32), attn_logits,
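For context (not part of the diff): the "head_dim bigger than 256" comment refers to flash-attn's per-head dimension limit, which the MLA decode path exceeds. A minimal sketch of the arithmetic, assuming the published DeepSeek-V2 values kv_lora_rank=512 and qk_rope_head_dim=64 (assumptions, not taken from this commit):

    # Why a triton kernel is used instead of flash-attn on the MLA decode path:
    # the per-head "key" is the compressed KV vector plus its rotary part.
    kv_lora_rank = 512                            # assumed DeepSeek-V2 config value
    qk_rope_head_dim = 64                         # assumed DeepSeek-V2 config value
    head_dim = kv_lora_rank + qk_rope_head_dim    # 576
    FLASH_ATTN_MAX_HEAD_DIM = 256
    use_triton_decode = head_dim > FLASH_ATTN_MAX_HEAD_DIM   # True, so decode_attention_fwd_grouped is used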
@@ -1,3 +1,9 @@
+# Adapted from
+# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+# which was originally adapted from
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
 import triton
 import triton.language as tl
 
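The stage1/stage2 file names in the attribution above reflect the two-stage, split-KV ("flash decoding") structure of the grouped decode kernel: stage 1 computes a partial attention output and a log-sum-exp per KV split, and stage 2 merges the splits. Below is a minimal, non-triton reference sketch of that pattern for MQA; the function name, shapes, and split count are illustrative, not the API of this repo, SGLang, or lightllm:

    import torch

    def mqa_flash_decode_reference(q, k, v, num_kv_splits=4):
        # q: [num_heads, d]  one decode token; all heads share a single KV head (MQA)
        # k: [kv_len, d], v: [kv_len, dv]
        sm_scale = q.shape[-1] ** -0.5
        splits = torch.chunk(torch.arange(k.shape[0]), num_kv_splits)

        partial_out, partial_lse = [], []
        # Stage 1: each KV split yields a locally normalised output and its log-sum-exp.
        for idx in splits:
            scores = (q @ k[idx].T) * sm_scale             # [num_heads, split_len]
            partial_lse.append(torch.logsumexp(scores, dim=-1))
            partial_out.append(torch.softmax(scores, dim=-1) @ v[idx])

        # Stage 2: merge splits, weighting each by its share of the total softmax mass.
        lse = torch.stack(partial_lse)                      # [num_splits, num_heads]
        weights = torch.softmax(lse, dim=0)
        out = torch.stack(partial_out)                      # [num_splits, num_heads, dv]
        return (weights.unsqueeze(-1) * out).sum(dim=0)     # [num_heads, dv]

    # Example: 16 query heads over 1024 cached positions.
    q = torch.randn(16, 576)
    k = torch.randn(1024, 576)
    v = torch.randn(1024, 512)
    out = mqa_flash_decode_reference(q, k, v)               # matches full softmax attention

In the triton kernel these per-split intermediates presumably live in the attn_logits buffer passed in the call shown above (an assumption about the kernel internals); the Python version here is only for clarity.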