From 7b7c6a657d8f4bb7a0e330dd91c03a3e81e802f4 Mon Sep 17 00:00:00 2001
From: Azure
Date: Sat, 22 Feb 2025 13:05:08 +0000
Subject: [PATCH 01/14] Add fp8 linear kernel; Add empty cache to fit in 16G
 VRAM; By 'wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225'
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../ktransformers_ext/triton/fp8gemm.py | 191 ++++++++++++++++++
ktransformers/operators/linear.py | 61 +++++-
ktransformers/tests/triton_fp8gemm_test.py | 73 +++++++
ktransformers/util/custom_gguf.py | 6 +
ktransformers/util/utils.py | 2 +-
5 files changed, 331 insertions(+), 2 deletions(-)
create mode 100644 ktransformers/ktransformers_ext/triton/fp8gemm.py
create mode 100644 ktransformers/tests/triton_fp8gemm_test.py
diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py
new file mode 100644
index 0000000..4da4cfe
--- /dev/null
+++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py
@@ -0,0 +1,191 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+from triton import Config
+
+
+@triton.jit
+def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+ """
+ Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
+
+ Args:
+ x_ptr (triton.Pointer): Pointer to the input tensor.
+ y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
+ s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
+ BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
+
+ Returns:
+ None
+ """
+ pid = tl.program_id(axis=0)
+ offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ x = tl.load(x_ptr + offs).to(tl.float32)
+ s = tl.max(tl.abs(x)) / 448.
+ y = x / s
+ y = y.to(y_ptr.dtype.element_ty)
+ tl.store(y_ptr + offs, y)
+ tl.store(s_ptr + pid, s)
+
+
+def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Quantizes the input tensor `x` using block-wise quantization.
+
+ Args:
+ x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
+ block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - The quantized tensor with dtype `torch.float8_e4m3fn`.
+ - A tensor of scaling factors with dtype `torch.float32`.
+ """
+ assert x.is_contiguous(), 'Input tensor must be contiguous'
+ assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
+ y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+ s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
+ grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
+ act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
+ return y, s
+
+
+@triton.jit
+def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+ """
+ Dequantizes weights using the provided scaling factors and stores the result.
+
+ Args:
+ x_ptr (tl.pointer): Pointer to the quantized weights.
+ s_ptr (tl.pointer): Pointer to the scaling factors.
+ y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
+ M (int): Number of rows in the weight matrix.
+ N (int): Number of columns in the weight matrix.
+ BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+ Returns:
+ None
+ """
+ pid_m = tl.program_id(axis=0)
+ pid_n = tl.program_id(axis=1)
+ n = tl.cdiv(N, BLOCK_SIZE)
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offs = offs_m[:, None] * N + offs_n[None, :]
+ mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+ x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+ s = tl.load(s_ptr + pid_m * n + pid_n)
+ y = x * s
+ tl.store(y_ptr + offs, y, mask=mask)
+
+
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+ """
+ Dequantizes the given weight tensor using the provided scale tensor.
+
+ Args:
+ x (torch.Tensor): The quantized weight tensor of shape (M, N).
+ s (torch.Tensor): The block-wise scale tensor of shape (M // block_size, N // block_size).
+ block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+
+ Returns:
+ torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+ Raises:
+ AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+ """
+ assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
+ assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
+ M, N = x.size()
+ y = torch.empty_like(x, dtype=torch.get_default_dtype())
+ grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+ weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+ return y
+
+
+fp8_gemm_configs = [
+ Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
+ for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
+]
+
+@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
+@triton.jit
+def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
+ a_s_ptr, b_s_ptr,
+ M, N: tl.constexpr, K: tl.constexpr,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr):
+ """
+ Performs a matrix multiplication operation on FP8 matrices with scaling factors.
+
+ Args:
+ a_ptr (tl.tensor): Pointer to the first input matrix A.
+ b_ptr (tl.tensor): Pointer to the second input matrix B.
+ c_ptr (tl.tensor): Pointer to the output matrix C.
+ a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
+ b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
+ M (int): Number of rows in matrix A and C.
+ N (tl.constexpr): Number of rows in matrix B and columns in matrix C.
+ K (tl.constexpr): Number of columns in matrices A and B (B is stored row-major as (N, K)).
+ BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
+ BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
+ BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
+
+ Returns:
+ None
+ """
+ pid_m = tl.program_id(axis=0)
+ pid_n = tl.program_id(axis=1)
+ k = tl.cdiv(K, BLOCK_SIZE_K)
+ offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+ offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+ b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
+ a_s_ptrs = a_s_ptr + offs_m * k
+ b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for i in range(k):
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
+ a_s = tl.load(a_s_ptrs)
+ b_s = tl.load(b_s_ptrs)
+ accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+ a_ptrs += BLOCK_SIZE_K
+ b_ptrs += BLOCK_SIZE_K
+ a_s_ptrs += 1
+ b_s_ptrs += 1
+ c = accumulator.to(c_ptr.dtype.element_ty)
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
+ mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+ tl.store(c_ptrs, c, mask=mask)
+
+
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+ """
+ Perform a matrix multiplication using FP8 precision.
+
+ Args:
+ a (torch.Tensor): The first input matrix, must be contiguous.
+ a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
+ b (torch.Tensor): The second input matrix, must be contiguous.
+ b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
+
+ Returns:
+ torch.Tensor: The result of the matrix multiplication.
+ """
+ assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
+ assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
+ K = a.size(-1)
+ M = a.numel() // K
+ N = b.size(0)
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
+ fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
+ return c
\ No newline at end of file
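
Taken together, the three helpers above cover both sides of the block-wise FP8 scheme: `act_quant` produces per-128-element scales for activations, `weight_dequant` rebuilds a full-precision weight from an fp8 weight plus its `(N/128, K/128)` scale grid, and `fp8_gemm` consumes both sets of scales directly. A minimal shape-level sketch (not part of the patch; the random `w_fp8`/`w_s` pair merely stands in for a checkpoint's `weight`/`weight_scale_inv`, and a CUDA GPU with FP8 support is assumed):

```python
import torch
from ktransformers.ktransformers_ext.triton.fp8gemm import act_quant, weight_dequant, fp8_gemm

M, K, N, block_size = 64, 256, 512, 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")          # activations
w_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)    # stand-in fp8 weight (row-major, like nn.Linear)
w_s = torch.ones(N // block_size, K // block_size, device="cuda")   # stand-in for weight_scale_inv

x_q, x_s = act_quant(x, block_size)   # x_q: (M, K) float8_e4m3fn, x_s: (M, K // block_size) float32
w_full = weight_dequant(w_fp8, w_s)   # (N, K) in the default dtype

y = fp8_gemm(x_q, x_s, w_fp8, w_s)    # (M, N), fp32 accumulation, default output dtype
ref = x.float() @ w_full.T            # dequantize-then-matmul reference for comparison
print(y.shape, ref.shape)
```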
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 08a2cca..5aff964 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -25,6 +25,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
@@ -164,7 +165,65 @@ class KLinearTorch(KLinearBase):
if self.has_bias:
self.bias = None
-
+class KLinearFP8(KLinearBase):
+ marlin_q_w: torch.Tensor
+ marlin_s: torch.Tensor
+ g_idx: torch.Tensor
+ sort_indices: torch.Tensor
+ has_bias: bool
+ weight: torch.Tensor
+ scale_w: torch.Tensor
+ bias: torch.Tensor
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module = None,
+ device: str = "cuda",
+ block_size: int = 128,
+ **kwargs,
+ ):
+ super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+ self.has_bias = False
+ self.dtype = torch.get_default_dtype()
+ self.block_size = block_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.to(self.device)
+ orig_shape = list(x.shape)
+ orig_dtype = x.dtype
+ x = x.reshape(-1, orig_shape[-1])
+ x_quantized, scale_x = act_quant(x, self.block_size)
+ y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
+ if self.bias is not None:
+ y += self.bias
+ return y.to(orig_dtype).reshape(orig_shape)
+
+ def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+ if device is None: device = self.device
+ if w is None:
+ w = self.load_weight(device=device)
+ if isinstance(w, nn.Parameter):
+ self.weight = w.to(device)
+ self.has_bias = False
+ elif isinstance(w, tuple):
+ self.weight = w[0].to(device)
+ self.bias = w[1].to(device)
+ self.has_bias = True
+ else:
+ raise ValueError("Invalid weight type")
+ self.weight = self.weight.to(device)
+ if self.has_bias:
+ self.bias = self.bias.to(device)
+
+ def unload(self):
+ if self.weight is not None:
+ self.weight = None
+ if self.has_bias:
+ self.bias = None
+
+
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
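
For intuition, `KLinearFP8.forward` above is the fused counterpart of "dequantize the weight once, then run an ordinary matmul". A rough reference pair for spot-checking the layer, assuming an fp8 `weight` with its block scale (attached as `weight.scale` in this patch and renamed to `weight_scale_inv` later in the series); bias handling is omitted:

```python
import torch
from ktransformers.ktransformers_ext.triton.fp8gemm import act_quant, weight_dequant, fp8_gemm

def fp8_linear_triton(x: torch.Tensor, weight_fp8: torch.Tensor, scale: torch.Tensor,
                      block_size: int = 128) -> torch.Tensor:
    """What KLinearFP8.forward does: quantize activations, then a block-scaled FP8 GEMM."""
    x2d = x.reshape(-1, x.shape[-1]).contiguous()
    x_q, x_s = act_quant(x2d, block_size)
    y = fp8_gemm(x_q, x_s, weight_fp8, scale)
    return y.to(x.dtype).reshape(*x.shape[:-1], -1)

def fp8_linear_reference(x: torch.Tensor, weight_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Slow reference: dequantize the whole weight once, then use a plain matmul."""
    w = weight_dequant(weight_fp8, scale)          # (out_features, in_features)
    return (x.float() @ w.T).to(x.dtype)
```

On random data the two paths should agree up to the quantization error introduced by the fp8 activations.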
diff --git a/ktransformers/tests/triton_fp8gemm_test.py b/ktransformers/tests/triton_fp8gemm_test.py
new file mode 100644
index 0000000..bb3801c
--- /dev/null
+++ b/ktransformers/tests/triton_fp8gemm_test.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn.functional as F
+from typing import Optional
+import pytest
+from typing import Tuple, Optional, Literal
+
+# make the project root importable (hard-coded path for this test environment)
+import os
+import sys
+sys.path.insert(0, "/home/azure/ktransformers")
+print(sys.path)
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
+from safetensors import safe_open
+
+world_size = 1
+rank = 0
+block_size = 128
+gemm_impl: Literal["bf16", "fp8"] = "bf16"
+# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined
+
+def test_fp8_gemm_vs_torch_matmul():
+ # Test case 1: Create random matrices of size (M, K) and (K, N)
+ M, K, N = 64, 128, 256 # Matrix dimensions
+ x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+ weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')
+
+ # Apply act_quant to both matrices
+ x_quantized, scale_x = act_quant(x, block_size)
+ weight_quantized, scale_w = act_quant(weight, block_size)
+
+ # make tensors contiguous
+ x_quantized = x_quantized.contiguous()
+ weight_quantized = weight_quantized.contiguous()
+ scale_x = scale_x.contiguous()
+ scale_w = scale_w.contiguous()
+
+ # Perform fp8_gemm using the quantized tensors
+ result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)
+
+ # Perform torch.matmul using the original floating point tensors
+ result_torch_matmul = torch.matmul(x, weight.T)
+ print(f'result_torch_matmul: {result_torch_matmul.shape}')
+ print(f'result_fp8_gemm: {result_fp8_gemm.shape}')
+
+ print(f"result_fp8_gemm:\n {result_fp8_gemm}")
+ print(f"result_torch_matmul:\n {result_torch_matmul}")
+
+def test_fp8_gemm_vs_torch_matmul_load():
+ file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
+ with safe_open(file_path, framework="pt", device=0) as f:
+ weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
+ scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
+
+ # weight_dequant
+ weight_dequantized = weight_dequant(weight, scale)
+ print(f"weight_dequantized: {weight_dequantized.shape}")
+ N, K = weight_dequantized.shape
+ M = 64
+ x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+ x_quantized, scale_x = act_quant(x, block_size)
+
+ # Test case 1: quantized x matmul with the still-quantized (fp8) weight
+ result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+ print(f"result_fp8_gemm:\n {result_fp8_gemm}")
+
+ # Perform torch.matmul using the original floating point tensors
+ result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
+ print(f"result_torch_matmul:\n {result_torch_matmul}")
+
+if __name__ == "__main__":
+ test_fp8_gemm_vs_torch_matmul()
+ test_fp8_gemm_vs_torch_matmul_load()
+
\ No newline at end of file
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index eaa1a7d..26afd39 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -127,6 +127,7 @@ GGML_BLOCK_SIZES = {
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
"Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
"IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
+ "FP8": 1,
}
GGML_ELEMENTS_PER_BLOCK = {
@@ -142,6 +143,7 @@ GGML_ELEMENTS_PER_BLOCK = {
"Q5_K": 256,
"Q6_K": 256,
"IQ4_XS": 256,
+ "FP8": 1,
}
DATA_TYPES = {
@@ -158,6 +160,7 @@ DATA_TYPES = {
"uint64": 10,
"int64": 11,
"float64": 12,
+ "FP8": 13,
}
class GGUFLoader:
@@ -393,6 +396,9 @@ def read_value(f, data_type):
elem_type, count = struct.unpack("
Date: Sun, 23 Feb 2025 07:40:47 +0000
Subject: [PATCH 02/14] remove causal mask
---
ktransformers/operators/models.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index 5d2e911..3877dbc 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -649,9 +649,12 @@ class KDeepseekV2Model(BaseInjectedModule):
if per_layer_prefill_flag:
causal_mask = None
else:
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ if os.name == 'nt':
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+ else:
+ causal_mask = None
# embed positions
hidden_states = inputs_embeds
From 581a524f65db422011d9a8db99439d658223db6f Mon Sep 17 00:00:00 2001
From: Azure
Date: Mon, 24 Feb 2025 11:16:23 +0000
Subject: [PATCH 03/14] Add data loader to read special weights for fp8; Add
special weight processing script
---
.../ktransformers_ext/triton/fp8gemm.py | 1 +
ktransformers/operators/experts.py | 11 +-
ktransformers/operators/gate.py | 13 +-
ktransformers/operators/linear.py | 38 ++--
...pSeek-V3-Chat-fp8-linear-ggml-experts.yaml | 63 ++++++
ktransformers/tests/triton_fp8gemm_test.py | 47 +++-
ktransformers/util/custom_gguf.py | 19 +-
ktransformers/util/custom_loader.py | 86 +++++++
ktransformers/util/utils.py | 15 +-
merge_tensors/merge_safetensor_gguf.py | 214 ++++++++++++++++++
10 files changed, 481 insertions(+), 26 deletions(-)
create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml
create mode 100644 ktransformers/util/custom_loader.py
create mode 100644 merge_tensors/merge_safetensor_gguf.py
diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py
index 4da4cfe..7d5b72e 100644
--- a/ktransformers/ktransformers_ext/triton/fp8gemm.py
+++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py
@@ -1,3 +1,4 @@
+# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
from typing import Tuple
import torch
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 21b4830..1ea244a 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
down_type = None
for key in keys:
- if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+ if self.gguf_loader.safetensor_loader is not None:
+ # temporary (admittedly ugly) way to load the tensor directly
+ gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
+ up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
+ down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
+ gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
+ up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
+ down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
+
+ elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py
index 52bb33a..d908093 100644
--- a/ktransformers/operators/gate.py
+++ b/ktransformers/operators/gate.py
@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
for key in keys:
key = ".".join(key.split(".")[:-1])
- if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
+ if self.gguf_loader.safetensor_loader is not None:
+ targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+ weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
+ e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
+ weight_type = weight.dtype
+ e_score_correction_bias_type = e_score_correction_bias.dtype
+ res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
+ elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
tensors = self.load_multi(key, targets, device=device)
weight = tensors[".ffn_gate_inp.weight"]
@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
- self.orig_module.weight = self.orig_module.weight.to(device)
- self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+ self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+ self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
def unload(self):
if self.weight is not None:
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 5aff964..e778102 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -76,7 +76,13 @@ class KLinearBase(ABC):
keys = [self.key]
for key in keys:
- if key + ".weight" in self.gguf_loader.tensor_file_map:
+ if self.gguf_loader.safetensor_loader is not None:
+ # using safetensor_loader
+ tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
+ weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
+ return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+
+ elif key + ".weight" in self.gguf_loader.tensor_file_map:
if key + ".bias" in self.gguf_loader.tensor_file_map:
tensors = self.load_multi(key, ["weight", "bias"], device=device)
tensor = tensors["weight"]
@@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase):
self.bias = None
class KLinearFP8(KLinearBase):
+ # this kernel requires special handling for weight
+ # Please load the weight file downloaded from KVCache.AI
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
@@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
- orig_shape = list(x.shape)
- orig_dtype = x.dtype
- x = x.reshape(-1, orig_shape[-1])
+ orig_dtype = x.dtype
x_quantized, scale_x = act_quant(x, self.block_size)
- y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
- if self.bias is not None:
- y += self.bias
- return y.to(orig_dtype).reshape(orig_shape)
+ y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
+ return y.to(dtype=orig_dtype)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None:
w = self.load_weight(device=device)
- if isinstance(w, nn.Parameter):
- self.weight = w.to(device)
- self.has_bias = False
- elif isinstance(w, tuple):
+ ### TODO: fit the weight_scale_inv format
+ if isinstance(w, tuple):
self.weight = w[0].to(device)
- self.bias = w[1].to(device)
- self.has_bias = True
+ self.weight_scale_inv = w[1].to(device)
+ self.has_bias = False
else:
raise ValueError("Invalid weight type")
self.weight = self.weight.to(device)
@@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase):
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
- "KLinearCPUInfer": KLinearCPUInfer
+ "KLinearCPUInfer": KLinearCPUInfer,
+ "KLinearFP8": KLinearFP8,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
@@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
def forward(self, x):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
- return self.prefill_linear.forward(x)
+ y = self.prefill_linear.forward(x)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
- return self.generate_linear.forward(x)
+ y = self.generate_linear.forward(x)
+ return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
if not mode:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml
new file mode 100644
index 0000000..25f021e
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml
@@ -0,0 +1,63 @@
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ generate_op: "KLinearFP8"
+ prefill_op: "KLinearTorch"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "cuda"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+ per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
\ No newline at end of file
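
This rule file is consumed by the same injection entry point shown in the project README: build the model on the meta device, then let `optimize_and_load_gguf` walk the module tree, apply these match/replace rules, and load the weights. A minimal sketch (the helper's import path and the weight folder are assumptions for illustration):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf  # assumed import path

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

rule_path = "ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml"
gguf_path = "/path/to/merged-fp8-and-gguf-weights"   # folder produced by the merge script in this patch
optimize_and_load_gguf(model, rule_path, gguf_path, config)
```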
diff --git a/ktransformers/tests/triton_fp8gemm_test.py b/ktransformers/tests/triton_fp8gemm_test.py
index bb3801c..58888d6 100644
--- a/ktransformers/tests/triton_fp8gemm_test.py
+++ b/ktransformers/tests/triton_fp8gemm_test.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
from typing import Optional
import pytest
from typing import Tuple, Optional, Literal
-
+import time
# use dir path
import os
import sys
@@ -56,18 +56,61 @@ def test_fp8_gemm_vs_torch_matmul_load():
print(f"weight_dequantized: {weight_dequantized.shape}")
N, K = weight_dequantized.shape
M = 64
- x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+ x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
x_quantized, scale_x = act_quant(x, block_size)
# Test case 1: quantized x matmal with undequantized weight
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
print(f"result_fp8_gemm:\n {result_fp8_gemm}")
+ print(f"dtype {result_fp8_gemm.dtype}")
# Perform torch.matmul using the original floating point tensors
result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
print(f"result_torch_matmul:\n {result_torch_matmul}")
+def test_fp8_gemm_tflops():
+ file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
+ with safe_open(file_path, framework="pt", device=0) as f:
+ weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
+ scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
+
+ # weight_dequant
+ weight_dequantized = weight_dequant(weight, scale)
+ print(f"weight_dequantized: {weight_dequantized.shape}")
+ N, K = weight_dequantized.shape
+ M = 6400
+ x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
+ # x_quantized, scale_x = act_quant(x, block_size)
+
+ # Time i iterations of act_quant + fp8_gemm (after the two warm-up runs below)
+ i = 10
+ flops_per_gemm = 2 * x.shape[0] * M * N * K # x carries a leading batch dim of 2
+ total_flops = i * flops_per_gemm
+
+ x_quantized, scale_x = act_quant(x, block_size)
+ result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+ x_quantized, scale_x = act_quant(x, block_size)
+ result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for _ in range(i):
+ x_quantized, scale_x = act_quant(x, block_size)
+ result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+ torch.cuda.synchronize()
+ t1 = time.time()
+
+ total_time = t1 - t0
+ tflops = total_flops / total_time / 1e12
+ print(f"total_time: {total_time}")
+ print(f"tflops: {tflops}")
+
+
+
+
if __name__ == "__main__":
test_fp8_gemm_vs_torch_matmul()
test_fp8_gemm_vs_torch_matmul_load()
+ test_fp8_gemm_tflops()
\ No newline at end of file
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 26afd39..d054ad3 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -25,6 +25,7 @@ import os
from enum import IntEnum
import torch
import KTransformersOps
+from .custom_loader import SafeTensorLoader
class GGMLQuantizationType(IntEnum):
F32 = 0
@@ -168,12 +169,15 @@ class GGUFLoader:
gguf_path: str
tensor_file_map: dict # {tensor_name: tensor_file_path}
gguf_file_meta: dict
+ safetensor_loader: SafeTensorLoader
def __init__(self, gguf_path: str):
# Check dir exist
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
if os.path.isfile(gguf_path):
gguf_path = os.path.dirname(gguf_path)
+
+ self.safetensor_loader = None
self.tensor_info = {}
self.gguf_path = gguf_path
@@ -181,7 +185,13 @@ class GGUFLoader:
self.file_data_map = {}
self.gguf_file_meta = {}
self.tensor_device_map = {}
-
+
+ # I know this is ugly, but I don't want to change the original code too much
+ # TODO: merge gguf load and other loads.
+ safetensor_loader = SafeTensorLoader(gguf_path)
+ if safetensor_loader.tensor_file_map:
+ self.safetensor_loader = safetensor_loader
+ return
# Walk through all the .gguf files in the directory
found_gguf = False
for root, dirs, files in os.walk(gguf_path):
@@ -288,6 +298,13 @@ class GGUFLoader:
itemsize = int(np.empty([], dtype = item_type).itemsize)
return mmap_data[offset : offset + itemsize * item_count]
+ def get_undequanted_tensor_and_ggml_type(self, name):
+ t = self.tensor_info[name]
+ data = self.get_mmap_tensor(name)
+ ggml_type = t["ggml_type"]
+ data = torch.from_numpy(data)
+ return data, ggml_type
+
def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
t = self.tensor_info[name]
if device.lower() == "cpu":
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
new file mode 100644
index 0000000..ecc09a0
--- /dev/null
+++ b/ktransformers/util/custom_loader.py
@@ -0,0 +1,86 @@
+import struct
+import warnings
+import numpy as np
+import re
+import numpy.typing as npt
+from typing import Sequence
+import os
+from enum import IntEnum
+import torch
+import KTransformersOps
+from safetensors import safe_open
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
+from safetensors.torch import save_file
+
+class SafeTensorLoader:
+ tensor_file_map = {}
+ tensor_type_map = {}
+ file_handle_map = {}
+
+ def __init__(self, file_path: str):
+ self.__load_tensor_file_map(file_path)
+
+ def __load_tensor_file_map(self, file_path: str):
+ # normalize the given path and make sure it points to a folder
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"Path not found: {file_path}")
+ if os.path.isfile(file_path):
+ folder_path = os.path.dirname(file_path)
+ else:
+ folder_path = file_path
+
+ found_safetensor = False
+ for root, _, files in os.walk(folder_path):
+ files = sorted(files)
+ for file in files:
+ if file.endswith(".safetensors"):
+ found_safetensor = True
+ file_path = os.path.join(root, file)
+ if file not in self.file_handle_map:
+ try:
+ handle = safe_open(file_path, framework="pt")
+ self.file_handle_map[file] = handle
+ except Exception as e:
+ print(f"Error opening Safetensor file {file_path}: {e}")
+ continue
+
+ f = self.file_handle_map.get(file)
+ if f is None:
+ continue
+ try:
+ for key in f.keys():
+ self.tensor_file_map[key] = file
+ except Exception as e:
+ print(f"Error reading Safetensor file {file_path}: {e}")
+
+ # if not found_safetensor:
+ # raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
+
+ def load_tensor(self, key: str, device: str="cpu"):
+ if key not in self.tensor_file_map:
+ raise KeyError(f"Key {key} not found in Safetensor files")
+ file = self.tensor_file_map[key]
+ f = self.file_handle_map.get(file)
+ if f is None:
+ raise FileNotFoundError(f"File {file} not found in Safetensor files")
+ tensor = f.get_tensor(key)
+ return tensor.to(device)
+
+ def close_all_handles(self):
+ for handle in self.file_handle_map.values():
+ handle.close()
+ self.file_handle_map.clear()
+
+ def load_dequantized_tensor(self, key:str, device: str="cpu"):
+ if key not in self.tensor_file_map:
+ raise KeyError(f"Key {key} not found in Safetensor files")
+ file = self.tensor_file_map[key]
+ f = self.file_handle_map.get(file)
+ if f is None:
+ raise FileNotFoundError(f"File {file} not found in Safetensor files")
+ tensor = f.get_tensor(key).to(device)
+ if key.endswith(".weight"):
+ if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
+ weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
+ tensor = weight_dequant(tensor, weight_scale_inv)
+ return tensor.to(device)
\ No newline at end of file
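
A short sketch of how `SafeTensorLoader` is driven, pointed here at the original DeepSeek-V3 fp8 checkpoint folder (the folder path is a placeholder; the tensor names match the keys used in the test file earlier in this patch):

```python
from ktransformers.util.custom_loader import SafeTensorLoader

loader = SafeTensorLoader("/path/to/DeepSeek-V3-fp8-safetensors")  # folder containing *.safetensors shards

# raw fp8 weight and its per-block scale, exactly as stored in the checkpoint
w = loader.load_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")
s = loader.load_tensor("model.layers.0.mlp.down_proj.weight_scale_inv", device="cuda")

# or let the loader pair `.weight` with `.weight_scale_inv` and dequantize in one step
w_full = loader.load_dequantized_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")

loader.close_all_handles()
```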
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 81d007c..1c21135 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
for name, param in local_state.items():
key = prefix + name
translated_key = translate_name_to_gguf(key)
- if translated_key in gguf_loader.tensor_file_map:
+
+ # TODO: merge all loaders.
+ # I know this is ugly, but let's do it for now.
+ if gguf_loader.safetensor_loader is not None:
+ load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
+ tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+ else:
+ load_dequantized_tensor = gguf_loader.load_gguf_tensor
+ tensor_file_map = gguf_loader.tensor_file_map
+
+ if translated_key in tensor_file_map:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
- weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+ # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+ weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
else:
diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py
new file mode 100644
index 0000000..7aeb62d
--- /dev/null
+++ b/merge_tensors/merge_safetensor_gguf.py
@@ -0,0 +1,214 @@
+# This script merges the fp8 safetensors with the GGUF-quantized tensors.
+
+import os
+# insert the path of the project
+import sys
+sys.path.insert(0, "/home/azure/ktransformers")
+import argparse
+import torch
+from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
+from safetensors import safe_open
+from safetensors.torch import save_file
+import re
+from collections import defaultdict
+
+def read_safetensor_keys_from_folder(folder_path)->dict:
+ """
+ :param folder_path: folder path
+ :return: key_to_file_map
+ """
+ # check that the folder path exists
+ if not os.path.exists(folder_path):
+ raise FileNotFoundError(f"GGUF dir not found: {folder_path}")
+ if os.path.isfile(folder_path):
+ folder_path = os.path.dirname(folder_path)
+
+ key_to_file_map = {}
+
+ found_safetensor = False
+ for root, dirs, files in os.walk(folder_path):
+ # sort files
+ files = sorted(files)
+ for file in files:
+ if file.endswith(".safetensors"):
+ found_safetensor = True
+ file_path = os.path.join(root, file)
+ try:
+ with safe_open(file_path, framework="pt") as f:
+ for key in f.keys():
+ if "model.layers.61" in key:
+ # skip MTP layer
+ continue
+ # try:
+ # if int(key.split('.')[2]) > 4:
+ # continue
+ # except:
+ # pass
+ key_to_file_map[key] = file_path
+ except Exception as e:
+ print(f"Error reading Safetensor file {file_path}: {e}")
+
+ if not found_safetensor:
+ raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
+
+ return key_to_file_map
+
+tensor_from_gguf = [] # todo: add keys in gguf that should be used in the final tensor
+
+def translate_name(name:str)->str:
+ """
+ :param name: name of the tensor
+ :return: translated name
+ """
+ name = translate_name_to_gguf(name)
+ name = name.replace(".up_proj.", ".ffn_up_exps.")
+ name = name.replace(".down_proj.", ".ffn_down_exps.")
+ name = name.replace(".gate_proj.", ".ffn_gate_exps.")
+ name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
+ return name
+
+
+def combine_tensor_sources(safetensor_path:str, gguf_path:str):
+ gguf_loader = GGUFLoader(gguf_path)
+ gguf_tensor_file_map = gguf_loader.tensor_file_map
+ safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)
+
+ # build a map for the key to the tensor
+ # according to the key, we can get the tensor from the file
+
+ target_tensor_map = {}
+ for key in safetensor_tensor_file_map.keys():
+ # for all experts, we use the gguf tensor
+ if ".mlp.experts." in key:
+ if '.weight_scale_inv' in key:
+ continue
+ key = '.'.join(key.split('.')[:5]+key.split('.')[-2:])
+ translated_key = translate_name(key)
+ target_tensor_map[key] = gguf_tensor_file_map[translated_key]
+ continue
+
+ if any(target_key in key for target_key in tensor_from_gguf):
+ target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
+ else:
+ target_tensor_map[key] = safetensor_tensor_file_map[key]
+
+ return target_tensor_map, gguf_loader
+
+def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
+ # Ensure output directory exists
+ os.makedirs(output_path, exist_ok=True)
+
+ # Cache for safetensor file handles and GGUF loaders
+ safetensors_cache = {}
+ gguf_cache = {}
+
+ # Group tensors by layer
+ layer_groups = defaultdict(list)
+ non_layer_keys = []
+ layer_pattern = re.compile(r'\.layers\.(\d+)\.')
+
+ for key in target_tensor_map:
+ match = layer_pattern.search(key)
+ if match:
+ layer_num = int(match.group(1))
+ layer_groups[layer_num].append(key)
+ else:
+ non_layer_keys.append(key)
+
+ # Calculate total shards
+ total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
+ if total_shards == 0:
+ raise ValueError("No tensors to save")
+
+ shard_idx = 0
+
+ # Save non-layer tensors to the first shard if they exist
+ if non_layer_keys:
+ tensors = {}
+ for key in non_layer_keys:
+ file_path = target_tensor_map[key]
+ tensor = None
+ ggml_type = None
+ if file_path.endswith('.safetensors'):
+ if file_path not in safetensors_cache:
+ safetensors_cache[file_path] = safe_open(file_path, framework='pt')
+ f = safetensors_cache[file_path]
+ tensor = f.get_tensor(key)
+ elif file_path.endswith('.gguf'):
+ gguf_name = translate_name(key)
+ tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
+ else:
+ raise ValueError(f"Unsupported file format: {file_path}")
+ tensors[translate_name(key)] = tensor
+ if ggml_type is not None:
+ ggml_type = torch.tensor(ggml_type)
+ ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
+ tensors[ggml_key] = ggml_type
+
+ output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
+ print(f"Saving non-layer tensors to {output_file}")
+ save_file(tensors, output_file)
+ print(tensors.keys())
+
+ shard_idx += 1
+
+ # Save each layer's tensors to subsequent shards
+ for layer_num in sorted(layer_groups.keys()):
+ layer_keys = layer_groups[layer_num]
+ tensors = {}
+ for key in layer_keys:
+ file_path = target_tensor_map[key]
+ tensor = None
+ ggml_type = None
+ if file_path.endswith('.safetensors'):
+ if file_path not in safetensors_cache:
+ safetensors_cache[file_path] = safe_open(file_path, framework='pt')
+ f = safetensors_cache[file_path]
+ tensor = f.get_tensor(key)
+ tensor_info = tensor.shape
+ elif file_path.endswith('.gguf'):
+ gguf_name = translate_name(key)
+ tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
+ # tensor_info = gguf_loader.tensor_info[gguf_name]
+ # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']
+ else:
+ raise ValueError(f"Unsupported file format: {file_path}")
+ tensors[translate_name(key)] = tensor
+ if ggml_type is not None:
+ ggml_type = torch.tensor(ggml_type)
+ ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
+ tensors[ggml_key] = ggml_type
+
+ output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
+ print(f"Saving layer {layer_num} to {output_file}")
+ print(tensors.keys())
+ save_file(tensors, output_file)
+ shard_idx += 1
+
+ return
+
+def main():
+ # build the command-line argument parser
+ parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files")
+ parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
+ parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
+ parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")
+
+ # print all the arguments
+ print("All the arguments:")
+ print(parser.parse_args())
+
+ # parse the command-line arguments
+ args = parser.parse_args()
+
+ safetensor_path = args.safetensor_path
+ gguf_path = args.gguf_path
+ output_path = args.output_path
+
+ target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
+ write_combined_tensor(target_tensor_map, output_path, gguf_loader)
+
+ return
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
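
Once the merge script has written its shards, a quick check such as the one below (the shard path is a placeholder) confirms that each CPU-side expert tensor kept its raw GGML payload together with the `.ggml_type` companion that `KExpertsCPU` reads back through the safetensor loader:

```python
from safetensors import safe_open

shard = "/path/to/output/model-00001-of-00061.safetensors"  # any layer shard written above
with safe_open(shard, framework="pt") as f:
    keys = list(f.keys())
    for k in keys:
        # expert weights are stored under their GGUF names, each with a ggml_type sidecar
        if k.endswith(".ffn_up_exps.weight"):
            assert k[:-len(".weight")] + ".ggml_type" in keys, f"missing ggml_type for {k}"
    print(f"{len(keys)} tensors in shard")
```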
From 4dc5518e4d2ae89a687709bcbe05d2f3f80e00ad Mon Sep 17 00:00:00 2001
From: Azure
Date: Mon, 24 Feb 2025 15:37:01 +0000
Subject: [PATCH 04/14] update fp8 kernel tutorial
---
doc/SUMMARY.md | 3 +-
doc/en/FAQ.md | 5 ++-
doc/en/injection_tutorial.md | 1 +
.../server/backend/interfaces/transformers.py | 2 +-
merge_tensors/README.md | 36 +++++++++++++++++++
merge_tensors/merge_safetensor_gguf.py | 1 -
requirements-local_chat.txt | 3 +-
7 files changed, 46 insertions(+), 5 deletions(-)
create mode 100644 merge_tensors/README.md
diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md
index 4efb7ea..47b4a02 100644
--- a/doc/SUMMARY.md
+++ b/doc/SUMMARY.md
@@ -5,10 +5,11 @@
- [Installation Guide](en/install.md)
# Tutorial
-- [Deepseek-R1/V3 Show Case](en/DeepseekR1_V3_tutorial.md)
+- [Deepseek-R1/V3 Show Case/Tutorial](en/DeepseekR1_V3_tutorial.md)
- [Why KTransformers So Fast](en/deepseek-v2-injection.md)
- [Injection Tutorial](en/injection_tutorial.md)
- [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
+- [Using FP8 GPU Kernel](../merge_tensors/README.md)
# Server
- [Server](en/api/server/server.md)
- [Website](en/api/server/website.md)
diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md
index 51c6271..93977cf 100644
--- a/doc/en/FAQ.md
+++ b/doc/en/FAQ.md
@@ -55,7 +55,7 @@ You have to set `--cpu_infer` to the number of cores you want to use. The more c
### Q: My DeepSeek-R1 model is not thinking.
-According to DeepSeek, you need to enforce the model to initiate its response with "\\n" at the beginning of every output by passing the arg `--force_think true `.
+According to DeepSeek, you need to enforce the model to initiate its response with "\\n" at the beginning of every output by passing the arg `--force_think True `.
### Q: Loading gguf error
@@ -63,9 +63,12 @@ Make sure you:
1. Have the `gguf` file in the `--gguf_path` directory.
2. The directory only contains gguf files from one model. If you have multiple models, you need to separate them into different directories.
3. The folder name itself should not end with `.gguf`, e.g. `Deep-gguf` is correct, `Deep.gguf` is wrong.
+4. The file itself is not corrupted; you can verify this by checking that the sha256sum matches the one from huggingface, modelscope, or hf-mirror.
### Q: Version `GLIBCXX_3.4.30' not found
The detailed error:
>ImportError: /mnt/data/miniconda3/envs/xxx/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/xxx/xxx/ktransformers/./cpuinfer_ext.cpython-312-x86_64-linux-gnu.so)
Running `conda install -c conda-forge libstdcxx-ng` can solve the problem.
+
+
diff --git a/doc/en/injection_tutorial.md b/doc/en/injection_tutorial.md
index 5ebb327..4518836 100644
--- a/doc/en/injection_tutorial.md
+++ b/doc/en/injection_tutorial.md
@@ -59,6 +59,7 @@ Supported operators and their corresponding classes are as follows:
| Linear | KTransformersLinear | KLinearMarlin | Marlin as backend |
| | | KLinearTorch | pytorch as backend |
| | | KLinearCPUInfer | llamafile as backend |
+| | | KLinearFP8 | Triton fp8_gemm kernel as backend. Requires a GPU capable of fp8 computation |
| experts | KTransformersExperts | KExpertsTorch | pytorch as backend |
| | | KExpertsMarlin | Marlin as backend |
| | | KExpertsCPU | llamafile as backend |
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 8211933..c6f8e3a 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -340,7 +340,7 @@ class TransformersInterface(BackendInterfaceBase):
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
- if next_token == self.tokenizer.eos_token_id:
+ if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token)
diff --git a/merge_tensors/README.md b/merge_tensors/README.md
new file mode 100644
index 0000000..c786197
--- /dev/null
+++ b/merge_tensors/README.md
@@ -0,0 +1,36 @@
+# FP8 Linear Kernel
+For DeepSeek-R1/V3, the DeepSeek-AI team provides fp8 safetensors. We have integrated an FP8 GPU kernel into KTransformers. To keep the experts on the CPU and save GPU memory, the experts still use ggml (GGUF tensor) quantization. This raises the precision of the attention calculation, which may improve the model's performance.
+
+Therefore, to use the fp8 linear kernel, the fp8 weights and the gguf files need to be merged. We provide pre-merged weights on Hugging Face so that you can use them directly.
+
+[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3/upload/main)
+
+
+If you want to use other formats of ggml quantization, you can use the following script to merge them.
+
+## Example
+To use the fp8 linear kernel with q4km experts:
+```shell
+python convert_model.py \
+ --safetensor_path <path_to_fp8_safetensors> \
+ --gguf_path <path_to_gguf_folder> \
+ --output_path <path_to_merged_output>
-在仅 24GB VRAM 的桌面上进行 1M 上下文本地推理
-
-
-https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12
+
+
+
更多高级功能即将推出,敬请期待!
@@ -116,7 +117,7 @@ KTransformers 的核心是一个用户友好的、基于模板的注入框架。
```python
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
-optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
+optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
...
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000)
```
@@ -151,7 +152,7 @@ YAML 文件中的每个规则都有两部分:`match` 和 `replace`。`match`
致谢和贡献者
-KTransformer 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 和 Marlin 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。
+KTransformer 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 、 Marlin、sglang和flashinfer 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。
KTransformer 由清华大学 MADSys group 小组的成员以及 Approaching.AI 的成员积极维护和开发。我们欢迎新的贡献者加入我们,使 KTransformer 更快、更易于使用。
diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md
index 47b4a02..94172ce 100644
--- a/doc/SUMMARY.md
+++ b/doc/SUMMARY.md
@@ -9,7 +9,7 @@
- [Why KTransformers So Fast](en/deepseek-v2-injection.md)
- [Injection Tutorial](en/injection_tutorial.md)
- [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
-- [Using FP8 GPU Kernel](../merge_tensors/README.md)
+- [Use FP8 GPU Kernel](en/fp8_kernel.md)
# Server
- [Server](en/api/server/server.md)
- [Website](en/api/server/website.md)
diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md
index 93977cf..c8f58a2 100644
--- a/doc/en/FAQ.md
+++ b/doc/en/FAQ.md
@@ -45,7 +45,7 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?
-Use the `--optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two optimized rule yaml file. You may also use it as an example to write your own 4/8 gpu optimized rule yaml file.
+Use the `--optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two optimized rule yaml file. You may also use it as an example to write your own 4/8 gpu optimized rule yaml file.
> Note: ktransformers' multi-GPU strategy is pipeline parallelism, which does not speed up inference; it only distributes the model's weights across GPUs.
diff --git a/doc/en/fp8_kernel.md b/doc/en/fp8_kernel.md
new file mode 100644
index 0000000..f1d86c3
--- /dev/null
+++ b/doc/en/fp8_kernel.md
@@ -0,0 +1,74 @@
+# FP8 Linear Kernel for DeepSeek-V3
+
+## Overview
+The DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We optimize performance through the following work:
+- **FP8 GPU Kernel Integration**: FP8 linear layer acceleration kernels integrated in KTransformers
+- **Hybrid Quantization Architecture**:
+ - Attention and Shared-Expert modules use FP8 precision (enhances computational accuracy)
+ - Experts modules retain GGML quantization (GGUF format, kept in CPU memory to save GPU memory)
+
+So those who are pursuing the best performance can use the FP8 linear kernel for DeepSeek-V3/R1.
+
+## Key Features
+✅ Hybrid Precision Architecture (FP8 + GGML)
+✅ Memory Optimization (~19GB VRAM usage)
+
+## Quick Start
+### Using Pre-Merged Weights
+
+Pre-merged weights are available on Hugging Face:
+[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3)
+[KVCache-ai/DeepSeek-R1](https://huggingface.co/KVCache-ai/DeepSeek-R1)
+> Please confirm the weights are fully uploaded before downloading. The large file size may extend Hugging Face upload time.
+
+
+Download Pre-Merged Weights
+```shell
+pip install -U huggingface_hub
+
+# Optional: use the HF mirror for faster downloads in some regions.
+# export HF_ENDPOINT=https://hf-mirror.com
+
+huggingface-cli download --resume-download KVCache-ai/DeepSeek-V3 --local-dir <local_dir>
+```
+### Using the merge script
+If you have local DeepSeek-R1/V3 fp8 safetensors and q4km gguf weights, you can merge them using the following script.
+
+```shell
+python convert_model.py \
+ --safetensor_path <path_to_fp8_safetensors> \
+ --gguf_path <path_to_gguf_folder> \
+ --output_path <path_to_merged_output>
+```
+
+* `--safetensor_path`: input path of safetensor file([Download](https://huggingface.co/deepseek-ai/DeepSeek-V3/tree/main)).
+* `--gguf_path`: input path of gguf folder ([Download](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)).
+* `--output_path`: output path of merged file.
+
+
+### Execution Notes
+
+Launch local_chat.py with custom quantized experts
+```shell
+python ktransformers/local_chat.py \
+ --model_path deepseek-ai/DeepSeek-V3 \
+ --gguf_path <path_to_merged_weights_folder> \
+ --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml \
+ --cpu_infer <cpu_cores_num>
+```
+
+
+## Notes
+
+⚠️ Hardware Requirements
+* Recommended minimum 19GB available VRAM for FP8 kernel.
+* Requires GPU with FP8 support (e.g., 4090)
+
+⏳ First-Run Optimization
+JIT compilation causes longer initial execution (subsequent runs retain optimized speed).
+
+🔄 Temporary Interface
+Current weight loading implementation is provisional - will be refined in future versions
+
+📁 Path Specification
+Despite hybrid quantization, merged weights are stored as .safetensors - pass the containing folder path to `--gguf_path`
\ No newline at end of file
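
The hardware note above can be checked from Python before launching; a rough probe (the capability threshold is an assumption — native FP8 matmul generally needs an Ada- or Hopper-class GPU, i.e. compute capability 8.9 or newer):

```python
import torch

def fp8_kernel_usable() -> bool:
    """Heuristic check for the KLinearFP8 path; not an official requirement list."""
    if not torch.cuda.is_available():
        return False
    has_fp8_dtype = hasattr(torch, "float8_e4m3fn")    # needs a reasonably recent PyTorch
    major, minor = torch.cuda.get_device_capability()
    return has_fp8_dtype and (major, minor) >= (8, 9)  # e.g. RTX 4090 (8.9), H100 (9.0)

print("FP8 linear kernel usable:", fp8_kernel_usable())
```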
diff --git a/doc/en/install.md b/doc/en/install.md
index 269e8fb..2a4a6af 100644
--- a/doc/en/install.md
+++ b/doc/en/install.md
@@ -141,7 +141,7 @@ It features the following arguments:
- `--gguf_path` (required): Path of a directory containing GGUF files which can be downloaded from [Hugging Face](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main). Note that the directory should only contain GGUF files of the current model, which means you need a separate directory for each model.
-- `--optimize_rule_path` (required except for Qwen2Moe and DeepSeek-V2): Path of YAML file containing optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models.
+- `--optimize_config_path` (required except for Qwen2Moe and DeepSeek-V2): Path of YAML file containing optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models.
- `--max_new_tokens`: Int (default=1000). Maximum number of new tokens to generate.
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index d5e74de..920a3f6 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -54,7 +54,7 @@ default_optimize_rules = {
def local_chat(
model_path: str | None = None,
- optimize_rule_path: str = None,
+ optimize_config_path: str = None,
gguf_path: str | None = None,
max_new_tokens: int = 300,
cpu_infer: int = Config().cpu_infer,
@@ -95,12 +95,12 @@ def local_chat(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
- if optimize_rule_path is None:
+ if optimize_config_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0])
- optimize_rule_path = default_optimize_rules[config.architectures[0]]
+ optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
- optimize_rule_path = input(
+ optimize_config_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
@@ -108,7 +108,7 @@ def local_chat(
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
- optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
+ optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 49a3f16..55e32de 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -35,9 +35,9 @@ class KTransformersInterface(TransformersInterface):
with torch.device("meta"):
self.model = custom_models[config.architectures[0]](config)
if default_args.optimize_config_path is None:
- optimize_rule_path = default_optimize_rules[config.architectures[0]]
+ optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
- optimize_rule_path = args.optimize_config_path
+ optimize_config_path = args.optimize_config_path
# print(optimize_config)
@@ -47,7 +47,7 @@ class KTransformersInterface(TransformersInterface):
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
- optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
+ optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
diff --git a/merge_tensors/README.md b/merge_tensors/README.md
deleted file mode 100644
index c786197..0000000
--- a/merge_tensors/README.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# FP8 Linear Kernel.
-For DeepSeek-R1/V3, the DeepSeek-AI team provides fp8 safetensors. We have integrated the FP8 GPU kernel into the KTransformers. But to keep the experts still in CPU to save GPU memory, we still use ggml(GGUF tensors) quantization for experts. In this way, we can increase the precision in calculating attention, which may improve the model's performance.
-
-Therefore, to use fp8 linear kernel, we need to merge fp8 weights and gguf files. We have provides prepared weights in huggingface so that you can use them directly.
-
-[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3/upload/main)
-
-
-If you want to use other formats of ggml quantization, you can use the following script to merge them.
-
-## Example
-To use fp8 linear kernal and q4km experts.
-```shell
-bash
-python convert_model.py \
- --safetensor_path \
- --gguf_path \
- --output_path