From ff6b265e536b93d498244d7e1551793732a64b84 Mon Sep 17 00:00:00 2001
From: Azure
Date: Sun, 16 Feb 2025 06:03:12 +0000
Subject: [PATCH] Mock triton mla due to precision issue

---
 ktransformers/operators/attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 650f0ae..2971cc7 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -43,11 +43,13 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                  orig_module: nn.Module,
                  device: str = "cuda",
                  chunck_size: int = 1000,
+                 use_triton: bool = False,
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
+        self.use_triton = use_triton
 
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -401,7 +403,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt':
+        if not self.use_triton: # os.name == 'nt'
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
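
Note: the patch above replaces the OS-based dispatch (os.name == 'nt') with an explicit opt-in flag, so the triton MLA kernel runs only when use_triton=True is forwarded into the injected module's constructor. Below is a minimal, self-contained Python sketch of that dispatch pattern; the class and method names are illustrative placeholders, not the ktransformers API.

# Minimal sketch of the dispatch pattern introduced by the patch:
# a constructor flag selects the attention implementation at call time,
# instead of keying off the host OS. Names here are illustrative only.
class AttentionDispatchSketch:
    def __init__(self, use_triton: bool = False):
        # Default False mirrors the patch: the triton MLA path is opt-in.
        self.use_triton = use_triton

    def forward_windows(self, hidden_states):
        # Stand-in for the chunked torch fallback path.
        return ("torch fallback", hidden_states)

    def forward_triton(self, hidden_states):
        # Stand-in for the triton MLA kernel path.
        return ("triton MLA", hidden_states)

    def forward(self, hidden_states):
        # Mirrors the patched condition: fall back unless explicitly enabled.
        if not self.use_triton:
            return self.forward_windows(hidden_states)
        return self.forward_triton(hidden_states)


if __name__ == "__main__":
    print(AttentionDispatchSketch().forward("h"))                 # ('torch fallback', 'h')
    print(AttentionDispatchSketch(use_triton=True).forward("h"))  # ('triton MLA', 'h')

Defaulting the flag to False disables the triton kernel everywhere, which is the point of the mock, while keeping a switch to re-enable that path once the precision issue is resolved.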