From 476b1d8dc69da6c8823cc7a7d004e574d1cabaf1 Mon Sep 17 00:00:00 2001 From: Azure Date: Fri, 31 Jan 2025 08:27:24 +0000 Subject: [PATCH 01/26] support deepseekv3; runable but have precition problem --- ktransformers/local_chat.py | 8 +- .../models/configuration_deepseekv3.py | 231 ++++ ktransformers/models/custom_cache.py | 11 +- ktransformers/models/modeling_deepseekv3.py | 1216 +++++++++++++++++ ktransformers/operators/attention.py | 201 +++ ktransformers/operators/experts.py | 101 ++ ktransformers/operators/gate.py | 128 ++ ktransformers/operators/linear.py | 18 +- ktransformers/operators/models.py | 6 +- .../DeepSeek-V3-Chat-multi-gpu.yaml | 143 ++ .../optimize_rules/DeepSeek-V3-Chat.yaml | 56 + .../backend/interfaces/ktransformers.py | 79 +- .../server/backend/interfaces/transformers.py | 4 +- 13 files changed, 2178 insertions(+), 24 deletions(-) create mode 100644 ktransformers/models/configuration_deepseekv3.py create mode 100644 ktransformers/models/modeling_deepseekv3.py create mode 100644 ktransformers/operators/gate.py create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 41f98a1..cec7e5d 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -15,6 +15,7 @@ from ktransformers.server.args import ArgumentParser from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM +from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM @@ -22,6 +23,7 @@ from ktransformers.server.config.config import Config custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, + "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, @@ -30,6 +32,8 @@ custom_models = { ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", + # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", + "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-multi-gpu.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", @@ -74,8 +78,8 @@ def local_chat(): else: content += line + "\n" if content == "": - if config.prompt_file == None or config.prompt_file == "": - content = "Please write a piece of quicksort code in C++." + if True: # config.prompt_file == None or config.prompt_file == "": + content = "hi" else: content = open(config.prompt_file, "r").read() elif os.path.isfile(content): diff --git a/ktransformers/models/configuration_deepseekv3.py b/ktransformers/models/configuration_deepseekv3.py new file mode 100644 index 0000000..5c599b3 --- /dev/null +++ b/ktransformers/models/configuration_deepseekv3.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeepSeekV3 model configuration """ + +from transformers.configuration_utils import PretrainedConfig + + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 7168): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18432): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 61): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 128): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 128): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to 256): + Number of routed experts, None means dense model. + ep_size (``, *optional*, defaults to 1): + routed_scaling_factor (`float`, *optional*, defaults to 2.5): + Scaling factor or routed experts. + kv_lora_rank (``, *optional*, defaults to 512): + q_lora_rank (``, *optional*, defaults to 1536): + qk_rope_head_dim (``, *optional*, defaults to 64): + v_head_dim (``, *optional*, defaults to 128): + qk_nope_head_dim (``, *optional*, defaults to 128): + topk_method (`str`, *optional*, defaults to `"noaux_tc"`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to 8): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 4): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 3): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to `"sigmoid"`): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + Whether to compute the auxiliary loss for each individual sample. + seq_aux (``, *optional*, defaults to `True`): + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size = 2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts = 1, + n_routed_experts = 256, + ep_size = 1, + routed_scaling_factor = 2.5, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'noaux_tc', + n_group = 8, + topk_group = 4, + num_experts_per_tok = 8, + moe_layer_freq = 1, + first_k_dense_replace = 3, + norm_topk_prob = True, + scoring_func = 'sigmoid', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["DeepseekV3Config"] \ No newline at end of file diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index dbaea57..c85c7bb 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -34,9 +34,12 @@ class StaticCache(transformers.StaticCache): self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) + if config.architectures[0] == "DeepseekV3ForCausalLM": + self.head_dim = config.qk_rope_head_dim + else: + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) self.dtype = dtype if dtype is not None else torch.float32 self.num_key_value_heads = ( @@ -46,7 +49,7 @@ class StaticCache(transformers.StaticCache): self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - if config.architectures[0] == "DeepseekV2ForCausalLM": + if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM": # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim) # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim) diff --git a/ktransformers/models/modeling_deepseekv3.py b/ktransformers/models/modeling_deepseekv3.py new file mode 100644 index 0000000..5ab042c --- /dev/null +++ b/ktransformers/models/modeling_deepseekv3.py @@ -0,0 +1,1216 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseekv3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +# from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel # ALL_ATTENTION_FUNCTIONS, PreTrainedModel +# from transformers.processing_utils import Unpack +from transformers.utils import ( + # LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from .configuration_deepseekv3 import DeepseekV3Config + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "DeepseekV3Config" + + +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DeepseekV3RotaryEmbedding(nn.Module): + def __init__(self, config: DeepseekV3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class DeepseekV3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.moe_intermediate_size + # TODO rm hard coding + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + else: + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + + ### select top-k experts + if self.topk_method == "noaux_tc": + # assert not self.training + scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + + return topk_idx, topk_weight + + +class DeepseekV3MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV3MLP(config) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP(config=config) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + if not self.training: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum( + dim=0 + ) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + self.rotary_emb = DeepseekV3RotaryEmbedding( + config=self.config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs# : Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + pass + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) + + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs# : Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +DEEPSEEKV3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DEEPSEEKV3_START_DOCSTRING, +) +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = DeepseekV3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DEEPSEEKV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DEEPSEEKV3_START_DOCSTRING, +) +class DeepseekV3Model(DeepseekV3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`] + + Args: + config: DeepseekV3Config + """ + + def __init__(self, config: DeepseekV3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs# : Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs# : Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + + >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DEEPSEEKV3_START_DOCSTRING, +) +class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index ff2d644..b3b1802 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -13,6 +13,7 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb +from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader @@ -20,6 +21,206 @@ import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache logger = logging.getLogger("attention") + +class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + attn_mask: Optional[torch.Tensor] = None + + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + self.softmax_scale = self.q_head_dim ** (-0.5) + + def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): + kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) + q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) + out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) + self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, + bias=False, dtype=q_absorb.dtype, device=q_absorb.device) + self.q_absorb.weight.data = q_absorb + self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, + bias=False, dtype=out_absorb.dtype, device=out_absorb.device) + self.out_absorb.weight.data = out_absorb + del self.orig_module.kv_b_proj + q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) + out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) + return q_absorb, out_absorb + + def forward_chunck( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + + kv_seq_len = k_pe.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(q_pe, position_ids) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + compressed_kv = compressed_kv.unsqueeze(1) + k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) + compressed_kv = compressed_kv.squeeze(1) + #if cache_position is not None: + # compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:] + # k_pe = k_pe[:,:,: cache_position[-1] + 1,:] + q_absorb, out_absorb = self.get_absorbed() + + q_nope = torch.matmul(q_nope, q_absorb) + attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale + """ + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + """ + if attention_mask is not None: + """ + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + """ + #causal_mask = attention_mask[:, :, :, : kv_seq_len] + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q_pe.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) + + attn_output = torch.matmul(attn_output, out_absorb.mT) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if q_len <= self.chunck_size: + return self.forward_chunck( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs + ) + + assert output_attentions == False, "output_attentions is not supported when using chunked attention" + attn_output = None + attn_weight = None + cur_idx = 0 + while cur_idx < q_len: + if attention_mask is not None: + chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] + else: + # generate chunk_mask automatically. + self.attn_mask = \ + torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ + if self.attn_mask is None \ + else self.attn_mask + self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \ + -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\ + [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))] + self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38 + self.attn_mask[:, :, :, :cur_idx] = 0 + chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx)) + + cur_output, cur_attn_weight = self.forward_chunck( + hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], + chunk_mask, + position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], + past_key_value, + output_attentions, + use_cache, + cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], + **kwargs + ) + cur_idx += self.chunck_size + if attn_output is None: + attn_output = cur_output + attn_weight = cur_attn_weight + else: + attn_output = torch.cat((attn_output, cur_output), dim=-2) + attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2) + + return attn_output, attn_weight + class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" attn_mask: Optional[torch.Tensor] = None diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 81135ea..ddfcda9 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -519,6 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase): from ktransformers.models.modeling_deepseek import DeepseekV2MoE +from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock @@ -727,6 +728,106 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE): ) return final_out +class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + topk_idx, topk_weight= self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): + self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) + if self.config.n_shared_experts is not None: + y_ = self.shared_experts(identity).squeeze(0) + y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) + y += y_ + y.resize_(*orig_shape) + return y + + if self.config.n_shared_experts is not None: + y_ = self.shared_experts(identity).squeeze(0) + + if isinstance(self.experts, KExpertsBase): + y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, topk_idx, topk_weight) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + if self.config.n_shared_experts is not None: + y += y_ + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py new file mode 100644 index 0000000..91a3872 --- /dev/null +++ b/ktransformers/operators/gate.py @@ -0,0 +1,128 @@ + +from typing import Any, Union +import numpy as np +import numpy.typing as npt +from torch import Tensor, nn +import torch.nn.functional as F +import torch +import sys, os +from ktransformers.operators.base_operator import BaseInjectedModule + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) +import cpuinfer_ext +from cpuinfer_ext.moe import MOEConfig, MOE +import ctypes +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.models.modeling_deepseekv3 import MoEGate +from ktransformers.util.utils import InferenceState +from ktransformers.server.config.config import Config +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from abc import ABC, abstractmethod +import time + + +# class Base(BaseInjectedModule, ABC): +class KMoEGateBase(ABC): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + device: str = "cuda", + **kwargs): + # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + super().__init__() + self.key = key + self.gguf_loader = gguf_loader + self.config = config + self.device = device + self.orig_module = orig_module + + @abstractmethod + def forward(self, input_tensor, expert_ids, weights): + pass + + @abstractmethod + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False): + pass + + @abstractmethod + def unload(): + pass + + def load_weights(self, override_key: str | None = None, device: str = "cpu"): + res = {} + if override_key is not None: + keys = override_key + else: + keys = [self.key] + + gate = None + up = None + down = None + gate_type = None + up_type = None + down_type = None + + for key in keys: + key = ".".join(key.split(".")[:-1]) + if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info: + targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] + tensors = self.load_multi(key, targets, device=device) + weight = tensors[".ffn_gate_inp.weight"] + e_score_correction_bias = tensors[".exp_probs_b.bias"] + weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"] + e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"] + else: + raise ValueError(f"Experts {key} not found in gguf_loader") + res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} + return res + + def load_multi(self, key: str, keys: list[str], device: str = "cpu"): + tensors = {} + for k in keys: + tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device) + return tensors + + +class KMoEGate(BaseInjectedModule, KMoEGateBase): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + + def forward(self, hidden_states) -> torch.Tensor: + return self.orig_module.forward(hidden_states) + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if device is None: device = self.device + if w is None: w = self.load_weights(device=device) + + if isinstance(w, dict): + self.weight_type = w["weight_type"] + self.e_score_correction_bias_type = w["e_score_correction_bias_type"] + self.orig_module.weight = nn.Parameter(w["weight"]) + self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) + else: + raise ValueError("Invalid weight type") + self.orig_module.weight = self.orig_module.weight.to(device) + if self.topk_method == "noaux_tc": + self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) + + def unload(self): + if self.weight is not None: + self.weight = None + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = None diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 7cdb204..7510f82 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -54,15 +54,15 @@ class KLinearBase(ABC): self.has_bias = False self.dtype = torch.get_default_dtype() - if orig_module is not None: - self.in_features = orig_module.in_features - self.out_features = orig_module.out_features - else: - shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"] - if len(shape) == 1: - print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF") - self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] - self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] + # if orig_module is not None: + # self.in_features = orig_module.in_features + # self.out_features = orig_module.out_features + # else: + shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"] + if len(shape) == 1: + print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF") + self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] + self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] @abstractmethod def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index f6e85c0..9fa1a19 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -641,6 +641,7 @@ class KDeepseekV2Model(BaseInjectedModule): if inputs_embeds is None: org_device = input_ids.device + # TODO move to embed_tokens's device, not hard code to cpu input_ids = input_ids.to("cpu") inputs_embeds = self.embed_tokens(input_ids) input_ids = input_ids.to(org_device) @@ -737,8 +738,9 @@ class KDeepseekV2Model(BaseInjectedModule): hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3 + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml new file mode 100644 index 0000000..3fd86d9 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -0,0 +1,143 @@ +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\." + class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseekv3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseekv3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 30: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml new file mode 100644 index 0000000..6fb87b7 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml @@ -0,0 +1,56 @@ +- match: + class: ktransformers.models.modeling_deepseek.DeepseekV3YarnRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + replace: + class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 420f37e..c34f17f 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -46,17 +46,26 @@ class KTransformersInterface(TransformersInterface): ) optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config) - device_map = self.model.gguf_loader.tensor_device_map - logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}") + self.device_map = self.model.gguf_loader.tensor_device_map + # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}") self.cache = StaticCache( config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, - device=device_map, + device=self.device_map, dtype=self.model.dtype, ) - logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}") - self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir) + # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}") + try: + self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir) + except: + gen_config = GenerationConfig( + max_length=128, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + self.model.generation_config = gen_config if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.streamer = TextStreamer(self.tokenizer) @@ -102,3 +111,63 @@ class KTransformersInterface(TransformersInterface): logits = logits[0, -1, :] return self.logits_to_token(logits) + + + + @torch.no_grad + def prefill(self, input_ids: torch.Tensor, is_new: bool): + input_ids_length = input_ids.shape[-1] + self.profiler.set_counter("prefill", input_ids_length) + logger.debug(f"input_ids: {input_ids.shape}") + + device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") + + if is_new: + self.cache.reset() + self.ever_generated_ids.clear() + former_seq_length = 0 + self.seq_length = input_ids_length + self.generated_ids = torch.zeros( + self.args.batch_size, + self.seq_length + self.args.max_new_tokens + 1, + dtype=torch.int, + device=self.args.device, + ) + else: + logger.debug(f"generate_ids: {self.generated_ids.shape}") + former_seq_length = self.seq_length + self.seq_length += input_ids_length + expected_length = self.seq_length + self.args.max_new_tokens + 1 + delta_length = expected_length - self.generated_ids.shape[-1] + if delta_length > 0: + new_generate_ids = torch.zeros( + self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device + ) + self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) + logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") + cache_position = torch.arange(former_seq_length, self.seq_length, device=device) + self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) + + mask = torch.ones((1, self.seq_length)).to(device) + if not (type(self) is TransformersInterface): + input_ids = input_ids.to("cpu") + inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) + if self.use_static_cache: + logits = self.model( + inputs_embeds=inputs_embeds, + cache_position=cache_position, + past_key_values=self.cache, + return_dict=False, + use_cache=True, + attention_mask=mask, + )[0] + else: + logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] + + next_token = self.logits_to_token(logits[0, -1, :]) + yield self.append_new_tokens(next_token) + + @property + def active_cache_position(self): + device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") + return torch.tensor([self.seq_length - 1], device=device) \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index f205ac5..517045b 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -134,7 +134,7 @@ class TransformersInterface(BackendInterfaceBase): self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True) - logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}") + # logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}") self.cache = StaticCache( config=self.model.config, @@ -143,7 +143,7 @@ class TransformersInterface(BackendInterfaceBase): device=args.device, dtype=self.model.dtype, ) - logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}") + # logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}") self.streamer = TextStreamer(self.tokenizer) From 5a50b346271b93a828c3002320be6f14e3426433 Mon Sep 17 00:00:00 2001 From: Azure Date: Fri, 31 Jan 2025 15:25:50 +0000 Subject: [PATCH 02/26] fix hard coding caused by rope dim calculation, load from config now --- ktransformers/local_chat.py | 2 +- ktransformers/models/modeling_deepseekv3.py | 2 +- ktransformers/util/modeling_rope_utils.py | 551 ++++++++++++++++++++ 3 files changed, 553 insertions(+), 2 deletions(-) create mode 100644 ktransformers/util/modeling_rope_utils.py diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index cec7e5d..a924a1d 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -78,7 +78,7 @@ def local_chat(): else: content += line + "\n" if content == "": - if True: # config.prompt_file == None or config.prompt_file == "": + if config.prompt_file == None or config.prompt_file == "": content = "hi" else: content = open(config.prompt_file, "r").read() diff --git a/ktransformers/models/modeling_deepseekv3.py b/ktransformers/models/modeling_deepseekv3.py index 5ab042c..d8a888c 100644 --- a/ktransformers/models/modeling_deepseekv3.py +++ b/ktransformers/models/modeling_deepseekv3.py @@ -19,7 +19,7 @@ from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter # from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ktransformers.util.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel # ALL_ATTENTION_FUNCTIONS, PreTrainedModel # from transformers.processing_utils import Unpack from transformers.utils import ( diff --git a/ktransformers/util/modeling_rope_utils.py b/ktransformers/util/modeling_rope_utils.py new file mode 100644 index 0000000..2598a52 --- /dev/null +++ b/ktransformers/util/modeling_rope_utils.py @@ -0,0 +1,551 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_linear_scaling_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + factor = rope_kwargs["factor"] + elif config is not None: + factor = config.rope_scaling["factor"] + + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + # Then applies linear scaling to the frequencies. + # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so + # applying scaling to the inverse frequencies is equivalent. + inv_freq /= factor + return inv_freq, attention_factor + + +def _compute_dynamic_ntk_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length, used to update the dynamic RoPE at inference time. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + max_position_embeddings = rope_kwargs["max_position_embeddings"] + factor = rope_kwargs["factor"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + attention_factor = 1.0 # Unused in this type of RoPE + + # seq_len: default to max_position_embeddings, e.g. at init time + seq_len = seq_len if seq_len is not None else max_position_embeddings + + # Compute the inverse frequencies + base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = config.qk_rope_head_dim + + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device) + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + return inv_freq, attention_factor + + +def _compute_longrope_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with LongRoPE scaling. Please refer to the + [original implementation](https://github.com/microsoft/LongRoPE) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + # No need to keep BC with longrope, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " + f"{rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + long_factor = config.rope_scaling["long_factor"] + short_factor = config.rope_scaling["short_factor"] + factor = config.rope_scaling.get("factor") + attention_factor = config.rope_scaling.get("attention_factor") + + # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if hasattr(config, "original_max_position_embeddings"): + max_position_embeddings = config.original_max_position_embeddings + expanded_max_position_embeddings = config.max_position_embeddings + factor = expanded_max_position_embeddings / max_position_embeddings + else: + max_position_embeddings = config.max_position_embeddings + expanded_max_position_embeddings = max_position_embeddings * factor + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if factor <= 1.0: + attention_factor = 1.0 + else: + attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) + + # Compute the inverse frequencies -- scaled based on the target sequence length + if expanded_max_position_embeddings > max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) + inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + + return inv_freq, attention_factor + + +def _compute_llama3_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in inv_freq: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / factor + smooth * freq) + inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) + return inv_freq, attention_factor + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "linear": _compute_linear_scaling_rope_parameters, + "dynamic": _compute_dynamic_ntk_parameters, + "yarn": _compute_yarn_parameters, + "longrope": _compute_longrope_parameters, + "llama3": _compute_llama3_parameters, +} + + +def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's gracefully handle it + if "rope_type" not in received_keys and "type" in received_keys: + received_keys -= {"type"} + received_keys.add("rope_type") + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + +def _validate_linear_scaling_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_yarn_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +def _validate_longrope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "short_factor", "long_factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + + short_factor = rope_scaling.get("short_factor") + if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): + logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") + if not len(short_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") + + long_factor = rope_scaling.get("long_factor") + if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): + logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") + if not len(long_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") + + # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over + # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is + # unique to longrope (= undesirable) + if hasattr(config, "original_max_position_embeddings"): + logger.warning_once( + "This model has set a `original_max_position_embeddings` field, to be used together with " + "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" + "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " + "as it is compatible with most model architectures." + ) + else: + factor = rope_scaling.get("factor") + if factor is None: + logger.warning("Missing required keys in `rope_scaling`: 'factor'") + elif not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + + +def _validate_llama3_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + if low_freq_factor is None or not isinstance(low_freq_factor, float): + logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") + if high_freq_factor is None or not isinstance(high_freq_factor, float): + logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") + if high_freq_factor < low_freq_factor: + logger.warning( + "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" + f"{high_freq_factor} and low_freq_factor={low_freq_factor}" + ) + + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " + f"{original_max_position_embeddings}" + ) + if original_max_position_embeddings >= config.max_position_embeddings: + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " + f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" + ) + + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "linear": _validate_linear_scaling_rope_parameters, + "dynamic": _validate_dynamic_scaling_rope_parameters, + "yarn": _validate_yarn_parameters, + "longrope": _validate_longrope_parameters, + "llama3": _validate_llama3_parameters, +} + + +def rope_config_validation(config: PretrainedConfig): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) From f873558a893d611444ad73ecabae4e0ea21f1713 Mon Sep 17 00:00:00 2001 From: Azure Date: Sat, 1 Feb 2025 07:32:21 +0000 Subject: [PATCH 03/26] update rope calculation; update modeling.py; update gate for moe --- ktransformers/configs/config.yaml | 2 +- ktransformers/local_chat.py | 4 +- ...seekv3.py => configuration_deepseek_v3.py} | 98 ++-- ktransformers/models/custom_cache.py | 4 + ..._deepseekv3.py => modeling_deepseek_v3.py} | 512 ++++++++---------- ktransformers/operators/attention.py | 5 +- ktransformers/operators/experts.py | 9 +- ktransformers/operators/gate.py | 7 +- .../DeepSeek-V3-Chat-multi-gpu.yaml | 8 +- ktransformers/server/config/config.py | 2 +- ktransformers/util/modeling_rope_utils.py | 163 +++--- 11 files changed, 402 insertions(+), 412 deletions(-) rename ktransformers/models/{configuration_deepseekv3.py => configuration_deepseek_v3.py} (81%) rename ktransformers/models/{modeling_deepseekv3.py => modeling_deepseek_v3.py} (77%) diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 7bde376..80de09a 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -54,4 +54,4 @@ long_context: token_step: local_chat: - prompt_file: "./ktransformers/p.txt" \ No newline at end of file + prompt_file: "" \ No newline at end of file diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index a924a1d..f16ee7f 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM -from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM @@ -78,7 +78,7 @@ def local_chat(): else: content += line + "\n" if content == "": - if config.prompt_file == None or config.prompt_file == "": + if not config.prompt_file: content = "hi" else: content = open(config.prompt_file, "r").read() diff --git a/ktransformers/models/configuration_deepseekv3.py b/ktransformers/models/configuration_deepseek_v3.py similarity index 81% rename from ktransformers/models/configuration_deepseekv3.py rename to ktransformers/models/configuration_deepseek_v3.py index 5c599b3..6227092 100644 --- a/ktransformers/models/configuration_deepseekv3.py +++ b/ktransformers/models/configuration_deepseek_v3.py @@ -14,19 +14,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" DeepSeekV3 model configuration """ +"""DeepSeekV3 model configuration""" from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + class DeepseekV3Config(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. + + Args: vocab_size (`int`, *optional*, defaults to 129280): Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the @@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig): Dimension of the MoE representations. num_hidden_layers (`int`, *optional*, defaults to 61): Number of hidden layers in the Transformer decoder. - num_nextn_predict_layers (`int`, *optional*, defaults to 1): - Number of nextn predict layers in the DeepSeekV3 Model. num_attention_heads (`int`, *optional*, defaults to 128): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*, defaults to 128): @@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig): paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. n_shared_experts (`int`, *optional*, defaults to 1): - Number of shared experts, None means dense model. + Number of shared experts. n_routed_experts (`int`, *optional*, defaults to 256): - Number of routed experts, None means dense model. - ep_size (``, *optional*, defaults to 1): + Number of routed experts. routed_scaling_factor (`float`, *optional*, defaults to 2.5): Scaling factor or routed experts. - kv_lora_rank (``, *optional*, defaults to 512): - q_lora_rank (``, *optional*, defaults to 1536): - qk_rope_head_dim (``, *optional*, defaults to 64): - v_head_dim (``, *optional*, defaults to 128): - qk_nope_head_dim (``, *optional*, defaults to 128): - topk_method (`str`, *optional*, defaults to `"noaux_tc"`): - Topk method used in routed gate. + kv_lora_rank (`int`, *optional*, defaults to 512): + Rank of the LoRA matrices for key and value projections. + q_lora_rank (`int`, *optional*, defaults to 1536): + Rank of the LoRA matrices for query projections. + qk_rope_head_dim (`int`, *optional*, defaults to 64): + Dimension of the query/key heads that use rotary position embeddings. + v_head_dim (`int`, *optional*, defaults to 128): + Dimension of the value heads. + qk_nope_head_dim (`int`, *optional*, defaults to 128): + Dimension of the query/key heads that don't use rotary position embeddings. n_group (`int`, *optional*, defaults to 8): Number of groups for routed experts. topk_group (`int`, *optional*, defaults to 4): Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). num_experts_per_tok (`int`, *optional*, defaults to 8): Number of selected experts, None means dense model. - moe_layer_freq (`int`, *optional*, defaults to 1): - The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. first_k_dense_replace (`int`, *optional*, defaults to 3): Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). \--k dense layers--/ norm_topk_prob (`bool`, *optional*, defaults to `True`): Whether to normalize the weights of the routed experts. - scoring_func (`str`, *optional*, defaults to `"sigmoid"`): - Method of computing expert weights. aux_loss_alpha (`float`, *optional*, defaults to 0.001): Auxiliary loss weight coefficient. Whether to compute the auxiliary loss for each individual sample. - seq_aux (``, *optional*, defaults to `True`): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 4096): @@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. + ```python >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `DeepseekV3Model` + base_model_tp_plan = { + "layers.*.gate_proj": "colwise", + "layers.*.up_proj": "colwise", + "layers.*.down_proj": "rowwise", + } def __init__( self, vocab_size=129280, hidden_size=7168, intermediate_size=18432, - moe_intermediate_size = 2048, + moe_intermediate_size=2048, num_hidden_layers=61, - num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, - n_shared_experts = 1, - n_routed_experts = 256, - ep_size = 1, - routed_scaling_factor = 2.5, - kv_lora_rank = 512, - q_lora_rank = 1536, - qk_rope_head_dim = 64, - v_head_dim = 128, - qk_nope_head_dim = 128, - topk_method = 'noaux_tc', - n_group = 8, - topk_group = 4, - num_experts_per_tok = 8, - moe_layer_freq = 1, - first_k_dense_replace = 3, - norm_topk_prob = True, - scoring_func = 'sigmoid', - aux_loss_alpha = 0.001, - seq_aux = True, + n_shared_experts=1, + n_routed_experts=256, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + n_group=8, + topk_group=4, + num_experts_per_tok=8, + first_k_dense_replace=3, + norm_topk_prob=True, + aux_loss_alpha=0.001, hidden_act="silu", max_position_embeddings=4096, initializer_range=0.02, @@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig): rope_scaling=None, attention_bias=False, attention_dropout=0.0, - mlp_bias=False, **kwargs, ): self.vocab_size = vocab_size @@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig): self.intermediate_size = intermediate_size self.moe_intermediate_size = moe_intermediate_size self.num_hidden_layers = num_hidden_layers - self.num_nextn_predict_layers = num_nextn_predict_layers self.num_attention_heads = num_attention_heads self.n_shared_experts = n_shared_experts self.n_routed_experts = n_routed_experts - self.ep_size = ep_size self.routed_scaling_factor = routed_scaling_factor self.kv_lora_rank = kv_lora_rank self.q_lora_rank = q_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.qk_nope_head_dim = qk_nope_head_dim - self.topk_method = topk_method + self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.head_dim = qk_rope_head_dim self.n_group = n_group self.topk_group = topk_group self.num_experts_per_tok = num_experts_per_tok - self.moe_layer_freq = moe_layer_freq self.first_k_dense_replace = first_k_dense_replace self.norm_topk_prob = norm_topk_prob - self.scoring_func = scoring_func self.aux_loss_alpha = aux_loss_alpha - self.seq_aux = seq_aux + # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig): self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.mlp_bias = mlp_bias + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) super().__init__( pad_token_id=pad_token_id, diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index c85c7bb..e402506 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache): # In-place ops prevent breaking the static address self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + + def get_max_cache_shape(self) -> Tuple[int, int, int, int]: + """Returns the maximum shape of the cache.""" + return self.max_cache_len \ No newline at end of file diff --git a/ktransformers/models/modeling_deepseekv3.py b/ktransformers/models/modeling_deepseek_v3.py similarity index 77% rename from ktransformers/models/modeling_deepseekv3.py rename to ktransformers/models/modeling_deepseek_v3.py index d8a888c..8eb9b9c 100644 --- a/ktransformers/models/modeling_deepseekv3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -1,15 +1,13 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py. +# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the -# modular_deepseekv3.py file directly. One of our CI enforces this. +# modular_deepseek_v3.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import Callable, List, Optional, Tuple, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn.functional as F from torch import nn @@ -30,7 +28,7 @@ from transformers.utils import ( replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg -from .configuration_deepseekv3 import DeepseekV3Config +from .configuration_deepseek_v3 import DeepseekV3Config logger = logging.get_logger(__name__) @@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module): class DeepseekV3MLP(nn.Module): - def __init__(self, config): + def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.moe_intermediate_size - # TODO rm hard coding - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias) + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): @@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module): return down_proj -class MoEGate(nn.Module): +class DeepseekV3TopkRouter(nn.Module): def __init__(self, config): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor - self.scoring_func = config.scoring_func - self.seq_aux = config.seq_aux - self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group - # topk selection algorithm - self.norm_topk_prob = config.norm_topk_prob - self.gating_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) - if self.topk_method == "noaux_tc": - self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) - self.reset_parameters() - - def reset_parameters(self) -> None: - import torch.nn.init as init - - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) def forward(self, hidden_states): - bsz, seq_len, h = hidden_states.shape - ### compute gating score - hidden_states = hidden_states.view(-1, h) - logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) - if self.scoring_func == "sigmoid": - scores = logits.sigmoid() - else: - raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + batch_size, seq_length = hidden_states.shape[:-1] + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - ### select top-k experts - if self.topk_method == "noaux_tc": - # assert not self.training - scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) - .reshape(bsz * seq_len, -1) - ) # [n, e] - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) - topk_weight = scores.gather(1, topk_idx) - else: - raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") - - ### norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 - topk_weight = topk_weight / denominator - topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor - - return topk_idx, topk_weight + scores = router_logits.sigmoid() + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) # [n, e] + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) + topk_weights = scores.gather(1, topk_indices) + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor + return topk_indices, topk_weights, router_logits class DeepseekV3MoE(nn.Module): @@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - - if hasattr(config, "ep_size") and config.ep_size > 1: - assert config.ep_size == dist.get_world_size() - self.ep_size = config.ep_size - self.experts_per_rank = config.n_routed_experts // config.ep_size - self.ep_rank = dist.get_rank() - self.experts = nn.ModuleList( - [ - ( - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank - else None - ) - for i in range(config.n_routed_experts) - ] - ) - else: - self.ep_size = 1 - self.experts_per_rank = config.n_routed_experts - self.ep_rank = 0 - self.experts = nn.ModuleList( - [ - DeepseekV3MLP(config) - for i in range(config.n_routed_experts) - ] - ) - self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP(config=config) + self.experts = nn.ModuleList( + [ + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + self.gate = DeepseekV3TopkRouter(config) + self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size) def forward(self, hidden_states): - identity = hidden_states + residuals = hidden_states orig_shape = hidden_states.shape - topk_idx, topk_weight = self.gate(hidden_states) + topk_indices, topk_weights, router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if not self.training: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) - return y + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states, router_logits - @torch.no_grad() - def moe_infer(self, x, topk_ids, topk_weight): - cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) - cnts.scatter_(1, topk_ids, 1) - tokens_per_expert = cnts.sum(dim=0) - idxs = topk_ids.view(-1).argsort() - sorted_tokens = x[idxs // topk_ids.shape[1]] - sorted_tokens_shape = sorted_tokens.shape - if self.ep_size > 1: - tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) - tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) - dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) - output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() - gathered_tokens = sorted_tokens.new_empty( - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] - ) - input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() - dist.all_to_all( - list(gathered_tokens.split(output_splits)), - list(sorted_tokens.split(input_split_sizes)), - ) - tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum( - dim=0 - ) - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) - s = 0 - for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): - gatherd_idxs[s : s + k] = i % self.experts_per_rank - s += k - gatherd_idxs = gatherd_idxs.argsort() - sorted_tokens = gathered_tokens[gatherd_idxs] - tokens_per_expert = tokens_per_expert_post_gather - tokens_per_expert = tokens_per_expert.cpu().numpy() + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) - outputs = [] - start_idx = 0 - for i, num_tokens in enumerate(tokens_per_expert): - end_idx = start_idx + num_tokens - if num_tokens == 0: - continue - expert = self.experts[i + self.ep_rank * self.experts_per_rank] - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = expert(tokens_for_this_expert) - outputs.append(expert_out) - start_idx = end_idx + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) - if self.ep_size > 1: - new_x = torch.empty_like(outs) - new_x[gatherd_idxs] = outs - gathered_tokens = new_x.new_empty(*sorted_tokens_shape) - dist.all_to_all( - list(gathered_tokens.split(input_split_sizes)), - list(new_x.split(output_splits)), - ) - outs = gathered_tokens + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + return final_hidden_states.type(hidden_states.dtype) - new_x = torch.empty_like(outs) - new_x[idxs] = outs - final_out = ( - new_x.view(*topk_ids.shape, -1) - .type(topk_weight.dtype) - .mul_(topk_weight.unsqueeze(dim=-1)) - .sum(dim=1) - .type(new_x.dtype) - ) - return final_out + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -359,150 +292,94 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekV3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): + def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - - self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim self.is_causal = True - - if self.q_lora_rank is None: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) - else: - self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - config.kv_lora_rank + config.qk_rope_head_dim, + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) self.kv_b_proj = nn.Linear( - config.kv_lora_rank, + self.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, - self.hidden_size, + config.hidden_size, bias=config.attention_bias, ) - self.rotary_emb = DeepseekV3RotaryEmbedding( - config=self.config, - ) + self.scaling = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + # TODO apply in DeepSeekV3Model to share accrose layers + self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs# : Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, self.num_heads, -1) - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + cos, sin = position_embeddings + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(-1, self.num_heads, -1, -1) - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) - if self.q_head_dim != self.v_head_dim: + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) if past_key_value is not None: @@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module): 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: - pass - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + raise NotImplementedError( + f"Attention implementation {self.config._attn_implementation} is not supported. " + "Please use 'eager' or 'sdpa'." + ) + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module): scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module): self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) - self.mlp = ( - DeepseekV3MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV3MLP(config) - ) + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module): position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),) + hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) + if output_router_logits: + outputs += (router_logits,) return outputs -DEEPSEEKV3_START_DOCSTRING = r""" +DEEPSEEK_V3_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r""" @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DEEPSEEKV3_START_DOCSTRING, + DEEPSEEK_V3_START_DOCSTRING, ) class DeepseekV3PreTrainedModel(PreTrainedModel): config_class = DeepseekV3Config @@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): module.weight.data[module.padding_idx].zero_() -DEEPSEEKV3_INPUTS_DOCSTRING = r""" +DEEPSEEK_V3_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r""" @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DEEPSEEKV3_START_DOCSTRING, + DEEPSEEK_V3_START_DOCSTRING, ) class DeepseekV3Model(DeepseekV3PreTrainedModel): """ @@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): config: DeepseekV3Config """ - def __init__(self, config: DeepseekV3Config): + def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) self.gradient_checkpointing = False + self._register_load_state_dict_pre_hook(self.load_hook) # Initialize weights and apply final processing self.post_init() @@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): return causal_mask + def load_hook(self, state_dict, prefix, *args): + """ + Weights have to be permuted for correct rope formulation. We can't do this in the weights + as every other framework already uses the `Llama` original function (which is copyrighted btw). + And I am not even sure it's better.... anyways end of my rant + """ + + def permute_for_rope(input_tensor): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] + input_tensor = input_tensor.reshape(n_heads * dim1, dim2) + input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) + input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) + return input_tensor + + def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): + weight = state_dict[key] + weight = weight.view(num_heads, head_dim, -1) + weight_rot = weight[:, -rope_dim:] + weight_rot = permute_for_rope(weight_rot) + weight[:, -rope_dim:] = weight_rot + weight = weight.view(-1, weight.shape[-1]) + state_dict[key] = weight + + for k in state_dict: + if "q_b_proj." in k: + permute_layer_for_rope( + k, + num_heads=self.config.num_attention_heads, + head_dim=self.config.q_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + if "kv_a_proj_with_mqa." in k: + permute_layer_for_rope( + k, + num_heads=1, + head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + # class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): return self.model @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): ```python >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM - >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") + >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - DEEPSEEKV3_START_DOCSTRING, + DEEPSEEK_V3_START_DOCSTRING, ) class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): def __init__(self, config): @@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1213,4 +1145,12 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) \ No newline at end of file + ) + + +__all__ = [ + "DeepseekV3PreTrainedModel", + "DeepseekV3Model", + "DeepseekV3ForCausalLM", + "DeepseekV3ForSequenceClassification", +] \ No newline at end of file diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index b3b1802..f98bfff 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb -from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention +from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3 from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader @@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) + q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index ddfcda9..03a1488 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase): from ktransformers.models.modeling_deepseek import DeepseekV2MoE -from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock @@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): identity = hidden_states orig_shape = hidden_states.shape sequence_length = orig_shape[1] - topk_idx, topk_weight= self.gate(hidden_states) + topk_idx, topk_weight, router_logits= self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + # only for generate phase if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: @@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y += y_ y.resize_(*orig_shape) - return y + return y, router_logits if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): ) if self.config.n_shared_experts is not None: y += y_ - return y + return y, router_logits @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index 91a3872..dcf45cb 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE import ctypes from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader -from ktransformers.models.modeling_deepseekv3 import MoEGate +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter from ktransformers.util.utils import InferenceState from ktransformers.server.config.config import Config from transformers.activations import ACT2FN @@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): else: raise ValueError("Invalid weight type") self.orig_module.weight = self.orig_module.weight.to(device) - if self.topk_method == "noaux_tc": - self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) + self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) def unload(self): if self.weight is not None: self.weight = None - if self.topk_method == "noaux_tc": + if self.e_score_correction_bias is not None: self.e_score_correction_bias = None diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml index 3fd86d9..7135933 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -47,7 +47,7 @@ - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" - class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: @@ -55,7 +55,7 @@ prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp$" - class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: @@ -64,7 +64,7 @@ - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" - class: ktransformers.models.modeling_deepseekv3.MoEGate + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter replace: class: ktransformers.operators.gate.KMoEGate kwargs: @@ -72,7 +72,7 @@ prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" - class: ktransformers.models.modeling_deepseekv3.MoEGate + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter replace: class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function kwargs: diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index 27b788f..cf5c0ef 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -102,7 +102,7 @@ class Config(metaclass=Singleton): self.total_context = self.model.get("total_context", 2**18) self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1) self.max_chunk_size = self.model.get("max_chunk_size", 2048) - self.max_new_tokens = self.model.get("max_new_tokens", 500) + self.max_new_tokens = self.model.get("max_new_tokens", 2000) self.json_mode = self.model.get("json_mode", False) self.healing = self.model.get("healing", False) self.ban_strings: Optional[list] = self.model.get("ban_strings", None) diff --git a/ktransformers/util/modeling_rope_utils.py b/ktransformers/util/modeling_rope_utils.py index 2598a52..4fec4bc 100644 --- a/ktransformers/util/modeling_rope_utils.py +++ b/ktransformers/util/modeling_rope_utils.py @@ -58,7 +58,8 @@ def _compute_default_rope_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] attention_factor = 1.0 # Unused in this type of RoPE # seq_len: default to max_position_embeddings, e.g. at init time - seq_len = seq_len if seq_len is not None else max_position_embeddings + seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) @@ -185,15 +187,33 @@ def _compute_yarn_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = config.qk_rope_head_dim - - max_position_embeddings = config.max_position_embeddings + head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) factor = config.rope_scaling["factor"] + attention_factor = config.rope_scaling.get("attention_factor") + mscale = config.rope_scaling.get("mscale") + mscale_all_dim = config.rope_scaling.get("mscale_all_dim") + + # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if "original_max_position_embeddings" in config.rope_scaling: + original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"] + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 # Sets the attention factor as suggested in the paper - attention_factor = config.rope_scaling.get("attention_factor") if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 + if mscale and mscale_all_dim: + attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) + else: + attention_factor = get_mscale(factor) # Optional config options # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) @@ -211,7 +231,7 @@ def _compute_yarn_parameters( high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) - def linear_ramp_mask(min, max, dim): + def linear_ramp_factor(min, max, dim): if min == max: max += 0.001 # Prevent singularity @@ -219,16 +239,20 @@ def _compute_yarn_parameters( ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (factor * pos_freqs) - low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device) - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) return inv_freq, attention_factor @@ -244,7 +268,7 @@ def _compute_longrope_parameters( device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. + The current sequence length. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: @@ -261,7 +285,8 @@ def _compute_longrope_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) long_factor = config.rope_scaling["long_factor"] short_factor = config.rope_scaling["short_factor"] factor = config.rope_scaling.get("factor") @@ -271,22 +296,20 @@ def _compute_longrope_parameters( # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two # values to compute the default attention scaling factor, instead of using `factor`. if hasattr(config, "original_max_position_embeddings"): - max_position_embeddings = config.original_max_position_embeddings - expanded_max_position_embeddings = config.max_position_embeddings - factor = expanded_max_position_embeddings / max_position_embeddings + original_max_position_embeddings = config.original_max_position_embeddings + factor = config.max_position_embeddings / config.original_max_position_embeddings else: - max_position_embeddings = config.max_position_embeddings - expanded_max_position_embeddings = max_position_embeddings * factor + original_max_position_embeddings = config.max_position_embeddings # Sets the attention factor as suggested in the paper if attention_factor is None: if factor <= 1.0: attention_factor = 1.0 else: - attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) + attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings)) # Compute the inverse frequencies -- scaled based on the target sequence length - if expanded_max_position_embeddings > max_position_embeddings: + if seq_len and seq_len > original_max_position_embeddings: ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) else: ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) @@ -325,19 +348,18 @@ def _compute_llama3_parameters( low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in inv_freq: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) - new_freqs.append((1 - smooth) * freq / factor + smooth * freq) - inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) - return inv_freq, attention_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters @@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = { } -def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" - # BC: "rope_type" was originally "type" -- let's gracefully handle it - if "rope_type" not in received_keys and "type" in received_keys: + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: received_keys -= {"type"} - received_keys.add("rope_type") + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys missing_keys = required_keys - received_keys if missing_keys: @@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") -def _validate_default_rope_parameters(config: PretrainedConfig): +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type"} received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) -def _validate_linear_scaling_rope_parameters(config: PretrainedConfig): +def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor"} received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") -def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig): +def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor"} # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` optional_keys = {"original_max_position_embeddings"} received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") -def _validate_yarn_parameters(config: PretrainedConfig): +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor"} - optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + optional_keys = { + "attention_factor", + "beta_fast", + "beta_slow", + "original_max_position_embeddings", + "mscale", + "mscale_all_dim", + } received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: @@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig): ) -def _validate_longrope_parameters(config: PretrainedConfig): +def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "short_factor", "long_factor"} # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) short_factor = rope_scaling.get("short_factor") if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): @@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig): logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") attention_factor = rope_scaling.get("attention_factor") - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - logger.warning( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) + if attention_factor is not None: + if not isinstance(attention_factor, float) or attention_factor < 0.0: + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) -def _validate_llama3_parameters(config: PretrainedConfig): +def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: @@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig): logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") if high_freq_factor is None or not isinstance(high_freq_factor, float): logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") - if high_freq_factor < low_freq_factor: + if high_freq_factor <= low_freq_factor: logger.warning( "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" f"{high_freq_factor} and low_freq_factor={low_freq_factor}" @@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = { } -def rope_config_validation(config: PretrainedConfig): +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): """ Validate the RoPE config arguments, given a `PretrainedConfig` object """ @@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig): rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) if validation_fn is not None: - validation_fn(config) + validation_fn(config, ignore_keys=ignore_keys) else: logger.warning( f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" - ) + ) \ No newline at end of file From f748cd29f0c63bbe39ab617b701eabe58b6753db Mon Sep 17 00:00:00 2001 From: Azure Date: Sat, 1 Feb 2025 18:05:45 +0000 Subject: [PATCH 04/26] fix rope; update moegate --- ktransformers/models/modeling_deepseek_v3.py | 35 +++++++++++-------- ktransformers/operators/RoPE.py | 28 +++++++++++++++ ktransformers/operators/linear.py | 2 +- ktransformers/operators/models.py | 2 +- .../DeepSeek-V3-Chat-multi-gpu.yaml | 8 ++--- 5 files changed, 54 insertions(+), 21 deletions(-) diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py index 8eb9b9c..1a197c7 100644 --- a/ktransformers/models/modeling_deepseek_v3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -142,37 +142,42 @@ class DeepseekV3TopkRouter(nn.Module): self.routed_scaling_factor = config.routed_scaling_factor self.n_group = config.n_group self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) def forward(self, hidden_states): - batch_size, seq_length = hidden_states.shape[:-1] hidden_states = hidden_states.view(-1, self.config.hidden_size) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights, router_logits + + @torch.no_grad() + def get_topk_indices(self, scores): scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) .topk(2, dim=-1)[0] .sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) .reshape(-1, self.n_routed_experts) - ) # [n, e] - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) - topk_weights = scores.gather(1, topk_indices) - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor - return topk_indices, topk_weights, router_logits + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices class DeepseekV3MoE(nn.Module): diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index dca441d..06b0ab4 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -12,6 +12,9 @@ from ktransformers.models.modeling_llama import ( LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding, ) +from ktransformers.models.modeling_deepseek_v3 import ( + DeepseekV3RotaryEmbedding +) from ktransformers.models.modeling_deepseek import ( DeepseekV2YarnRotaryEmbedding, DeepseekV2RotaryEmbedding, @@ -134,6 +137,31 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding): self.orig_module.mscale_all_dim, ) +class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, generate_device, **kwargs + ) + self.generate_device = generate_device + self.prefill_device = prefill_device + + def load(self): + # TODO support perlayer prefill + self.orig_module.__init__( + self.config, + device=self.generate_device + ) + return class DynamicNTKScalingRotaryEmbedding( BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 7510f82..a79778a 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase): x = x.to(self.device) orig_shape = list(x.shape) orig_dtype = x.dtype - x = x.reshape(-1, x.shape[-1]) + x = x.reshape(-1, orig_shape[-1]) marlin_s = self.marlin_s.to(x.dtype) x = KTransformersOps.gptq_marlin_gemm( x, diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index 9fa1a19..fd8fee9 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -643,7 +643,7 @@ class KDeepseekV2Model(BaseInjectedModule): org_device = input_ids.device # TODO move to embed_tokens's device, not hard code to cpu input_ids = input_ids.to("cpu") - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids).to(org_device) input_ids = input_ids.to(org_device) if per_layer_prefill_flag: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml index 7135933..4dfff61 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -8,17 +8,17 @@ - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." - class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.YarnRotaryEmbedding + class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\." - class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.YarnRotaryEmbedding + class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" From 907251c7432d3f4d61db3c3579443379b5633356 Mon Sep 17 00:00:00 2001 From: Azure Date: Tue, 4 Feb 2025 15:53:38 +0000 Subject: [PATCH 05/26] done support deepseekv3 --- ktransformers/models/modeling_deepseek_v3.py | 1881 ++++++++++++----- ktransformers/operators/RoPE.py | 53 +- ktransformers/operators/attention.py | 4 +- ktransformers/operators/experts.py | 6 +- ktransformers/operators/gate.py | 5 +- ktransformers/operators/models.py | 24 +- .../DeepSeek-V3-Chat-multi-gpu.yaml | 12 +- .../backend/interfaces/ktransformers.py | 4 +- .../server/backend/interfaces/transformers.py | 4 +- 9 files changed, 1413 insertions(+), 580 deletions(-) diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py index 1a197c7..10b8766 100644 --- a/ktransformers/models/modeling_deepseek_v3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -1,40 +1,96 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_deepseek_v3.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" import math -from typing import Callable, List, Optional, Tuple, Union +import warnings +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F +import torch.utils.checkpoint from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -# from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from ktransformers.util.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_utils import PreTrainedModel # ALL_ATTENTION_FUNCTIONS, PreTrainedModel -# from transformers.processing_utils import Unpack +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) from transformers.utils import ( - # LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) -from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.import_utils import is_torch_fx_available from .configuration_deepseek_v3 import DeepseekV3Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) logger = logging.get_logger(__name__) + _CONFIG_FOR_DOC = "DeepseekV3Config" +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + class DeepseekV3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -51,69 +107,268 @@ class DeepseekV3RMSNorm(nn.Module): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + +ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) class DeepseekV3RotaryEmbedding(nn.Module): - def __init__(self, config: DeepseekV3Config, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed class DeepseekV3MLP(nn.Module): @@ -121,7 +376,9 @@ class DeepseekV3MLP(nn.Module): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -133,52 +390,87 @@ class DeepseekV3MLP(nn.Module): return down_proj -class DeepseekV3TopkRouter(nn.Module): +class MoEGate(nn.Module): def __init__(self, config): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = nn.Parameter( + torch.empty((self.n_routed_experts)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) - router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - scores = router_logits.sigmoid() - topk_indices = self.get_topk_indices(scores) - topk_weights = scores.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights, router_logits - - @torch.no_grad() - def get_topk_indices(self, scores): - scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - return topk_indices + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + ### select top-k experts + if self.topk_method == "noaux_tc": + scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError( + f"insupportable TopK function for MoE gating: {self.topk_method}" + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + + return topk_idx, topk_weight class DeepseekV3MoE(nn.Module): """ @@ -188,77 +480,135 @@ class DeepseekV3MoE(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.experts = nn.ModuleList( - [ - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - for _ in range(config.n_routed_experts) - ] - ) - self.gate = DeepseekV3TopkRouter(config) - self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size) + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV3MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV3MLP( + config, intermediate_size=config.moe_intermediate_size + ) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=intermediate_size + ) def forward(self, hidden_states): - residuals = hidden_states + identity = hidden_states orig_shape = hidden_states.shape - topk_indices, topk_weights, router_logits = self.gate(hidden_states) + topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states, router_logits + flat_topk_idx = topk_idx.view(-1) + if not self.training: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y - def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) - expert_mask = expert_mask.permute(2, 0, 1) + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() - for expert_idx in range(len(self.experts)): - expert = self.experts[expert_idx] - mask = expert_mask[expert_idx] - token_indices, weight_indices = torch.where(mask) + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx - if token_indices.numel() > 0: - expert_weights = topk_weights[token_indices, weight_indices] - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - weighted_output = expert_output * expert_weights.unsqueeze(-1) - final_hidden_states.index_add_(0, token_indices, weighted_output) - return final_hidden_states.type(hidden_states.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out +# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -267,198 +617,592 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def yarn_get_mscale(scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3 class DeepseekV3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DeepseekV3Config, layer_idx: int): + def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim self.is_causal = True - self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) self.kv_a_proj_with_mqa = nn.Linear( - config.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) self.kv_b_proj = nn.Linear( - self.kv_lora_rank, - self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, - config.hidden_size, + self.hidden_size, bias=config.attention_bias, ) + self._init_rope() - self.scaling = self.q_head_dim ** (-0.5) + self.softmax_scale = self.q_head_dim ** (-0.5) if self.config.rope_scaling is not None: mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) scaling_factor = self.config.rope_scaling["factor"] if mscale_all_dim: mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scaling = self.scaling * mscale * mscale - # TODO apply in DeepSeekV3Model to share accrose layers - self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) + self.softmax_scale = self.softmax_scale * mscale * mscale - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs# : Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, self.num_heads, -1) - - q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) - q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - - k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2) - k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2) - - cos, sin = position_embeddings - q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) - k_rot = k_rot.expand(-1, self.num_heads, -1, -1) - - query_states = torch.cat((q_pass, q_rot), dim=-1) - key_states = torch.cat((k_pass, k_rot), dim=-1) - - if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV3RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, ) else: - raise NotImplementedError( - f"Attention implementation {self.config._attn_implementation} is not supported. " - "Please use 'eager' or 'sdpa'." - ) - # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() ) - if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class DeepseekV3DecoderLayer(nn.Module): - def __init__(self, config: DeepseekV3Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) - - if layer_idx >= config.first_k_dense_replace: - self.mlp = DeepseekV3MoE(config) - else: - self.mlp = DeepseekV3MLP(config) - - self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 +class DeepseekV3FlashAttention2(DeepseekV3Attention): + """ + DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV3FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV3RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV3Attention, + "flash_attention_2": DeepseekV3FlashAttention2, +} + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs# : Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -466,7 +1210,6 @@ class DeepseekV3DecoderLayer(nn.Module): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -475,24 +1218,20 @@ class DeepseekV3DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - - if isinstance(hidden_states, tuple): - hidden_states, router_logits = hidden_states - else: - router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),) - hidden_states = residual + hidden_states outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights,) - if output_router_logits: - outputs += (router_logits,) + + if use_cache: + outputs += (present_key_value,) return outputs -DEEPSEEK_V3_START_DOCSTRING = r""" +DeepseekV3_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -511,21 +1250,16 @@ DEEPSEEK_V3_START_DOCSTRING = r""" @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DEEPSEEK_V3_START_DOCSTRING, + DeepseekV3_START_DOCSTRING, ) class DeepseekV3PreTrainedModel(PreTrainedModel): config_class = DeepseekV3Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV3DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] + _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -539,7 +1273,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): module.weight.data[module.padding_idx].zero_() -DEEPSEEK_V3_INPUTS_DOCSTRING = r""" +DeepseekV3_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -580,8 +1314,7 @@ DEEPSEEK_V3_INPUTS_DOCSTRING = r""" returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. @@ -607,16 +1340,12 @@ DEEPSEEK_V3_INPUTS_DOCSTRING = r""" more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. """ @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DEEPSEEK_V3_START_DOCSTRING, + DeepseekV3_START_DOCSTRING, ) class DeepseekV3Model(DeepseekV3PreTrainedModel): """ @@ -626,20 +1355,24 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): config: DeepseekV3Config """ - def __init__(self, config): + def __init__(self, config: DeepseekV3Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx ) + self.layers = nn.ModuleList( + [ + DeepseekV3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) - self.gradient_checkpointing = False - self._register_load_state_dict_pre_hook(self.load_hook) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -649,96 +1382,111 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs# : Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" ) - use_cache = False + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - + # embed positions hidden_states = inputs_embeds - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + next_decoder_cache = None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + if output_attentions: all_self_attns += (layer_outputs[1],) @@ -748,14 +1496,27 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - output = BaseModelOutputWithPast( + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -764,8 +1525,13 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): past_key_values: Cache, output_attentions: bool, ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): + if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None @@ -786,6 +1552,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() @@ -796,17 +1563,27 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): else past_seen_tokens + sequence_length + 1 ) - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -816,117 +1593,12 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - def load_hook(self, state_dict, prefix, *args): - """ - Weights have to be permuted for correct rope formulation. We can't do this in the weights - as every other framework already uses the `Llama` original function (which is copyrighted btw). - And I am not even sure it's better.... anyways end of my rant - """ - - def permute_for_rope(input_tensor): - """ - When you go from the complex ROPE formulation to sin and cos one, you need - to permute the query and key weights (to avoid doing it on the fly) - """ - n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] - input_tensor = input_tensor.reshape(n_heads * dim1, dim2) - input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) - input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) - return input_tensor - - def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): - weight = state_dict[key] - weight = weight.view(num_heads, head_dim, -1) - weight_rot = weight[:, -rope_dim:] - weight_rot = permute_for_rope(weight_rot) - weight[:, -rope_dim:] = weight_rot - weight = weight.view(-1, weight.shape[-1]) - state_dict[key] = weight - - for k in state_dict: - if "q_b_proj." in k: - permute_layer_for_rope( - k, - num_heads=self.config.num_attention_heads, - head_dim=self.config.q_head_dim, - rope_dim=self.config.qk_rope_head_dim, - ) - if "kv_a_proj_with_mqa." in k: - permute_layer_for_rope( - k, - num_heads=1, - head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, - rope_dim=self.config.qk_rope_head_dim, - ) - - -# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - -class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) @@ -955,15 +1627,16 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -971,22 +1644,13 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs# : Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. Returns: @@ -995,8 +1659,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): ```python >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM - >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1006,11 +1670,19 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1024,17 +1696,24 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, - **kwargs, ) hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] @@ -1048,6 +1727,82 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): attentions=outputs.attentions, ) + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + @add_start_docstrings( """ @@ -1062,7 +1817,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - DEEPSEEK_V3_START_DOCSTRING, + DeepseekV3_START_DOCSTRING, ) class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): def __init__(self, config): @@ -1080,13 +1835,13 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1096,11 +1851,13 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) transformer_outputs = self.model( input_ids, @@ -1122,24 +1879,50 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) else: sequence_lengths = -1 - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1150,12 +1933,4 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) - - -__all__ = [ - "DeepseekV3PreTrainedModel", - "DeepseekV3Model", - "DeepseekV3ForCausalLM", - "DeepseekV3ForSequenceClassification", -] \ No newline at end of file + ) \ No newline at end of file diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index 06b0ab4..9e2eb44 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -23,7 +23,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.utils import InferenceState from transformers.configuration_utils import PretrainedConfig - +import torch # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): @@ -56,6 +56,57 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): ) +class RotaryEmbeddingV3(BaseInjectedModule): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, generate_device, **kwargs + ) + self.generate_device = generate_device + self.prefill_device = prefill_device + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def load(self): + self._init( + dim=self.config.qk_rope_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + device=self.device, + ) + def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0): + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + # self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding): def __init__( self, diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index f98bfff..9b47b89 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -151,7 +151,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention): attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output, attn_weights, past_key_value def forward( self, @@ -220,7 +220,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention): attn_output = torch.cat((attn_output, cur_output), dim=-2) attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2) - return attn_output, attn_weight + return attn_output, attn_weight, past_key_value class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 03a1488..f3fd515 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -734,7 +734,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): identity = hidden_states orig_shape = hidden_states.shape sequence_length = orig_shape[1] - topk_idx, topk_weight, router_logits= self.gate(hidden_states) + topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # only for generate phase @@ -745,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y += y_ y.resize_(*orig_shape) - return y, router_logits + return y if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -768,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): ) if self.config.n_shared_experts is not None: y += y_ - return y, router_logits + return y @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index dcf45cb..ab7d0b2 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -16,9 +16,6 @@ from cpuinfer_ext.moe import MOEConfig, MOE import ctypes from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader -from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter -from ktransformers.util.utils import InferenceState -from ktransformers.server.config.config import Config from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from abc import ABC, abstractmethod @@ -102,6 +99,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + self.generate_device = generate_device + self.prefill_device = prefill_device def forward(self, hidden_states) -> torch.Tensor: return self.orig_module.forward(hidden_states) diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index fd8fee9..5d2e911 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule): if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + + if inputs_embeds is None: + org_device = input_ids.device + # TODO move to embed_tokens's device, not hard code to cpu + input_ids = input_ids.to("cpu") + inputs_embeds = self.embed_tokens(input_ids).to(org_device) + input_ids = input_ids.to(org_device) if cache_position is None: past_seen_tokens = ( @@ -639,13 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule): if position_ids is None: position_ids = cache_position.unsqueeze(0) - if inputs_embeds is None: - org_device = input_ids.device - # TODO move to embed_tokens's device, not hard code to cpu - input_ids = input_ids.to("cpu") - inputs_embeds = self.embed_tokens(input_ids).to(org_device) - input_ids = input_ids.to(org_device) - if per_layer_prefill_flag: causal_mask = None else: @@ -717,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule): self.load_layer_to(decoder_layer, InferenceState.PREFILL) torch.cuda.empty_cache() t4 = time.time() + # with open("log.txt", "a") as f: + # f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n") layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, @@ -739,13 +741,17 @@ class KDeepseekV2Model(BaseInjectedModule): hidden_states = layer_outputs[0] # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3 - # if use_cache: - # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) + # with open("log.txt", "a") as f: + # f.write(f"@@@After layers\n") + # f.write(f"hidden_states={hidden_states}\n") + # f.write(f"hidden_states.shape={hidden_states.shape}\n") if per_layer_prefill_flag: t6 = time.time() diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml index 4dfff61..22be22e 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -10,7 +10,7 @@ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding + class: ktransformers.operators.RoPE.RotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" @@ -18,7 +18,7 @@ name: "^model\\.layers\\.([3456][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding + class: ktransformers.operators.RoPE.RotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" @@ -64,7 +64,7 @@ - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" - class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter + class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: @@ -72,7 +72,7 @@ prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" - class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter + class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function kwargs: @@ -106,14 +106,14 @@ - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" replace: - class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" replace: - class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index c34f17f..fd8f10b 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -24,7 +24,7 @@ class KTransformersInterface(TransformersInterface): self.args = args torch.set_default_dtype(torch.bfloat16) torch.set_grad_enabled(False) - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device) + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True) config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) if config.architectures[0] == "Qwen2MoeForCausalLM": config._attn_implementation = "flash_attention_2" @@ -99,7 +99,7 @@ class KTransformersInterface(TransformersInterface): if self.use_static_cache: mask = torch.ones((1, self.seq_length)).to(torch_device) logits = self.model( - self.current_ids, + self.current_ids.to(torch_device), cache_position=self.active_cache_position, past_key_values=self.cache, attention_mask=mask, diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 517045b..ad24dbf 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase): return self.streamer.put(new_tokens) def logits_to_token(self, logits: torch.Tensor): - logits = logits / self.args.temperature + logits = logits / self.args.temperature if self.args.temperature!=0 else logits for token_idx in self.ever_generated_ids: if logits[token_idx] < 0: @@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase): if isinstance(local_messages, List): input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages) elif isinstance(local_messages, str): + #local_messages = local_messages[0]['content'] input_ids = self.tokenize_prompt(local_messages) + #input_ids = torch.tensor([[6366]], device=input_ids.device) else: raise ValueError("local_messages should be List or str") From ee24a27001d8eadd798dfea5983b16a4f49fdca8 Mon Sep 17 00:00:00 2001 From: Azure Date: Tue, 4 Feb 2025 16:14:35 +0000 Subject: [PATCH 06/26] update v3 single gpu rule yaml; --- ktransformers/local_chat.py | 3 +-- .../optimize/optimize_rules/DeepSeek-V3-Chat.yaml | 15 +++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index f16ee7f..513f480 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -32,8 +32,7 @@ custom_models = { ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", - # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", - "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-multi-gpu.yaml", + "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml index 6fb87b7..4a306be 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml @@ -1,7 +1,7 @@ - match: - class: ktransformers.models.modeling_deepseek.DeepseekV3YarnRotaryEmbedding + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.YarnRotaryEmbedding + class: ktransformers.operators.RoPE.RotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" @@ -17,12 +17,19 @@ prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" - class: ktransformers.models.modeling_deepseek.DeepseekV2MoE + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: - class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: From 027b11266c9bb59ca4853fa854aeacbb2370e7c0 Mon Sep 17 00:00:00 2001 From: Azure Date: Thu, 6 Feb 2025 14:07:38 +0000 Subject: [PATCH 07/26] modify moeinfer param --- ktransformers/local_chat.py | 2 ++ ktransformers/operators/experts.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 513f480..827d88f 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -81,8 +81,10 @@ def local_chat(): content = "hi" else: content = open(config.prompt_file, "r").read() + print("User: ", content) elif os.path.isfile(content): content = open(content, "r").read() + print("User: ", content) messages = his_content + [{"role": "user", "content": content}] async def async_inference(messages): diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index f3fd515..dcca038 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -163,7 +163,7 @@ class KExpertsCPU(KExpertsBase): self.config.hidden_size, self.config.moe_intermediate_size, 64, - 10, + 1024, 1024, gate_ptr, up_ptr, From 3dca28d23b1fac0c38c746fd3e15851572a54460 Mon Sep 17 00:00:00 2001 From: liam Date: Thu, 6 Feb 2025 22:39:16 +0800 Subject: [PATCH 08/26] :zap: fix moe.cpp int overflow problem --- ktransformers/ktransformers_ext/operators/llamafile/moe.cpp | 4 ++-- ktransformers/local_chat.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp index 0fcf9df..a131b1f 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp +++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp @@ -224,7 +224,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* int stride = QK_K; int nth = config_.intermediate_size / stride; backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) { - int expert_idx = task_id / nth; + uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx]; void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); @@ -246,7 +246,7 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* stride = QK_K; nth = config_.hidden_size / stride; backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) { - int expert_idx = task_id / nth; + uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* down_input_ptr = m_local_down_input_ptr_[expert_idx]; void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 513f480..93f0ab8 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -32,7 +32,7 @@ custom_models = { ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", - "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", + "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-multi-gpu.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", From c4d9bc6670c927abb82f2436762da3fc8efb2bf3 Mon Sep 17 00:00:00 2001 From: Azure Date: Fri, 7 Feb 2025 05:57:40 +0000 Subject: [PATCH 09/26] support KExpertsMarlin backend --- ktransformers/operators/experts.py | 60 +++++--- ktransformers/operators/linear.py | 4 +- .../DeepSeek-V3-Chat-multi-gpu-marlin.yaml | 143 ++++++++++++++++++ .../backend/interfaces/ktransformers.py | 51 ++++--- ktransformers/server/config/config.py | 2 + 5 files changed, 214 insertions(+), 46 deletions(-) create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index f3fd515..274a3ca 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -302,13 +302,13 @@ class KExpertsMarlin(KExpertsBase): if w is None: w = self.load_weights()[self.key] if isinstance(w, dict): - self.gate = nn.Parameter(torch.from_numpy(w["gate"])) - self.up = nn.Parameter(torch.from_numpy(w["up"])) - self.down = nn.Parameter(torch.from_numpy(w["down"])) + self.gate = w["gate"] + self.up = (w["up"]) + self.down = (w["down"]) for i in range(self.expert_num): - self.up_projs[i].load(self.up[i,...], device=device) - self.gate_projs[i].load(self.gate[i,...], device=device) - self.down_projs[i].load(self.down[i,...], device=device) + self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device) + self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device) + self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device) self.loaded_experts_idx.append(i) return @@ -342,23 +342,45 @@ class KExpertsMarlin(KExpertsBase): up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"]) - res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} + res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} return res - def forward(self, input_tensor:torch.Tensor, expert_ids, weights): - # forward - device = input_tensor.device - input_tensor = input_tensor.to("cuda") - outs = torch.zeros_like(input_tensor) - for expert_idx in range(expert_ids.size(0)): - down_proj = self.down_projs[expert_idx] - gate_proj = self.gate_projs[expert_idx] - up_proj = self.up_projs[expert_idx] + def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: + org_dtype = hidden_states_cpu.dtype + org_device = hidden_states_cpu.device + hidden_states_cpu = hidden_states_cpu.to(self.device) + selected_experts_cpu = selected_experts_cpu.to(self.device) + routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype) + + batch_sequence_length, hidden_dim = hidden_states_cpu.size() - outs += down_proj(self.act_fn(gate_proj(input_tensor)) * up_proj(input_tensor)) * weights[expert_idx] - outs = outs.to(device) - return outs + final_hidden_states = torch.zeros( + (batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device + ) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0) + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.expert_num): + if not expert_mask[expert_idx].any(): + continue + idx, top_x = torch.where(expert_mask[expert_idx]) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim) + G = self.gate_projs[expert_idx].forward(current_state) + A = self.act_fn(G) + U = self.up_projs[expert_idx].forward(current_state) + H = A * U # Element-wise multiplication + current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None] + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states) + + return final_hidden_states.to(dtype=org_dtype, device=org_device) + class KExpertsTorch(KExpertsBase): expert_num: int loaded_experts_idx: list[int] diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index a79778a..9e35e8d 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -138,10 +138,10 @@ class KLinearTorch(KLinearBase): if w is None: w = self.load_weight(device=device) if isinstance(w, nn.Parameter): - self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T + self.w = w.to(dtype=self.dtype).T self.has_bias = False elif isinstance(w, tuple): - self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T + self.w = w[0].to(dtype=self.dtype).T self.bias = w[1].to(dtype=self.dtype) self.has_bias = True else: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml new file mode 100644 index 0000000..22be22e --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml @@ -0,0 +1,143 @@ +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\." + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 30: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index fd8f10b..d228b64 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface): self.args = args torch.set_default_dtype(torch.bfloat16) torch.set_grad_enabled(False) - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True) - config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code) + config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code) if config.architectures[0] == "Qwen2MoeForCausalLM": config._attn_implementation = "flash_attention_2" @@ -71,30 +71,31 @@ class KTransformersInterface(TransformersInterface): self.streamer = TextStreamer(self.tokenizer) def decode_one_tokens(self): - if not hasattr(self, "cuda_graph_runner"): - device_map = self.model.gguf_loader.tensor_device_map - torch_device = get_device("blk.0.self_attn", device_map) - torch_device = "cuda:0" if torch_device == "cuda" else torch_device - self.cuda_graph_runner = CUDAGraphRunner() - self.cuda_graph_runner.capture( - self.model, - self.current_ids, - self.active_cache_position.unsqueeze(0), - self.active_cache_position, - self.cache, - main_device=torch_device, - return_dict=False, - use_cache=True, - ) + device_map = self.model.gguf_loader.tensor_device_map + torch_device = get_device("blk.0.self_attn", device_map) + torch_device = "cuda:0" if torch_device == "cuda" else torch_device + if self.args.use_cuda_graph: + if not hasattr(self, "cuda_graph_runner"): + self.cuda_graph_runner = CUDAGraphRunner() + self.cuda_graph_runner.capture( + self.model, + self.current_ids, + self.active_cache_position.unsqueeze(0), + self.active_cache_position, + self.cache, + main_device=torch_device, + return_dict=False, + use_cache=True, + ) - if hasattr(self, "cuda_graph_runner"): - logits = self.cuda_graph_runner( - self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position - ) - self.cache.change_seq_length(1) - torch.cuda.synchronize() - logits = logits[0, -1, :] - return self.logits_to_token(logits) + if hasattr(self, "cuda_graph_runner"): + logits = self.cuda_graph_runner( + self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position + ) + self.cache.change_seq_length(1) + torch.cuda.synchronize() + logits = logits[0, -1, :] + return self.logits_to_token(logits) if self.use_static_cache: mask = torch.ones((1, self.seq_length)).to(torch_device) diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index cf5c0ef..7ce616b 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -93,6 +93,8 @@ class Config(metaclass=Singleton): self.model_name: str = self.model.get("name", "") self.model_device: str = self.model.get("device", "cuda:0") self.gguf_path: Optional[str] = self.model.get("gguf_path", None) + self.use_cuda_graph = self.model.get("use_cuda_graph", True) + self.trust_remote_code = self.model.get("trust_remote_code", True) # self.model_cache_lens = self.model.get("cache_lens") self.optimize_config_path: Optional[str] = self.model.get( "optimize_config_path", None From c18ecd7b7f4c68318d147db296652bd635372141 Mon Sep 17 00:00:00 2001 From: liam Date: Sat, 8 Feb 2025 13:15:52 +0800 Subject: [PATCH 10/26] :zap: add flush print in local_chat output and change default optimize yaml of deepseekv3 to single gpu --- ktransformers/local_chat.py | 2 +- ktransformers/server/backend/interfaces/transformers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index bc26dda..827d88f 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -32,7 +32,7 @@ custom_models = { ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", - "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-multi-gpu.yaml", + "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index ad24dbf..81fa6e5 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -329,14 +329,14 @@ class TransformersInterface(BackendInterfaceBase): self.profiler.create_and_start_timer("prefill") for t in self.prefill(input_ids, self.check_is_new(thread_id)): if t is not None: - print(t, end="") + print(t, end="",flush=True) yield t self.profiler.pause_timer("prefill") self.profiler.create_and_start_timer("decode") for t in self.generate(): if t is not None: - print(t, end="") + print(t, end="",flush=True) yield t print("") self.profiler.pause_timer("decode") From 098602b08fdc92badf8331eab3deb6f56c5166f1 Mon Sep 17 00:00:00 2001 From: liam Date: Sun, 9 Feb 2025 22:39:01 +0800 Subject: [PATCH 11/26] :zap: v0.2 ongoing --- .gitignore | 5 +- Makefile | 2 +- README.md | 61 +++-- doc/en/DeepseekR1_V3_tutorial.md | 64 +++++ .../ktransformers_ext/CMakeLists.txt | 21 ++ .../ktransformers_ext/cpu_backend/backend.cpp | 17 ++ .../ktransformers_ext/cpu_backend/backend.h | 3 + .../operators/llamafile/moe.cpp | 75 +++++ .../operators/llamafile/moe.h | 6 + ktransformers/local_chat.py | 258 +++++++++++++++--- setup.py | 8 +- 11 files changed, 450 insertions(+), 70 deletions(-) create mode 100644 doc/en/DeepseekR1_V3_tutorial.md diff --git a/.gitignore b/.gitignore index c33a95d..d45e956 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,7 @@ compile_commands.json ktransformers/server/local_store/ ktransformers/server_test1.db *.patch -img/ \ No newline at end of file +img/ +tmp1.txt +test_65_300_1536.txt +test.txt diff --git a/Makefile b/Makefile index dbf771d..f8633a9 100644 --- a/Makefile +++ b/Makefile @@ -17,5 +17,5 @@ dev_install: pip install -r requirements-local_chat.txt echo "Installing ktransformers" - KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation + KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation echo "Installation completed successfully" \ No newline at end of file diff --git a/README.md b/README.md index eb23bf8..8d92cb7 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

πŸ”₯ Updates

+* **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to XXX speedup. The Detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md) * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU. @@ -31,6 +32,43 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin * **Aug 9, 2024**: Support windows native.

πŸ”₯ Show Cases

+ +
+

GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM

+
+ +https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 + +

+ +- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 12GB VRAM and 382GB DRAM. + - Prefill Speed: + - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ xxx (optimized AMX-based MoE kernel, v3 only) β†’ XXX (selectively using 6 experts, v3 only) + - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **XXXΓ— speedup**. + - Decode Speed(tokens/s): + - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, v3 only) + - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. + - Upcoming Open Source Release: + - AMX optimizations and selective expert activation will be open-sourced in v0.3. + - Currently available only in preview binary distribution, which can be found here. + +- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench). + +

+ + DeepSeek-Coder-V2 Score + +

+ +- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin). +- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends. + +

+ +https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c + +

+

1M Context Local Inference on a Desktop with Only 24GB VRAM

@@ -54,30 +92,7 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 * **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md). -

-

GPT-4-level Local VSCode Copilot on a Desktop with only 24GB VRAM

-
-https://github.com/user-attachments/assets/0b9fa2da-66f0-48eb-b4b9-f0e1f06f8927 - -

- -- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench). - -

- - DeepSeek-Coder-V2 Score - -

- -- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin). -- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends. - -

- -https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c - -

More advanced features will coming soon, so stay tuned! diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md new file mode 100644 index 0000000..1b1e6c7 --- /dev/null +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -0,0 +1,64 @@ +## prerequisites +We run our best performance tests on
+cpu: Intel(R) Xeon(R) Gold 6454S 1T DRAM(2 NUMA nodes)
+gpu: 4090D 24G VRAM
+## bench result +### V0.2 +#### settings +- model: DeepseekV3-q4km(int4οΌ‰
+- CPU: cpu_model_name:Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 socket, 2numa nodes +- GPU: 4090D 24GVRAM +- we test after enough warm up! +#### memory consumption: + - single socket: 382G DRAM, 12G VRAM + - dual socket: 1T DRAM, 12G VRAM + +#### Benchmark Results + +"6 experts" case is part of v0.3's preview + +| Prompt
(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts)| Llama (8 experts) | +| --- | --- | --- | --- | --- | --- | +| Prefill token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 | +| Decode token/s | 13.69 | 12.208 | 10.303 | 8.73 |4.51 | + +**The highest speedup reaches up to x3.03 in decoding and x9.44 in prefill.** + +## how to run +### v0.2 showcase +#### single socket version(32 cores) +our local_chat test command is: +``` shell +git clone https://github.com/kvcache-ai/ktransformers.git +cd ktransformers +numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 33 --cache_lens 1536 + +``` +\ can be local or set from onlie hugging face like deepseek-ai/DeepSeek-V3. If onlie encounters connection problem, try use mirror(hf-mirror.com)
+\ can also be onlie, but as its large we recommend you download it and quantize the model to what you want.
+the command numactl -N 1 -m 1 aims to adoid data transfer between numa nodes. +### dual socket version(64 cores) +make suer before you install(use install.sh or `make dev_install`), setting the env var `USE_NUMA=1` by `export USE_NUMA=1`(if already installed, reinstall it with this env var set)
+our local_chat test command is: +``` shell +git clone https://github.com/kvcache-ai/ktransformers.git +cd ktransformers +export USE_NUMA=1 +make dev_install # or sh ./install.sh +python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 65 --cache_lens 1536 + +``` +The parameters meaning is the same. But As we use dual socket, so we set cpu_infer to 65. +## some explanations +1. From our perspective on DeepSeekV2, DeepSeekV3 and DeepSeekR1, +when we slightly decrease the activation experts num in inference, +the output quality doesn't change(within 1% accuracy drop),But the speed of decoding and prefill +is speed up about 30% which is inspiring. So our showcase makes use of this finding, +changing the activation experts of DeepSeekV3/R1 from 8 to 6.
+2. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. +To avoid the cost of data transfer between nodes, we "copy" the critical matrix on +both nodes which takes more memory consumption but accelerates the prefill and decoding process. +But this method takes huge memory and slow when loading weights, So be patient when loading +and monitor the memory usage.(we are considering to make this method as an option)
+3. the command args `--cpu_infer 65` specifies how many cores to use(it's ok that it exceeds the physical number, +but it's not the more the better. Adjust it slight lower to your actual number of cores)
diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index 1ef9823..d9ecd7a 100644 --- a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -230,3 +230,24 @@ elseif(UNIX) endif() target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") endif() + +# Define the USE_NUMA option +option(USE_NUMA "Disable NUMA support" OFF) +# Check if the USE_NUMA environment variable is set +if(DEFINED ENV{USE_NUMA}) + set(USE_NUMA ON) +endif() +if (USE_NUMA) + message(STATUS "NUMA support is enabled") +else() + message(STATUS "NUMA support is disabled") +endif() + +find_library(NUMA_LIBRARY NAMES numa) +if (NUMA_LIBRARY AND USE_NUMA) + message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support") + target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY}) + target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA) +else() + message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support") +endif() diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp index 16693f0..5980ba3 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp +++ b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp @@ -10,6 +10,13 @@ #include "backend.h" +#ifdef USE_NUMA +#include +#include + +thread_local int Backend::numa_node = -1; +#endif + thread_local int Backend::thread_local_id = -1; Backend::Backend(int max_thread_num) { @@ -74,6 +81,16 @@ void Backend::do_work_stealing_job(int task_num, } void Backend::process_tasks(int thread_id) { + + #ifdef USE_NUMA + if(numa_node == -1){ + numa_node = thread_id * numa_num_configured_nodes() / thread_num_; + struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes()); + numa_bitmask_setbit(mask, numa_node); + numa_bind(mask); + } + #endif + if (init_func_ != nullptr) { init_func_(thread_id); } diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.h b/ktransformers/ktransformers_ext/cpu_backend/backend.h index 80ff7f9..7a95f27 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/backend.h +++ b/ktransformers/ktransformers_ext/cpu_backend/backend.h @@ -38,6 +38,9 @@ class Backend { void do_work_stealing_job(int, std::function, std::function, std::function); + #ifdef USE_NUMA + static thread_local int numa_node; + #endif static thread_local int thread_local_id; private: diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp index a131b1f..35c144f 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp +++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.cpp @@ -11,11 +11,41 @@ #include #include +#ifdef USE_NUMA +#include +#include +#endif + MOE::MOE(MOEConfig config) { config_ = config; gate_proj_ = config_.gate_proj; up_proj_ = config_.up_proj; down_proj_ = config_.down_proj; + + #ifdef USE_NUMA + int numa_nodes = numa_num_configured_nodes(); + gate_proj_numa_.resize(numa_nodes); + up_proj_numa_.resize(numa_nodes); + down_proj_numa_.resize(numa_nodes); + size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size; + for (int i = 0; i < numa_nodes; i++) { + gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i); + up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i); + down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i); + if (!gate_proj_numa_[i]) { + std::cout << "Memory allocation failed for gate_proj_numa_ on node " << i << std::endl; + } + if (!up_proj_numa_[i]) { + std::cout << "Memory allocation failed for up_proj_numa_ on node " << i << std::endl; + } + if (!down_proj_numa_[i]) { + std::cout << "Memory allocation failed for down_proj_numa_ on node " << i << std::endl; + } + memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type)); + memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type)); + memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type)); + } + #endif std::vector> s_mem_requests; s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size}); @@ -74,6 +104,15 @@ MOE::MOE(MOEConfig config) { MOE::~MOE() { shared_mem_buffer.dealloc(this); + + #ifdef USE_NUMA + int numa_nodes = numa_num_configured_nodes(); + for (int i = 0; i < numa_nodes; i++) { + numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type)); + numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type)); + numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type)); + } + #endif } void MOE::warm_up(Backend* backend) { @@ -125,10 +164,22 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c int expert_idx = task_id / nth; uint64_t expert_id = expert_ids[expert_idx]; int ith = task_id % nth; + + #ifdef USE_NUMA + void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #else void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #endif + float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); + + #ifdef USE_NUMA + void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #else void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #endif + float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { @@ -153,7 +204,13 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c } for (int expert_idx = 0; expert_idx < k; expert_idx++) { uint64_t expert_id = expert_ids[expert_idx]; + + #ifdef USE_NUMA + void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #else void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #endif + float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { @@ -227,11 +284,23 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx]; + + #ifdef USE_NUMA + void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #else void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); + #endif + float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); void* up_input_ptr = m_local_up_input_ptr_[expert_idx]; + + #ifdef USE_NUMA + void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #else void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); + #endif + float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = 0; i < m_local_num_[expert_idx]; i++) { @@ -249,7 +318,13 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* down_input_ptr = m_local_down_input_ptr_[expert_idx]; + + #ifdef USE_NUMA + void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #else void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); + #endif + float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); }, nullptr); diff --git a/ktransformers/ktransformers_ext/operators/llamafile/moe.h b/ktransformers/ktransformers_ext/operators/llamafile/moe.h index a1470aa..a39e21d 100644 --- a/ktransformers/ktransformers_ext/operators/llamafile/moe.h +++ b/ktransformers/ktransformers_ext/operators/llamafile/moe.h @@ -61,6 +61,12 @@ class MOE { void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] + #ifdef USE_NUMA + std::vector gate_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)] + std::vector up_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)] + std::vector down_proj_numa_; // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)] + #endif + float* s_input_fp32_; // [hidden_size] uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)] uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)] diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 827d88f..5f17c21 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -1,24 +1,140 @@ +# """ +# Description : +# Author : Boxin Zhang, Azure-Tang +# Version : 0.1.0 +# Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +# """ + +# import asyncio +# import os +# import platform +# import sys +# project_dir = os.path.dirname(os.path.dirname(__file__)) +# sys.path.insert(0, project_dir) +# from ktransformers.server.args import ArgumentParser + + +# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM +# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM +# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM +# from ktransformers.models.modeling_llama import LlamaForCausalLM +# from ktransformers.models.modeling_mixtral import MixtralForCausalLM +# from ktransformers.server.config.config import Config + +# custom_models = { +# "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, +# "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, +# "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, +# "LlamaForCausalLM": LlamaForCausalLM, +# "MixtralForCausalLM": MixtralForCausalLM, +# } + +# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" +# default_optimize_rules = { +# "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", +# "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", +# "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", +# "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", +# "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", +# } + + +# def local_chat(): +# config = Config() +# arg_parser = ArgumentParser(config) +# # εˆε§‹εŒ–ζΆˆζ― +# arg_parser.parse_args() +# if config.backend_type == "transformers": +# from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface +# elif config.backend_type == "exllamav2": +# from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface +# elif config.backend_type == "ktransformers": +# from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface +# else: +# raise NotImplementedError(f"{config.backend_type} not implemented") +# interface = BackendInterface(config) + +# system = platform.system() +# if system == "Windows": +# os.system("cls") +# else: +# os.system("clear") +# # add a history chat content +# his_content = [] +# while True: +# content = input("Chat: ") +# if content.startswith('"""'): # prefix """ +# # multi lines input +# content = content[3:] + "\n" +# while True: +# line = input("") +# if line.endswith('"""'): +# # end multi lines input +# line = line[:-3] # suffix """ +# if line: +# content += line + "\n" +# break +# else: +# content += line + "\n" +# if content == "": +# if not config.prompt_file: +# content = "hi" +# else: +# content = open(config.prompt_file, "r").read() +# print("User: ", content) +# elif os.path.isfile(content): +# content = open(content, "r").read() +# print("User: ", content) +# messages = his_content + [{"role": "user", "content": content}] + +# async def async_inference(messages): +# generated = "" +# async for token in interface.inference(messages, "local_chat"): +# generated += token +# return generated + +# generated = asyncio.run(async_inference(messages)) +# his_content += [ +# {"role": "user", "content": content}, +# {"role": "assistant", "content": generated}, +# ] + + +# if __name__ == "__main__": +# local_chat() + + """ -Description : +Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 -Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ -import asyncio import os import platform import sys + project_dir = os.path.dirname(os.path.dirname(__file__)) sys.path.insert(0, project_dir) -from ktransformers.server.args import ArgumentParser - - +import torch +import logging +from transformers import ( + AutoTokenizer, + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + TextStreamer, +) +import json +import fire +from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM -from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM +from ktransformers.util.utils import prefill_and_generate from ktransformers.server.config.config import Config custom_models = { @@ -29,7 +145,9 @@ custom_models = { "MixtralForCausalLM": MixtralForCausalLM, } -ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" +ktransformer_rules_dir = ( + os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" +) default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", @@ -39,28 +157,85 @@ default_optimize_rules = { } -def local_chat(): - config = Config() - arg_parser = ArgumentParser(config) - # εˆε§‹εŒ–ζΆˆζ― - arg_parser.parse_args() - if config.backend_type == "transformers": - from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface - elif config.backend_type == "exllamav2": - from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface - elif config.backend_type == "ktransformers": - from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface +def local_chat( + model_path: str | None = None, + optimize_rule_path: str = None, + gguf_path: str | None = None, + max_new_tokens: int = 1000, + cpu_infer: int = Config().cpu_infer, + use_cuda_graph: bool = True, + prompt_file : str | None = None, + mode: str = "normal", +): + + + torch.set_grad_enabled(False) + + Config().cpu_infer = cpu_infer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + if mode == 'long_context': + assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" + torch.set_default_dtype(torch.float16) else: - raise NotImplementedError(f"{config.backend_type} not implemented") - interface = BackendInterface(config) + torch.set_default_dtype(config.torch_dtype) + + with torch.device("meta"): + if config.architectures[0] in custom_models: + print("using custom modeling_xxx.py.") + if ( + "Qwen2Moe" in config.architectures[0] + ): # Qwen2Moe must use flash_attention_2 to avoid overflow. + config._attn_implementation = "flash_attention_2" + if "Llama" in config.architectures[0]: + config._attn_implementation = "eager" + if "Mixtral" in config.architectures[0]: + config._attn_implementation = "flash_attention_2" + + model = custom_models[config.architectures[0]](config) + else: + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=True, attn_implementation="flash_attention_2" + ) + + if optimize_rule_path is None: + if config.architectures[0] in default_optimize_rules: + print("using default_optimize_rule for", config.architectures[0]) + optimize_rule_path = default_optimize_rules[config.architectures[0]] + else: + optimize_rule_path = input( + "please input the path of your rule file(yaml file containing optimize rules):" + ) + + if gguf_path is None: + gguf_path = input( + "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" + ) + optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) + + try: + model.generation_config = GenerationConfig.from_pretrained(model_path) + except: + gen_config = GenerationConfig( + max_length=128, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + model.generation_config = gen_config + # model.generation_config = GenerationConfig.from_pretrained(model_path) + if model.generation_config.pad_token_id is None: + model.generation_config.pad_token_id = model.generation_config.eos_token_id + model.eval() + logging.basicConfig(level=logging.INFO) system = platform.system() if system == "Windows": os.system("cls") else: os.system("clear") - # add a history chat content - his_content = [] + while True: content = input("Chat: ") if content.startswith('"""'): # prefix """ @@ -76,29 +251,28 @@ def local_chat(): break else: content += line + "\n" + if content == "": - if not config.prompt_file: - content = "hi" + if prompt_file != None: + content = open(prompt_file, "r").read() else: - content = open(config.prompt_file, "r").read() - print("User: ", content) + content = "Please write a piece of quicksort code in C++." elif os.path.isfile(content): content = open(content, "r").read() - print("User: ", content) - messages = his_content + [{"role": "user", "content": content}] - - async def async_inference(messages): - generated = "" - async for token in interface.inference(messages, "local_chat"): - generated += token - return generated - - generated = asyncio.run(async_inference(messages)) - his_content += [ - {"role": "user", "content": content}, - {"role": "assistant", "content": generated}, - ] + messages = [{"role": "user", "content": content}] + input_tensor = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, return_tensors="pt" + ) + if mode == 'long_context': + assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ + "please change max_seq_len in ~/.ktransformers/config.yaml" + torch.set_default_dtype( + torch.bfloat16 + ) # TODO: Remove this, replace dtype using config + generated = prefill_and_generate( + model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode + ) if __name__ == "__main__": - local_chat() + fire.Fire(local_chat) \ No newline at end of file diff --git a/setup.py b/setup.py index 2a09b48..d24db14 100644 --- a/setup.py +++ b/setup.py @@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension): if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: if hasattr(self, "parallel") and self.parallel: build_args += [f"-j{self.parallel}"] - + print("CMake args:", cmake_args) build_temp = Path(ext.sourcedir) / "build" if not build_temp.exists(): build_temp.mkdir(parents=True) - subprocess.run( - ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True + result = subprocess.run( + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True ) + print("Standard output:", result.stdout) + print("Standard error:", result.stderr) subprocess.run( ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True ) From 6b33f41de45f4911a1e4008a3d39ff9d6e8f32eb Mon Sep 17 00:00:00 2001 From: chenht2022 Date: Sun, 9 Feb 2025 16:08:16 +0000 Subject: [PATCH 12/26] Add V0.3-preview doc --- README.md | 4 ++-- doc/en/DeepseekR1_V3_tutorial.md | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8d92cb7..c7411b5 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 - **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 12GB VRAM and 382GB DRAM. - Prefill Speed: - - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ xxx (optimized AMX-based MoE kernel, v3 only) β†’ XXX (selectively using 6 experts, v3 only) - - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **XXXΓ— speedup**. + - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, v3 only) β†’ 286.55 (selectively using 6 experts, v3 only) + - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **63.53Γ— speedup**. - Decode Speed(tokens/s): - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, v3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 1b1e6c7..1bc1adf 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -24,6 +24,27 @@ gpu: 4090D 24G VRAM
**The highest speedup reaches up to x3.03 in decoding and x9.44 in prefill.** +### V0.3-Preview +#### settings +- model: DeepseekV3-BF16 (online quant into int8 for CPU and int4 for GPU) +- CPU: cpu_model_name:Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 socket, 2numa nodes +- GPU: (1~4)x 4090D 24GVRAM (requires more VRAM for longer prompt) + +#### memory consumptions: +- 644GB DRAM, at least 12GB VRAM + +#### Benchmark Results +| Prompt length | 1K | 2K | 4K | 8K | +|---------------|-----|-----|-----|-----| +| KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | +| KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | + +**The prefill of KTrans V0.3 is up to x3.45 times faster than KTrans V0.2. The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted.** + +The main acceleration comes from +- Intel AMX instruction set and our specially designed cache friendly memory layout +- Expert selection strategy that selects fewer experts based on offline profile results of out of domain data + ## how to run ### v0.2 showcase #### single socket version(32 cores) From 2d684ee96a1a5ad57f9edbbb20297c1322b683ca Mon Sep 17 00:00:00 2001 From: chenht2022 Date: Sun, 9 Feb 2025 16:25:43 +0000 Subject: [PATCH 13/26] Small fix --- doc/en/DeepseekR1_V3_tutorial.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 1bc1adf..4a4a27f 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -27,7 +27,7 @@ gpu: 4090D 24G VRAM
### V0.3-Preview #### settings - model: DeepseekV3-BF16 (online quant into int8 for CPU and int4 for GPU) -- CPU: cpu_model_name:Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 socket, 2numa nodes +- CPU: cpu_model_name:Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 socket, 2 numa nodes - GPU: (1~4)x 4090D 24GVRAM (requires more VRAM for longer prompt) #### memory consumptions: @@ -39,7 +39,8 @@ gpu: 4090D 24G VRAM
| KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | | KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | -**The prefill of KTrans V0.3 is up to x3.45 times faster than KTrans V0.2. The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted.** +**The prefill of KTrans V0.3 is up to x3.45 times faster than KTrans V0.2, and is up to x63.53 times faster than Llama.** +**The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted.** The main acceleration comes from - Intel AMX instruction set and our specially designed cache friendly memory layout From c7e6d09068a88e752b43eed0f2c4e56ace6b7005 Mon Sep 17 00:00:00 2001 From: unicornchan Date: Mon, 10 Feb 2025 01:00:57 +0000 Subject: [PATCH 14/26] [feature] update version and github action jobs for package --- .github/workflows/package_wheel_release.yml | 4 ++-- .github/workflows/package_wheel_test.yml | 4 ++-- ktransformers/__init__.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/package_wheel_release.yml b/.github/workflows/package_wheel_release.yml index 8028d59..dfbfde4 100644 --- a/.github/workflows/package_wheel_release.yml +++ b/.github/workflows/package_wheel_release.yml @@ -142,11 +142,11 @@ jobs: - name: Setup Mamba if: matrix.cuda != '' - uses: conda-incubator/setup-miniconda@v2.3.0 + uses: conda-incubator/setup-miniconda@v3 with: activate-environment: "ktransformers" python-version: ${{ matrix.pyver }} - miniforge-variant: Mambaforge + miniforge-variant: Miniforge3 miniforge-version: latest use-mamba: true add-pip-as-python-dependency: true diff --git a/.github/workflows/package_wheel_test.yml b/.github/workflows/package_wheel_test.yml index 35636db..cd8db62 100644 --- a/.github/workflows/package_wheel_test.yml +++ b/.github/workflows/package_wheel_test.yml @@ -54,11 +54,11 @@ jobs: - name: Setup Mamba if: matrix.cuda != '' - uses: conda-incubator/setup-miniconda@v2.3.0 + uses: conda-incubator/setup-miniconda@v3 with: activate-environment: "ktransformers" python-version: ${{ matrix.pyver }} - miniforge-variant: Mambaforge + miniforge-variant: Miniforge3 miniforge-version: latest use-mamba: true add-pip-as-python-dependency: true diff --git a/ktransformers/__init__.py b/ktransformers/__init__.py index 2c7b4dc..8c5108b 100644 --- a/ktransformers/__init__.py +++ b/ktransformers/__init__.py @@ -5,7 +5,7 @@ Description : Author : kkk1nak0 Date : 2024-08-15 07:34:46 Version : 1.0.0 -LastEditors : Azure-Tang -LastEditTime : 2024-08-29 22:35:51 +LastEditors : unicornchan +LastEditTime : 2025-02-10 00:59:53 ''' -__version__ = "0.1.4" \ No newline at end of file +__version__ = "0.2.0" \ No newline at end of file From 6dd4fa0e87de0e4252e2dad6b0cec9fde44c7c7a Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 09:38:26 +0800 Subject: [PATCH 15/26] :zap: improve readme --- README.md | 4 ++-- doc/en/DeepseekR1_V3_tutorial.md | 24 ++++++++++-------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index c7411b5..b8e83cc 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

πŸ”₯ Updates

-* **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to XXX speedup. The Detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md) +* **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64X speedup. The Detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md) * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU. @@ -50,7 +50,7 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. - Upcoming Open Source Release: - AMX optimizations and selective expert activation will be open-sourced in v0.3. - - Currently available only in preview binary distribution, which can be found here. + - Currently available only in preview binary distribution, which can be found [here](xxx). - **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench). diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 4a4a27f..a56b689 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,8 +1,9 @@ -## prerequisites +# Report +## Prerequisites We run our best performance tests on
cpu: Intel(R) Xeon(R) Gold 6454S 1T DRAM(2 NUMA nodes)
gpu: 4090D 24G VRAM
-## bench result +## Bench result ### V0.2 #### settings - model: DeepseekV3-q4km(int4οΌ‰
@@ -17,12 +18,12 @@ gpu: 4090D 24G VRAM
"6 experts" case is part of v0.3's preview -| Prompt
(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts)| Llama (8 experts) | +| Prompt
(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts)| llama.cpp (8 experts) | | --- | --- | --- | --- | --- | --- | | Prefill token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 | | Decode token/s | 13.69 | 12.208 | 10.303 | 8.73 |4.51 | -**The highest speedup reaches up to x3.03 in decoding and x9.44 in prefill.** +**The highest speedup reaches up to 3.03x in decoding and 9.44x in prefill.** ### V0.3-Preview #### settings @@ -39,7 +40,7 @@ gpu: 4090D 24G VRAM
| KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | | KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | -**The prefill of KTrans V0.3 is up to x3.45 times faster than KTrans V0.2, and is up to x63.53 times faster than Llama.** +**The prefill of KTrans V0.3 is up to 3.45x times faster than KTrans V0.2, and is up to 63.53x times faster than llama.cpp.** **The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted.** The main acceleration comes from @@ -72,15 +73,10 @@ python ./ktransformers/local_chat.py --model_path --gguf_path ``` The parameters meaning is the same. But As we use dual socket, so we set cpu_infer to 65. ## some explanations -1. From our perspective on DeepSeekV2, DeepSeekV3 and DeepSeekR1, -when we slightly decrease the activation experts num in inference, -the output quality doesn't change(within 1% accuracy drop),But the speed of decoding and prefill -is speed up about 30% which is inspiring. So our showcase makes use of this finding, -changing the activation experts of DeepSeekV3/R1 from 8 to 6.
-2. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. +1. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. To avoid the cost of data transfer between nodes, we "copy" the critical matrix on both nodes which takes more memory consumption but accelerates the prefill and decoding process. But this method takes huge memory and slow when loading weights, So be patient when loading -and monitor the memory usage.(we are considering to make this method as an option)
-3. the command args `--cpu_infer 65` specifies how many cores to use(it's ok that it exceeds the physical number, -but it's not the more the better. Adjust it slight lower to your actual number of cores)
+and monitor the memory usage.(we are considering to make this method as an option). We are going to optimize this huge memory overhead. Stay tuned~
+2. the command args `--cpu_infer 65` specifies how many cores to use(it's ok that it exceeds the physical number, +but it's not the more the better. Adjust it slightly lower to your actual number of cores)
From fd481af193b8baa14cafdcc5f887e8a795d8625c Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 09:48:14 +0800 Subject: [PATCH 16/26] :zap: update v0.3 preview --- README.md | 2 +- doc/en/DeepseekR1_V3_tutorial.md | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b8e83cc..b4ad06a 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

πŸ”₯ Updates

-* **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64X speedup. The Detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md) +* **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup. The Detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md) * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU. diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index a56b689..45a5aab 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -47,6 +47,12 @@ The main acceleration comes from - Intel AMX instruction set and our specially designed cache friendly memory layout - Expert selection strategy that selects fewer experts based on offline profile results of out of domain data + +*From our research on DeepSeekV2, DeepSeekV3 and DeepSeekR1, +when we slightly decrease the activation experts num in inference, +the output quality doesn't change,But the speed of decoding and prefill +is speed up which is inspiring. So our showcase makes use of this finding* + ## how to run ### v0.2 showcase #### single socket version(32 cores) From e968fa8d72db5ec1177f848a27b7fd625fec488f Mon Sep 17 00:00:00 2001 From: unicornchan Date: Mon, 10 Feb 2025 01:52:39 +0000 Subject: [PATCH 17/26] [feature] add flash_attn to requirements --- pyproject.toml | 1 + requirements-local_chat.txt | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 028c6a3..69c1e37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "blessed >= 1.20.0", "accelerate >= 0.31.0", "sentencepiece >= 0.1.97", + "flash_attn == 2.7.4.post1" "setuptools", "ninja", "wheel", diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index 50b1f65..d221e0e 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -1,5 +1,6 @@ fire -transformers +transformers==4.43.2 +flash_attn==2.7.4.post1 numpy torch>=2.3.0 packaging From cff68532cee0df4040bf3812a840951a56d28920 Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 09:52:48 +0800 Subject: [PATCH 18/26] :zap: fix typo --- doc/en/DeepseekR1_V3_tutorial.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 45a5aab..1bc1be8 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,6 +1,6 @@ # Report ## Prerequisites -We run our best performance tests on
+We run our best performance tests(V0.2) on
cpu: Intel(R) Xeon(R) Gold 6454S 1T DRAM(2 NUMA nodes)
gpu: 4090D 24G VRAM
## Bench result @@ -50,7 +50,7 @@ The main acceleration comes from *From our research on DeepSeekV2, DeepSeekV3 and DeepSeekR1, when we slightly decrease the activation experts num in inference, -the output quality doesn't change,But the speed of decoding and prefill +the output quality doesn't change. But the speed of decoding and prefill is speed up which is inspiring. So our showcase makes use of this finding* ## how to run From 107e4be41791c3f888051bb64e9ff91f0dab77f0 Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 10:50:40 +0800 Subject: [PATCH 19/26] :zap: fix typo --- doc/en/DeepseekR1_V3_tutorial.md | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 1bc1be8..c837a6a 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,18 +1,18 @@ # Report ## Prerequisites -We run our best performance tests(V0.2) on
-cpu: Intel(R) Xeon(R) Gold 6454S 1T DRAM(2 NUMA nodes)
-gpu: 4090D 24G VRAM
-## Bench result +We run our best performance tests (V0.2) on
+CPU: Intel(R) Xeon(R) Gold 6454S 1T DRAM (2 NUMA nodes)
+GPU: 4090D 24G VRAM
+## Bench Result ### V0.2 -#### settings -- model: DeepseekV3-q4km(int4οΌ‰
-- CPU: cpu_model_name:Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 socket, 2numa nodes +#### Settings +- Model: DeepseekV3-q4km(int4οΌ‰
+- CPU: cpu_model_name:Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 socket, 2 numa nodes - GPU: 4090D 24GVRAM -- we test after enough warm up! -#### memory consumption: - - single socket: 382G DRAM, 12G VRAM - - dual socket: 1T DRAM, 12G VRAM +- We test after enough warm up +#### Memory consumption: + - Single socket: 382G DRAM, at least 12G VRAM + - Dual socket: 1T DRAM, at least 12G VRAM #### Benchmark Results @@ -26,22 +26,22 @@ gpu: 4090D 24G VRAM
**The highest speedup reaches up to 3.03x in decoding and 9.44x in prefill.** ### V0.3-Preview -#### settings -- model: DeepseekV3-BF16 (online quant into int8 for CPU and int4 for GPU) +#### Settings +- Model: DeepseekV3-BF16 (online quant into int8 for CPU and int4 for GPU) - CPU: cpu_model_name:Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 socket, 2 numa nodes - GPU: (1~4)x 4090D 24GVRAM (requires more VRAM for longer prompt) -#### memory consumptions: +#### Memory consumptions: - 644GB DRAM, at least 12GB VRAM -#### Benchmark Results +#### Benchmark results | Prompt length | 1K | 2K | 4K | 8K | |---------------|-----|-----|-----|-----| | KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 | | KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 | **The prefill of KTrans V0.3 is up to 3.45x times faster than KTrans V0.2, and is up to 63.53x times faster than llama.cpp.** -**The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted.** +**The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted** The main acceleration comes from - Intel AMX instruction set and our specially designed cache friendly memory layout @@ -53,9 +53,9 @@ when we slightly decrease the activation experts num in inference, the output quality doesn't change. But the speed of decoding and prefill is speed up which is inspiring. So our showcase makes use of this finding* -## how to run -### v0.2 showcase -#### single socket version(32 cores) +## How to Run +### V0.2 Showcase +#### Single socket version(32 cores) our local_chat test command is: ``` shell git clone https://github.com/kvcache-ai/ktransformers.git @@ -64,10 +64,10 @@ numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path ``` \ can be local or set from onlie hugging face like deepseek-ai/DeepSeek-V3. If onlie encounters connection problem, try use mirror(hf-mirror.com)
-\ can also be onlie, but as its large we recommend you download it and quantize the model to what you want.
-the command numactl -N 1 -m 1 aims to adoid data transfer between numa nodes. -### dual socket version(64 cores) -make suer before you install(use install.sh or `make dev_install`), setting the env var `USE_NUMA=1` by `export USE_NUMA=1`(if already installed, reinstall it with this env var set)
+\ can also be onlie, but as its large we recommend you download it and quantize the model to what you want
+The command numactl -N 1 -m 1 aims to adoid data transfer between numa nodes +#### Dual socket version(64 cores) +Make suer before you install (use install.sh or `make dev_install`), setting the env var `USE_NUMA=1` by `export USE_NUMA=1` (if already installed, reinstall it with this env var set)
our local_chat test command is: ``` shell git clone https://github.com/kvcache-ai/ktransformers.git @@ -77,12 +77,12 @@ make dev_install # or sh ./install.sh python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 65 --cache_lens 1536 ``` -The parameters meaning is the same. But As we use dual socket, so we set cpu_infer to 65. -## some explanations +The parameters meaning is the same. But As we use dual socket, so we set cpu_infer to 65 +## Some Explanations 1. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. To avoid the cost of data transfer between nodes, we "copy" the critical matrix on both nodes which takes more memory consumption but accelerates the prefill and decoding process. But this method takes huge memory and slow when loading weights, So be patient when loading and monitor the memory usage.(we are considering to make this method as an option). We are going to optimize this huge memory overhead. Stay tuned~
-2. the command args `--cpu_infer 65` specifies how many cores to use(it's ok that it exceeds the physical number, +2. The command args `--cpu_infer 65` specifies how many cores to use(it's ok that it exceeds the physical number, but it's not the more the better. Adjust it slightly lower to your actual number of cores)
From 402b71446b90e517b4cac0227b75777db5ca3bb9 Mon Sep 17 00:00:00 2001 From: unicornchan Date: Mon, 10 Feb 2025 03:15:26 +0000 Subject: [PATCH 20/26] [fix] fix pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 69c1e37..3c3700d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "blessed >= 1.20.0", "accelerate >= 0.31.0", "sentencepiece >= 0.1.97", - "flash_attn == 2.7.4.post1" + "flash_attn == 2.7.4.post1", "setuptools", "ninja", "wheel", From 3d7dfd61510db0cd9a8fb078cda84f2691f474ee Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 11:12:52 +0800 Subject: [PATCH 21/26] :zap: fix typo --- doc/en/DeepseekR1_V3_tutorial.md | 38 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index c837a6a..376ffa1 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,14 +1,14 @@ # Report ## Prerequisites We run our best performance tests (V0.2) on
-CPU: Intel(R) Xeon(R) Gold 6454S 1T DRAM (2 NUMA nodes)
+CPU: Intel (R) Xeon (R) Gold 6454S 1T DRAM (2 NUMA nodes)
GPU: 4090D 24G VRAM
## Bench Result ### V0.2 #### Settings -- Model: DeepseekV3-q4km(int4οΌ‰
-- CPU: cpu_model_name:Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 socket, 2 numa nodes -- GPU: 4090D 24GVRAM +- Model: DeepseekV3-q4km (int4)
+- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes +- GPU: 4090D 24G VRAM - We test after enough warm up #### Memory consumption: - Single socket: 382G DRAM, at least 12G VRAM @@ -16,7 +16,7 @@ GPU: 4090D 24G VRAM
#### Benchmark Results -"6 experts" case is part of v0.3's preview +"6 experts" case is part of V0.3's preview | Prompt
(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts)| llama.cpp (8 experts) | | --- | --- | --- | --- | --- | --- | @@ -28,7 +28,7 @@ GPU: 4090D 24G VRAM
### V0.3-Preview #### Settings - Model: DeepseekV3-BF16 (online quant into int8 for CPU and int4 for GPU) -- CPU: cpu_model_name:Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 socket, 2 numa nodes +- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 socket, 2 numa nodes - GPU: (1~4)x 4090D 24GVRAM (requires more VRAM for longer prompt) #### Memory consumptions: @@ -55,34 +55,34 @@ is speed up which is inspiring. So our showcase makes use of this finding* ## How to Run ### V0.2 Showcase -#### Single socket version(32 cores) -our local_chat test command is: +#### Single socket version (32 cores) +Our local_chat test command is: ``` shell git clone https://github.com/kvcache-ai/ktransformers.git cd ktransformers -numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 33 --cache_lens 1536 +numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 33 --cache_lens 1536 ``` -\ can be local or set from onlie hugging face like deepseek-ai/DeepSeek-V3. If onlie encounters connection problem, try use mirror(hf-mirror.com)
-\ can also be onlie, but as its large we recommend you download it and quantize the model to what you want
-The command numactl -N 1 -m 1 aims to adoid data transfer between numa nodes -#### Dual socket version(64 cores) +\ can be local or set from online hugging face like deepseek-ai/DeepSeek-V3. If online encounters connection problem, try use mirror (hf-mirror.com)
+\ can also be online, but as its large we recommend you download it and quantize the model to what you want
+The command numactl -N 1 -m 1 aims to advoid data transfer between numa nodes +#### Dual socket version (64 cores) Make suer before you install (use install.sh or `make dev_install`), setting the env var `USE_NUMA=1` by `export USE_NUMA=1` (if already installed, reinstall it with this env var set)
-our local_chat test command is: +Our local_chat test command is: ``` shell git clone https://github.com/kvcache-ai/ktransformers.git cd ktransformers export USE_NUMA=1 make dev_install # or sh ./install.sh -python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 65 --cache_lens 1536 +python ./ktransformers/local_chat.py --model_path --gguf_path --prompt_file --cpu_infer 65 --cache_lens 1536 ``` -The parameters meaning is the same. But As we use dual socket, so we set cpu_infer to 65 +The parameters' meaning is the same. But As we use dual socket, we set cpu_infer to 65 ## Some Explanations 1. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. To avoid the cost of data transfer between nodes, we "copy" the critical matrix on both nodes which takes more memory consumption but accelerates the prefill and decoding process. But this method takes huge memory and slow when loading weights, So be patient when loading -and monitor the memory usage.(we are considering to make this method as an option). We are going to optimize this huge memory overhead. Stay tuned~
-2. The command args `--cpu_infer 65` specifies how many cores to use(it's ok that it exceeds the physical number, -but it's not the more the better. Adjust it slightly lower to your actual number of cores)
+and monitor the memory usage. (we are considering to make this method as an option). We are going to optimize this huge memory overhead. Stay tuned~
+2. The command args `--cpu_infer 65` specifies how many cores to use (it's ok that it exceeds the physical number, +but it's not the more the better. Adjust it slightly lower to your actual number of cores)
\ No newline at end of file From 0f73f40da0dfc0c022217d667a67f7044ae6a28a Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 11:31:58 +0800 Subject: [PATCH 22/26] :zap: add Summary part --- README.md | 2 +- doc/en/DeepseekR1_V3_tutorial.md | 26 ++++++++++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b4ad06a..2dc358a 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

-- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 12GB VRAM and 382GB DRAM. +- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM. - Prefill Speed: - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, v3 only) β†’ 286.55 (selectively using 6 experts, v3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **63.53Γ— speedup**. diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 376ffa1..4192125 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,4 +1,22 @@ -# Report +# GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM +# SUMMARY + +https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 + +

+ +- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM. + - Prefill Speed: + - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, v3 only) β†’ 286.55 (selectively using 6 experts, v3 only) + - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **63.53Γ— speedup**. + - Decode Speed(tokens/s): + - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, v3 only) + - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. + - Upcoming Open Source Release: + - AMX optimizations and selective expert activation will be open-sourced in v0.3. + - Currently available only in preview binary distribution, which can be found [here](xxx). + + ## Prerequisites We run our best performance tests (V0.2) on
CPU: Intel (R) Xeon (R) Gold 6454S 1T DRAM (2 NUMA nodes)
@@ -11,8 +29,8 @@ GPU: 4090D 24G VRAM
- GPU: 4090D 24G VRAM - We test after enough warm up #### Memory consumption: - - Single socket: 382G DRAM, at least 12G VRAM - - Dual socket: 1T DRAM, at least 12G VRAM + - Single socket: 382G DRAM, at least 14GB VRAM + - Dual socket: 1T DRAM, at least 14GB VRAM #### Benchmark Results @@ -32,7 +50,7 @@ GPU: 4090D 24G VRAM
- GPU: (1~4)x 4090D 24GVRAM (requires more VRAM for longer prompt) #### Memory consumptions: -- 644GB DRAM, at least 12GB VRAM +- 644GB DRAM, at least 14GB VRAM #### Benchmark results | Prompt length | 1K | 2K | 4K | 8K | From aecb50f0d11ae4396c05e6f535392b6ca47c2344 Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 11:36:46 +0800 Subject: [PATCH 23/26] :zap: fix typo readme --- README.md | 6 +++--- doc/en/DeepseekR1_V3_tutorial.md | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2dc358a..92c7766 100644 --- a/README.md +++ b/README.md @@ -43,13 +43,13 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 - **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM. - Prefill Speed: - - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, v3 only) β†’ 286.55 (selectively using 6 experts, v3 only) + - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, V0.3 only) β†’ 286.55 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **63.53Γ— speedup**. - Decode Speed(tokens/s): - - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, v3 only) + - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. - Upcoming Open Source Release: - - AMX optimizations and selective expert activation will be open-sourced in v0.3. + - AMX optimizations and selective expert activation will be open-sourced in V0.3. - Currently available only in preview binary distribution, which can be found [here](xxx). - **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench). diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 4192125..e65e1d1 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -7,13 +7,13 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 - **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM. - Prefill Speed: - - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, v3 only) β†’ 286.55 (selectively using 6 experts, v3 only) + - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, V0.3 only) β†’ 286.55 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **63.53Γ— speedup**. - Decode Speed(tokens/s): - - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, v3 only) + - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. - Upcoming Open Source Release: - - AMX optimizations and selective expert activation will be open-sourced in v0.3. + - AMX optimizations and selective expert activation will be open-sourced in V0.3. - Currently available only in preview binary distribution, which can be found [here](xxx). From f892d22849b45eda2084aa450d89ff265e3f8a45 Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 11:45:46 +0800 Subject: [PATCH 24/26] :zap: update v3 --- README.md | 4 ++-- doc/en/DeepseekR1_V3_tutorial.md | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 92c7766..4c85e1f 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,10 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM. - - Prefill Speed: + - Prefill Speed (tokens/s): - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, V0.3 only) β†’ 286.55 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **63.53Γ— speedup**. - - Decode Speed(tokens/s): + - Decode Speed (tokens/s): - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. - Upcoming Open Source Release: diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index e65e1d1..b69b304 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,15 +1,17 @@ # GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM # SUMMARY +- **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup.
+ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM. - - Prefill Speed: + - Prefill Speed (tokens/s): - KTransfermor: 54.21 (32 cores) β†’ 74.362 (dual-socket, 2Γ—32 cores) β†’ 255.26 (optimized AMX-based MoE kernel, V0.3 only) β†’ 286.55 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **63.53Γ— speedup**. - - Decode Speed(tokens/s): + - Decode Speed (tokens/s): - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. - Upcoming Open Source Release: From 83401dbb3becdf428229c69e899d94e6c9bbf385 Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 12:29:23 +0800 Subject: [PATCH 25/26] :zap: ready to publish --- README.md | 2 +- doc/en/DeepseekR1_V3_tutorial.md | 41 +++++- ktransformers/operators/RoPE.py | 123 +++++++++++++++++- .../DeepSeek-V3-Chat-multi-gpu-marlin.yaml | 4 +- .../DeepSeek-V3-Chat-multi-gpu.yaml | 4 +- .../optimize_rules/DeepSeek-V3-Chat.yaml | 2 +- 6 files changed, 157 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 4c85e1f..6735da9 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin * **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu. * **Aug 9, 2024**: Support windows native. -

πŸ”₯ Show Cases

+

🌟 Show Cases

GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM

diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index b69b304..24c7a87 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -1,7 +1,14 @@ # GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM # SUMMARY -- **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup.
+> **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup.
+ +Hi, we're the KTransformers team (formerly known for our local CPU/GPU hybrid inference open source project with DeepSeek-V2). + +We've heard your requests for DeepSeek-R1/V3 supportβ€”and we're excited to finally deliver! +Apologies for the wait, but we've been cooking up something truly amazing! + +Today, we're proud to announce that we not only support DeepSeek-R1/V3, as showcased in the video below: https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 @@ -14,9 +21,10 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 - Decode Speed (tokens/s): - KTransfermor: 8.73 (32 cores) β†’ 11.26 (dual-socket, 2Γ—32 cores) β†’ 13.69 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. - - Upcoming Open Source Release: - - AMX optimizations and selective expert activation will be open-sourced in V0.3. - - Currently available only in preview binary distribution, which can be found [here](xxx). + + +But we're also previewing our upcoming optimizations, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **64Γ— faster than llama.cpp** for local inference. +The binary distribution is available now and the source code will come ASAP! Check out the details [here](xxx) ## Prerequisites @@ -98,11 +106,32 @@ python ./ktransformers/local_chat.py --model_path --gguf_path ``` The parameters' meaning is the same. But As we use dual socket, we set cpu_infer to 65 + +### V0.3 Showcase +#### Dual socket version (64 cores) +Our local_chat test command is: +``` shell +python -m ktransformers.local_chat --model_path --gguf_path --prompt_file --cpu_infer 65 --cache_lens 1536 + +``` +The parameters' meaning is the same with V0.2. But As we use dual socket, we set cpu_infer to 65 + ## Some Explanations 1. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu. To avoid the cost of data transfer between nodes, we "copy" the critical matrix on both nodes which takes more memory consumption but accelerates the prefill and decoding process. But this method takes huge memory and slow when loading weights, So be patient when loading -and monitor the memory usage. (we are considering to make this method as an option). We are going to optimize this huge memory overhead. Stay tuned~
+and monitor the memory usage. We are going to optimize this huge memory overhead. Stay tuned~
2. The command args `--cpu_infer 65` specifies how many cores to use (it's ok that it exceeds the physical number, -but it's not the more the better. Adjust it slightly lower to your actual number of cores)
\ No newline at end of file +but it's not the more the better. Adjust it slightly lower to your actual number of cores)
+ +3. Why CPU/GPU Hybrid Inference? +DeepSeek's MLA operators are highly computationally intensive. While running everything on CPU is possible, offloading the heavy computations to the GPU results in a massive performance boost. + +4. Where Does the Speedup Come From? + + - Expert Offload: Unlike traditional layer-based or KVCache offloading (as seen in llama.cpp), we offload the expert computation to the CPU and MLA/KVCache to GPU, aligning perfectly with DeepSeek’s architecture for optimal efficiency. + - Intel AMX Optimization – Our AMX-accelerated kernel is meticulously tuned, running several times faster than existing llama.cpp implementations. We plan to open-source this kernel after cleansing and are considering upstream contributions to llama.cpp. + +5. Why Intel CPUs? +Intel is currently the only CPU vendor that supports AMX-like instructions, which delivers significantly better performance compared to AVX-only alternatives. \ No newline at end of file diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index 9e2eb44..dc5902c 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -18,6 +18,9 @@ from ktransformers.models.modeling_deepseek_v3 import ( from ktransformers.models.modeling_deepseek import ( DeepseekV2YarnRotaryEmbedding, DeepseekV2RotaryEmbedding, + yarn_get_mscale, + yarn_linear_ramp_mask, + yarn_find_correction_range ) from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader @@ -188,7 +191,33 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding): self.orig_module.mscale_all_dim, ) -class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding): +# class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding): +# def __init__( +# self, +# key: str, +# gguf_loader: GGUFLoader, +# config: PretrainedConfig, +# orig_module: nn.Module, +# # device: str = "cuda", +# generate_device: str = "cuda", +# prefill_device: str = "cuda", +# **kwargs, +# ): +# BaseInjectedModule.__init__( +# self, key, gguf_loader, config, orig_module, generate_device, **kwargs +# ) +# self.generate_device = generate_device +# self.prefill_device = prefill_device + +# def load(self): +# # TODO support perlayer prefill +# self.orig_module.__init__( +# self.config, +# device=self.generate_device +# ) +# return + +class YarnRotaryEmbeddingV3(BaseInjectedModule): def __init__( self, key: str, @@ -205,14 +234,94 @@ class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbeddin ) self.generate_device = generate_device self.prefill_device = prefill_device - + def load(self): - # TODO support perlayer prefill - self.orig_module.__init__( - self.config, - device=self.generate_device + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self._init( + dim=self.config.qk_rope_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + device=self.device, + scaling_factor=self.config.rope_scaling["factor"], + **kwargs, ) - return + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos()* self._mscale + sin = emb.sin()* self._mscale + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def _init( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self._mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings class DynamicNTKScalingRotaryEmbedding( BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml index 22be22e..06ab4db 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml @@ -10,7 +10,7 @@ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" @@ -18,7 +18,7 @@ name: "^model\\.layers\\.([3456][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml index 22be22e..06ab4db 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -10,7 +10,7 @@ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" @@ -18,7 +18,7 @@ name: "^model\\.layers\\.([3456][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml index 4a306be..7a44c5d 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml @@ -1,7 +1,7 @@ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: - class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" From 6f0fe953e1f3d494239971fc895f5032853de9a9 Mon Sep 17 00:00:00 2001 From: liam Date: Mon, 10 Feb 2025 13:52:24 +0800 Subject: [PATCH 26/26] :zap: release v0.2.0 --- README.md | 4 ++-- doc/en/DeepseekR1_V3_tutorial.md | 6 ++++-- pyproject.toml | 1 - requirements-local_chat.txt | 1 - 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 6735da9..d06163d 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

πŸ”₯ Updates

-* **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup. The Detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md) +* **Fed 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup. The detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md) * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU. @@ -50,7 +50,7 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. - Upcoming Open Source Release: - AMX optimizations and selective expert activation will be open-sourced in V0.3. - - Currently available only in preview binary distribution, which can be found [here](xxx). + - Currently available only in preview binary distribution, which can be downloaded [here](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl). - **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench). diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 24c7a87..0282ba1 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -23,8 +23,8 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285 - Compared to 4.51 tokens/s in llama.cpp with 2Γ—32 cores, achieving up to **3.03Γ— speedup**. -But we're also previewing our upcoming optimizations, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **64Γ— faster than llama.cpp** for local inference. -The binary distribution is available now and the source code will come ASAP! Check out the details [here](xxx) +We also give our upcoming optimizations previews, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **64Γ— faster than llama.cpp** for local inference. +The binary distribution is available now and the source code will come ASAP! Check out the wheel package [here](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl) ## Prerequisites @@ -111,6 +111,8 @@ The parameters' meaning is the same. But As we use dual socket, we set cpu_infe #### Dual socket version (64 cores) Our local_chat test command is: ``` shell +wget https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl +pip install ./ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl python -m ktransformers.local_chat --model_path --gguf_path --prompt_file --cpu_infer 65 --cache_lens 1536 ``` diff --git a/pyproject.toml b/pyproject.toml index 3c3700d..028c6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ dependencies = [ "blessed >= 1.20.0", "accelerate >= 0.31.0", "sentencepiece >= 0.1.97", - "flash_attn == 2.7.4.post1", "setuptools", "ninja", "wheel", diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index d221e0e..0479d36 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -1,6 +1,5 @@ fire transformers==4.43.2 -flash_attn==2.7.4.post1 numpy torch>=2.3.0 packaging