From 5977a63783814ae3478cbbf53d52071ac2837522 Mon Sep 17 00:00:00 2001
From: huangyuxiang03
Date: Sat, 3 Feb 2024 22:13:15 +0800
Subject: [PATCH] Fix: flash_attn and cpu inference

---
 model/configuration_minicpm.py | 5 +++++
 model/modeling_minicpm.py      | 8 ++++----
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/model/configuration_minicpm.py b/model/configuration_minicpm.py
index 692883c..21eb158 100644
--- a/model/configuration_minicpm.py
+++ b/model/configuration_minicpm.py
@@ -174,6 +174,11 @@ class MiniCPMConfig(PretrainedConfig):
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+        try:
+            import flash_attn
+            self._attn_implementation = "flash_attention_2"
+        except:
+            pass
 
     def _rope_scaling_validation(self):
         """
diff --git a/model/modeling_minicpm.py b/model/modeling_minicpm.py
index 23420d0..5a8455b 100644
--- a/model/modeling_minicpm.py
+++ b/model/modeling_minicpm.py
@@ -51,10 +51,11 @@ from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_minicpm import MiniCPMConfig
 import re
 
-
-if is_flash_attn_2_available():
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+except:
+    pass
 
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
@@ -125,7 +126,7 @@ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
 
 
 class MiniCPMRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda"):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
         self.dim = dim
@@ -763,7 +764,6 @@ class MiniCPMDecoderLayer(nn.Module):
     def __init__(self, config: MiniCPMConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
         self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = MiniCPMMLP(config)
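
A minimal sketch of the behavior this patch aims for, assuming the patched files are loaded as custom code through transformers; the checkpoint name below is illustrative and not taken from the patch. With flash_attn installed, the new try/except in MiniCPMConfig.__init__ switches the config to "flash_attention_2"; without it, the guard is a no-op and the model falls back to the default attention path, while device=None in MiniCPMRotaryEmbedding keeps CPU-only loading from touching CUDA.

    # Sketch only: "openbmb/MiniCPM-2B-sft-bf16" stands in for any checkpoint shipping
    # the patched configuration_minicpm.py / modeling_minicpm.py via trust_remote_code.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True)

    # "flash_attention_2" when flash_attn imports successfully; otherwise the
    # transformers default, so CPU-only environments still load the model.
    print(config._attn_implementation)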