From 95d937c51d91630ec6d764500c4e668038ef8f22 Mon Sep 17 00:00:00 2001
From: DDong Jianwei <1913953267@qq.com>
Date: Sun, 23 Feb 2025 18:51:42 +0800
Subject: [PATCH 1/3] tmp

---
 ktransformers/local_chat.py                        | 9 +++++++--
 ktransformers/operators/attention.py               | 6 +++---
 ktransformers/operators/experts.py                 | 4 ++--
 .../optimize/optimize_rules/DeepSeek-V3-Chat.yaml  | 2 +-
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index d5e74de..5b40455 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -58,7 +58,7 @@ def local_chat(
     gguf_path: str | None = None,
     max_new_tokens: int = 300,
     cpu_infer: int = Config().cpu_infer,
-    use_cuda_graph: bool = True,
+    use_cuda_graph: bool = False,
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
@@ -160,6 +160,9 @@ def local_chat(
         input_tensor = tokenizer.apply_chat_template(
             messages, add_generation_prompt=True, return_tensors="pt"
         )
+
+        # input_tensor = torch.tensor([[0, 6657, 84646]], device=input_tensor.device)
+
         if force_think:
             token_thinks = torch.tensor([tokenizer.encode("\\n",add_special_tokens=False)],device=input_tensor.device)
             input_tensor = torch.cat(
@@ -181,4 +184,6 @@ def local_chat(
 
 
 if __name__ == "__main__":
-    fire.Fire(local_chat)
\ No newline at end of file
+    # fire.Fire(local_chat)
+    # local_chat(model_path="/mnt/data/model/DeepSeek-R1", gguf_path="/mnt/data/model/DeepseekV3-q4km-gguf", cpu_infer=33, force_think=False)
+    local_chat(model_path="/mnt/data/model/Moonlight-16B-A3B-Instruct", gguf_path="/mnt/data/model/Moonlight-16B-A3B-Instruct-GGUF", cpu_infer=33, force_think=False)
\ No newline at end of file
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 85378ee..b4c5402 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -441,10 +441,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
+            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
 
-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
 
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 21b4830..04c04c5 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -450,9 +450,9 @@ class KExpertsTorch(KExpertsBase):
             self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
             self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
 
-        self.up = torch.cat(self.gate, dim=0)
+        self.up = torch.cat(self.up, dim=0)
         self.gate = torch.cat(self.gate, dim=0)
-        self.down = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.down, dim=0)
         return
 
     def unload(self):
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
index 6fb6586..4c8eca2 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
@@ -1,7 +1,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
   replace:
-    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"

From e8e02e5ccc9227055617247fad60e1a973885109 Mon Sep 17 00:00:00 2001
From: Atream
Date: Sun, 23 Feb 2025 14:21:18 +0000
Subject: [PATCH 2/3] support Moonlight

---
 ktransformers/local_chat.py        | 10 ++--------
 ktransformers/operators/experts.py |  2 +-
 ktransformers/util/utils.py        |  2 +-
 3 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 5b40455..d087752 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -58,13 +58,12 @@ def local_chat(
     gguf_path: str | None = None,
     max_new_tokens: int = 300,
     cpu_infer: int = Config().cpu_infer,
-    use_cuda_graph: bool = False,
+    use_cuda_graph: bool = True,
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
 ):
-
 
     torch.set_grad_enabled(False)
 
     Config().cpu_infer = cpu_infer
@@ -160,9 +159,6 @@ def local_chat(
         input_tensor = tokenizer.apply_chat_template(
             messages, add_generation_prompt=True, return_tensors="pt"
         )
-
-        # input_tensor = torch.tensor([[0, 6657, 84646]], device=input_tensor.device)
-
         if force_think:
             token_thinks = torch.tensor([tokenizer.encode("\\n",add_special_tokens=False)],device=input_tensor.device)
             input_tensor = torch.cat(
@@ -184,6 +180,4 @@ def local_chat(
 
 
 if __name__ == "__main__":
-    # fire.Fire(local_chat)
-    # local_chat(model_path="/mnt/data/model/DeepSeek-R1", gguf_path="/mnt/data/model/DeepseekV3-q4km-gguf", cpu_infer=33, force_think=False)
-    local_chat(model_path="/mnt/data/model/Moonlight-16B-A3B-Instruct", gguf_path="/mnt/data/model/Moonlight-16B-A3B-Instruct-GGUF", cpu_infer=33, force_think=False)
\ No newline at end of file
+    fire.Fire(local_chat)
\ No newline at end of file
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 04c04c5..035bac4 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -159,7 +159,7 @@ class KExpertsCPU(KExpertsBase):
         down_ptr = ctypes.addressof(
             ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
-        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
+        #print(self.gate_type, self.up_type, self.down_type)
         n_routed_experts = self.n_routed_experts
         # n_routed_experts = len(self.orig_module)
         moe_config = MOEConfig(
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index cc4a323..5c608b1 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -207,7 +207,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             tokens.append(int(next_token))
             seq_length += 1
 
-            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
+            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
                 print(stream.end(), end="", flush=True)
                 break
             else:

From f5f6c6b95d935e65fbc37d3245c2be064389cfa5 Mon Sep 17 00:00:00 2001
From: Atream
Date: Sun, 23 Feb 2025 14:33:58 +0000
Subject: [PATCH 3/3] update yaml

---
 .../optimize_rules/DeepSeek-V3-Chat.yaml           |  2 +-
 .../optimize_rules/Moonlight-16B-A3B.yaml          | 75 +++++++++++++++++++
 2 files changed, 76 insertions(+), 1 deletion(-)
 create mode 100644 ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml

diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
index 4c8eca2..6fb6586 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
@@ -1,7 +1,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
   replace:
-    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
+    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
diff --git a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
new file mode 100644
index 0000000..4c8eca2
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
@@ -0,0 +1,75 @@
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
+    kwargs:
+      prefill_device: "cuda"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda"
+  recursive: False # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
\ No newline at end of file