From 7b7c6a657d8f4bb7a0e330dd91c03a3e81e802f4 Mon Sep 17 00:00:00 2001 From: Azure Date: Sat, 22 Feb 2025 13:05:08 +0000 Subject: [PATCH 01/29] =?UTF-8?q?Add=20fp8=20linear=20kernel;\n=20Add=20em?= =?UTF-8?q?pty=20cache=20to=20fit=20in=2016G=20VRAM;=20By=20'wkGCaSS=20-?= =?UTF-8?q?=20=E7=9F=A5=E4=B9=8E=20https://zhuanlan.zhihu.com/p/2549161122?= =?UTF-8?q?5'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ktransformers_ext/triton/fp8gemm.py | 191 ++++++++++++++++++ ktransformers/operators/linear.py | 61 +++++- ktransformers/tests/triton_fp8gemm_test.py | 73 +++++++ ktransformers/util/custom_gguf.py | 6 + ktransformers/util/utils.py | 2 +- 5 files changed, 331 insertions(+), 2 deletions(-) create mode 100644 ktransformers/ktransformers_ext/triton/fp8gemm.py create mode 100644 ktransformers/tests/triton_fp8gemm_test.py diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py new file mode 100644 index 0000000..4da4cfe --- /dev/null +++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py @@ -0,0 +1,191 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl +from triton import Config + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448. + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + """ + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +fp8_gemm_configs = [ + Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) + for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] +] + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, + a_s_ptr, b_s_ptr, + M, N: tl.constexpr, K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + """ + Performs a matrix multiplication operation on FP8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous' + assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) + fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c \ No newline at end of file diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 08a2cca..5aff964 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -25,6 +25,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl ) from ktransformers.operators.base_operator import BaseInjectedModule from transformers.configuration_utils import PretrainedConfig +from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant from abc import ABC, abstractmethod import sys, os sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) @@ -164,7 +165,65 @@ class KLinearTorch(KLinearBase): if self.has_bias: self.bias = None - +class KLinearFP8(KLinearBase): + marlin_q_w: torch.Tensor + marlin_s: torch.Tensor + g_idx: torch.Tensor + sort_indices: torch.Tensor + has_bias: bool + weight: torch.Tensor + scale_w: torch.Tensor + bias: torch.Tensor + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cuda", + block_size: int = 128, + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.has_bias = False + self.dtype = torch.get_default_dtype() + self.block_size = block_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(self.device) + orig_shape = list(x.shape) + orig_dtype = x.dtype + x = x.reshape(-1, orig_shape[-1]) + x_quantized, scale_x = act_quant(x, self.block_size) + y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale) + if self.bias is not None: + y += self.bias + return y.to(orig_dtype).reshape(orig_shape) + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if device is None: device = self.device + if w is None: + w = self.load_weight(device=device) + if isinstance(w, nn.Parameter): + self.weight = w.to(device) + self.has_bias = False + elif isinstance(w, tuple): + self.weight = w[0].to(device) + self.bias = w[1].to(device) + self.has_bias = True + else: + raise ValueError("Invalid weight type") + self.weight = self.weight.to(device) + if self.has_bias: + self.bias = self.bias.to(device) + + def unload(self): + if self.weight is not None: + self.weight = None + if self.has_bias: + self.bias = None + + class KLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor diff --git a/ktransformers/tests/triton_fp8gemm_test.py b/ktransformers/tests/triton_fp8gemm_test.py new file mode 100644 index 0000000..bb3801c --- /dev/null +++ b/ktransformers/tests/triton_fp8gemm_test.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from typing import Optional +import pytest +from typing import Tuple, Optional, Literal + +# use dir path +import os +import sys +sys.path.insert(0, "/home/azure/ktransformers") +print(sys.path) +from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant +from safetensors import safe_open + +world_size = 1 +rank = 0 +block_size = 128 +gemm_impl: Literal["bf16", "fp8"] = "bf16" +# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined + +def test_fp8_gemm_vs_torch_matmul(): + # Test case 1: Create random matrices of size (M, K) and (K, N) + M, K, N = 64, 128, 256 # Matrix dimensions + x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda') + weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda') + + # Apply act_quant to both matrices + x_quantized, scale_x = act_quant(x, block_size) + weight_quantized, scale_w = act_quant(weight, block_size) + + # mk continous + x_quantized = x_quantized.contiguous() + weight_quantized = weight_quantized.contiguous() + scale_x = scale_x.contiguous() + scale_w = scale_w.contiguous() + + # Perform fp8_gemm using the quantized tensors + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w) + + # Perform torch.matmul using the original floating point tensors + result_torch_matmul = torch.matmul(x, weight.T) + print(f'result_torch_matmul: {result_torch_matmul.shape}') + print(f'result_fp8_gemm: {result_fp8_gemm.shape}') + + print(f"result_fp8_gemm:\n {result_fp8_gemm}") + print(f"result_torch_matmul:\n {result_torch_matmul}") + +def test_fp8_gemm_vs_torch_matmul_load(): + file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors" + with safe_open(file_path, framework="pt", device=0) as f: + weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") + scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv") + + # weight_dequant + weight_dequantized = weight_dequant(weight, scale) + print(f"weight_dequantized: {weight_dequantized.shape}") + N, K = weight_dequantized.shape + M = 64 + x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda') + x_quantized, scale_x = act_quant(x, block_size) + + # Test case 1: quantized x matmal with undequantized weight + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + print(f"result_fp8_gemm:\n {result_fp8_gemm}") + + # Perform torch.matmul using the original floating point tensors + result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T) + print(f"result_torch_matmul:\n {result_torch_matmul}") + +if __name__ == "__main__": + test_fp8_gemm_vs_torch_matmul() + test_fp8_gemm_vs_torch_matmul_load() + \ No newline at end of file diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index eaa1a7d..26afd39 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -127,6 +127,7 @@ GGML_BLOCK_SIZES = { "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2, "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64, + "FP8": 1, } GGML_ELEMENTS_PER_BLOCK = { @@ -142,6 +143,7 @@ GGML_ELEMENTS_PER_BLOCK = { "Q5_K": 256, "Q6_K": 256, "IQ4_XS": 256, + "FP8": 1, } DATA_TYPES = { @@ -158,6 +160,7 @@ DATA_TYPES = { "uint64": 10, "int64": 11, "float64": 12, + "FP8": 13, } class GGUFLoader: @@ -393,6 +396,9 @@ def read_value(f, data_type): elem_type, count = struct.unpack(" Date: Sun, 23 Feb 2025 10:19:19 +0800 Subject: [PATCH 02/29] musa: support bf16 Signed-off-by: Xiaodong Ye --- ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h | 4 +++- setup.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h index 7c94102..1892221 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h @@ -1,7 +1,9 @@ #pragma once #include +#include #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaStream_t musaStream_t -#define cudaHostFn_t musaHostFn_t \ No newline at end of file +#define cudaHostFn_t musaHostFn_t +#define nv_bfloat16 mt_bfloat16 \ No newline at end of file diff --git a/setup.py b/setup.py index 345fdb1..ea15482 100644 --- a/setup.py +++ b/setup.py @@ -350,6 +350,7 @@ elif MUSA_HOME is not None: "at::cuda": "at::musa", "#include ": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"", "#include ": "#include \"torch_musa/csrc/core/MUSAGuard.h\"", + "nv_bfloat16": "mt_bfloat16", }).run() ops_module = MUSAExtension('KTransformersOps', [ 'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu', From 006e8c6abc6503921411db935efd80ad7f16032d Mon Sep 17 00:00:00 2001 From: Atream Date: Sun, 23 Feb 2025 07:40:47 +0000 Subject: [PATCH 03/29] remove causal mask --- ktransformers/operators/models.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index 5d2e911..3877dbc 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -649,9 +649,12 @@ class KDeepseekV2Model(BaseInjectedModule): if per_layer_prefill_flag: causal_mask = None else: - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + if os.name == 'nt': + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + else: + causal_mask = None # embed positions hidden_states = inputs_embeds From f327695079298e703e24eb8a9f2493fe4e2bde80 Mon Sep 17 00:00:00 2001 From: Atream Date: Mon, 24 Feb 2025 09:30:54 +0000 Subject: [PATCH 04/29] fix KExpertsMarlin on GPU with out CUDA Graph --- .../optimize/optimize_rules/Moonlight-16B-A3B.yaml | 11 +++++++++++ ktransformers/util/custom_gguf.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml index 4c8eca2..6cea246 100644 --- a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml +++ b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml @@ -53,6 +53,17 @@ generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module +# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance) +#- match: +# name: "^model\\.layers\\..*\\.mlp\\.experts$" +# replace: +# class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism +# kwargs: +# prefill_device: "cuda" +# prefill_op: "KExpertsTorch" +# generate_device: "cuda" +# generate_op: "KExpertsMarlin" +# recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 919f432..72c3efb 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -310,6 +310,8 @@ class GGUFLoader: values = GGML_DEQUANTIZE[ggml_name](data) values = torch.from_numpy(values.copy()) + if ggml_name == "BF16": + values = values.view(torch.bfloat16) values = values.view(shape[-2::-1]) return values From 581a524f65db422011d9a8db99439d658223db6f Mon Sep 17 00:00:00 2001 From: Azure Date: Mon, 24 Feb 2025 11:16:23 +0000 Subject: [PATCH 05/29] Add data loader to read special weights for fp8; Add special weight process script --- .../ktransformers_ext/triton/fp8gemm.py | 1 + ktransformers/operators/experts.py | 11 +- ktransformers/operators/gate.py | 13 +- ktransformers/operators/linear.py | 38 ++-- ...pSeek-V3-Chat-fp8-linear-ggml-experts.yaml | 63 ++++++ ktransformers/tests/triton_fp8gemm_test.py | 47 +++- ktransformers/util/custom_gguf.py | 19 +- ktransformers/util/custom_loader.py | 86 +++++++ ktransformers/util/utils.py | 15 +- merge_tensors/merge_safetensor_gguf.py | 214 ++++++++++++++++++ 10 files changed, 481 insertions(+), 26 deletions(-) create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml create mode 100644 ktransformers/util/custom_loader.py create mode 100644 merge_tensors/merge_safetensor_gguf.py diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py index 4da4cfe..7d5b72e 100644 --- a/ktransformers/ktransformers_ext/triton/fp8gemm.py +++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py @@ -1,3 +1,4 @@ +# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py from typing import Tuple import torch diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 21b4830..1ea244a 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase): down_type = None for key in keys: - if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: + if self.gguf_loader.safetensor_loader is not None: + # using a temp ugly way to temprary load the tensor + gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy() + up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy() + down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy() + gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item() + up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item() + down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item() + + elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index 52bb33a..d908093 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -67,7 +67,14 @@ class KMoEGateBase(ABC): for key in keys: key = ".".join(key.split(".")[:-1]) - if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info: + if self.gguf_loader.safetensor_loader is not None: + targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] + weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight") + e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias") + weight_type = weight.dtype + e_score_correction_bias_type = e_score_correction_bias.dtype + res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} + elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info: targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] tensors = self.load_multi(key, targets, device=device) weight = tensors[".ffn_gate_inp.weight"] @@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) else: raise ValueError("Invalid weight type") - self.orig_module.weight = self.orig_module.weight.to(device) - self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) + self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device)) + self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device)) def unload(self): if self.weight is not None: diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 5aff964..e778102 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -76,7 +76,13 @@ class KLinearBase(ABC): keys = [self.key] for key in keys: - if key + ".weight" in self.gguf_loader.tensor_file_map: + if self.gguf_loader.safetensor_loader is not None: + # using safetensor_loader + tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight') + weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv') + return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) + + elif key + ".weight" in self.gguf_loader.tensor_file_map: if key + ".bias" in self.gguf_loader.tensor_file_map: tensors = self.load_multi(key, ["weight", "bias"], device=device) tensor = tensors["weight"] @@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase): self.bias = None class KLinearFP8(KLinearBase): + # this kernel requires special handling for weight + # Please load the weight file downloaded from KVCache.AI marlin_q_w: torch.Tensor marlin_s: torch.Tensor g_idx: torch.Tensor @@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase): def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.to(self.device) - orig_shape = list(x.shape) - orig_dtype = x.dtype - x = x.reshape(-1, orig_shape[-1]) + orig_dtype = x.dtype x_quantized, scale_x = act_quant(x, self.block_size) - y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale) - if self.bias is not None: - y += self.bias - return y.to(orig_dtype).reshape(orig_shape) + y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv) + return y.to(dtype=orig_dtype) def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device if w is None: w = self.load_weight(device=device) - if isinstance(w, nn.Parameter): - self.weight = w.to(device) - self.has_bias = False - elif isinstance(w, tuple): + ### TODO fit weight_inv format + if isinstance(w, tuple): self.weight = w[0].to(device) - self.bias = w[1].to(device) - self.has_bias = True + self.weight_scale_inv = w[1].to(device) + self.has_bias = False else: raise ValueError("Invalid weight type") self.weight = self.weight.to(device) @@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase): LINEAR_MAP = { "KLinearMarlin": KLinearMarlin, "KLinearTorch": KLinearTorch, - "KLinearCPUInfer": KLinearCPUInfer + "KLinearCPUInfer": KLinearCPUInfer, + "KLinearFP8": KLinearFP8, } class KTransformersLinear(BaseInjectedModule, KLinearBase): @@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): def forward(self, x): if self.mode == InferenceState.PREFILL: assert self.prefill_linear is not None, "cpu linear is not initialized" - return self.prefill_linear.forward(x) + y = self.prefill_linear.forward(x) else: assert self.generate_linear is not None, "gpu linear is not initialized" - return self.generate_linear.forward(x) + y = self.generate_linear.forward(x) + return y def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE): if not mode: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml new file mode 100644 index 0000000..25f021e --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml @@ -0,0 +1,63 @@ +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearFP8" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/tests/triton_fp8gemm_test.py b/ktransformers/tests/triton_fp8gemm_test.py index bb3801c..58888d6 100644 --- a/ktransformers/tests/triton_fp8gemm_test.py +++ b/ktransformers/tests/triton_fp8gemm_test.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from typing import Optional import pytest from typing import Tuple, Optional, Literal - +import time # use dir path import os import sys @@ -56,18 +56,61 @@ def test_fp8_gemm_vs_torch_matmul_load(): print(f"weight_dequantized: {weight_dequantized.shape}") N, K = weight_dequantized.shape M = 64 - x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda') + x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda') x_quantized, scale_x = act_quant(x, block_size) # Test case 1: quantized x matmal with undequantized weight result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) print(f"result_fp8_gemm:\n {result_fp8_gemm}") + print(f"dtype {result_fp8_gemm.dtype}") # Perform torch.matmul using the original floating point tensors result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T) print(f"result_torch_matmul:\n {result_torch_matmul}") +def test_fp8_gemm_tplops(): + file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors" + with safe_open(file_path, framework="pt", device=0) as f: + weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") + scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv") + + # weight_dequant + weight_dequantized = weight_dequant(weight, scale) + print(f"weight_dequantized: {weight_dequantized.shape}") + N, K = weight_dequantized.shape + M = 6400 + x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda') + # x_quantized, scale_x = act_quant(x, block_size) + + # Calculate time for 1000 fp8_gemm + i = 10 + flops_per_gemm = 2 * M * N * K + total_flops = i * flops_per_gemm + + x_quantized, scale_x = act_quant(x, block_size) + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + x_quantized, scale_x = act_quant(x, block_size) + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + + + t0 = time.time() + torch.cuda.synchronize() + for i in range(i): + x_quantized, scale_x = act_quant(x, block_size) + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + torch.cuda.synchronize() + t1 = time.time() + + total_time = t1 - t0 + tflops = total_flops / total_time / 1e12 + print(f"total_time: {total_time}") + print(f"tflops: {tflops}") + + + + if __name__ == "__main__": test_fp8_gemm_vs_torch_matmul() test_fp8_gemm_vs_torch_matmul_load() + test_fp8_gemm_tplops() \ No newline at end of file diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 26afd39..d054ad3 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -25,6 +25,7 @@ import os from enum import IntEnum import torch import KTransformersOps +from .custom_loader import SafeTensorLoader class GGMLQuantizationType(IntEnum): F32 = 0 @@ -168,12 +169,15 @@ class GGUFLoader: gguf_path: str tensor_file_map: dict # {tensor_name: tensor_file_path} gguf_file_meta: dict + safetensor_loader: SafeTensorLoader def __init__(self, gguf_path: str): # Check dir exist if not os.path.exists(gguf_path): raise FileNotFoundError(f"GGUF dir not found: {gguf_path}") if os.path.isfile(gguf_path): gguf_path = os.path.dirname(gguf_path) + + self.safetensor_loader = None self.tensor_info = {} self.gguf_path = gguf_path @@ -181,7 +185,13 @@ class GGUFLoader: self.file_data_map = {} self.gguf_file_meta = {} self.tensor_device_map = {} - + + # I know this is ugly, but I don't want to change the original code too much + # TODO: merge gguf load and other loads. + safetensor_loader = SafeTensorLoader(gguf_path) + if safetensor_loader.tensor_file_map: + self.safetensor_loader = safetensor_loader + return # Walk through all the .gguf files in the directory found_gguf = False for root, dirs, files in os.walk(gguf_path): @@ -288,6 +298,13 @@ class GGUFLoader: itemsize = int(np.empty([], dtype = item_type).itemsize) return mmap_data[offset : offset + itemsize * item_count] + def get_undequanted_tensor_and_ggml_type(self, name): + t = self.tensor_info[name] + data = self.get_mmap_tensor(name) + ggml_type = t["ggml_type"] + data = torch.from_numpy(data) + return data, ggml_type + def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor: t = self.tensor_info[name] if device.lower() == "cpu": diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py new file mode 100644 index 0000000..ecc09a0 --- /dev/null +++ b/ktransformers/util/custom_loader.py @@ -0,0 +1,86 @@ +import struct +import warnings +import numpy as np +import re +import numpy.typing as npt +from typing import Sequence +import os +from enum import IntEnum +import torch +import KTransformersOps +from safetensors import safe_open +from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant +from safetensors.torch import save_file + +class SafeTensorLoader: + tensor_file_map = {} + tensor_type_map = {} + file_handle_map = {} + + def __init__(self, file_path: str): + self.__load_tensor_file_map(file_path) + + def __load_tensor_file_map(self, file_path: str): + # 处理传入路径,确保是文件夹路径 + if not os.path.exists(file_path): + raise FileNotFoundError(f"Path not found: {file_path}") + if os.path.isfile(file_path): + folder_path = os.path.dirname(file_path) + else: + folder_path = file_path + + found_safetensor = False + for root, _, files in os.walk(folder_path): + files = sorted(files) + for file in files: + if file.endswith(".safetensors"): + found_safetensor = True + file_path = os.path.join(root, file) + if file not in self.file_handle_map: + try: + handle = safe_open(file_path, framework="pt") + self.file_handle_map[file] = handle + except Exception as e: + print(f"Error opening Safetensor file {file_path}: {e}") + continue + + f = self.file_handle_map.get(file) + if f is None: + continue + try: + for key in f.keys(): + self.tensor_file_map[key] = file + except Exception as e: + print(f"Error reading Safetensor file {file_path}: {e}") + + # if not found_safetensor: + # raise FileNotFoundError(f"No Safetensor files found in {folder_path}") + + def load_tensor(self, key: str, device: str="cpu"): + if key not in self.tensor_file_map: + raise KeyError(f"Key {key} not found in Safetensor files") + file = self.tensor_file_map[key] + f = self.file_handle_map.get(file) + if f is None: + raise FileNotFoundError(f"File {file} not found in Safetensor files") + tensor = f.get_tensor(key) + return tensor.to(device) + + def close_all_handles(self): + for handle in self.file_handle_map.values(): + handle.close() + self.file_handle_map.clear() + + def load_dequantized_tensor(self, key:str, device: str="cpu"): + if key not in self.tensor_file_map: + raise KeyError(f"Key {key} not found in Safetensor files") + file = self.tensor_file_map[key] + f = self.file_handle_map.get(file) + if f is None: + raise FileNotFoundError(f"File {file} not found in Safetensor files") + tensor = f.get_tensor(key).to(device) + if key.endswith(".weight"): + if key[:-7] + ".weight_scale_inv" in self.tensor_file_map: + weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device) + tensor = weight_dequant(tensor, weight_scale_inv) + return tensor.to(device) \ No newline at end of file diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 81d007c..1c21135 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str for name, param in local_state.items(): key = prefix + name translated_key = translate_name_to_gguf(key) - if translated_key in gguf_loader.tensor_file_map: + + # TODO: Merge all loader. + # I know this is ugly but lets do it for now. + if gguf_loader.safetensor_loader is not None: + load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor + tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map + else: + load_dequantized_tensor = gguf_loader.load_gguf_tensor + tensor_file_map = gguf_loader.tensor_file_map + + if translated_key in tensor_file_map: target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) print(f"loading {translated_key} to {device}") torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225" - weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype) + # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype) + weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) set_param(module, name, weights) del weights else: diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py new file mode 100644 index 0000000..7aeb62d --- /dev/null +++ b/merge_tensors/merge_safetensor_gguf.py @@ -0,0 +1,214 @@ +# this script targets to merge the fp8 safe tensor and the gguf quantized tensors. + +import os +# insert the path of the project +import sys +sys.path.insert(0, "/home/azure/ktransformers") +import argparse +import torch +from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf +from safetensors import safe_open +from safetensors.torch import save_file +import re +from collections import defaultdict + +def read_safetensor_keys_from_folder(folder_path)->dict: + """ + :param folder_path: folder path + :return: key_to_file_map + """ + # check if the folder path is exist + if not os.path.exists(folder_path): + raise FileNotFoundError(f"GGUF dir not found: {folder_path}") + if os.path.isfile(folder_path): + folder_path = os.path.dirname(folder_path) + + key_to_file_map = {} + + found_safetensor = False + for root, dirs, files in os.walk(folder_path): + # sort files + files = sorted(files) + for file in files: + if file.endswith(".safetensors"): + found_safetensor = True + file_path = os.path.join(root, file) + try: + with safe_open(file_path, framework="pt") as f: + for key in f.keys(): + if "model.layers.61" in key: + # skip MTP layer + continue + # try: + # if int(key.split('.')[2]) > 4: + # continue + # except: + # pass + key_to_file_map[key] = file_path + except Exception as e: + print(f"Error reading Safetensor file {file_path}: {e}") + + if not found_safetensor: + raise FileNotFoundError(f"No Safetensor files found in {folder_path}") + + return key_to_file_map + +tensor_from_gguf = [] # todo: add keys in gguf that should be used in the final tensor + +def translate_name(name:str)->str: + """ + :param name: name of the tensor + :return: translated name + """ + name = translate_name_to_gguf(name) + name = name.replace(".up_proj.", ".ffn_up_exps.") + name = name.replace(".down_proj.", ".ffn_down_exps.") + name = name.replace(".gate_proj.", ".ffn_gate_exps.") + name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias") + return name + + +def combine_tensor_sources(safetensor_path:str, gguf_path:str): + gguf_loader = GGUFLoader(gguf_path) + gguf_tensor_file_map = gguf_loader.tensor_file_map + safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path) + + # build a map for the key to the tensor + # according to the key, we can get the tensor from the file + + target_tensor_map = {} + for key in safetensor_tensor_file_map.keys(): + # for all experts, we use the gguf tensor + if ".mlp.experts." in key: + if '.weight_scale_inv' in key: + continue + key = '.'.join(key.split('.')[:5]+key.split('.')[-2:]) + translated_key = translate_name(key) + target_tensor_map[key] = gguf_tensor_file_map[translated_key] + continue + + if any(target_key in key for target_key in tensor_from_gguf): + target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)] + else: + target_tensor_map[key] = safetensor_tensor_file_map[key] + + return target_tensor_map, gguf_loader + +def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader): + # Ensure output directory exists + os.makedirs(output_path, exist_ok=True) + + # Cache for safetensor file handles and GGUF loaders + safetensors_cache = {} + gguf_cache = {} + + # Group tensors by layer + layer_groups = defaultdict(list) + non_layer_keys = [] + layer_pattern = re.compile(r'\.layers\.(\d+)\.') + + for key in target_tensor_map: + match = layer_pattern.search(key) + if match: + layer_num = int(match.group(1)) + layer_groups[layer_num].append(key) + else: + non_layer_keys.append(key) + + # Calculate total shards + total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1 + if total_shards == 0: + raise ValueError("No tensors to save") + + shard_idx = 0 + + # Save non-layer tensors to the first shard if they exist + if non_layer_keys: + tensors = {} + for key in non_layer_keys: + file_path = target_tensor_map[key] + tensor = None + ggml_type = None + if file_path.endswith('.safetensors'): + if file_path not in safetensors_cache: + safetensors_cache[file_path] = safe_open(file_path, framework='pt') + f = safetensors_cache[file_path] + tensor = f.get_tensor(key) + elif file_path.endswith('.gguf'): + gguf_name = translate_name(key) + tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) + else: + raise ValueError(f"Unsupported file format: {file_path}") + tensors[translate_name(key)] = tensor + if ggml_type: + ggml_type = torch.tensor(ggml_type) + ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type" + tensors[ggml_key] = ggml_type + + output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") + print(f"Saving non-layer tensors to {output_file}") + save_file(tensors, output_file) + print(tensors.keys()) + + shard_idx += 1 + + # Save each layer's tensors to subsequent shards + for layer_num in sorted(layer_groups.keys()): + layer_keys = layer_groups[layer_num] + tensors = {} + for key in layer_keys: + file_path = target_tensor_map[key] + tensor = None + ggml_type = None + if file_path.endswith('.safetensors'): + if file_path not in safetensors_cache: + safetensors_cache[file_path] = safe_open(file_path, framework='pt') + f = safetensors_cache[file_path] + tensor = f.get_tensor(key) + tensor_info = tensor.shape + elif file_path.endswith('.gguf'): + gguf_name = translate_name(key) + tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) + # tensor_info = gguf_loader.tensor_info[gguf_name] + # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type'] + else: + raise ValueError(f"Unsupported file format: {file_path}") + tensors[translate_name(key)] = tensor + if ggml_type: + ggml_type = torch.tensor(ggml_type) + ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type" + tensors[ggml_key] = ggml_type + + output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") + print(f"Saving layer {layer_num} to {output_file}") + print(tensors.keys()) + save_file(tensors, output_file) + shard_idx += 1 + + return + +def main(): + # 创建命令行参数解析器 + parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files") + parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3") + parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf") + parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8") + + # print all the arguments + print("All the arguments:") + print(parser.parse_args()) + + # 解析命令行参数 + args = parser.parse_args() + + safetensor_path = args.safetensor_path + gguf_path = args.gguf_path + output_path = args.output_path + + target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path) + write_combined_tensor(target_tensor_map, output_path, gguf_loader) + + return + +if __name__ == "__main__": + main() \ No newline at end of file From f88c05a6f18c349639aa58823cc8072d52b2c349 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Mon, 24 Feb 2025 21:55:30 +0800 Subject: [PATCH 06/29] Ensure backward compatibility with Torch 2.2 Signed-off-by: Xiaodong Ye --- .../ktransformers_ext/cuda/binding.cpp | 39 +++++++++++-------- .../ktransformers_ext/cuda/custom_gguf/ops.h | 18 ++++----- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp index 0b1994d..5bba873 100644 --- a/ktransformers/ktransformers_ext/cuda/binding.cpp +++ b/ktransformers/ktransformers_ext/cuda/binding.cpp @@ -1,9 +1,9 @@ /** - * @Description : + * @Description : * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 * @Version : 0.2.2 - * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "custom_gguf/ops.h" @@ -20,38 +20,45 @@ PYBIND11_MODULE(KTransformersOps, m) { - m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q8_0 data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q6_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q5_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q4_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q3_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q2_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize iq4_xs data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h index a52db2d..1740cbf 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h @@ -1,11 +1,11 @@ /** - * @Description : + * @Description : * @Author : Azure-Tang * @Date : 2024-07-22 09:27:55 * @Version : 1.0.0 * @LastEditors : kkk1nak0 * @LastEditTime : 2024-08-12 03:48:46 - * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once @@ -13,10 +13,10 @@ #include #include -torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); From 4dc5518e4d2ae89a687709bcbe05d2f3f80e00ad Mon Sep 17 00:00:00 2001 From: Azure Date: Mon, 24 Feb 2025 15:37:01 +0000 Subject: [PATCH 07/29] update fp8 kernel tutorial --- doc/SUMMARY.md | 3 +- doc/en/FAQ.md | 5 ++- doc/en/injection_tutorial.md | 1 + .../server/backend/interfaces/transformers.py | 2 +- merge_tensors/README.md | 36 +++++++++++++++++++ merge_tensors/merge_safetensor_gguf.py | 1 - requirements-local_chat.txt | 3 +- 7 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 merge_tensors/README.md diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index 4efb7ea..47b4a02 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -5,10 +5,11 @@ - [Installation Guide](en/install.md) # Tutorial -- [Deepseek-R1/V3 Show Case](en/DeepseekR1_V3_tutorial.md) +- [Deepseek-R1/V3 Show Case/Tutorial](en/DeepseekR1_V3_tutorial.md) - [Why KTransformers So Fast](en/deepseek-v2-injection.md) - [Injection Tutorial](en/injection_tutorial.md) - [Multi-GPU Tutorial](en/multi-gpu-tutorial.md) +- [Using FP8 GPU Kernel](../merge_tensors/README.md) # Server - [Server](en/api/server/server.md) - [Website](en/api/server/website.md) diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md index 51c6271..93977cf 100644 --- a/doc/en/FAQ.md +++ b/doc/en/FAQ.md @@ -55,7 +55,7 @@ You have to set `--cpu_infer` to the number of cores you want to use. The more c ### Q: My DeepSeek-R1 model is not thinking. -According to DeepSeek, you need to enforce the model to initiate its response with "\\n" at the beginning of every output by passing the arg `--force_think true `. +According to DeepSeek, you need to enforce the model to initiate its response with "\\n" at the beginning of every output by passing the arg `--force_think True `. ### Q: Loading gguf error @@ -63,9 +63,12 @@ Make sure you: 1. Have the `gguf` file in the `--gguf_path` directory. 2. The directory only contains gguf files from one model. If you have multiple models, you need to separate them into different directories. 3. The folder name it self should not end with `.gguf`, eg. `Deep-gguf` is correct, `Deep.gguf` is wrong. +4. The file itself is not corrupted; you can verify this by checking that the sha256sum matches the one from huggingface, modelscope, or hf-mirror. ### Q: Version `GLIBCXX_3.4.30' not found The detailed error: >ImportError: /mnt/data/miniconda3/envs/xxx/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/xxx/xxx/ktransformers/./cpuinfer_ext.cpython-312-x86_64-linux-gnu.so) Running `conda install -c conda-forge libstdcxx-ng` can solve the problem. + + diff --git a/doc/en/injection_tutorial.md b/doc/en/injection_tutorial.md index 5ebb327..4518836 100644 --- a/doc/en/injection_tutorial.md +++ b/doc/en/injection_tutorial.md @@ -59,6 +59,7 @@ Supported operators and their corresponding classes are as follows: | Linear | KTransformersLinear | KLinearMarlin | Marlin as backend | | | | KLinearTorch | pytorch as backend | | | | KLinearCPUInfer | llamafile as backend | +| | | KLinearFP8 | Triton fp8_gemm kernel. Requires GPU be able to caluculate fp8 data | | experts | KTransformersExperts | KExpertsTorch | pytorch as backend | | | | KExpertsMarlin | Marlin as backend | | | | KExpertsCPU | llamafile as backend | diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 8211933..c6f8e3a 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -340,7 +340,7 @@ class TransformersInterface(BackendInterfaceBase): sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) next_token = self.decode_one_tokens() self.profiler.inc("decode") - if next_token == self.tokenizer.eos_token_id: + if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token): assert self.args.batch_size == 1 break yield self.append_new_tokens(next_token) diff --git a/merge_tensors/README.md b/merge_tensors/README.md new file mode 100644 index 0000000..c786197 --- /dev/null +++ b/merge_tensors/README.md @@ -0,0 +1,36 @@ +# FP8 Linear Kernel. +For DeepSeek-R1/V3, the DeepSeek-AI team provides fp8 safetensors. We have integrated the FP8 GPU kernel into the KTransformers. But to keep the experts still in CPU to save GPU memory, we still use ggml(GGUF tensors) quantization for experts. In this way, we can increase the precision in calculating attention, which may improve the model's performance. + +Therefore, to use fp8 linear kernel, we need to merge fp8 weights and gguf files. We have provides prepared weights in huggingface so that you can use them directly. + +[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3/upload/main) + + +If you want to use other formats of ggml quantization, you can use the following script to merge them. + +## Example +To use fp8 linear kernal and q4km experts. +```shell +bash +python convert_model.py \ + --safetensor_path \ + --gguf_path \ + --output_path +``` +* `--safetensor_path`: input path of safetensor file +* `--gguf_path`: input path of gguf folder +* `--output_path`: output path of merged file + + +## To Run DeepSeek-V3 with fp8 linear kernel and q4km experts + + +```shell +python ktransformers/local_chat.py --model_path deepseek-ai/DeepSeek-V3 --gguf_path --optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml --cpu_infer +``` + + +> NOTES: +> 1. Using fp8 linear kernel and q4km experts will consume approximatly 19GB GPU memory. +> 2. I know the the new way to load module is ugly, we are working on it. +> 3. Though the model is a mixture of fp8 and ggml, they are stored in .safetensor files. Please pass the folder path of the new weights to `--gguf_path`. diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py index 7aeb62d..67e09e5 100644 --- a/merge_tensors/merge_safetensor_gguf.py +++ b/merge_tensors/merge_safetensor_gguf.py @@ -3,7 +3,6 @@ import os # insert the path of the project import sys -sys.path.insert(0, "/home/azure/ktransformers") import argparse import torch from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index 0479d36..bf4df63 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -4,4 +4,5 @@ numpy torch>=2.3.0 packaging cpufeature -protobuf \ No newline at end of file +protobuf +tiktoken \ No newline at end of file From 36fbeee341e283a93b6befa2a4d9085b7a5dd2b1 Mon Sep 17 00:00:00 2001 From: Azure Date: Tue, 25 Feb 2025 08:21:18 +0000 Subject: [PATCH 08/29] Update doc --- README.md | 5 +- README_ZH.md | 17 +++-- doc/SUMMARY.md | 2 +- doc/en/FAQ.md | 2 +- doc/en/fp8_kernel.md | 74 +++++++++++++++++++ doc/en/install.md | 2 +- ktransformers/local_chat.py | 10 +-- .../backend/interfaces/ktransformers.py | 6 +- merge_tensors/README.md | 36 --------- merge_tensors/merge_safetensor_gguf.py | 3 +- requirements-local_chat.txt | 3 +- 11 files changed, 101 insertions(+), 59 deletions(-) create mode 100644 doc/en/fp8_kernel.md delete mode 100644 merge_tensors/README.md diff --git a/README.md b/README.md index 30f1425..f62528b 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

-* **Feb 15, 2025**: KTransformers V0.2.1: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%) (Up to 16 Tokens/s), update docs [here](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). +* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; Longer Context (from 8K to 128K for 24GB VRAM). +* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU. @@ -125,7 +126,7 @@ To utilize the provided kernels, users only need to create a YAML-based injectio ```python with torch.device("meta"): model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) -optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) +optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) ... generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000) ``` diff --git a/README_ZH.md b/README_ZH.md index 4cdd3c1..3696a08 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -21,7 +21,8 @@ KTransformers 是一个以 Python 为中心的灵活框架,其核心是可扩

🔥 更新

-* **2025 年 2 月 15 日**:KTransformers V0.2.1: 长上下文(从4K到8K,24GB VRAM) & 稍快的速度(+15%)(最快 16 Tokens/s),文档请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 和 [在线指南](https://kvcache-ai.github.io/ktransformers/) 。 +* **2025 年 2 月 15 日**:为DeepSeek-V3/R1支持[FP8 GPU内核](./doc/en/fp8_kernel.md); 支持更长的上下文 (从8K到128K仅用24GB VRAM). +* **2025 年 2 月 15 日**:长上下文(从4K到8K,24GB VRAM) & 稍快的速度(+15%)(最快 16 Tokens/s),文档请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 和 [在线指南](https://kvcache-ai.github.io/ktransformers/) 。 * **2025 年 2 月 10 日**:支持 Deepseek-R1 和 V3 在单个(24GB VRAM)/多 GPU 和 382G DRAM 上运行,速度提升高达 3~28 倍。详细教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md)。 * **2024 年 8 月 28 日**:支持 InternLM2.5-7B-Chat-1M 模型下的 1M 上下文,使用 24GB 的 VRAM 和 150GB 的 DRAM。详细教程请参见 [这里](./doc/en/long_context_tutorial.md)。 * **2024 年 8 月 28 日**:将 DeepseekV2 所需的 VRAM 从 21G 降低到 11G。 @@ -68,11 +69,11 @@ https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c

-

在仅 24GB VRAM 的桌面上进行 1M 上下文本地推理

-

- -https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 + + + 更多高级功能即将推出,敬请期待! @@ -116,7 +117,7 @@ KTransformers 的核心是一个用户友好的、基于模板的注入框架。 ```python with torch.device("meta"): model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) -optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) +optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) ... generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000) ``` @@ -151,7 +152,7 @@ YAML 文件中的每个规则都有两部分:`match` 和 `replace`。`match`

致谢和贡献者

-KTransformer 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 和 Marlin 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。 +KTransformer 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 、 Marlin、sglang和flashinfer 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。 KTransformer 由清华大学 MADSys group 小组的成员以及 Approaching.AI 的成员积极维护和开发。我们欢迎新的贡献者加入我们,使 KTransformer 更快、更易于使用。 diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index 47b4a02..94172ce 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -9,7 +9,7 @@ - [Why KTransformers So Fast](en/deepseek-v2-injection.md) - [Injection Tutorial](en/injection_tutorial.md) - [Multi-GPU Tutorial](en/multi-gpu-tutorial.md) -- [Using FP8 GPU Kernel](../merge_tensors/README.md) +- [Use FP8 GPU Kernel](en/fp8_kernel.md) # Server - [Server](en/api/server/server.md) - [Website](en/api/server/website.md) diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md index 93977cf..c8f58a2 100644 --- a/doc/en/FAQ.md +++ b/doc/en/FAQ.md @@ -45,7 +45,7 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552 ### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them? -Use the `--optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two optimized rule yaml file. You may also use it as an example to write your own 4/8 gpu optimized rule yaml file. +Use the `--optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two optimized rule yaml file. You may also use it as an example to write your own 4/8 gpu optimized rule yaml file. > Note: The ktransformers' multi-gpu stratigy is pipline, which is not able to speed up the model's inference. It's only for the model's weight distribution. diff --git a/doc/en/fp8_kernel.md b/doc/en/fp8_kernel.md new file mode 100644 index 0000000..f1d86c3 --- /dev/null +++ b/doc/en/fp8_kernel.md @@ -0,0 +1,74 @@ +# FP8 Linear Kernel for DeepSeek-V3 + +## Overview +The DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We achieve performance optimization through the following works: +- **FP8 GPU Kernel Integration**: FP8 linear layer acceleration kernels integrated in KTransformers +- **Hybrid Quantization Architecture**: + - Attention and Shared-Expert modules use FP8 precision (enhances computational accuracy) + - Experts modules retain GGML quantization (GGUF format, reside in CPU to save GPU memory) + +So those who are persuing the best performance can use the FP8 linear kernel for DeepSeek-V3/R1. + +## Key Features +✅ Hybrid Precision Architecture (FP8 + GGML) +✅ Memory Optimization (~19GB VRAM usage) + +## Quick Start +### Using Pre-Merged Weights + +Pre-merged weights are available on Hugging Face: +[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3) +[KVCache-ai/DeepSeek-R1](https://huggingface.co/KVCache-ai/DeepSeek-R1) +> Please confirm the weights are fully uploaded before downloading. The large file size may extend Hugging Face upload time. + + +Download Pre-Merged Weights +```shell +pip install -U huggingface_hub + +# Optional: Use HF Mirror for faster downloads in special area. +# export HF_ENDPOINT=https://hf-mirror.com + +huggingface-cli download --resume-download KVCache-ai/DeepSeek-V3 --local-dir +``` +### Using merge scripts +If you got local DeepSeek-R1/V3 fp8 safetensors and q4km gguf weights, you can merge them using the following scripts. + +```shell +python convert_model.py \ + --safetensor_path \ + --gguf_path \ + --output_path +``` + +* `--safetensor_path`: input path of safetensor file([Download](https://huggingface.co/deepseek-ai/DeepSeek-V3/tree/main)). +* `--gguf_path`: input path of gguf folder ([Download](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)). +* `--output_path`: output path of merged file. + + +### Execution Notes + +Launch local_chat.py with custom quantized experts +```shell +python ktransformers/local_chat.py \ + --model_path deepseek-ai/DeepSeek-V3 \ + --gguf_path \ + --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml \ + --cpu_infer +``` + + +## Notes + +⚠️ Hardware Requirements +* Recommended minimum 19GB available VRAM for FP8 kernel. +* Requires GPU with FP8 support (e.g., 4090) + +⏳ First-Run Optimization +JIT compilation causes longer initial execution (subsequent runs retain optimized speed). + +🔄 Temporary Interface +Current weight loading implementation is provisional - will be refined in future versions + +📁 Path Specification +Despite hybrid quantization, merged weights are stored as .safetensors - pass the containing folder path to `--gguf_path` \ No newline at end of file diff --git a/doc/en/install.md b/doc/en/install.md index 269e8fb..2a4a6af 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -141,7 +141,7 @@ It features the following arguments: - `--gguf_path` (required): Path of a directory containing GGUF files which could that can be downloaded from [Hugging Face](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main). Note that the directory should only contains GGUF of current model, which means you need one separate directory for each model. -- `--optimize_rule_path` (required except for Qwen2Moe and DeepSeek-V2): Path of YAML file containing optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models. +- `--optimize_config_path` (required except for Qwen2Moe and DeepSeek-V2): Path of YAML file containing optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models. - `--max_new_tokens`: Int (default=1000). Maximum number of new tokens to generate. diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index d5e74de..920a3f6 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -54,7 +54,7 @@ default_optimize_rules = { def local_chat( model_path: str | None = None, - optimize_rule_path: str = None, + optimize_config_path: str = None, gguf_path: str | None = None, max_new_tokens: int = 300, cpu_infer: int = Config().cpu_infer, @@ -95,12 +95,12 @@ def local_chat( config, trust_remote_code=True, attn_implementation="flash_attention_2" ) - if optimize_rule_path is None: + if optimize_config_path is None: if config.architectures[0] in default_optimize_rules: print("using default_optimize_rule for", config.architectures[0]) - optimize_rule_path = default_optimize_rules[config.architectures[0]] + optimize_config_path = default_optimize_rules[config.architectures[0]] else: - optimize_rule_path = input( + optimize_config_path = input( "please input the path of your rule file(yaml file containing optimize rules):" ) @@ -108,7 +108,7 @@ def local_chat( gguf_path = input( "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" ) - optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) + optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) try: model.generation_config = GenerationConfig.from_pretrained(model_path) diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 49a3f16..55e32de 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -35,9 +35,9 @@ class KTransformersInterface(TransformersInterface): with torch.device("meta"): self.model = custom_models[config.architectures[0]](config) if default_args.optimize_config_path is None: - optimize_rule_path = default_optimize_rules[config.architectures[0]] + optimize_config_path = default_optimize_rules[config.architectures[0]] else: - optimize_rule_path = args.optimize_config_path + optimize_config_path = args.optimize_config_path # print(optimize_config) @@ -47,7 +47,7 @@ class KTransformersInterface(TransformersInterface): "please input the path of your gguf file(gguf file in the dir containing input gguf file must all" " belong to current model):" ) - optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config) + optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config) self.device_map = self.model.gguf_loader.tensor_device_map # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}") diff --git a/merge_tensors/README.md b/merge_tensors/README.md deleted file mode 100644 index c786197..0000000 --- a/merge_tensors/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# FP8 Linear Kernel. -For DeepSeek-R1/V3, the DeepSeek-AI team provides fp8 safetensors. We have integrated the FP8 GPU kernel into the KTransformers. But to keep the experts still in CPU to save GPU memory, we still use ggml(GGUF tensors) quantization for experts. In this way, we can increase the precision in calculating attention, which may improve the model's performance. - -Therefore, to use fp8 linear kernel, we need to merge fp8 weights and gguf files. We have provides prepared weights in huggingface so that you can use them directly. - -[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3/upload/main) - - -If you want to use other formats of ggml quantization, you can use the following script to merge them. - -## Example -To use fp8 linear kernal and q4km experts. -```shell -bash -python convert_model.py \ - --safetensor_path \ - --gguf_path \ - --output_path -``` -* `--safetensor_path`: input path of safetensor file -* `--gguf_path`: input path of gguf folder -* `--output_path`: output path of merged file - - -## To Run DeepSeek-V3 with fp8 linear kernel and q4km experts - - -```shell -python ktransformers/local_chat.py --model_path deepseek-ai/DeepSeek-V3 --gguf_path --optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml --cpu_infer -``` - - -> NOTES: -> 1. Using fp8 linear kernel and q4km experts will consume approximatly 19GB GPU memory. -> 2. I know the the new way to load module is ugly, we are working on it. -> 3. Though the model is a mixture of fp8 and ggml, they are stored in .safetensor files. Please pass the folder path of the new weights to `--gguf_path`. diff --git a/merge_tensors/merge_safetensor_gguf.py b/merge_tensors/merge_safetensor_gguf.py index 67e09e5..69780fe 100644 --- a/merge_tensors/merge_safetensor_gguf.py +++ b/merge_tensors/merge_safetensor_gguf.py @@ -3,6 +3,7 @@ import os # insert the path of the project import sys +sys.path.insert(0, "/home/azure/ktransformers") import argparse import torch from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf @@ -180,7 +181,7 @@ def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") print(f"Saving layer {layer_num} to {output_file}") - print(tensors.keys()) + # print(tensors.keys()) save_file(tensors, output_file) shard_idx += 1 diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index bf4df63..ad280c0 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -5,4 +5,5 @@ torch>=2.3.0 packaging cpufeature protobuf -tiktoken \ No newline at end of file +tiktoken +blobfile \ No newline at end of file From f4c198bd42f037ccc570eb8f1f1b4ab1ea9c7fa2 Mon Sep 17 00:00:00 2001 From: Atream Date: Tue, 25 Feb 2025 08:52:02 +0000 Subject: [PATCH 09/29] support absorb for prefill long context --- ktransformers/local_chat.py | 4 +- ktransformers/operators/attention.py | 52 +++++++++++++------ ktransformers/operators/flashinfer_wrapper.py | 30 ++++++++--- ktransformers/operators/models.py | 6 ++- .../optimize_rules/DeepSeek-V3-Chat.yaml | 1 + .../backend/interfaces/ktransformers.py | 5 ++ .../server/backend/interfaces/transformers.py | 2 +- ktransformers/util/utils.py | 26 ++++++++-- 8 files changed, 93 insertions(+), 33 deletions(-) diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index d087752..5e57a22 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM -from ktransformers.util.utils import prefill_and_generate +from ktransformers.util.utils import prefill_and_generate, get_compute_capability from ktransformers.server.config.config import Config from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled @@ -168,7 +168,7 @@ def local_chat( assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ "please change max_seq_len in ~/.ktransformers/config.yaml" - if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled: + if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8: generated = prefill_and_generate( model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index b4c5402..5e7391f 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.utils import get_compute_capability import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache @@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, + absorb_for_prefill: bool = False, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. self.mla_wrapper = None + self.absorb_for_prefill = absorb_for_prefill def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): @@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below q_nope = torch.matmul(q_nope, q_absorb) # batched MM q_nope = q_nope.transpose(1, 2) - assert q_nope.is_contiguous() + #assert q_nope.is_contiguous() # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] @@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] attn_output = attn_output.transpose(1, 2) attn_output = torch.matmul(attn_output, out_absorb.mT) + attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) @@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] # decode - if q_len == 1: + if q_len == 1 or self.absorb_for_prefill: if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) @@ -395,27 +399,41 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below q_nope = torch.matmul(q_nope, q_absorb) # batched MM q_nope = q_nope.transpose(1, 2) - assert q_nope.is_contiguous() + q_nope = q_nope.contiguous() + #assert q_nope.is_contiguous() # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] - q_nope.squeeze_(1) - q_pe.squeeze_(1) + q_nope.squeeze_(0) + q_pe.squeeze_(0) # flash attn doesn't support head_dim bigger than 256, use flashinfer if self.mla_wrapper is None: self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True) - if self.mla_wrapper.need_plan: - self.mla_wrapper.need_plan = False + if self.mla_wrapper.need_plan: + self.mla_wrapper.need_plan = False + if q_len == 1: self.mla_wrapper.plan(None,None,None, - position_ids.squeeze(1)+1, - self.num_heads, - self.kv_lora_rank, - self.qk_rope_head_dim, - past_key_value.page_size, - self.softmax_scale, - q_nope.dtype, - compressed_kv.dtype) + position_ids.squeeze(1)+1, + self.num_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + past_key_value.page_size, + self.softmax_scale, + q_nope.dtype, + compressed_kv.dtype) + else: + qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device) + kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device) + self.mla_wrapper.plan(qo_indptr,None,None, + kv_len_arr, + self.num_heads, + self.kv_lora_rank, + self.qk_rope_head_dim, + past_key_value.page_size, + self.softmax_scale, + q_nope.dtype, + compressed_kv.dtype) attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank) """ @@ -443,6 +461,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank] attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim] + attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank] attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim] attn_output = self.o_proj(attn_output) @@ -571,7 +590,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.name == 'nt': + if os.name == 'nt' or get_compute_capability()<8: + print("for Windows or GPU before ampere, use forward_windows") return self.forward_windows( hidden_states, attention_mask, diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py index b3b9dd1..864b33e 100644 --- a/ktransformers/operators/flashinfer_wrapper.py +++ b/ktransformers/operators/flashinfer_wrapper.py @@ -9,7 +9,7 @@ flashinfer_enabled = False try: import flashinfer - flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable + flashinfer_enabled = True print("found flashinfer") except ImportError: @@ -132,14 +132,14 @@ class MLAWrapper(): head_dim_ckv, head_dim_kpe, page_size, - False, # causal is False for decoding + True, # causal sm_scale, q_data_type, kv_data_type, ) def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False): - return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse) + return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse) class MLAWrapperSingleton(): wrappers:dict = {} @@ -179,6 +179,17 @@ class MLAWrapperSingleton(): sm_scale, q_data_type, kv_data_type,) + wrapper.need_plan = False + + @classmethod + def need_plan_all(cls): + for device, wrapper in cls.wrappers.items(): + wrapper.need_plan = True + + @classmethod + def reset_buffer(cls): + for device, wrapper in cls.wrappers.items(): + wrapper.qo_indptr_buf[1] = 1 if __name__ == "__main__": @@ -187,8 +198,9 @@ if __name__ == "__main__": page_size = 64 num_heads = 128 - q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda") - q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda") + q_len = 10 + q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda") + q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda") ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda") k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda") @@ -199,10 +211,10 @@ if __name__ == "__main__": max_pages, ) - kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda") - + kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda") + qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda") wrapper.plan( - None, + qo_indptr, None, None, kv_len_arr, @@ -216,6 +228,7 @@ if __name__ == "__main__": ) attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe) + print(attn_output.shape) k = ( torch.cat([ckv, k_pe], dim=-1) @@ -235,6 +248,7 @@ if __name__ == "__main__": False, 192 ** (-0.5) ) + print(attn_ref.shape) torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3) print("test past") \ No newline at end of file diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index 3877dbc..57d4bea 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import ( from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.operators.base_operator import BaseInjectedModule -from ktransformers.util.utils import InferenceState +from ktransformers.util.utils import InferenceState, get_compute_capability from ktransformers.util.custom_gguf import GGUFLoader from transformers.configuration_utils import PretrainedConfig from ktransformers.models.modeling_llama import ( @@ -649,7 +649,9 @@ class KDeepseekV2Model(BaseInjectedModule): if per_layer_prefill_flag: causal_mask = None else: - if os.name == 'nt': + if os.name == 'nt' or get_compute_capability()<8: + print("for Windows or GPU before ampere, use forward_windows") + # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml index 6fb6586..d28e016 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml @@ -60,6 +60,7 @@ kwargs: generate_device: "cuda" prefill_device: "cuda" + absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 49a3f16..8e6e5f9 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.local_chat import custom_models, default_optimize_rules from ktransformers.util.utils import get_device +from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton warm_uped = False @@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface): input_ids = input_ids.to("cpu") inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) torch.cuda.set_device(device) + if flashinfer_enabled: + MLAWrapperSingleton.need_plan_all() if self.use_static_cache: logits = self.model( inputs_embeds=inputs_embeds, @@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface): else: logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] + if flashinfer_enabled: + MLAWrapperSingleton.reset_buffer() self.prepare_logits_wrapper(input_ids, device) next_token = self.logits_to_token(logits[0, -1, :]) yield self.append_new_tokens(next_token) diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 8211933..7e6bd15 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase): for i in range(1, self.args.max_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - if i > 1 and flashinfer_enabled: + if flashinfer_enabled: MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 5c608b1..1908373 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton warm_uped = False +def get_compute_capability(device:torch.device = None): + if torch.cuda.is_available(): + if device is None: + num_gpus = torch.cuda.device_count() + min_compute_capability_major = 100 + for gpu_id in range(num_gpus): + gpu_props = torch.cuda.get_device_properties(gpu_id) + min_compute_capability_major = min(min_compute_capability_major, gpu_props.major) + return min_compute_capability_major + else: + return torch.cuda.get_device_properties(device) + def set_module(model, submodule_key, module): tokens = submodule_key.split('.') sub_tokens = tokens[:-1] @@ -153,6 +165,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud inputs_embeds = model.model.embed_tokens(inputs.to("cpu")) else: inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device) + if use_flashinfer_mla: + MLAWrapperSingleton.need_plan_all() + logits = model( inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True )[0][:,-1,:].unsqueeze(0).clone().to(torch_device) @@ -175,6 +190,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud else: next_token = torch.argmax(next_token_scores, dim=-1) first_token_time = time.time() - start_time + + if use_flashinfer_mla: + MLAWrapperSingleton.reset_buffer() prefill_count = seq_length prefill_time = first_token_time @@ -192,15 +210,15 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud start_time = time.time() for i in range(1, max_new_tokens): + if use_flashinfer_mla: + MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1, + num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, + q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16) global warm_uped if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ): warm_uped = True cuda_graph_runner = CUDAGraphRunner() cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True) - if i > 1 and use_flashinfer_mla: - MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1, - num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, - q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16) next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) generated_ids[:, cache_position] = next_token.int() From 021822dd01b0ade6690ad358a46a4829de55ec84 Mon Sep 17 00:00:00 2001 From: Azure Date: Tue, 25 Feb 2025 09:02:32 +0000 Subject: [PATCH 10/29] update FAQ --- doc/en/FAQ.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md index c8f58a2..3269f6b 100644 --- a/doc/en/FAQ.md +++ b/doc/en/FAQ.md @@ -72,3 +72,24 @@ The detailed error: Running `conda install -c conda-forge libstdcxx-ng` can solve the problem. +### Q: When running the bfloat16 moe model, the data shows NaN +The detailed error: +```shell +Traceback (most recent call last): + File "/root/ktransformers/ktransformers/local_chat.py", line 183, in + fire.Fire(local_chat) + File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 135, in Fire + component_trace = _Fire(component, args, parsed_flag_args, context, name) + File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 468, in _Fire + component, remaining_args = _CallAndUpdateTrace( + File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace + component = fn(*varargs, **kwargs) + File "/root/ktransformers/ktransformers/local_chat.py", line 177, in local_chat + generated = prefill_and_generate( + File "/root/ktransformers/ktransformers/util/utils.py", line 204, in prefill_and_generate + next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device) + File "/root/ktransformers/ktransformers/util/utils.py", line 128, in decode_one_tokens + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) +RuntimeError: probability tensor contains either `inf`, `nan` or element < 0 +``` +**SOLUTION**: The issue of running ktransformers on Ubuntu 22.04 is caused by the current system's g++ version being too old, and the pre-defined macros do not include avx_bf16. We have tested and confirmed that it works on g++ 11.4 in Ubuntu 22.04. \ No newline at end of file From 0ca0b99fab7a98251ccb3eb9c51f8f877a4d2bc4 Mon Sep 17 00:00:00 2001 From: liam Date: Tue, 25 Feb 2025 17:19:19 +0800 Subject: [PATCH 11/29] :zap: update git ignore add docker dev container --- .devcontainer/Dockerfile | 19 ++++++++++++++++ .devcontainer/devcontainer.json | 34 ++++++++++++++++++++++++++++ .gitignore | 8 ++----- ktransformers/tests/mmlu_pro_test.py | 4 ++-- 4 files changed, 57 insertions(+), 8 deletions(-) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..5c2606f --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,19 @@ +FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server +WORKDIR /workspace +ENV CUDA_HOME /usr/local/cuda +RUN < Date: Tue, 25 Feb 2025 17:45:17 +0800 Subject: [PATCH 12/29] :memo: add benchmark.md --- doc/en/benchmark.md | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 doc/en/benchmark.md diff --git a/doc/en/benchmark.md b/doc/en/benchmark.md new file mode 100644 index 0000000..a07f240 --- /dev/null +++ b/doc/en/benchmark.md @@ -0,0 +1,43 @@ +## Benchmark + +To conduct a quick and convenient check, we have employed a simple Python script available [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/tests) to assess the precision of our **[ktransformers](https://github.com/kvcache-ai/ktransformers)** project. For this evaluation, we utilized the same dataset, which was shuffled in a consistent manner and limited to the first 1,000 data points, to test our implementation across a variety of CPU kernels, MLA kernels, and quantization formats. + +We selected the DeepSeek-V3 model in its bf16, int8, and q4km versions for this test. The MMLU dataset, which can be found [here](https://huggingface.co/datasets/cais/mmlu), was used (we selected all datasets and shuffled them with a fixed random seed). + +**!!! However, we skipped the few-shot part and only chose the first 1,000 data points for a quick check.** Please note that this approach may result in results that are not consistent with the technical report of DeepSeek-V3. And the test of R1 and further more tests are on going. + +To verify our results, we chose [cloud service platform](https://cloud.siliconflow.cn/models) as baseline. All tests were conducted using the same script and datasets, allowing us to make a preliminary assessment of our project's precision. + +We set the argument `temperature=0.6`, and to simplify the test process, we skipped the few-shot part and used the following prompt: `There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter. \nQuestion: {question}\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: '`. For more details, please refer to the [script](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/tests/mmlu_test.py). + +Given that we have only tested 1,000 cases, which provides only a preliminary judgment, some fluctuations in the results are reasonable. We selected all datasets and shuffled them with a fixed random seed to ensure consistency. + +## Some Detail + +- The bf16 model of DeepSeek-V3 is available [here](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16/tree/main) (you may convert it to gguf by llama.cpp). The q4km model can be found [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M). + +- The optimization YAML file is located [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules). For the Matrix MUL Kernel, you can change `KLinearMarlin` to `KLinearTorch`. + +- To switch the MLA Kernel from Triton to Torch, you can check and modify [this file](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/attention.py), specifically by using the `forward_windows` method. + +- When attempting to conduct the bf16 test (both CPU Weight and GPU Weight), you may encounter issues stemming from older versions of g++ and as, particularly when using Ubuntu 20 or earlier versions. To facilitate a smoother experience and enable you to reproduce our results, we have provided a development container. This container offers a pre-configured environment tailored for this purpose. However, please note that the container does not have the ktrans package installed. Therefore, you may still need to manually install certain packages to ensure everything runs smoothly. + + - You may config the model mount dir in `devcontainer/devcontainer.json`, check the `"mouts":` config. + + +## The Result Table + +| | | | | | | | | +| ------------------------ | ----------------- | ---------- | ----------------- | ------- | ---------- | ------------------------------------------------------ | ------------ | +| DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models)
| Ktrans Point | +| MMLU

(shuffle 1k) | bf16 | cpuinfer | bf16 | torch | torch | 81.6 | 81.9 | +| | int8 | cpuinfer | bf16 | torch | torch | 81.6 | 83.1 | +| | q4km | cpuinfer | bf16 | torch | torch | 81.6 | 82.8 | +| | q4km | cpuinfer | bf16 | torch | triton | 81.6 | 81.4 | +| | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 | +| | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 | +| | q4km | cpuinfer | fp8 | marlin | triton | 81.6 | 81.5 | +| MMLU-pro | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 | +| MMLU-pro | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 | +| HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd | +| GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd | From 7e5962af3d16af570f3108acd3238d9dae330a45 Mon Sep 17 00:00:00 2001 From: Azure Date: Tue, 25 Feb 2025 10:52:29 +0000 Subject: [PATCH 13/29] fix fp8 multi gpu; update FQA --- doc/en/FAQ.md | 6 +++++- ktransformers/ktransformers_ext/triton/fp8gemm.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md index 3269f6b..a8399bd 100644 --- a/doc/en/FAQ.md +++ b/doc/en/FAQ.md @@ -92,4 +92,8 @@ Traceback (most recent call last): next_token = torch.multinomial(probs, num_samples=1).squeeze(1) RuntimeError: probability tensor contains either `inf`, `nan` or element < 0 ``` -**SOLUTION**: The issue of running ktransformers on Ubuntu 22.04 is caused by the current system's g++ version being too old, and the pre-defined macros do not include avx_bf16. We have tested and confirmed that it works on g++ 11.4 in Ubuntu 22.04. \ No newline at end of file +**SOLUTION**: The issue of running ktransformers on Ubuntu 22.04 is caused by the current system's g++ version being too old, and the pre-defined macros do not include avx_bf16. We have tested and confirmed that it works on g++ 11.4 in Ubuntu 22.04. + +### Q: Using fp8 prefill very slow. + +The FP8 kernel is build by JIT, so the first run will be slow. The subsequent runs will be faster. \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py index 7d5b72e..d5c913d 100644 --- a/ktransformers/ktransformers_ext/triton/fp8gemm.py +++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py @@ -102,7 +102,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) - weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + with torch.cuda.device(x.device): + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y From 477ac28a9c5002a066a956320297d2ec1c098d52 Mon Sep 17 00:00:00 2001 From: Atream Date: Tue, 25 Feb 2025 12:47:31 +0000 Subject: [PATCH 14/29] fix-update-flashinfer_wrapper_local_chat --- ktransformers/operators/attention.py | 3 +-- ktransformers/operators/flashinfer_wrapper.py | 11 +++++++++-- .../optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml | 4 ++++ ktransformers/util/utils.py | 1 + 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index 5e7391f..35c8093 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -435,7 +435,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): q_nope.dtype, compressed_kv.dtype) attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank) - """ k = ( torch.cat([compressed_kv, k_pe], dim=-1) @@ -465,7 +464,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim] attn_output = self.o_proj(attn_output) - + return attn_output, None, past_key_value else: if past_key_value is not None: diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py index 864b33e..f8ea3ce 100644 --- a/ktransformers/operators/flashinfer_wrapper.py +++ b/ktransformers/operators/flashinfer_wrapper.py @@ -122,7 +122,7 @@ class MLAWrapper(): if kv_indices is None: assert self.max_batch_size == 1 kv_indices = self.kv_indices_buf - + self.wrapper.plan( qo_indptr, kv_indptr, @@ -189,7 +189,14 @@ class MLAWrapperSingleton(): @classmethod def reset_buffer(cls): for device, wrapper in cls.wrappers.items(): - wrapper.qo_indptr_buf[1] = 1 + wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here. + + @classmethod + def update_buffer(cls, max_pages): + for device, wrapper in cls.wrappers.items(): + wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here. + wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device) + wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf if __name__ == "__main__": diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml index 03c85a0..ea75b30 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml @@ -293,6 +293,7 @@ kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" + absorb_for_prefill: False # GPU 1: layers 15–29 - match: @@ -302,6 +303,7 @@ kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" + absorb_for_prefill: False # GPU 2: layers 30–44 - match: @@ -311,6 +313,7 @@ kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" + absorb_for_prefill: False # GPU 3: layers 45–60 - match: @@ -320,6 +323,7 @@ kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" + absorb_for_prefill: False # === Overall Model Replacement with Transfer Map === diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 64b9131..3f5ad8e 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -177,6 +177,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud else: inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device) if use_flashinfer_mla: + MLAWrapperSingleton.update_buffer(past_key_values.max_pages) MLAWrapperSingleton.need_plan_all() logits = model( From 2c0cce90d0d2bf9e361e514134fa5689f9f46db4 Mon Sep 17 00:00:00 2001 From: Azure Date: Tue, 25 Feb 2025 13:32:09 +0000 Subject: [PATCH 15/29] add fp8 multi gpu yaml example --- ...hat-multi-gpu-fp8-linear-ggml-experts.yaml | 157 ++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml new file mode 100644 index 0000000..fa8c03d --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml @@ -0,0 +1,157 @@ +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\." + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearFP8" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearFP8" + prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:0" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda:1" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + absorb_for_prefill: False # change this to True to enable long context(prefill may slower). + +- match: + name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + absorb_for_prefill: False # change this to True to enable long context(prefill may slower). + +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill + transfer_map: + 30: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" From 03f8bc9f79d9b3915bce73ab13174f91e53a79d9 Mon Sep 17 00:00:00 2001 From: Atream <80757050+Atream@users.noreply.github.com> Date: Tue, 25 Feb 2025 21:35:31 +0800 Subject: [PATCH 16/29] Update DeepseekR1_V3_tutorial.md add long context --- doc/en/DeepseekR1_V3_tutorial.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 52e9d32..22cfab7 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -154,6 +154,18 @@ the output quality doesn't change. But the speed of decoding and prefill is speed up which is inspiring. So our showcase makes use of this finding* ## How to Run +### V0.2.2 longer context +If you want to use long context(longer than 20K) for prefill, enable the matrix absorption MLA during the prefill phase, which will significantly reduce the size of the kv cache. Modify yaml file like this: +``` +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + absorb_for_prefill: True # change this to True to enable long context(prefill may slower). +``` ### V0.2 & V0.2.1 Showcase #### Single socket version (32 cores) Our local_chat test command is: From 13974eb2642156118ae16ba469f7ddc9265b0498 Mon Sep 17 00:00:00 2001 From: Atream <80757050+Atream@users.noreply.github.com> Date: Tue, 25 Feb 2025 21:36:52 +0800 Subject: [PATCH 17/29] Update DeepseekR1_V3_tutorial.md --- doc/en/DeepseekR1_V3_tutorial.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 22cfab7..84d3418 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -16,6 +16,7 @@ - [Memory consumptions:](#memory-consumptions) - [Benchmark results](#benchmark-results-2) - [How to Run](#how-to-run) + - [V0.2.2 longer context](#v022-longer-context) - [V0.2 \& V0.2.1 Showcase](#v02--v021-showcase) - [Single socket version (32 cores)](#single-socket-version-32-cores) - [Dual socket version (64 cores)](#dual-socket-version-64-cores) From ddf3339339e5c247f960573ba4dfd7db96b697b7 Mon Sep 17 00:00:00 2001 From: liam Date: Tue, 25 Feb 2025 22:06:36 +0800 Subject: [PATCH 18/29] :zap: release v0.2.2rc1 --- ktransformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ktransformers/__init__.py b/ktransformers/__init__.py index df85afe..5b7ecae 100644 --- a/ktransformers/__init__.py +++ b/ktransformers/__init__.py @@ -8,4 +8,4 @@ Version : 1.0.0 LastEditors : chenxl LastEditTime : 2025-02-15 03:53:02 ''' -__version__ = "0.2.1.post1" \ No newline at end of file +__version__ = "0.2.2rc1" \ No newline at end of file From bb6920ed72241556b87f1fea704180143af2c997 Mon Sep 17 00:00:00 2001 From: Azure Date: Tue, 25 Feb 2025 15:46:15 +0000 Subject: [PATCH 19/29] update doc --- doc/en/DeepseekR1_V3_tutorial.md | 22 ++++++++++++++++++++-- doc/en/fp8_kernel.md | 8 ++++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 84d3418..29bfe3b 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -16,7 +16,9 @@ - [Memory consumptions:](#memory-consumptions) - [Benchmark results](#benchmark-results-2) - [How to Run](#how-to-run) - - [V0.2.2 longer context](#v022-longer-context) + - [V0.2.2 longer context \& FP8 kernel](#v022-longer-context--fp8-kernel) + - [longer context](#longer-context) + - [FP8 kernel](#fp8-kernel) - [V0.2 \& V0.2.1 Showcase](#v02--v021-showcase) - [Single socket version (32 cores)](#single-socket-version-32-cores) - [Dual socket version (64 cores)](#dual-socket-version-64-cores) @@ -155,7 +157,11 @@ the output quality doesn't change. But the speed of decoding and prefill is speed up which is inspiring. So our showcase makes use of this finding* ## How to Run -### V0.2.2 longer context +### V0.2.2 longer context & FP8 kernel +#### longer context +To use this feature, [install flashinfer](https://github.com/flashinfer-ai/flashinfer) first. + + If you want to use long context(longer than 20K) for prefill, enable the matrix absorption MLA during the prefill phase, which will significantly reduce the size of the kv cache. Modify yaml file like this: ``` - match: @@ -167,6 +173,18 @@ If you want to use long context(longer than 20K) for prefill, enable the matrix prefill_device: "cuda" absorb_for_prefill: True # change this to True to enable long context(prefill may slower). ``` +#### FP8 kernel + +The DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We achieve performance optimization through the following works: +- **FP8 GPU Kernel Integration**: FP8 linear layer acceleration kernels integrated in KTransformers +- **Hybrid Quantization Architecture**: + - Attention and Shared-Expert modules use FP8 precision (enhances computational accuracy) + - Experts modules retain GGML quantization (GGUF format, reside in CPU to save GPU memory) + +So those who are persuing the best performance can use the FP8 linear kernel for DeepSeek-V3/R1. + +The detailed guide is [here](./fp8_kernel.md). + ### V0.2 & V0.2.1 Showcase #### Single socket version (32 cores) Our local_chat test command is: diff --git a/doc/en/fp8_kernel.md b/doc/en/fp8_kernel.md index f1d86c3..5237a5c 100644 --- a/doc/en/fp8_kernel.md +++ b/doc/en/fp8_kernel.md @@ -1,4 +1,4 @@ -# FP8 Linear Kernel for DeepSeek-V3 +# FP8 Linear Kernel for DeepSeek-V3/R1 ## Overview The DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We achieve performance optimization through the following works: @@ -17,8 +17,8 @@ So those who are persuing the best performance can use the FP8 linear kernel for ### Using Pre-Merged Weights Pre-merged weights are available on Hugging Face: -[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3) -[KVCache-ai/DeepSeek-R1](https://huggingface.co/KVCache-ai/DeepSeek-R1) +[KVCache-ai/DeepSeek-V3-GGML-FP8-Hybrid](https://huggingface.co/KVCache-ai/DeepSeek-V3) +[KVCache-ai/DeepSeek-R1-GGML-FP8-Hybrid](https://huggingface.co/KVCache-ai/DeepSeek-R1) > Please confirm the weights are fully uploaded before downloading. The large file size may extend Hugging Face upload time. @@ -29,7 +29,7 @@ pip install -U huggingface_hub # Optional: Use HF Mirror for faster downloads in special area. # export HF_ENDPOINT=https://hf-mirror.com -huggingface-cli download --resume-download KVCache-ai/DeepSeek-V3 --local-dir +huggingface-cli download --resume-download KVCache-ai/DeepSeek-V3-GGML-FP8-Hybrid --local-dir ``` ### Using merge scripts If you got local DeepSeek-R1/V3 fp8 safetensors and q4km gguf weights, you can merge them using the following scripts. From 05339ad0ef2e2ede71e4ddfd53523fa4fc2e27cc Mon Sep 17 00:00:00 2001 From: liam Date: Tue, 25 Feb 2025 23:56:19 +0800 Subject: [PATCH 20/29] :memo: update benchmark.md --- doc/SUMMARY.md | 2 ++ doc/en/benchmark.md | 35 ++++++++++++++++++++++++----------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index 94172ce..d9fa9b8 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -21,3 +21,5 @@ - [FAQ](en/FAQ.md) # V3 Reproduction - [Success List](en/V3-success.md) +# Benchmark +- [Benchmark](en/benchmark.md) \ No newline at end of file diff --git a/doc/en/benchmark.md b/doc/en/benchmark.md index a07f240..c9e152f 100644 --- a/doc/en/benchmark.md +++ b/doc/en/benchmark.md @@ -16,7 +16,7 @@ Given that we have only tested 1,000 cases, which provides only a preliminary ju - The bf16 model of DeepSeek-V3 is available [here](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16/tree/main) (you may convert it to gguf by llama.cpp). The q4km model can be found [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M). -- The optimization YAML file is located [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules). For the Matrix MUL Kernel, you can change `KLinearMarlin` to `KLinearTorch`. +- The optimization YAML file is located [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules). For the GEMM Kernel, you can change `KLinearMarlin` to `KLinearTorch`. - To switch the MLA Kernel from Triton to Torch, you can check and modify [this file](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/attention.py), specifically by using the `forward_windows` method. @@ -29,15 +29,28 @@ Given that we have only tested 1,000 cases, which provides only a preliminary ju | | | | | | | | | | ------------------------ | ----------------- | ---------- | ----------------- | ------- | ---------- | ------------------------------------------------------ | ------------ | -| DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models)
| Ktrans Point | -| MMLU

(shuffle 1k) | bf16 | cpuinfer | bf16 | torch | torch | 81.6 | 81.9 | -| | int8 | cpuinfer | bf16 | torch | torch | 81.6 | 83.1 | -| | q4km | cpuinfer | bf16 | torch | torch | 81.6 | 82.8 | -| | q4km | cpuinfer | bf16 | torch | triton | 81.6 | 81.4 | -| | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 | -| | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 | -| | q4km | cpuinfer | fp8 | marlin | triton | 81.6 | 81.5 | -| MMLU-pro | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 | -| MMLU-pro | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 | +| DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM Kernel | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models)
| Ktrans Point | +| MMLU

(shuffle 1k) | | | | | | | | +| 1 | bf16 | cpuinfer | bf16 | torch | torch | 81.6 | 81.9 | +| 2 | q8_0 | cpuinfer | bf16 | torch | torch | 81.6 | 83.1 | +| 3 | q4km | cpuinfer | bf16 | torch | triton | 81.6 | 81.4 | +| 4 | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 | +| 5 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 | +| 6 | q4km | cpuinfer | fp8 | fp8gemm | triton | 81.6 | 81.5 | +| MMLU-pro | | | | | | | | +| 1 | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 | +| 2 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 | | HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd | | GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd | + +**the yaml files used for each case are listed below**: +- MMLU test + 1. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml) change all the `KLinearMarlin` to `KLinearTorch` (just find all the usage in this file). The source weight comes from [there](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16) (you need to use llama.cpp to convert it to gguf) + 2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to seperately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q8_0 is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q8_0) + 3. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to seperately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) + 4. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You don't need to change the source code as they both use q4km. But note the yaml file [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L29) and [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L18), below these lines you need to add `num_bits: 8` (in other words: add this kwargs to all that use `KLinearMarlin`). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) + 5. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) + 6. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case. +- MMLU-pro test + 1. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case. + 2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) \ No newline at end of file From 3ad12751cf17edb6e2f4fa7d5e3afe29b080a816 Mon Sep 17 00:00:00 2001 From: liam Date: Wed, 26 Feb 2025 00:17:02 +0800 Subject: [PATCH 21/29] :memo: update more detail and fix typo --- doc/en/benchmark.md | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/doc/en/benchmark.md b/doc/en/benchmark.md index c9e152f..599aef8 100644 --- a/doc/en/benchmark.md +++ b/doc/en/benchmark.md @@ -12,7 +12,7 @@ We set the argument `temperature=0.6`, and to simplify the test process, we skip Given that we have only tested 1,000 cases, which provides only a preliminary judgment, some fluctuations in the results are reasonable. We selected all datasets and shuffled them with a fixed random seed to ensure consistency. -## Some Detail +## Some Details - The bf16 model of DeepSeek-V3 is available [here](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16/tree/main) (you may convert it to gguf by llama.cpp). The q4km model can be found [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M). @@ -43,14 +43,17 @@ Given that we have only tested 1,000 cases, which provides only a preliminary ju | HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd | | GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd | -**the yaml files used for each case are listed below**: +**The details for each case are listed below**: + +By default, The MLA kernel uses triton in linux and torch in windows. But we need to test torch in linux, so we manually modify the [file](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/attention.py#L592). Just get rid of all the if branch and force it to use `self.forward_windows` + - MMLU test 1. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml) change all the `KLinearMarlin` to `KLinearTorch` (just find all the usage in this file). The source weight comes from [there](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16) (you need to use llama.cpp to convert it to gguf) - 2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to seperately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q8_0 is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q8_0) - 3. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to seperately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) + 2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to separately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q8_0 is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q8_0) + 3. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to separately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) 4. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You don't need to change the source code as they both use q4km. But note the yaml file [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L29) and [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L18), below these lines you need to add `num_bits: 8` (in other words: add this kwargs to all that use `KLinearMarlin`). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) 5. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) 6. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case. - MMLU-pro test 1. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case. - 2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) \ No newline at end of file + 2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M) \ No newline at end of file From b2bff1777502519838b8a5ec112d2566635d42e5 Mon Sep 17 00:00:00 2001 From: wkgcass Date: Wed, 26 Feb 2025 14:48:22 +0800 Subject: [PATCH 22/29] fix numa cpu distribution The numa node location would be calculated based on the total number of worker threads. So we should always use the actual number of threads instead of using a min() op. --- ktransformers/ktransformers_ext/cpu_backend/backend.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp index 5980ba3..a254db9 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp +++ b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp @@ -54,7 +54,12 @@ void Backend::do_work_stealing_job(int task_num, init_func_ = init_func; compute_func_ = compute_func; finalize_func_ = finalize_func; +#ifdef USE_NUMA + // numa node location will be calculated based on the number of threads + thread_num_ = max_thread_num_; +#else thread_num_ = std::min(max_thread_num_, task_num); +#endif int base = task_num / thread_num_; int remain = task_num % thread_num_; thread_state_[0].end = base + (0 < remain); @@ -146,4 +151,4 @@ void Backend::worker_thread(int thread_id) { return; } } -} \ No newline at end of file +} From de082f141c16992bb8584996938396d8ebcd1ac7 Mon Sep 17 00:00:00 2001 From: liam Date: Wed, 26 Feb 2025 14:54:47 +0800 Subject: [PATCH 23/29] :zap: fix cd error --- .devcontainer/Dockerfile | 1 - 1 file changed, 1 deletion(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 5c2606f..03ede05 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -10,7 +10,6 @@ apt update -y && apt install -y --no-install-recommends \ g++ \ cmake && rm -rf /var/lib/apt/lists/* && -cd ktransformers && pip install ninja pyproject numpy cpufeature && pip install flash-attn && cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/ From ffb86c66e3ac06a5eb57a3f4ebbc219d959142c4 Mon Sep 17 00:00:00 2001 From: liam Date: Wed, 26 Feb 2025 15:04:25 +0800 Subject: [PATCH 24/29] :zap: fix experts torch --- ktransformers/operators/experts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 10e3a66..88960c7 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase): self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype) self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype) - self.up = torch.cat(self.up, dim=0) - self.gate = torch.cat(self.gate, dim=0) - self.down = torch.cat(self.down, dim=0) + self.up = torch.stack(self.up, dim=0) + self.gate = torch.stack(self.gate, dim=0) + self.down = torch.stack(self.down, dim=0) return def unload(self): From 68e7df3a251650995384740a62b678310c7c73c2 Mon Sep 17 00:00:00 2001 From: swu-hyk Date: Wed, 26 Feb 2025 17:05:00 +0800 Subject: [PATCH 25/29] implementation of chat routing for Ollama --- .../server/api/ollama/completions.py | 109 +++++++++++------- 1 file changed, 69 insertions(+), 40 deletions(-) diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py index e3a1a51..d0ac17e 100644 --- a/ktransformers/server/api/ollama/completions.py +++ b/ktransformers/server/api/ollama/completions.py @@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config from ktransformers.server.utils.create_interface import get_interface from ktransformers.server.schemas.assistants.streaming import check_link_response from ktransformers.server.backend.base import BackendInterfaceBase -router = APIRouter(prefix='/api') +router = APIRouter(prefix='/api') # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion class OllamaGenerateCompletionRequest(BaseModel): @@ -40,61 +40,95 @@ class OllamaGenerateCompletionRequest(BaseModel): keep_alive: Optional[str] = Field( "5m", description="Controls how long the model will stay loaded into memory following the request.") - class OllamaGenerationStreamResponse(BaseModel): model: str created_at: str response: str done: bool = Field(...) - class OllamaGenerationResponse(BaseModel): pass - @router.post("/generate", tags=['ollama']) async def generate(request: Request, input: OllamaGenerateCompletionRequest): id = str(uuid4()) - interface: BackendInterfaceBase = get_interface() print(f'COMPLETION INPUT:----\n{input.prompt}\n----') - config = Config() if input.stream: async def inner(): - async for token in interface.inference(input.prompt,id): - d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False) - yield d.model_dump_json()+'\n' - # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False} - # yield f"{json.dumps(d)}\n" - # d = {'model':config.model_name,'created_at':"", 'response':'','done':True} - # yield f"{json.dumps(d)}\n" - d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True) - yield d.model_dump_json()+'\n' - return check_link_response(request,inner()) + async for token in interface.inference(input.prompt, id): + d = OllamaGenerationStreamResponse( + model=config.model_name, + created_at=str(datetime.now()), + response=token, + done=False + ) + yield d.model_dump_json() + '\n' + d = OllamaGenerationStreamResponse( + model=config.model_name, + created_at=str(datetime.now()), + response='', + done=True + ) + yield d.model_dump_json() + '\n' + return check_link_response(request, inner()) else: raise NotImplementedError # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion - +class OllamaChatCompletionMessage(BaseModel): + role: str + content: str class OllamaChatCompletionRequest(BaseModel): - pass - + model: str = Field(..., description="The model name, which is required.") + messages: List[OllamaChatCompletionMessage] = Field( + ..., description="A list of messages to generate a response for.") + stream: bool = Field(True, description="If true, the response will be streamed.") class OllamaChatCompletionStreamResponse(BaseModel): - pass - + model: str + created_at: str + message: str + done: bool = Field(...) class OllamaChatCompletionResponse(BaseModel): pass - @router.post("/chat", tags=['ollama']) async def chat(request: Request, input: OllamaChatCompletionRequest): - raise NotImplementedError + id = str(uuid4()) + interface: BackendInterfaceBase = get_interface() + config = Config() + # 将消息转换为提示字符串 + prompt = "" + for msg in input.messages: + prompt += f"{msg.role}: {msg.content}\n" + prompt += "assistant:" + + if input.stream: + async def inner(): + async for token in interface.inference(prompt, id): + d = OllamaChatCompletionStreamResponse( + model=config.model_name, + created_at=str(datetime.now()), + message=token, + done=False + ) + yield d.model_dump_json() + '\n' + d = OllamaChatCompletionStreamResponse( + model=config.model_name, + created_at=str(datetime.now()), + message='', + done=True + ) + yield d.model_dump_json() + '\n' + return check_link_response(request, inner()) + else: + raise NotImplementedError("Non-streaming chat is not implemented.") # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models class OllamaModel(BaseModel): @@ -103,9 +137,8 @@ class OllamaModel(BaseModel): size: int # TODO: fill the rest correctly - # mock ollama -@router.get("/tags",tags=['ollama']) +@router.get("/tags", tags=['ollama']) async def tags(): config = Config() # TODO: fill this correctly, although it does not effect Tabby @@ -138,25 +171,21 @@ class OllamaShowResponse(BaseModel): class Config: protected_namespaces = () - - @router.post("/show", tags=['ollama']) async def show(request: Request, input: OllamaShowRequest): config = Config() # TODO: Add more info in config to return, although it does not effect Tabby return OllamaShowResponse( - modelfile = "# Modelfile generated by ...", - parameters = " ", - template = " ", - details = OllamaShowDetial( - parent_model = " ", - format = "gguf", - family = " ", - families = [ - " " - ], - parameter_size = " ", - quantization_level = " " + modelfile="# Modelfile generated by ...", + parameters=" ", + template=" ", + details=OllamaShowDetial( + parent_model=" ", + format="gguf", + family=" ", + families=[" "], + parameter_size=" ", + quantization_level=" " ), - model_info = OllamaModelInfo() + model_info=OllamaModelInfo() ) \ No newline at end of file From ec7e912feed51db8c247e96ea582a9427966134a Mon Sep 17 00:00:00 2001 From: swu-hyk Date: Wed, 26 Feb 2025 19:21:30 +0800 Subject: [PATCH 26/29] modify --- .../server/api/ollama/completions.py | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py index d0ac17e..0ff6183 100644 --- a/ktransformers/server/api/ollama/completions.py +++ b/ktransformers/server/api/ollama/completions.py @@ -91,8 +91,16 @@ class OllamaChatCompletionRequest(BaseModel): class OllamaChatCompletionStreamResponse(BaseModel): model: str created_at: str - message: str + message: dict done: bool = Field(...) + total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds") + load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds") + prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt") + prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds") + eval_count: Optional[int] = Field(None, description="Number of tokens generated") + eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds") + + class OllamaChatCompletionResponse(BaseModel): pass @@ -111,19 +119,37 @@ async def chat(request: Request, input: OllamaChatCompletionRequest): if input.stream: async def inner(): + start_time = time() # 记录开始时间(秒) + eval_count = 0 # 统计生成的 token 数量 + tokens = [] + async for token in interface.inference(prompt, id): d = OllamaChatCompletionStreamResponse( model=config.model_name, created_at=str(datetime.now()), - message=token, + message={"role": "assistant", "content": token}, done=False ) yield d.model_dump_json() + '\n' + # 计算性能数据 + end_time = time() + total_duration = int((end_time - start_time) * 1_000_000_000) # 转换为纳秒 + prompt_eval_count = len(prompt.split()) # 简单估算提示词数量 + eval_duration = total_duration # 假设全部时间用于生成(简化) + prompt_eval_duration = 0 # 假设无单独提示评估时间 + load_duration = 0 # 假设加载时间未知 + d = OllamaChatCompletionStreamResponse( model=config.model_name, created_at=str(datetime.now()), - message='', - done=True + message={}, + done=True, + total_duration=total_duration, + load_duration=load_duration, + prompt_eval_count=prompt_eval_count, + prompt_eval_duration=prompt_eval_duration, + eval_count=eval_count, + eval_duration=eval_duration ) yield d.model_dump_json() + '\n' return check_link_response(request, inner()) From 90eb87b3fca3c6c13805d7f80c86833f379d5fe7 Mon Sep 17 00:00:00 2001 From: Atream <80757050+Atream@users.noreply.github.com> Date: Wed, 26 Feb 2025 21:53:50 +0800 Subject: [PATCH 27/29] Update DeepSeek-V3-Chat-multi-gpu-marlin.yaml --- .../optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml index 6b39121..e04c6ce 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml @@ -168,5 +168,5 @@ replace: class: "default" kwargs: - generate_device: "cuda:0" - prefill_device: "cuda:0" + generate_device: "cuda:1" + prefill_device: "cuda:1" From 369f4d917dd911e15b0587cbca703178176af9f8 Mon Sep 17 00:00:00 2001 From: Atream <80757050+Atream@users.noreply.github.com> Date: Wed, 26 Feb 2025 22:04:29 +0800 Subject: [PATCH 28/29] Update DeepseekR1_V3_tutorial.md --- doc/en/DeepseekR1_V3_tutorial.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md index 29bfe3b..02575c9 100644 --- a/doc/en/DeepseekR1_V3_tutorial.md +++ b/doc/en/DeepseekR1_V3_tutorial.md @@ -160,7 +160,7 @@ is speed up which is inspiring. So our showcase makes use of this finding* ### V0.2.2 longer context & FP8 kernel #### longer context To use this feature, [install flashinfer](https://github.com/flashinfer-ai/flashinfer) first. - +Note: The latest MLA kernel in FlashInfer still has a few minor issues. They are continuously fixing them on the main branch. If you are using FlashInfer, please install it from the main source code. If you want to use long context(longer than 20K) for prefill, enable the matrix absorption MLA during the prefill phase, which will significantly reduce the size of the kv cache. Modify yaml file like this: ``` From c05ebb74b1a04376cc4f7863a66efec1457bdede Mon Sep 17 00:00:00 2001 From: Azure Date: Wed, 26 Feb 2025 15:43:08 +0000 Subject: [PATCH 29/29] Update fp8 doc; Update install.md broken link --- doc/en/fp8_kernel.md | 20 +++++++++++--------- doc/en/install.md | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/doc/en/fp8_kernel.md b/doc/en/fp8_kernel.md index 5237a5c..e76bae5 100644 --- a/doc/en/fp8_kernel.md +++ b/doc/en/fp8_kernel.md @@ -10,15 +10,17 @@ The DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We achi So those who are persuing the best performance can use the FP8 linear kernel for DeepSeek-V3/R1. ## Key Features -✅ Hybrid Precision Architecture (FP8 + GGML) + +✅ Hybrid Precision Architecture (FP8 + GGML)
✅ Memory Optimization (~19GB VRAM usage) ## Quick Start ### Using Pre-Merged Weights -Pre-merged weights are available on Hugging Face: -[KVCache-ai/DeepSeek-V3-GGML-FP8-Hybrid](https://huggingface.co/KVCache-ai/DeepSeek-V3) +Pre-merged weights are available on Hugging Face:
+[KVCache-ai/DeepSeek-V3-GGML-FP8-Hybrid](https://huggingface.co/KVCache-ai/DeepSeek-V3)
[KVCache-ai/DeepSeek-R1-GGML-FP8-Hybrid](https://huggingface.co/KVCache-ai/DeepSeek-R1) + > Please confirm the weights are fully uploaded before downloading. The large file size may extend Hugging Face upload time. @@ -32,12 +34,12 @@ pip install -U huggingface_hub huggingface-cli download --resume-download KVCache-ai/DeepSeek-V3-GGML-FP8-Hybrid --local-dir ``` ### Using merge scripts -If you got local DeepSeek-R1/V3 fp8 safetensors and q4km gguf weights, you can merge them using the following scripts. +If you got local DeepSeek-R1/V3 fp8 safetensors and gguf weights(eg.q4km), you can merge them using the following scripts. ```shell -python convert_model.py \ +python merge_tensors/merge_safetensor_gguf.py \ --safetensor_path \ - --gguf_path \ + --gguf_path \ --output_path ``` @@ -60,15 +62,15 @@ python ktransformers/local_chat.py \ ## Notes -⚠️ Hardware Requirements +⚠️ Hardware Requirements
* Recommended minimum 19GB available VRAM for FP8 kernel. * Requires GPU with FP8 support (e.g., 4090) ⏳ First-Run Optimization JIT compilation causes longer initial execution (subsequent runs retain optimized speed). -🔄 Temporary Interface +🔄 Temporary Interface
Current weight loading implementation is provisional - will be refined in future versions -📁 Path Specification +📁 Path Specification
Despite hybrid quantization, merged weights are stored as .safetensors - pass the containing folder path to `--gguf_path` \ No newline at end of file diff --git a/doc/en/install.md b/doc/en/install.md index 2a4a6af..3f0acf0 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -121,7 +121,7 @@ We provide a simple command-line local chat Python script that you can run for t mkdir DeepSeek-V2-Lite-Chat-GGUF cd DeepSeek-V2-Lite-Chat-GGUF -wget https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/resolve/main/DeepSeek-V2-Lite-Chat.Q4_K_M.gguf -O DeepSeek-V2-Lite-Chat.Q4_K_M.gguf +wget https://huggingface.co/mradermacher/DeepSeek-V2-Lite-GGUF/resolve/main/DeepSeek-V2-Lite.Q4_K_M.gguf -O DeepSeek-V2-Lite-Chat.Q4_K_M.gguf cd .. # Move to repo's root dir