diff --git a/README.md b/README.md index 30f1425..f62528b 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

-* **Feb 15, 2025**: KTransformers V0.2.1: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%) (Up to 16 Tokens/s), update docs [here](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). +* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; Longer Context (from 8K to 128K for 24GB VRAM). +* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU. @@ -125,7 +126,7 @@ To utilize the provided kernels, users only need to create a YAML-based injectio ```python with torch.device("meta"): model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) -optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) +optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) ... generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000) ``` diff --git a/README_ZH.md b/README_ZH.md index 4cdd3c1..3696a08 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -21,7 +21,8 @@ KTransformers 是一个以 Python 为中心的灵活框架,其核心是可扩

🔥 更新

-* **2025 年 2 月 15 日**:KTransformers V0.2.1: 长上下文(从4K到8K,24GB VRAM) & 稍快的速度(+15%)(最快 16 Tokens/s),文档请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 和 [在线指南](https://kvcache-ai.github.io/ktransformers/) 。
+* **2025 年 2 月 25 日**:为DeepSeek-V3/R1支持[FP8 GPU内核](./doc/en/fp8_kernel.md); 支持更长的上下文 (从8K到128K仅用24GB VRAM).
+* **2025 年 2 月 15 日**:长上下文(从4K到8K,24GB VRAM) & 稍快的速度(+15%)(最快 16 Tokens/s),文档请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 和 [在线指南](https://kvcache-ai.github.io/ktransformers/) 。
 * **2025 年 2 月 10 日**:支持 Deepseek-R1 和 V3 在单个(24GB VRAM)/多 GPU 和 382G DRAM 上运行,速度提升高达 3~28 倍。详细教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md)。
 * **2024 年 8 月 28 日**:支持 InternLM2.5-7B-Chat-1M 模型下的 1M 上下文,使用 24GB 的 VRAM 和 150GB 的 DRAM。详细教程请参见 [这里](./doc/en/long_context_tutorial.md)。
 * **2024 年 8 月 28 日**:将 DeepseekV2 所需的 VRAM 从 21G 降低到 11G。
@@ -68,11 +69,11 @@ https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c

-

在仅 24GB VRAM 的桌面上进行 1M 上下文本地推理

-

- -https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 + + + 更多高级功能即将推出,敬请期待! @@ -116,7 +117,7 @@ KTransformers 的核心是一个用户友好的、基于模板的注入框架。 ```python with torch.device("meta"): model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) -optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) +optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) ... generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000) ``` @@ -151,7 +152,7 @@ YAML 文件中的每个规则都有两部分:`match` 和 `replace`。`match`

致谢和贡献者

-KTransformer 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 和 Marlin 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。 +KTransformer 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 、 Marlin、sglang和flashinfer 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。 KTransformer 由清华大学 MADSys group 小组的成员以及 Approaching.AI 的成员积极维护和开发。我们欢迎新的贡献者加入我们,使 KTransformer 更快、更易于使用。 diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index 4efb7ea..94172ce 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -5,10 +5,11 @@ - [Installation Guide](en/install.md) # Tutorial -- [Deepseek-R1/V3 Show Case](en/DeepseekR1_V3_tutorial.md) +- [Deepseek-R1/V3 Show Case/Tutorial](en/DeepseekR1_V3_tutorial.md) - [Why KTransformers So Fast](en/deepseek-v2-injection.md) - [Injection Tutorial](en/injection_tutorial.md) - [Multi-GPU Tutorial](en/multi-gpu-tutorial.md) +- [Use FP8 GPU Kernel](en/fp8_kernel.md) # Server - [Server](en/api/server/server.md) - [Website](en/api/server/website.md) diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md index 51c6271..c8f58a2 100644 --- a/doc/en/FAQ.md +++ b/doc/en/FAQ.md @@ -45,7 +45,7 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552 ### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them? -Use the `--optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two optimized rule yaml file. You may also use it as an example to write your own 4/8 gpu optimized rule yaml file. +Use the `--optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two optimized rule yaml file. You may also use it as an example to write your own 4/8 gpu optimized rule yaml file. > Note: The ktransformers' multi-gpu stratigy is pipline, which is not able to speed up the model's inference. It's only for the model's weight distribution. @@ -55,7 +55,7 @@ You have to set `--cpu_infer` to the number of cores you want to use. The more c ### Q: My DeepSeek-R1 model is not thinking. 
-According to DeepSeek, you need to enforce the model to initiate its response with "\\n" at the beginning of every output by passing the arg `--force_think true `.
+According to DeepSeek, you need to enforce the model to initiate its response with "\\n" at the beginning of every output by passing the arg `--force_think True `.
 
 ### Q: Loading gguf error
 
@@ -63,9 +63,12 @@ Make sure you:
 1. Have the `gguf` file in the `--gguf_path` directory.
 2. The directory only contains gguf files from one model. If you have multiple models, you need to separate them into different directories.
 3. The folder name it self should not end with `.gguf`, eg. `Deep-gguf` is correct, `Deep.gguf` is wrong.
+4. The file itself is not corrupted; you can verify this by checking that the sha256sum matches the one from huggingface, modelscope, or hf-mirror.
 
 ### Q: Version `GLIBCXX_3.4.30' not found
 The detailed error:
 >ImportError: /mnt/data/miniconda3/envs/xxx/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/xxx/xxx/ktransformers/./cpuinfer_ext.cpython-312-x86_64-linux-gnu.so)
 
 Running `conda install -c conda-forge libstdcxx-ng` can solve the problem.
+
+
diff --git a/doc/en/fp8_kernel.md b/doc/en/fp8_kernel.md
new file mode 100644
index 0000000..f1d86c3
--- /dev/null
+++ b/doc/en/fp8_kernel.md
@@ -0,0 +1,74 @@
+# FP8 Linear Kernel for DeepSeek-V3
+
+## Overview
+The DeepSeek-AI team provides FP8 safetensors for the DeepSeek-R1/V3 models. We optimize performance through the following work:
+- **FP8 GPU Kernel Integration**: FP8 linear layer acceleration kernels integrated in KTransformers
+- **Hybrid Quantization Architecture**:
+  - Attention and Shared-Expert modules use FP8 precision (enhances computational accuracy)
+  - Expert modules retain GGML quantization (GGUF format, resident on the CPU to save GPU memory)
+
+Those pursuing the best performance can therefore use the FP8 linear kernel for DeepSeek-V3/R1.
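Reviewer note: the `act_quant` kernel added later in this change scales each block of `block_size` activations by `max(|x|) / 448` before casting to `torch.float8_e4m3fn`. The pure-Python sketch below illustrates only that block-wise scaling step. The names `blockwise_scale` and `blockwise_unscale` and the all-zero-block guard are assumptions made for illustration and are not part of KTransformers; rounding to real FP8 values is deliberately not simulated.

```python
from typing import List, Tuple

FP8_E4M3_MAX = 448.0  # largest finite magnitude of float8 e4m3fn


def blockwise_scale(values: List[float], block_size: int = 128) -> Tuple[List[float], List[float]]:
    """Scale each block so its largest magnitude maps to FP8_E4M3_MAX.

    Returns (scaled_values, per_block_scales). No FP8 rounding is applied.
    """
    if len(values) % block_size != 0:
        raise ValueError("length must be divisible by block_size")
    scaled: List[float] = []
    scales: List[float] = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        # Assumption: guard all-zero blocks to avoid dividing by zero.
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        scales.append(scale)
        scaled.extend(v / scale for v in block)
    return scaled, scales


def blockwise_unscale(scaled: List[float], scales: List[float], block_size: int = 128) -> List[float]:
    """Undo blockwise_scale by multiplying each value by its block's scale."""
    return [v * scales[i // block_size] for i, v in enumerate(scaled)]


if __name__ == "__main__":
    data = [0.5, -2.0, 1.0, 0.25, 100.0, -50.0, 25.0, 0.0]
    q, s = blockwise_scale(data, block_size=4)
    print("scales:", s)
    print("round trip:", blockwise_unscale(q, s, block_size=4))
```

Because every block carries its own scale, a single large outlier only reduces the resolution of the block it sits in, which is presumably why the kernels in this change work on 128-element blocks instead of one scale per tensor.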
+
+## Key Features
+✅ Hybrid Precision Architecture (FP8 + GGML)
+✅ Memory Optimization (~19GB VRAM usage)
+
+## Quick Start
+### Using Pre-Merged Weights
+
+Pre-merged weights are available on Hugging Face:
+[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3)
+[KVCache-ai/DeepSeek-R1](https://huggingface.co/KVCache-ai/DeepSeek-R1)
+> Please confirm the weights are fully uploaded before downloading. Because the files are large, the Hugging Face upload may take a long time.
+
+
+Download Pre-Merged Weights
+```shell
+pip install -U huggingface_hub
+
+# Optional: use the HF mirror for faster downloads in some regions.
+# export HF_ENDPOINT=https://hf-mirror.com
+
+huggingface-cli download --resume-download KVCache-ai/DeepSeek-V3 --local-dir
+```
+### Using merge scripts
+If you have local DeepSeek-R1/V3 FP8 safetensors and Q4_K_M GGUF weights, you can merge them using the following script.
+
+```shell
+python convert_model.py \
+  --safetensor_path \
+  --gguf_path \
+  --output_path
+```
+
+* `--safetensor_path`: input path of the safetensors files ([Download](https://huggingface.co/deepseek-ai/DeepSeek-V3/tree/main)).
+* `--gguf_path`: input path of the GGUF folder ([Download](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)).
+* `--output_path`: output path of the merged files.
+
+
+### Execution Notes
+
+Launch local_chat.py with custom quantized experts
+```shell
+python ktransformers/local_chat.py \
+  --model_path deepseek-ai/DeepSeek-V3 \
+  --gguf_path \
+  --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml \
+  --cpu_infer
+```
+
+
+## Notes
+
+⚠️ Hardware Requirements
+* At least 19GB of available VRAM is recommended for the FP8 kernel.
+* Requires a GPU with FP8 support (e.g., RTX 4090)
+
+⏳ First-Run Optimization
+JIT compilation makes the first run slower (subsequent runs retain the optimized speed).
+
+🔄 Temporary Interface
+The current weight-loading implementation is provisional and will be refined in future versions.
+
+📁 Path Specification
+Despite the hybrid quantization, merged weights are stored as `.safetensors` files; pass the folder that contains them to `--gguf_path`
\ No newline at end of file
diff --git a/doc/en/injection_tutorial.md b/doc/en/injection_tutorial.md
index 5ebb327..4518836 100644
--- a/doc/en/injection_tutorial.md
+++ b/doc/en/injection_tutorial.md
@@ -59,6 +59,7 @@ Supported operators and their corresponding classes are as follows:
 | Linear | KTransformersLinear | KLinearMarlin | Marlin as backend |
 | | | KLinearTorch | pytorch as backend |
 | | | KLinearCPUInfer | llamafile as backend |
+| | | KLinearFP8 | Triton fp8_gemm kernel. Requires a GPU that can compute on FP8 data |
 | experts | KTransformersExperts | KExpertsTorch | pytorch as backend |
 | | | KExpertsMarlin | Marlin as backend |
 | | | KExpertsCPU | llamafile as backend |
diff --git a/doc/en/install.md b/doc/en/install.md
index 269e8fb..2a4a6af 100644
--- a/doc/en/install.md
+++ b/doc/en/install.md
@@ -141,7 +141,7 @@ It features the following arguments:
 
 - `--gguf_path` (required): Path of a directory containing GGUF files which could that can be downloaded from [Hugging Face](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main). Note that the directory should only contains GGUF of current model, which means you need one separate directory for each model.
 
-- `--optimize_rule_path` (required except for Qwen2Moe and DeepSeek-V2): Path of YAML file containing optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models.
+- `--optimize_config_path` (required except for Qwen2Moe and DeepSeek-V2): Path of YAML file containing optimize rules.
There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models. - `--max_new_tokens`: Int (default=1000). Maximum number of new tokens to generate. diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py new file mode 100644 index 0000000..7d5b72e --- /dev/null +++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py @@ -0,0 +1,192 @@ +# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py +from typing import Tuple + +import torch +import triton +import triton.language as tl +from triton import Config + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448. + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). 
+ block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + """ + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +fp8_gemm_configs = [ + Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) + for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] +] + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, + a_s_ptr, b_s_ptr, + M, N: tl.constexpr, K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + """ + Performs a matrix multiplication operation on FP8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. 
+ + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. 
+ """ + assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous' + assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) + fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c \ No newline at end of file diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index d5e74de..920a3f6 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -54,7 +54,7 @@ default_optimize_rules = { def local_chat( model_path: str | None = None, - optimize_rule_path: str = None, + optimize_config_path: str = None, gguf_path: str | None = None, max_new_tokens: int = 300, cpu_infer: int = Config().cpu_infer, @@ -95,12 +95,12 @@ def local_chat( config, trust_remote_code=True, attn_implementation="flash_attention_2" ) - if optimize_rule_path is None: + if optimize_config_path is None: if config.architectures[0] in default_optimize_rules: print("using default_optimize_rule for", config.architectures[0]) - optimize_rule_path = default_optimize_rules[config.architectures[0]] + optimize_config_path = default_optimize_rules[config.architectures[0]] else: - optimize_rule_path = input( + optimize_config_path = input( "please input the path of your rule file(yaml file containing optimize rules):" ) @@ -108,7 +108,7 @@ def local_chat( gguf_path = input( "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" ) - optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config) + optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) try: model.generation_config = GenerationConfig.from_pretrained(model_path) diff --git a/ktransformers/operators/experts.py 
b/ktransformers/operators/experts.py index 21b4830..1ea244a 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase): down_type = None for key in keys: - if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: + if self.gguf_loader.safetensor_loader is not None: + # using a temp ugly way to temprary load the tensor + gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy() + up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy() + down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy() + gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item() + up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item() + down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item() + + elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index 52bb33a..d908093 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -67,7 +67,14 @@ class KMoEGateBase(ABC): for key in keys: key = ".".join(key.split(".")[:-1]) - if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info: + if self.gguf_loader.safetensor_loader is not None: + targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] + weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight") + e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias") + weight_type = weight.dtype + e_score_correction_bias_type = 
e_score_correction_bias.dtype + res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} + elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info: targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] tensors = self.load_multi(key, targets, device=device) weight = tensors[".ffn_gate_inp.weight"] @@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) else: raise ValueError("Invalid weight type") - self.orig_module.weight = self.orig_module.weight.to(device) - self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) + self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device)) + self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device)) def unload(self): if self.weight is not None: diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 394aa03..96d3578 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -26,6 +26,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl ) from ktransformers.operators.base_operator import BaseInjectedModule from transformers.configuration_utils import PretrainedConfig +from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant from abc import ABC, abstractmethod import sys, os sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) @@ -78,7 +79,13 @@ class KLinearBase(ABC): keys = [self.key] for key in keys: - if key + ".weight" in self.gguf_loader.tensor_file_map: + if self.gguf_loader.safetensor_loader is not None: + # using safetensor_loader + tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight') + weight_scale_inv = 
self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv') + return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) + + elif key + ".weight" in self.gguf_loader.tensor_file_map: if key + ".bias" in self.gguf_loader.tensor_file_map: tensors = self.load_multi(key, ["weight", "bias"], device=device) tensor = tensors["weight"] @@ -169,7 +176,61 @@ class KLinearTorch(KLinearBase): if self.has_bias: self.bias = None - +class KLinearFP8(KLinearBase): + # this kernel requires special handling for weight + # Please load the weight file downloaded from KVCache.AI + marlin_q_w: torch.Tensor + marlin_s: torch.Tensor + g_idx: torch.Tensor + sort_indices: torch.Tensor + has_bias: bool + weight: torch.Tensor + scale_w: torch.Tensor + bias: torch.Tensor + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cuda", + block_size: int = 128, + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.has_bias = False + self.dtype = torch.get_default_dtype() + self.block_size = block_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(self.device) + orig_dtype = x.dtype + x_quantized, scale_x = act_quant(x, self.block_size) + y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv) + return y.to(dtype=orig_dtype) + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if device is None: device = self.device + if w is None: + w = self.load_weight(device=device) + ### TODO fit weight_inv format + if isinstance(w, tuple): + self.weight = w[0].to(device) + self.weight_scale_inv = w[1].to(device) + self.has_bias = False + else: + raise ValueError("Invalid weight type") + self.weight = self.weight.to(device) + if self.has_bias: + self.bias = self.bias.to(device) + + def unload(self): + if self.weight is not None: + self.weight = None + if self.has_bias: + self.bias = None + + 
class KLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor @@ -404,7 +465,8 @@ class KLinearCPUInfer(KLinearBase): LINEAR_MAP = { "KLinearMarlin": KLinearMarlin, "KLinearTorch": KLinearTorch, - "KLinearCPUInfer": KLinearCPUInfer + "KLinearCPUInfer": KLinearCPUInfer, + "KLinearFP8": KLinearFP8, } class KTransformersLinear(BaseInjectedModule, KLinearBase): @@ -440,10 +502,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): def forward(self, x): if self.mode == InferenceState.PREFILL: assert self.prefill_linear is not None, "cpu linear is not initialized" - return self.prefill_linear.forward(x) + y = self.prefill_linear.forward(x) else: assert self.generate_linear is not None, "gpu linear is not initialized" - return self.generate_linear.forward(x) + y = self.generate_linear.forward(x) + return y def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE): if not mode: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml new file mode 100644 index 0000000..25f021e --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml @@ -0,0 +1,63 @@ +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearFP8" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: 
ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.gate.KMoEGate + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KDeepseekV2Model" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 49a3f16..55e32de 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -35,9 +35,9 @@ class KTransformersInterface(TransformersInterface): with torch.device("meta"): self.model = custom_models[config.architectures[0]](config) if default_args.optimize_config_path is None: - optimize_rule_path = default_optimize_rules[config.architectures[0]] + optimize_config_path = default_optimize_rules[config.architectures[0]] 
else: - optimize_rule_path = args.optimize_config_path + optimize_config_path = args.optimize_config_path # print(optimize_config) @@ -47,7 +47,7 @@ class KTransformersInterface(TransformersInterface): "please input the path of your gguf file(gguf file in the dir containing input gguf file must all" " belong to current model):" ) - optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config) + optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config) self.device_map = self.model.gguf_loader.tensor_device_map # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}") diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 8211933..c6f8e3a 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -340,7 +340,7 @@ class TransformersInterface(BackendInterfaceBase): sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) next_token = self.decode_one_tokens() self.profiler.inc("decode") - if next_token == self.tokenizer.eos_token_id: + if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token): assert self.args.batch_size == 1 break yield self.append_new_tokens(next_token) diff --git a/ktransformers/tests/triton_fp8gemm_test.py b/ktransformers/tests/triton_fp8gemm_test.py new file mode 100644 index 0000000..58888d6 --- /dev/null +++ b/ktransformers/tests/triton_fp8gemm_test.py @@ -0,0 +1,116 @@ +import torch +import torch.nn.functional as F +from typing import Optional +import pytest +from typing import Tuple, Optional, Literal +import time +# use dir path +import os +import sys +sys.path.insert(0, "/home/azure/ktransformers") +print(sys.path) +from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, 
weight_dequant +from safetensors import safe_open + +world_size = 1 +rank = 0 +block_size = 128 +gemm_impl: Literal["bf16", "fp8"] = "bf16" +# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined + +def test_fp8_gemm_vs_torch_matmul(): + # Test case 1: Create random matrices of size (M, K) and (N, K) + M, K, N = 64, 128, 256 # Matrix dimensions + x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda') + weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda') + + # Apply act_quant to both matrices + x_quantized, scale_x = act_quant(x, block_size) + weight_quantized, scale_w = act_quant(weight, block_size) + + # make the tensors contiguous + x_quantized = x_quantized.contiguous() + weight_quantized = weight_quantized.contiguous() + scale_x = scale_x.contiguous() + scale_w = scale_w.contiguous() + + # Perform fp8_gemm using the quantized tensors + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w) + + # Perform torch.matmul using the original floating point tensors + result_torch_matmul = torch.matmul(x, weight.T) + print(f'result_torch_matmul: {result_torch_matmul.shape}') + print(f'result_fp8_gemm: {result_fp8_gemm.shape}') + + print(f"result_fp8_gemm:\n {result_fp8_gemm}") + print(f"result_torch_matmul:\n {result_torch_matmul}") + +def test_fp8_gemm_vs_torch_matmul_load(): + file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors" + with safe_open(file_path, framework="pt", device=0) as f: + weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") + scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv") + + # weight_dequant + weight_dequantized = weight_dequant(weight, scale) + print(f"weight_dequantized: {weight_dequantized.shape}") + N, K = weight_dequantized.shape + M = 64 + x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda') + x_quantized, scale_x = act_quant(x, block_size) + + # Test case 1: quantized x matmul with undequantized weight + 
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + print(f"result_fp8_gemm:\n {result_fp8_gemm}") + print(f"dtype {result_fp8_gemm.dtype}") + + # Perform torch.matmul using the original floating point tensors + result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T) + print(f"result_torch_matmul:\n {result_torch_matmul}") + +def test_fp8_gemm_tplops(): + file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors" + with safe_open(file_path, framework="pt", device=0) as f: + weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") + scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv") + + # weight_dequant + weight_dequantized = weight_dequant(weight, scale) + print(f"weight_dequantized: {weight_dequantized.shape}") + N, K = weight_dequantized.shape + M = 6400 + x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda') + # x_quantized, scale_x = act_quant(x, block_size) + + # Time 10 iterations of act_quant + fp8_gemm (after two warm-up runs) + i = 10 + flops_per_gemm = 2 * M * N * K + total_flops = i * flops_per_gemm + + x_quantized, scale_x = act_quant(x, block_size) + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + x_quantized, scale_x = act_quant(x, block_size) + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + + + t0 = time.time() + torch.cuda.synchronize() + for i in range(i): + x_quantized, scale_x = act_quant(x, block_size) + result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) + torch.cuda.synchronize() + t1 = time.time() + + total_time = t1 - t0 + tflops = total_flops / total_time / 1e12 + print(f"total_time: {total_time}") + print(f"tflops: {tflops}") + + + + +if __name__ == "__main__": + test_fp8_gemm_vs_torch_matmul() + test_fp8_gemm_vs_torch_matmul_load() + test_fp8_gemm_tplops() + \ No newline at end of file diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 919f432..d26dc26 100644 --- a/ktransformers/util/custom_gguf.py +++ 
b/ktransformers/util/custom_gguf.py @@ -25,6 +25,7 @@ import os from enum import IntEnum import torch import KTransformersOps +from .custom_loader import SafeTensorLoader import ctypes class GGMLQuantizationType(IntEnum): @@ -128,6 +129,7 @@ GGML_BLOCK_SIZES = { "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2, "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64, + "FP8": 1, } GGML_ELEMENTS_PER_BLOCK = { @@ -143,6 +145,7 @@ GGML_ELEMENTS_PER_BLOCK = { "Q5_K": 256, "Q6_K": 256, "IQ4_XS": 256, + "FP8": 1, } DATA_TYPES = { @@ -159,6 +162,7 @@ DATA_TYPES = { "uint64": 10, "int64": 11, "float64": 12, + "FP8": 13, } class GGUFLoader: @@ -166,12 +170,15 @@ class GGUFLoader: gguf_path: str tensor_file_map: dict # {tensor_name: tensor_file_path} gguf_file_meta: dict + safetensor_loader: SafeTensorLoader def __init__(self, gguf_path: str): # Check dir exist if not os.path.exists(gguf_path): raise FileNotFoundError(f"GGUF dir not found: {gguf_path}") if os.path.isfile(gguf_path): gguf_path = os.path.dirname(gguf_path) + + self.safetensor_loader = None self.tensor_info = {} self.gguf_path = gguf_path @@ -179,7 +186,13 @@ class GGUFLoader: self.file_data_map = {} self.gguf_file_meta = {} self.tensor_device_map = {} - + + # I know this is ugly, but I don't want to change the original code too much + # TODO: merge gguf load and other loads. 
+ safetensor_loader = SafeTensorLoader(gguf_path) + if safetensor_loader.tensor_file_map: + self.safetensor_loader = safetensor_loader + return # Walk through all the .gguf files in the directory found_gguf = False for root, dirs, files in os.walk(gguf_path): @@ -286,6 +299,13 @@ class GGUFLoader: itemsize = int(np.empty([], dtype = item_type).itemsize) return mmap_data[offset : offset + itemsize * item_count] + def get_undequanted_tensor_and_ggml_type(self, name): + t = self.tensor_info[name] + data = self.get_mmap_tensor(name) + ggml_type = t["ggml_type"] + data = torch.from_numpy(data) + return data, ggml_type + def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor: t = self.tensor_info[name] if device.lower() == "cpu": @@ -418,6 +438,9 @@ def read_value(f, data_type): elem_type, count = struct.unpack("dict: + """ + :param folder_path: folder path + :return: key_to_file_map + """ + # check if the folder path is exist + if not os.path.exists(folder_path): + raise FileNotFoundError(f"GGUF dir not found: {folder_path}") + if os.path.isfile(folder_path): + folder_path = os.path.dirname(folder_path) + + key_to_file_map = {} + + found_safetensor = False + for root, dirs, files in os.walk(folder_path): + # sort files + files = sorted(files) + for file in files: + if file.endswith(".safetensors"): + found_safetensor = True + file_path = os.path.join(root, file) + try: + with safe_open(file_path, framework="pt") as f: + for key in f.keys(): + if "model.layers.61" in key: + # skip MTP layer + continue + # try: + # if int(key.split('.')[2]) > 4: + # continue + # except: + # pass + key_to_file_map[key] = file_path + except Exception as e: + print(f"Error reading Safetensor file {file_path}: {e}") + + if not found_safetensor: + raise FileNotFoundError(f"No Safetensor files found in {folder_path}") + + return key_to_file_map + +tensor_from_gguf = [] # todo: add keys in gguf that should 
be used in the final tensor + +def translate_name(name:str)->str: + """ + :param name: name of the tensor + :return: translated name + """ + name = translate_name_to_gguf(name) + name = name.replace(".up_proj.", ".ffn_up_exps.") + name = name.replace(".down_proj.", ".ffn_down_exps.") + name = name.replace(".gate_proj.", ".ffn_gate_exps.") + name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias") + return name + + +def combine_tensor_sources(safetensor_path:str, gguf_path:str): + gguf_loader = GGUFLoader(gguf_path) + gguf_tensor_file_map = gguf_loader.tensor_file_map + safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path) + + # build a map for the key to the tensor + # according to the key, we can get the tensor from the file + + target_tensor_map = {} + for key in safetensor_tensor_file_map.keys(): + # for all experts, we use the gguf tensor + if ".mlp.experts." in key: + if '.weight_scale_inv' in key: + continue + key = '.'.join(key.split('.')[:5]+key.split('.')[-2:]) + translated_key = translate_name(key) + target_tensor_map[key] = gguf_tensor_file_map[translated_key] + continue + + if any(target_key in key for target_key in tensor_from_gguf): + target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)] + else: + target_tensor_map[key] = safetensor_tensor_file_map[key] + + return target_tensor_map, gguf_loader + +def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader): + # Ensure output directory exists + os.makedirs(output_path, exist_ok=True) + + # Cache for safetensor file handles and GGUF loaders + safetensors_cache = {} + gguf_cache = {} + + # Group tensors by layer + layer_groups = defaultdict(list) + non_layer_keys = [] + layer_pattern = re.compile(r'\.layers\.(\d+)\.') + + for key in target_tensor_map: + match = layer_pattern.search(key) + if match: + layer_num = int(match.group(1)) + layer_groups[layer_num].append(key) + else: + non_layer_keys.append(key) + + 
# Calculate total shards + total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1 + if total_shards == 0: + raise ValueError("No tensors to save") + + shard_idx = 0 + + # Save non-layer tensors to the first shard if they exist + if non_layer_keys: + tensors = {} + for key in non_layer_keys: + file_path = target_tensor_map[key] + tensor = None + ggml_type = None + if file_path.endswith('.safetensors'): + if file_path not in safetensors_cache: + safetensors_cache[file_path] = safe_open(file_path, framework='pt') + f = safetensors_cache[file_path] + tensor = f.get_tensor(key) + elif file_path.endswith('.gguf'): + gguf_name = translate_name(key) + tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) + else: + raise ValueError(f"Unsupported file format: {file_path}") + tensors[translate_name(key)] = tensor + if ggml_type: + ggml_type = torch.tensor(ggml_type) + ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type" + tensors[ggml_key] = ggml_type + + output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") + print(f"Saving non-layer tensors to {output_file}") + save_file(tensors, output_file) + print(tensors.keys()) + + shard_idx += 1 + + # Save each layer's tensors to subsequent shards + for layer_num in sorted(layer_groups.keys()): + layer_keys = layer_groups[layer_num] + tensors = {} + for key in layer_keys: + file_path = target_tensor_map[key] + tensor = None + ggml_type = None + if file_path.endswith('.safetensors'): + if file_path not in safetensors_cache: + safetensors_cache[file_path] = safe_open(file_path, framework='pt') + f = safetensors_cache[file_path] + tensor = f.get_tensor(key) + tensor_info = tensor.shape + elif file_path.endswith('.gguf'): + gguf_name = translate_name(key) + tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) + # tensor_info = 
gguf_loader.tensor_info[gguf_name] + # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type'] + else: + raise ValueError(f"Unsupported file format: {file_path}") + tensors[translate_name(key)] = tensor + if ggml_type: + ggml_type = torch.tensor(ggml_type) + ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type" + tensors[ggml_key] = ggml_type + + output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") + print(f"Saving layer {layer_num} to {output_file}") + # print(tensors.keys()) + save_file(tensors, output_file) + shard_idx += 1 + + return + +def main(): + # Create the command-line argument parser + parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files") + parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3") + parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf") + parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8") + + # print all the arguments + print("All the arguments:") + print(parser.parse_args()) + + # Parse the command-line arguments + args = parser.parse_args() + + safetensor_path = args.safetensor_path + gguf_path = args.gguf_path + output_path = args.output_path + + target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path) + write_combined_tensor(target_tensor_map, output_path, gguf_loader) + + return + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt index 0479d36..ad280c0 100644 --- a/requirements-local_chat.txt +++ b/requirements-local_chat.txt @@ -4,4 +4,6 @@ numpy torch>=2.3.0 packaging cpufeature -protobuf \ No newline at end of file +protobuf +tiktoken +blobfile \ No newline at end of file