From 476b1d8dc69da6c8823cc7a7d004e574d1cabaf1 Mon Sep 17 00:00:00 2001 From: Azure Date: Fri, 31 Jan 2025 08:27:24 +0000 Subject: [PATCH] support deepseekv3; runable but have precition problem --- ktransformers/local_chat.py | 8 +- .../models/configuration_deepseekv3.py | 231 ++++ ktransformers/models/custom_cache.py | 11 +- ktransformers/models/modeling_deepseekv3.py | 1216 +++++++++++++++++ ktransformers/operators/attention.py | 201 +++ ktransformers/operators/experts.py | 101 ++ ktransformers/operators/gate.py | 128 ++ ktransformers/operators/linear.py | 18 +- ktransformers/operators/models.py | 6 +- .../DeepSeek-V3-Chat-multi-gpu.yaml | 143 ++ .../optimize_rules/DeepSeek-V3-Chat.yaml | 56 + .../backend/interfaces/ktransformers.py | 79 +- .../server/backend/interfaces/transformers.py | 4 +- 13 files changed, 2178 insertions(+), 24 deletions(-) create mode 100644 ktransformers/models/configuration_deepseekv3.py create mode 100644 ktransformers/models/modeling_deepseekv3.py create mode 100644 ktransformers/operators/gate.py create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 41f98a1..cec7e5d 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -15,6 +15,7 @@ from ktransformers.server.args import ArgumentParser from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM +from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM @@ -22,6 +23,7 @@ from ktransformers.server.config.config import Config custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, + "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, @@ -30,6 +32,8 @@ custom_models = { ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", + # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", + "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-multi-gpu.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", @@ -74,8 +78,8 @@ def local_chat(): else: content += line + "\n" if content == "": - if config.prompt_file == None or config.prompt_file == "": - content = "Please write a piece of quicksort code in C++." + if True: # config.prompt_file == None or config.prompt_file == "": + content = "hi" else: content = open(config.prompt_file, "r").read() elif os.path.isfile(content): diff --git a/ktransformers/models/configuration_deepseekv3.py b/ktransformers/models/configuration_deepseekv3.py new file mode 100644 index 0000000..5c599b3 --- /dev/null +++ b/ktransformers/models/configuration_deepseekv3.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeepSeekV3 model configuration """ + +from transformers.configuration_utils import PretrainedConfig + + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 7168): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18432): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 61): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 128): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 128): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to 256): + Number of routed experts, None means dense model. + ep_size (``, *optional*, defaults to 1): + routed_scaling_factor (`float`, *optional*, defaults to 2.5): + Scaling factor or routed experts. + kv_lora_rank (``, *optional*, defaults to 512): + q_lora_rank (``, *optional*, defaults to 1536): + qk_rope_head_dim (``, *optional*, defaults to 64): + v_head_dim (``, *optional*, defaults to 128): + qk_nope_head_dim (``, *optional*, defaults to 128): + topk_method (`str`, *optional*, defaults to `"noaux_tc"`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to 8): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 4): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 3): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to `"sigmoid"`): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + Whether to compute the auxiliary loss for each individual sample. + seq_aux (``, *optional*, defaults to `True`): + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size = 2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts = 1, + n_routed_experts = 256, + ep_size = 1, + routed_scaling_factor = 2.5, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'noaux_tc', + n_group = 8, + topk_group = 4, + num_experts_per_tok = 8, + moe_layer_freq = 1, + first_k_dense_replace = 3, + norm_topk_prob = True, + scoring_func = 'sigmoid', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["DeepseekV3Config"] \ No newline at end of file diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index dbaea57..c85c7bb 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -34,9 +34,12 @@ class StaticCache(transformers.StaticCache): self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) + if config.architectures[0] == "DeepseekV3ForCausalLM": + self.head_dim = config.qk_rope_head_dim + else: + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) self.dtype = dtype if dtype is not None else torch.float32 self.num_key_value_heads = ( @@ -46,7 +49,7 @@ class StaticCache(transformers.StaticCache): self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - if config.architectures[0] == "DeepseekV2ForCausalLM": + if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM": # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim) # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim) diff --git a/ktransformers/models/modeling_deepseekv3.py b/ktransformers/models/modeling_deepseekv3.py new file mode 100644 index 0000000..5ab042c --- /dev/null +++ b/ktransformers/models/modeling_deepseekv3.py @@ -0,0 +1,1216 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseekv3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +# from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel # ALL_ATTENTION_FUNCTIONS, PreTrainedModel +# from transformers.processing_utils import Unpack +from transformers.utils import ( + # LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from .configuration_deepseekv3 import DeepseekV3Config + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "DeepseekV3Config" + + +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DeepseekV3RotaryEmbedding(nn.Module): + def __init__(self, config: DeepseekV3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class DeepseekV3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.moe_intermediate_size + # TODO rm hard coding + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + else: + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + + ### select top-k experts + if self.topk_method == "noaux_tc": + # assert not self.training + scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + + return topk_idx, topk_weight + + +class DeepseekV3MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV3MLP(config) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP(config=config) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + if not self.training: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum( + dim=0 + ) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + self.rotary_emb = DeepseekV3RotaryEmbedding( + config=self.config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs# : Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + pass + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) + + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs# : Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +DEEPSEEKV3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DEEPSEEKV3_START_DOCSTRING, +) +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = DeepseekV3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DEEPSEEKV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DEEPSEEKV3_START_DOCSTRING, +) +class DeepseekV3Model(DeepseekV3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`] + + Args: + config: DeepseekV3Config + """ + + def __init__(self, config: DeepseekV3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs# : Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs# : Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + + >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DEEPSEEKV3_START_DOCSTRING, +) +class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index ff2d644..b3b1802 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb +from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader @@ -20,6 +21,206 @@ import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache logger = logging.getLogger("attention") + +class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + attn_mask: Optional[torch.Tensor] = None + + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + self.softmax_scale = self.q_head_dim ** (-0.5) + + def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): + kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) + q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) + out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) + self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, + bias=False, dtype=q_absorb.dtype, device=q_absorb.device) + self.q_absorb.weight.data = q_absorb + self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, + bias=False, dtype=out_absorb.dtype, device=out_absorb.device) + self.out_absorb.weight.data = out_absorb + del self.orig_module.kv_b_proj + q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) + out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) + return q_absorb, out_absorb + + def forward_chunck( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + + kv_seq_len = k_pe.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(q_pe, position_ids) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + compressed_kv = compressed_kv.unsqueeze(1) + k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) + compressed_kv = compressed_kv.squeeze(1) + #if cache_position is not None: + # compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:] + # k_pe = k_pe[:,:,: cache_position[-1] + 1,:] + q_absorb, out_absorb = self.get_absorbed() + + q_nope = torch.matmul(q_nope, q_absorb) + attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale + """ + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + """ + if attention_mask is not None: + """ + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + """ + #causal_mask = attention_mask[:, :, :, : kv_seq_len] + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q_pe.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) + + attn_output = torch.matmul(attn_output, out_absorb.mT) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if q_len <= self.chunck_size: + return self.forward_chunck( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs + ) + + assert output_attentions == False, "output_attentions is not supported when using chunked attention" + attn_output = None + attn_weight = None + cur_idx = 0 + while cur_idx < q_len: + if attention_mask is not None: + chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] + else: + # generate chunk_mask automatically. + self.attn_mask = \ + torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ + if self.attn_mask is None \ + else self.attn_mask + self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \ + -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\ + [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))] + self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38 + self.attn_mask[:, :, :, :cur_idx] = 0 + chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx)) + + cur_output, cur_attn_weight = self.forward_chunck( + hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], + chunk_mask, + position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], + past_key_value, + output_attentions, + use_cache, + cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], + **kwargs + ) + cur_idx += self.chunck_size + if attn_output is None: + attn_output = cur_output + attn_weight = cur_attn_weight + else: + attn_output = torch.cat((attn_output, cur_output), dim=-2) + attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2) + + return attn_output, attn_weight + class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" attn_mask: Optional[torch.Tensor] = None diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 81135ea..ddfcda9 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -519,6 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase): from ktransformers.models.modeling_deepseek import DeepseekV2MoE +from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock @@ -727,6 +728,106 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE): ) return final_out +class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + topk_idx, topk_weight= self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): + self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) + if self.config.n_shared_experts is not None: + y_ = self.shared_experts(identity).squeeze(0) + y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) + y += y_ + y.resize_(*orig_shape) + return y + + if self.config.n_shared_experts is not None: + y_ = self.shared_experts(identity).squeeze(0) + + if isinstance(self.experts, KExpertsBase): + y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + if self.config.n_shared_experts is not None: + y += y_ + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py new file mode 100644 index 0000000..91a3872 --- /dev/null +++ b/ktransformers/operators/gate.py @@ -0,0 +1,128 @@ + +from typing import Any, Union +import numpy as np +import numpy.typing as npt +from torch import Tensor, nn +import torch.nn.functional as F +import torch +import sys, os +from ktransformers.operators.base_operator import BaseInjectedModule + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) +import cpuinfer_ext +from cpuinfer_ext.moe import MOEConfig, MOE +import ctypes +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.models.modeling_deepseekv3 import MoEGate +from ktransformers.util.utils import InferenceState +from ktransformers.server.config.config import Config +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from abc import ABC, abstractmethod +import time + + +# class Base(BaseInjectedModule, ABC): +class KMoEGateBase(ABC): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = "cuda", + **kwargs): + # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + super().__init__() + self.key = key + self.gguf_loader = gguf_loader + self.config = config + self.device = device + self.orig_module = orig_module + + @abstractmethod + def forward(self, input_tensor, expert_ids, weights): + pass + + @abstractmethod + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False): + pass + + @abstractmethod + def unload(): + pass + + def load_weights(self, override_key: str | None = None, device: str = "cpu"): + res = {} + if override_key is not None: + keys = override_key + else: + keys = [self.key] + + gate = None + up = None + down = None + gate_type = None + up_type = None + down_type = None + + for key in keys: + key = ".".join(key.split(".")[:-1]) + if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info: + targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] + tensors = self.load_multi(key, targets, device=device) + weight = tensors[".ffn_gate_inp.weight"] + e_score_correction_bias = tensors[".exp_probs_b.bias"] + weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"] + e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"] + else: + raise ValueError(f"Experts {key} not found in gguf_loader") + res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} + return res + + def load_multi(self, key: str, keys: list[str], device: str = "cpu"): + tensors = {} + for k in keys: + tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device) + return tensors + + +class KMoEGate(BaseInjectedModule, KMoEGateBase): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + + def forward(self, hidden_states) -> torch.Tensor: + return self.orig_module.forward(hidden_states) + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if device is None: device = self.device + if w is None: w = self.load_weights(device=device) + + if isinstance(w, dict): + self.weight_type = w["weight_type"] + self.e_score_correction_bias_type = w["e_score_correction_bias_type"] + self.orig_module.weight = nn.Parameter(w["weight"]) + self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) + else: + raise ValueError("Invalid weight type") + self.orig_module.weight = self.orig_module.weight.to(device) + if self.topk_method == "noaux_tc": + self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) + + def unload(self): + if self.weight is not None: + self.weight = None + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = None diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 7cdb204..7510f82 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -54,15 +54,15 @@ class KLinearBase(ABC): self.has_bias = False self.dtype = torch.get_default_dtype() - if orig_module is not None: - self.in_features = orig_module.in_features - self.out_features = orig_module.out_features - else: - shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"] - if len(shape) == 1: - print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF") - self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] - self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] + # if orig_module is not None: + # self.in_features = orig_module.in_features + # self.out_features = orig_module.out_features + # else: + shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"] + if len(shape) == 1: + print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF") + self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] + self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] @abstractmethod def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index f6e85c0..9fa1a19 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -641,6 +641,7 @@ class KDeepseekV2Model(BaseInjectedModule): if inputs_embeds is None: org_device = input_ids.device + # TODO move to embed_tokens's device, not hard code to cpu input_ids = input_ids.to("cpu") inputs_embeds = self.embed_tokens(input_ids) input_ids = input_ids.to(org_device) @@ -737,8 +738,9 @@ class KDeepseekV2Model(BaseInjectedModule): hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3 + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml new file mode 100644 index 0000000..3fd86d9 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -0,0 +1,143 @@ +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseekv3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseekv3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 30: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml new file mode 100644 index 0000000..6fb87b7 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml @@ -0,0 +1,56 @@ +- match: + class: ktransformers.models.modeling_deepseek.DeepseekV3YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 420f37e..c34f17f 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -46,17 +46,26 @@ class KTransformersInterface(TransformersInterface): ) optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config) - device_map = self.model.gguf_loader.tensor_device_map - logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}") + self.device_map = self.model.gguf_loader.tensor_device_map + # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}") self.cache = StaticCache( config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, - device=device_map, + device=self.device_map, dtype=self.model.dtype, ) - logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}") - self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir) + # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}") + try: + self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir) + except: + gen_config = GenerationConfig( + max_length=128, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + self.model.generation_config = gen_config if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.streamer = TextStreamer(self.tokenizer) @@ -102,3 +111,63 @@ class KTransformersInterface(TransformersInterface): logits = logits[0, -1, :] return self.logits_to_token(logits) + + + + @torch.no_grad + def prefill(self, input_ids: torch.Tensor, is_new: bool): + input_ids_length = input_ids.shape[-1] + self.profiler.set_counter("prefill", input_ids_length) + logger.debug(f"input_ids: {input_ids.shape}") + + device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") + + if is_new: + self.cache.reset() + self.ever_generated_ids.clear() + former_seq_length = 0 + self.seq_length = input_ids_length + self.generated_ids = torch.zeros( + self.args.batch_size, + self.seq_length + self.args.max_new_tokens + 1, + dtype=torch.int, + device=self.args.device, + ) + else: + logger.debug(f"generate_ids: {self.generated_ids.shape}") + former_seq_length = self.seq_length + self.seq_length += input_ids_length + expected_length = self.seq_length + self.args.max_new_tokens + 1 + delta_length = expected_length - self.generated_ids.shape[-1] + if delta_length > 0: + new_generate_ids = torch.zeros( + self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device + ) + self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) + logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") + cache_position = torch.arange(former_seq_length, self.seq_length, device=device) + self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) + + mask = torch.ones((1, self.seq_length)).to(device) + if not (type(self) is TransformersInterface): + input_ids = input_ids.to("cpu") + inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) + if self.use_static_cache: + logits = self.model( + inputs_embeds=inputs_embeds, + cache_position=cache_position, + past_key_values=self.cache, + return_dict=False, + use_cache=True, + attention_mask=mask, + )[0] + else: + logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] + + next_token = self.logits_to_token(logits[0, -1, :]) + yield self.append_new_tokens(next_token) + + @property + def active_cache_position(self): + device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") + return torch.tensor([self.seq_length - 1], device=device) \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index f205ac5..517045b 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -134,7 +134,7 @@ class TransformersInterface(BackendInterfaceBase): self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True) - logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}") + # logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}") self.cache = StaticCache( config=self.model.config, @@ -143,7 +143,7 @@ class TransformersInterface(BackendInterfaceBase): device=args.device, dtype=self.model.dtype, ) - logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}") + # logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}") self.streamer = TextStreamer(self.tokenizer)