Mirror of https://github.com/RYDE-WORK/ktransformers.git
Update attention.py

commit 92399283b6
parent d90749d35d
@@ -262,7 +262,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
 """

 # flash attn doesn't support head_dim bigger than 256
-# use vLLM triton attention kernel for MQA
+# use triton attention kernel adapted from vLLM and SGLang for MQA
 decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                              page_table,
                              position_ids.squeeze(0).to(torch.int32), attn_logits,
@@ -551,4 +551,4 @@ class KLlamaAttention(BaseInjectedModule):
 if not output_attentions:
     attn_weights = None

-return attn_output, attn_weights, past_key_value
+return attn_output, attn_weights, past_key_value
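The first hunk only rewords the comment above the grouped decode call: the triton kernel used for MQA decode is adapted from vLLM and SGLang, and it is used because flash attn does not support head_dim larger than 256. For orientation, below is a minimal plain-PyTorch sketch of the computation a grouped (MQA/GQA) decode-attention kernel performs for a single decode step. It is illustrative only: the function name and tensor layout are hypothetical and do not reflect decode_attention_fwd_grouped's actual signature, and it omits the paged-KV bookkeeping (page_table, attn_logits) visible in the diff.

# Illustrative sketch only (hypothetical names): a plain-PyTorch reference of
# grouped (MQA/GQA) decode attention, i.e. the math a fused triton kernel
# performs for one decode step. No paging, no kernel fusion.
import torch

def grouped_decode_attention_reference(
    q: torch.Tensor,        # [batch, num_q_heads, head_dim], one new token per sequence
    k_cache: torch.Tensor,  # [batch, kv_len, num_kv_heads, head_dim]
    v_cache: torch.Tensor,  # [batch, kv_len, num_kv_heads, head_dim]
) -> torch.Tensor:
    batch, num_q_heads, head_dim = q.shape
    num_kv_heads = k_cache.shape[2]
    group = num_q_heads // num_kv_heads  # query heads sharing each KV head (MQA: num_kv_heads == 1)

    # Expand the shared KV heads so every query head has a matching KV head.
    k = k_cache.repeat_interleave(group, dim=2)  # [batch, kv_len, num_q_heads, head_dim]
    v = v_cache.repeat_interleave(group, dim=2)

    # Scaled dot-product attention for the single query position.
    scores = torch.einsum("bhd,bshd->bhs", q, k) / head_dim ** 0.5  # [batch, num_q_heads, kv_len]
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhs,bshd->bhd", probs, v)                  # [batch, num_q_heads, head_dim]

# Example: 16 query heads sharing a single KV head (MQA) with head_dim 320,
# i.e. larger than the 256 limit the comment above attributes to flash attn.
q = torch.randn(2, 16, 320)
k_cache = torch.randn(2, 64, 1, 320)
v_cache = torch.randn(2, 64, 1, 320)
print(grouped_decode_attention_reference(q, k_cache, v_cache).shape)  # torch.Size([2, 16, 320])

The point of a fused kernel such as decode_attention_fwd_grouped is to avoid materialising the expanded K/V tensors and the full score matrix that this reference builds, while also reading the KV cache through the page_table shown in the diff.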