diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index d5e74de..d087752 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -64,7 +64,6 @@ def local_chat( force_think: bool = False, ): - torch.set_grad_enabled(False) Config().cpu_infer = cpu_infer diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index 85378ee..b4c5402 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -441,10 +441,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank] # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank] # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] - attn_output = attn_output.transpose(1, 2) - attn_output = torch.matmul(attn_output, out_absorb.mT) + attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank] + attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim] - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim] attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 21b4830..035bac4 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -159,7 +159,7 @@ class KExpertsCPU(KExpertsBase): down_ptr = ctypes.addressof( ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents ) - # print(self.gate_qtype, self.up_qtype, self.down_qtype) + #print(self.gate_type, self.up_type, self.down_type) n_routed_experts = self.n_routed_experts # n_routed_experts = len(self.orig_module) moe_config = MOEConfig( @@ -450,9 +450,9 @@ class KExpertsTorch(KExpertsBase): self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype) self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype) - self.up = torch.cat(self.gate, dim=0) + self.up = torch.cat(self.up, dim=0) self.gate = torch.cat(self.gate, dim=0) - self.down = torch.cat(self.gate, dim=0) + self.down = torch.cat(self.down, dim=0) return def unload(self): diff --git a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml new file mode 100644 index 0000000..4c8eca2 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml @@ -0,0 +1,75 @@ +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbeddingV3 + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index cc4a323..5c608b1 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -207,7 +207,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud tokens.append(int(next_token)) seq_length += 1 - if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>': + if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>': print(stream.end(), end="", flush=True) break else: