diff --git a/README.md b/README.md
index f62528b..7b539b4 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
-* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; Longer Context (from 8K to 128K for 24GB VRAM).
+* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) (from 8K to 128K for 24GB VRAM).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s); see the updated [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online book](https://kvcache-ai.github.io/ktransformers/).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi GPU and 382G DRAM, up to 3~28x speedup. For a detailed showcase and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
diff --git a/README_ZH.md b/README_ZH.md
index 3696a08..6c82805 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -21,7 +21,7 @@ KTransformers 是一个以 Python 为中心的灵活框架,其核心是可扩
🔥 Updates
-* **Feb 15, 2025**: Support the [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3/R1; longer context (from 8K to 128K with 24GB VRAM).
+* **Feb 25, 2025**: Support the [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3/R1; longer context (from 8K to 128K with 24GB VRAM, [tutorial](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context)).
* **Feb 15, 2025**: Longer context (from 4K to 8K for 24GB VRAM) & slightly faster speed (+15%, up to 16 tokens/s); see the [docs](./doc/en/DeepseekR1_V3_tutorial.md) and the [online book](https://kvcache-ai.github.io/ktransformers/).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi GPU and 382G DRAM, up to 3~28x speedup. For a detailed tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, using 24GB of VRAM and 150GB of DRAM. For a detailed tutorial, see [here](./doc/en/long_context_tutorial.md).
diff --git a/doc/README.md b/doc/README.md
index 9c955c2..f50e6b7 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **Feb 25, 2025**: Support [FP8 GPU kernel](./en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./en/DeepseekR1_V3_tutorial.md#v022-longer-context) (from 8K to 128K for 24GB VRAM).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./en/long_context_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml
new file mode 100644
index 0000000..fa8c03d
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml
@@ -0,0 +1,157 @@
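+# Two-GPU injection rules for DeepSeek-V3/R1 with FP8 linear kernels and CPU (GGML) experts:
+# layers 0-29 are pinned to cuda:0 and layers 30 and above to cuda:1; the token embedding
+# stays on the CPU and the MoE experts are offloaded to the CPU during generation.
+# Usage sketch (the flag name is an assumption, not part of this patch): pass this file to
+# local_chat.py via --optimize_config_path, alongside --model_path and --gguf_path.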
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
+
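+# Rotary embeddings: (0|[1-9]|[12][0-9]) matches layers 0-29 and ([3456][0-9]) matches
+# layers 30-69; each half gets the YaRN rotary embedding on its own GPU.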
+- match:
+ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+- match:
+ name: "^model\\.layers\\.([3456][0-9])\\."
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "cuda:1"
+ prefill_device: "cuda:1"
+
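+# Linear layers (everything except self_attn.kv_b_proj) become KTransformersLinear,
+# using the FP8 kernel for generation and the plain Torch kernel for prefill.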
+- match:
+ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
+ class: torch.nn.Linear # only replace modules that match both the name and the class
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized kernel for quantized data types
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+ generate_op: "KLinearFP8"
+ prefill_op: "KLinearTorch"
+
+- match:
+ name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
+ class: torch.nn.Linear # only replace modules that match both the name and the class
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized kernel for quantized data types
+ kwargs:
+ generate_device: "cuda:1"
+ prefill_device: "cuda:1"
+ generate_op: "KLinearFP8"
+ prefill_op: "KLinearTorch"
+
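+# The MoE blocks are swapped for KDeepseekV3MoE (a custom forward) on each layer's assigned GPU.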
+- match:
+ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+- match:
+ name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "cuda:1"
+ prefill_device: "cuda:1"
+
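+# The MoE routing gates are replaced with KMoEGate on the same per-layer device split.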
+- match:
+ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+- match:
+ name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cuda:1"
+ prefill_device: "cuda:1"
+
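+# Experts: prefill runs on the GPU (KExpertsTorch) while generation runs on the CPU
+# (KExpertsCPU); out_device moves expert outputs back to the layer's GPU, and
+# recursive: False keeps the injector from descending into the expert submodules.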
+- match:
+ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
+ kwargs:
+ prefill_device: "cuda:0"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda:0"
+ recursive: False # don't recursively inject submodules of this module
+
+- match:
+ name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE kernel with expert parallelism
+ kwargs:
+ prefill_device: "cuda:1"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda:1"
+ recursive: False # don't recursively inject submodules of this module
+
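+# Attention is replaced with the optimized MLA implementation; setting absorb_for_prefill
+# to True enables longer context at the cost of slower prefill.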
+- match:
+ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+ absorb_for_prefill: False # change this to True to enable long context (prefill may be slower)
+
+- match:
+ name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda:1"
+ prefill_device: "cuda:1"
+ absorb_for_prefill: False # change this to True to enable long context (prefill may be slower)
+
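+# Model wrapper: transfer_map hands the hidden states to cuda:1 once layer 30 is reached;
+# per_layer_prefill_intput_threshold: 0 turns layer-wise prefill off.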
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+ transfer_map:
+ 30: "cuda:1"
+
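+# Anything not matched above falls back to the "default" operator, kept on the GPU that
+# owns its layer range (cuda:0 for layers 0-29; cuda:1 for layers 30+, lm_head, and model.norm).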
+- match:
+ name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+
+- match:
+ name: "^lm_head"
+ class: torch.nn.Linear
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cuda:1"
+ prefill_device: "cuda:1"
+
+
+- match:
+ name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cuda:1"
+ prefill_device: "cuda:1"