Mirror of https://github.com/RYDE-WORK/ktransformers.git (synced 2026-03-21 10:01:39 +08:00)
done support deepseekv3
This commit is contained in: parent f748cd29f0, commit 907251c743
File diff suppressed because it is too large.
@@ -23,7 +23,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig

import torch

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
@@ -56,6 +56,57 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
        )


class RotaryEmbeddingV3(BaseInjectedModule):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def load(self):
        self._init(
            dim=self.config.qk_rope_head_dim,
            max_position_embeddings=self.config.max_position_embeddings,
            base=self.config.rope_theta,
            device=self.device,
        )
    def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
    def __init__(
        self,
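For orientation, the cos/sin pair returned by RotaryEmbeddingV3.forward is consumed by the usual rotate-half formulation of RoPE. A minimal, self-contained sketch of that application step; the helper names below follow the conventional Hugging Face style and are not taken from this diff:

    import torch

    def rotate_half(x):
        # Split the last dimension in half and swap the halves with a sign flip.
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
        # cos/sin have shape [bs, seq_len, head_dim]; unsqueeze to broadcast over heads.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed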
@@ -151,7 +151,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):

        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
        return attn_output, attn_weights, past_key_value

    def forward(
        self,
@@ -220,7 +220,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
            attn_output = torch.cat((attn_output, cur_output), dim=-2)
            attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2)

        return attn_output, attn_weight
        return attn_output, attn_weight, past_key_value

class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
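The paired return statements above differ only in whether past_key_value is part of the returned tuple; with transformers-style Cache objects the cache is typically mutated in place via past_key_value.update(...), so returning it is an interface choice rather than a functional one. A toy illustration of the two conventions (not the ktransformers call site):

    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        # Illustrative only: shows the two return conventions side by side.
        def __init__(self, return_cache: bool):
            super().__init__()
            self.return_cache = return_cache

        def forward(self, hidden_states, past_key_value=None):
            attn_output, attn_weights = hidden_states, None  # stand-in computation
            if self.return_cache:
                return attn_output, attn_weights, past_key_value
            return attn_output, attn_weights

    x = torch.zeros(1, 4, 8)
    out, weights = ToyAttention(return_cache=False)(x)
    out, weights, cache = ToyAttention(return_cache=True)(x)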
@@ -734,7 +734,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
        identity = hidden_states
        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]
        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # only for generate phase
@@ -745,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
            y += y_
            y.resize_(*orig_shape)
            return y, router_logits
            return y

        if self.config.n_shared_experts is not None:
            y_ = self.shared_experts(identity).squeeze(0)
@@ -768,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
            )
        if self.config.n_shared_experts is not None:
            y += y_
        return y, router_logits
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
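The gate-call change above reflects a router that hands back only the top-k expert indices and weights, with no separate router_logits value, and the MoE forward correspondingly returns just y. A toy sketch of that interface shape, with hypothetical names:

    import torch
    import torch.nn as nn

    class ToyTopkRouter(nn.Module):
        # Hypothetical stand-in for the MoE gate interface used above.
        def __init__(self, hidden_size: int, n_experts: int, top_k: int):
            super().__init__()
            self.proj = nn.Linear(hidden_size, n_experts, bias=False)
            self.top_k = top_k

        def forward(self, hidden_states: torch.Tensor):
            scores = self.proj(hidden_states).softmax(dim=-1)              # [tokens, n_experts]
            topk_weight, topk_idx = torch.topk(scores, self.top_k, dim=-1)
            return topk_idx, topk_weight                                   # two values, no router_logits

    gate = ToyTopkRouter(hidden_size=16, n_experts=8, top_k=2)
    topk_idx, topk_weight = gate(torch.randn(4, 16))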
@@ -16,9 +16,6 @@ from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
from ktransformers.util.utils import InferenceState
from ktransformers.server.config.config import Config
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
@@ -102,6 +99,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def forward(self, hidden_states) -> torch.Tensor:
        return self.orig_module.forward(hidden_states)
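KMoEGate here stores the device hints and simply forwards to the wrapped module. A generic sketch of that delegation pattern, which is the idea rather than the BaseInjectedModule API itself:

    import torch.nn as nn

    class DelegatingWrapper(nn.Module):
        # Generic illustration: keep a reference to the original module and
        # delegate forward to it until a custom implementation is needed.
        def __init__(self, orig_module: nn.Module, generate_device: str, prefill_device: str):
            super().__init__()
            self.orig_module = orig_module
            self.generate_device = generate_device
            self.prefill_device = prefill_device

        def forward(self, hidden_states):
            return self.orig_module.forward(hidden_states)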
@@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule):
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if inputs_embeds is None:
            org_device = input_ids.device
            # TODO move to embed_tokens's device, not hard code to cpu
            input_ids = input_ids.to("cpu")
            inputs_embeds = self.embed_tokens(input_ids).to(org_device)
            input_ids = input_ids.to(org_device)

        if cache_position is None:
            past_seen_tokens = (
@@ -639,13 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if inputs_embeds is None:
            org_device = input_ids.device
            # TODO move to embed_tokens's device, not hard code to cpu
            input_ids = input_ids.to("cpu")
            inputs_embeds = self.embed_tokens(input_ids).to(org_device)
            input_ids = input_ids.to(org_device)

        if per_layer_prefill_flag:
            causal_mask = None
        else:
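The relocated inputs_embeds block performs the token-embedding lookup on CPU, where embed_tokens is placed, and moves the result back to the original device. A standalone sketch of that pattern, under the assumption that only the embedding table lives on CPU:

    import torch
    import torch.nn as nn

    embed_tokens = nn.Embedding(num_embeddings=32000, embedding_dim=64)   # kept on CPU
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    input_ids = torch.randint(0, 32000, (1, 8), device=device)
    org_device = input_ids.device
    inputs_embeds = embed_tokens(input_ids.to("cpu")).to(org_device)      # lookup on CPU, move result back
    input_ids = input_ids.to(org_device)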
@@ -717,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule):
                    self.load_layer_to(decoder_layer, InferenceState.PREFILL)
                    torch.cuda.empty_cache()
                t4 = time.time()
            # with open("log.txt", "a") as f:
            #     f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
@@ -739,13 +741,17 @@ class KDeepseekV2Model(BaseInjectedModule):
            hidden_states = layer_outputs[0]

            # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3
            # if use_cache:
            #     next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        # with open("log.txt", "a") as f:
        #     f.write(f"@@@After layers\n")
        #     f.write(f"hidden_states={hidden_states}\n")
        #     f.write(f"hidden_states.shape={hidden_states.shape}\n")

        if per_layer_prefill_flag:
            t6 = time.time()
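The re-enabled cache bookkeeping indexes into layer_outputs, whose layout follows the usual decoder-layer convention of (hidden_states, optional attention weights, present_key_value), hence index 2 when attentions are returned and index 1 otherwise. A small illustration of that indexing with toy tuples, not the real layer outputs:

    # Toy stand-ins for the decoder-layer return value.
    with_attn = ("hidden", "attn_weights", "present_cache")
    without_attn = ("hidden", "present_cache")

    for output_attentions, layer_outputs in ((True, with_attn), (False, without_attn)):
        next_decoder_cache = layer_outputs[2 if output_attentions else 1]
        assert next_decoder_cache == "present_cache"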
@@ -10,7 +10,7 @@
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
@@ -18,7 +18,7 @@
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
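The two rules above differ only in the layer-index pattern and the target GPU: (0|[1-9]|[12][0-9]) covers layers 0 through 29 on cuda:0, and ([3456][0-9]) covers layers 30 through 69, which captures the remaining layers, on cuda:1. A quick check of the split (standalone script, not part of the repo):

    import re

    pat_gpu0 = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.")
    pat_gpu1 = re.compile(r"^model\.layers\.([3456][0-9])\.")

    for i in (0, 9, 29, 30, 45, 60):
        name = f"model.layers.{i}.self_attn"
        target = "cuda:0" if pat_gpu0.match(name) else "cuda:1" if pat_gpu1.match(name) else "unmatched"
        print(name, "->", target)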
@@ -64,7 +64,7 @@

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
@@ -72,7 +72,7 @@
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
    kwargs:
@@ -106,14 +106,14 @@
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
@@ -24,7 +24,7 @@ class KTransformersInterface(TransformersInterface):
        self.args = args
        torch.set_default_dtype(torch.bfloat16)
        torch.set_grad_enabled(False)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            config._attn_implementation = "flash_attention_2"
@@ -99,7 +99,7 @@ class KTransformersInterface(TransformersInterface):
        if self.use_static_cache:
            mask = torch.ones((1, self.seq_length)).to(torch_device)
            logits = self.model(
                self.current_ids,
                self.current_ids.to(torch_device),
                cache_position=self.active_cache_position,
                past_key_values=self.cache,
                attention_mask=mask,
@@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase):
        return self.streamer.put(new_tokens)

    def logits_to_token(self, logits: torch.Tensor):
        logits = logits / self.args.temperature
        logits = logits / self.args.temperature if self.args.temperature != 0 else logits

        for token_idx in self.ever_generated_ids:
            if logits[token_idx] < 0:
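The guarded division above avoids dividing by zero when temperature is 0, which is conventionally treated as greedy decoding. A minimal sketch of that convention (illustrative helper, not the ktransformers sampler):

    import torch

    def pick_token(logits: torch.Tensor, temperature: float) -> int:
        # temperature == 0 -> greedy argmax; otherwise temperature-scaled sampling.
        if temperature == 0:
            return int(torch.argmax(logits, dim=-1))
        probs = torch.softmax(logits / temperature, dim=-1)
        return int(torch.multinomial(probs, num_samples=1))

    logits = torch.randn(32000)
    print(pick_token(logits, temperature=0), pick_token(logits, temperature=0.6))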
@@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase):
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            # local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
            # input_ids = torch.tensor([[6366]], device=input_ids.device)
        else:
            raise ValueError("local_messages should be List or str")
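The branch above accepts either a chat-style message list, which is templated and tokenized, or a raw prompt string. A toy version of the same dispatch, with hypothetical helper behavior standing in for format_and_tokenize_input_ids and tokenize_prompt:

    from typing import List, Union

    def to_input_ids(local_messages: Union[List[dict], str]) -> str:
        # Stand-ins for format_and_tokenize_input_ids / tokenize_prompt.
        if isinstance(local_messages, list):
            return "tokenized chat: " + " ".join(m["content"] for m in local_messages)
        elif isinstance(local_messages, str):
            return "tokenized prompt: " + local_messages
        else:
            raise ValueError("local_messages should be List or str")

    print(to_input_ids([{"role": "user", "content": "hi"}]))
    print(to_input_ids("hi"))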