Mirror of https://github.com/RYDE-WORK/ktransformers.git (synced 2026-01-19 12:43:16 +08:00)

Merge pull request #643 from Azure-Tang/support-fp8

[feat] Support fp8 linear kernel;

This commit is contained in: commit 050b745a6e
@@ -23,7 +23,8 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

<h2 id="Updates">🔥 Updates</h2>

* **Feb 15, 2025**: KTransformers V0.2.1: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%) (Up to 16 Tokens/s), update docs [here](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; Longer Context (from 8K to 128K for 24GB VRAM).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
* **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU.
@@ -125,7 +126,7 @@ To utilize the provided kernels, users only need to create a YAML-based injectio

```python
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
...
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000)
```
README_ZH.md (17 changed lines)
@@ -21,7 +21,8 @@ KTransformers is a flexible, Python-centric framework designed with extensibility at its core

<h2 id="Updates">🔥 Updates</h2>

* **Feb 15, 2025**: KTransformers V0.2.1: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s); see the docs [here](./doc/en/DeepseekR1_V3_tutorial.md) and the [online guide](https://kvcache-ai.github.io/ktransformers/).
* **Feb 15, 2025**: Support the [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3/R1; Longer Context (from 8K to 128K with only 24GB VRAM).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s); see the docs [here](./doc/en/DeepseekR1_V3_tutorial.md) and the [online guide](https://kvcache-ai.github.io/ktransformers/).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on a single GPU (24GB VRAM) or multiple GPUs with 382G DRAM, up to 3~28x speedup. For a detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Support 1M context with the InternLM2.5-7B-Chat-1M model, using 24GB VRAM and 150GB DRAM. For a detailed tutorial, see [here](./doc/en/long_context_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
@@ -68,11 +69,11 @@ https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c

</p>

<h3>1M-context local inference on a desktop with only 24GB VRAM</h3>
<p align="center">

https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12
<!-- <h3>1M-context local inference on a desktop with only 24GB VRAM</h3>
<p align="center"> -->

<!-- https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 -->
<!--
* **1M-context InternLM 2.5 7B**: Runs in full bf16 precision with 24GB VRAM and 150GB DRAM, achievable on a local desktop setup. Reaches a 92.88% success rate on the 1M "Needle in a Haystack" test and 100% on the 128K NIAH test.

<p align="center">
@@ -89,7 +90,7 @@ https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12

* **Enhanced speed**: Reaches 16.91 tokens/s for 1M-context generation with sparse attention, powered by the llamafile kernels. This approach is more than 10x faster than llama.cpp's full-attention approach.

* **Flexible sparse attention framework**: Offers a flexible block-sparse attention framework for CPU-offloaded decoding. Compatible with SnapKV, Quest and InfLLM. More information is available [here](./doc/en/long_context_introduction.md).
* **Flexible sparse attention framework**: Offers a flexible block-sparse attention framework for CPU-offloaded decoding. Compatible with SnapKV, Quest and InfLLM. More information is available [here](./doc/en/long_context_introduction.md). -->

<strong>More advanced features are coming soon, so stay tuned!</strong>
@@ -116,7 +117,7 @@ At its core, KTransformers is a user-friendly, template-based injection framework.

```python
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
...
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000)
```
@@ -151,7 +152,7 @@ Each rule in the YAML file has two parts: `match` and `replace`. The `match` part

<h2 id="ack">Acknowledgments and Contributors</h2>

The development of KTransformers is built on the flexible and versatile framework provided by Transformers. We have also benefited from advanced kernels such as GGUF/GGML, Llamafile and Marlin. We plan to give back to the community by contributing our changes upstream.
The development of KTransformers is built on the flexible and versatile framework provided by Transformers. We have also benefited from advanced kernels such as GGUF/GGML, Llamafile, Marlin, sglang and flashinfer. We plan to give back to the community by contributing our changes upstream.

KTransformers is actively maintained and developed by members of the <a href="https://madsys.cs.tsinghua.edu.cn/">MADSys group</a> at Tsinghua University and members of <a href="http://approaching.ai/">Approaching.AI</a>. We welcome new contributors to join us in making KTransformers faster and easier to use.
@@ -5,10 +5,11 @@
- [Installation Guide](en/install.md)

# Tutorial
- [Deepseek-R1/V3 Show Case](en/DeepseekR1_V3_tutorial.md)
- [Deepseek-R1/V3 Show Case/Tutorial](en/DeepseekR1_V3_tutorial.md)
- [Why KTransformers So Fast](en/deepseek-v2-injection.md)
- [Injection Tutorial](en/injection_tutorial.md)
- [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
- [Use FP8 GPU Kernel](en/fp8_kernel.md)
# Server
- [Server](en/api/server/server.md)
- [Website](en/api/server/website.md)
@@ -45,7 +45,7 @@ from https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552

### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?

Use `--optimize_rule_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two-GPU optimize rule YAML file. You may also use it as an example to write your own 4/8-GPU optimize rule YAML file.
Use `--optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` to load the two-GPU optimize rule YAML file. You may also use it as an example to write your own 4/8-GPU optimize rule YAML file.

> Note: KTransformers' multi-GPU strategy is pipeline parallelism, which does not speed up the model's inference. It is only used to distribute the model's weights.

@@ -55,7 +55,7 @@ You have to set `--cpu_infer` to the number of cores you want to use. The more c

### Q: My DeepSeek-R1 model is not thinking.

According to DeepSeek, you need to force the model to initiate its response with "\<think>\n" at the beginning of every output by passing the arg `--force_think true`.
According to DeepSeek, you need to force the model to initiate its response with "\<think>\n" at the beginning of every output by passing the arg `--force_think True`.
### Q: Loading gguf error

@@ -63,9 +63,12 @@ Make sure you:
1. Have the `gguf` file in the `--gguf_path` directory.
2. The directory only contains gguf files from one model. If you have multiple models, you need to separate them into different directories.
3. The folder name itself should not end with `.gguf`, e.g. `Deep-gguf` is correct, `Deep.gguf` is wrong.
4. The file itself is not corrupted; you can verify this by checking that the sha256sum matches the one from huggingface, modelscope, or hf-mirror (see the example below).
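For example, on Linux you can check a downloaded shard like this (the file name is a placeholder; compare the printed hash against the value listed on the model page):

```shell
sha256sum <model-shard>.gguf
```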
### Q: Version `GLIBCXX_3.4.30' not found

The detailed error:
> ImportError: /mnt/data/miniconda3/envs/xxx/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/xxx/xxx/ktransformers/./cpuinfer_ext.cpython-312-x86_64-linux-gnu.so)

Running `conda install -c conda-forge libstdcxx-ng` can solve the problem.
doc/en/fp8_kernel.md (new file, 74 lines)
@@ -0,0 +1,74 @@
# FP8 Linear Kernel for DeepSeek-V3

## Overview
The DeepSeek-AI team provides FP8 safetensors for the DeepSeek-R1/V3 models. We achieve performance optimization through the following work:
- **FP8 GPU Kernel Integration**: FP8 linear-layer acceleration kernels integrated into KTransformers
- **Hybrid Quantization Architecture**:
  - Attention and Shared-Expert modules use FP8 precision (improves computational accuracy)
  - Experts modules retain GGML quantization (GGUF format, residing on the CPU to save GPU memory)

So users pursuing the best performance can use the FP8 linear kernel for DeepSeek-V3/R1.

## Key Features
✅ Hybrid Precision Architecture (FP8 + GGML)
✅ Memory Optimization (~19GB VRAM usage)
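For orientation, the rule file referenced later in this document (`ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml`, added in this PR) expresses the hybrid split roughly as follows (abridged sketch; see the full file in the diff below):

```yaml
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearFP8"    # FP8 GEMM on GPU for attention / shared-expert linears
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts
    kwargs:
      generate_device: "cpu"
      generate_op: "KExpertsCPU"   # routed experts stay GGML-quantized on CPU
      out_device: "cuda"
```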
## Quick Start
### Using Pre-Merged Weights

Pre-merged weights are available on Hugging Face:
[KVCache-ai/DeepSeek-V3](https://huggingface.co/KVCache-ai/DeepSeek-V3)
[KVCache-ai/DeepSeek-R1](https://huggingface.co/KVCache-ai/DeepSeek-R1)
> Please confirm the weights are fully uploaded before downloading. Because of the large file size, the Hugging Face upload may take some time.

Download the pre-merged weights:
```shell
pip install -U huggingface_hub

# Optional: Use the HF Mirror for faster downloads in certain regions.
# export HF_ENDPOINT=https://hf-mirror.com

huggingface-cli download --resume-download KVCache-ai/DeepSeek-V3 --local-dir <local_dir>
```
### Using merge scripts
If you have local DeepSeek-R1/V3 FP8 safetensors and Q4_K_M GGUF weights, you can merge them with the following script.

```shell
python convert_model.py \
  --safetensor_path <fp8_safetensor_path> \
  --gguf_path <q4km_gguf_folder_path> \
  --output_path <merged_output_path>
```

* `--safetensor_path`: input path of the safetensor file ([Download](https://huggingface.co/deepseek-ai/DeepSeek-V3/tree/main)).
* `--gguf_path`: input path of the gguf folder ([Download](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)).
* `--output_path`: output path of the merged file.
### Execution Notes

Launch local_chat.py with custom quantized experts:
```shell
python ktransformers/local_chat.py \
  --model_path deepseek-ai/DeepSeek-V3 \
  --gguf_path <merged_weights_folder> \
  --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml \
  --cpu_infer <cpu_cores + 1>
```

## Notes

⚠️ Hardware Requirements
* Recommended minimum of 19GB available VRAM for the FP8 kernel.
* Requires a GPU with FP8 support (e.g., 4090).

⏳ First-Run Optimization
JIT compilation makes the first run slower; subsequent runs retain the optimized speed.

🔄 Temporary Interface
The current weight-loading implementation is provisional and will be refined in future versions.

📁 Path Specification
Despite the hybrid quantization, merged weights are stored as .safetensors; pass the containing folder path to `--gguf_path`.
@@ -59,6 +59,7 @@ Supported operators and their corresponding classes are as follows:
| Linear | KTransformersLinear | KLinearMarlin | Marlin as backend |
| | | KLinearTorch | pytorch as backend |
| | | KLinearCPUInfer | llamafile as backend |
| | | KLinearFP8 | Triton fp8_gemm kernel. Requires a GPU that can compute fp8 data |
| experts | KTransformersExperts | KExpertsTorch | pytorch as backend |
| | | KExpertsMarlin | Marlin as backend |
| | | KExpertsCPU | llamafile as backend |
@@ -141,7 +141,7 @@ It features the following arguments:

- `--gguf_path` (required): Path of a directory containing GGUF files, which can be downloaded from [Hugging Face](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main). Note that the directory should only contain the GGUF of the current model, which means you need one separate directory for each model.

- `--optimize_rule_path` (required except for Qwen2Moe and DeepSeek-V2): Path of the YAML file containing the optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models.
- `--optimize_config_path` (required except for Qwen2Moe and DeepSeek-V2): Path of the YAML file containing the optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models.

- `--max_new_tokens`: Int (default=1000). Maximum number of new tokens to generate.
ktransformers/ktransformers_ext/triton/fp8gemm.py (new file, 192 lines)
@@ -0,0 +1,192 @@
# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
from typing import Tuple

import torch
import triton
import triton.language as tl
from triton import Config


@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)


def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M, N).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y


fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]

@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    a_s_ptr, b_s_ptr,
                    M, N: tl.constexpr, K: tl.constexpr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A.
        b_ptr (tl.tensor): Pointer to the second input matrix B.
        c_ptr (tl.tensor): Pointer to the output matrix C.
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    """
    Perform a matrix multiplication using FP8 precision.

    Args:
        a (torch.Tensor): The first input matrix, must be contiguous.
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix, must be contiguous.
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    return c
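A minimal usage sketch of `act_quant` and `fp8_gemm`, modeled on the test file added later in this PR (`ktransformers/tests/triton_fp8gemm_test.py`); note that in the real `KLinearFP8` path the weight-side scale comes from the model's `weight_scale_inv` tensors rather than from `act_quant`, and the FP8 result only roughly tracks the bf16 reference:

```python
import torch
from ktransformers.ktransformers_ext.triton.fp8gemm import act_quant, fp8_gemm

block_size = 128
M, K, N = 64, 128, 256                              # K must be a multiple of block_size
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")   # (out_features, in_features)

# Block-quantize both operands to float8_e4m3fn with one fp32 scale per 128-element block.
x_q, x_s = act_quant(x, block_size)
w_q, w_s = act_quant(weight, block_size)

y = fp8_gemm(x_q, x_s, w_q, w_s)                    # approximately x @ weight.T
ref = torch.matmul(x, weight.T)
print(y.shape, ref.shape)
```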
@@ -54,7 +54,7 @@ default_optimize_rules = {

def local_chat(
    model_path: str | None = None,
    optimize_rule_path: str = None,
    optimize_config_path: str = None,
    gguf_path: str | None = None,
    max_new_tokens: int = 300,
    cpu_infer: int = Config().cpu_infer,

@@ -95,12 +95,12 @@ def local_chat(
        config, trust_remote_code=True, attn_implementation="flash_attention_2"
    )

    if optimize_rule_path is None:
    if optimize_config_path is None:
        if config.architectures[0] in default_optimize_rules:
            print("using default_optimize_rule for", config.architectures[0])
            optimize_rule_path = default_optimize_rules[config.architectures[0]]
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_rule_path = input(
            optimize_config_path = input(
                "please input the path of your rule file(yaml file containing optimize rules):"
            )

@@ -108,7 +108,7 @@ def local_chat(
        gguf_path = input(
            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
        )
    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)

    try:
        model.generation_config = GenerationConfig.from_pretrained(model_path)
@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
        down_type = None

        for key in keys:
            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
            if self.gguf_loader.safetensor_loader is not None:
                # using a temporary (admittedly ugly) way to load the tensor
                gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
                up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
                down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
                gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
                up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
                down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()

            elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):

        for key in keys:
            key = ".".join(key.split(".")[:-1])
            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
            if self.gguf_loader.safetensor_loader is not None:
                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
                weight_type = weight.dtype
                e_score_correction_bias_type = e_score_correction_bias.dtype
                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
            elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                tensors = self.load_multi(key, targets, device=device)
                weight = tensors[".ffn_gate_inp.weight"]

@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = self.orig_module.weight.to(device)
        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))

    def unload(self):
        if self.weight is not None:
@@ -26,6 +26,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))

@@ -78,7 +79,13 @@ class KLinearBase(ABC):
            keys = [self.key]

        for key in keys:
            if key + ".weight" in self.gguf_loader.tensor_file_map:
            if self.gguf_loader.safetensor_loader is not None:
                # using safetensor_loader
                tensor = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight')
                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight_scale_inv')
                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)

            elif key + ".weight" in self.gguf_loader.tensor_file_map:
                if key + ".bias" in self.gguf_loader.tensor_file_map:
                    tensors = self.load_multi(key, ["weight", "bias"], device=device)
                    tensor = tensors["weight"]
@@ -169,7 +176,61 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None


class KLinearFP8(KLinearBase):
    # this kernel requires special handling for the weight
    # Please load the weight file downloaded from KVCache.AI
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
    sort_indices: torch.Tensor
    has_bias: bool
    weight: torch.Tensor
    scale_w: torch.Tensor
    bias: torch.Tensor
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        block_size: int = 128,
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.dtype = torch.get_default_dtype()
        self.block_size = block_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(self.device)
        orig_dtype = x.dtype
        x_quantized, scale_x = act_quant(x, self.block_size)
        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
        return y.to(dtype=orig_dtype)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None: device = self.device
        if w is None:
            w = self.load_weight(device=device)
        ### TODO fit weight_inv format
        if isinstance(w, tuple):
            self.weight = w[0].to(device)
            self.weight_scale_inv = w[1].to(device)
            self.has_bias = False
        else:
            raise ValueError("Invalid weight type")
        self.weight = self.weight.to(device)
        if self.has_bias:
            self.bias = self.bias.to(device)

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.has_bias:
            self.bias = None


class KLinearMarlin(KLinearBase):
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor

@@ -404,7 +465,8 @@ class KLinearCPUInfer(KLinearBase):
LINEAR_MAP = {
    "KLinearMarlin": KLinearMarlin,
    "KLinearTorch": KLinearTorch,
    "KLinearCPUInfer": KLinearCPUInfer
    "KLinearCPUInfer": KLinearCPUInfer,
    "KLinearFP8": KLinearFP8,
}

class KTransformersLinear(BaseInjectedModule, KLinearBase):

@@ -440,10 +502,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
    def forward(self, x):
        if self.mode == InferenceState.PREFILL:
            assert self.prefill_linear is not None, "cpu linear is not initialized"
            return self.prefill_linear.forward(x)
            y = self.prefill_linear.forward(x)
        else:
            assert self.generate_linear is not None, "gpu linear is not initialized"
            return self.generate_linear.forward(x)
            y = self.generate_linear.forward(x)
        return y

    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
        if not mode:
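For context, a sketch of the weight pair that `KLinearFP8.load` receives from the safetensor loading path above; the exact 128x128 block layout of `weight_scale_inv` is an assumption based on the DeepSeek-V3 FP8 checkpoint format, and the sizes below are purely illustrative:

```python
import math
import torch

out_features, in_features, block = 2048, 7168, 128   # illustrative sizes only
weight = torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn, device="cuda")
weight_scale_inv = torch.empty(math.ceil(out_features / block),
                               math.ceil(in_features / block),
                               dtype=torch.float32, device="cuda")
# KLinearFP8.load() stores exactly this pair; forward() then does roughly:
#   x_q, x_s = act_quant(x, block)
#   y = fp8_gemm(x_q, x_s, weight, weight_scale_inv)
```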
@@ -0,0 +1,63 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
@@ -35,9 +35,9 @@ class KTransformersInterface(TransformersInterface):
        with torch.device("meta"):
            self.model = custom_models[config.architectures[0]](config)
        if default_args.optimize_config_path is None:
            optimize_rule_path = default_optimize_rules[config.architectures[0]]
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_rule_path = args.optimize_config_path
            optimize_config_path = args.optimize_config_path

        # print(optimize_config)

@@ -47,7 +47,7 @@ class KTransformersInterface(TransformersInterface):
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
        optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)

        self.device_map = self.model.gguf_loader.tensor_device_map
        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")

@@ -340,7 +340,7 @@ class TransformersInterface(BackendInterfaceBase):
                sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
            next_token = self.decode_one_tokens()
            self.profiler.inc("decode")
            if next_token == self.tokenizer.eos_token_id:
            if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
                assert self.args.batch_size == 1
                break
            yield self.append_new_tokens(next_token)
ktransformers/tests/triton_fp8gemm_test.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import torch
import torch.nn.functional as F
from typing import Optional
import pytest
from typing import Tuple, Optional, Literal
import time
# use dir path
import os
import sys
sys.path.insert(0, "/home/azure/ktransformers")
print(sys.path)
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors import safe_open

world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined

def test_fp8_gemm_vs_torch_matmul():
    # Test case 1: Create random matrices of size (M, K) and (K, N)
    M, K, N = 64, 128, 256  # Matrix dimensions
    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')

    # Apply act_quant to both matrices
    x_quantized, scale_x = act_quant(x, block_size)
    weight_quantized, scale_w = act_quant(weight, block_size)

    # make contiguous
    x_quantized = x_quantized.contiguous()
    weight_quantized = weight_quantized.contiguous()
    scale_x = scale_x.contiguous()
    scale_w = scale_w.contiguous()

    # Perform fp8_gemm using the quantized tensors
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)

    # Perform torch.matmul using the original floating point tensors
    result_torch_matmul = torch.matmul(x, weight.T)
    print(f'result_torch_matmul: {result_torch_matmul.shape}')
    print(f'result_fp8_gemm: {result_fp8_gemm.shape}')

    print(f"result_fp8_gemm:\n {result_fp8_gemm}")
    print(f"result_torch_matmul:\n {result_torch_matmul}")

def test_fp8_gemm_vs_torch_matmul_load():
    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
    with safe_open(file_path, framework="pt", device=0) as f:
        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

    # weight_dequant
    weight_dequantized = weight_dequant(weight, scale)
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 64
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
    x_quantized, scale_x = act_quant(x, block_size)

    # Test case 1: quantized x matmul with undequantized weight
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    print(f"result_fp8_gemm:\n {result_fp8_gemm}")
    print(f"dtype {result_fp8_gemm.dtype}")

    # Perform torch.matmul using the original floating point tensors
    result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
    print(f"result_torch_matmul:\n {result_torch_matmul}")

def test_fp8_gemm_tplops():
    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
    with safe_open(file_path, framework="pt", device=0) as f:
        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

    # weight_dequant
    weight_dequantized = weight_dequant(weight, scale)
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 6400
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
    # x_quantized, scale_x = act_quant(x, block_size)

    # Time i iterations of act_quant + fp8_gemm
    i = 10
    flops_per_gemm = 2 * M * N * K
    total_flops = i * flops_per_gemm

    x_quantized, scale_x = act_quant(x, block_size)
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    x_quantized, scale_x = act_quant(x, block_size)
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)


    t0 = time.time()
    torch.cuda.synchronize()
    for i in range(i):
        x_quantized, scale_x = act_quant(x, block_size)
        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    torch.cuda.synchronize()
    t1 = time.time()

    total_time = t1 - t0
    tflops = total_flops / total_time / 1e12
    print(f"total_time: {total_time}")
    print(f"tflops: {tflops}")


if __name__ == "__main__":
    test_fp8_gemm_vs_torch_matmul()
    test_fp8_gemm_vs_torch_matmul_load()
    test_fp8_gemm_tplops()
@@ -25,6 +25,7 @@ import os
from enum import IntEnum
import torch
import KTransformersOps
from .custom_loader import SafeTensorLoader
import ctypes

class GGMLQuantizationType(IntEnum):

@@ -128,6 +129,7 @@ GGML_BLOCK_SIZES = {
    "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
    "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
    "FP8": 1,
}

GGML_ELEMENTS_PER_BLOCK = {

@@ -143,6 +145,7 @@ GGML_ELEMENTS_PER_BLOCK = {
    "Q5_K": 256,
    "Q6_K": 256,
    "IQ4_XS": 256,
    "FP8": 1,
}

DATA_TYPES = {

@@ -159,6 +162,7 @@ DATA_TYPES = {
    "uint64": 10,
    "int64": 11,
    "float64": 12,
    "FP8": 13,
}

class GGUFLoader:

@@ -166,12 +170,15 @@
    gguf_path: str
    tensor_file_map: dict  # {tensor_name: tensor_file_path}
    gguf_file_meta: dict
    safetensor_loader: SafeTensorLoader
    def __init__(self, gguf_path: str):
        # Check that the dir exists
        if not os.path.exists(gguf_path):
            raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
        if os.path.isfile(gguf_path):
            gguf_path = os.path.dirname(gguf_path)

        self.safetensor_loader = None

        self.tensor_info = {}
        self.gguf_path = gguf_path

@@ -179,7 +186,13 @@
        self.file_data_map = {}
        self.gguf_file_meta = {}
        self.tensor_device_map = {}

        # I know this is ugly, but I don't want to change the original code too much
        # TODO: merge gguf load and other loads.
        safetensor_loader = SafeTensorLoader(gguf_path)
        if safetensor_loader.tensor_file_map:
            self.safetensor_loader = safetensor_loader
            return
        # Walk through all the .gguf files in the directory
        found_gguf = False
        for root, dirs, files in os.walk(gguf_path):

@@ -286,6 +299,13 @@
        itemsize = int(np.empty([], dtype = item_type).itemsize)
        return mmap_data[offset : offset + itemsize * item_count]

    def get_undequanted_tensor_and_ggml_type(self, name):
        t = self.tensor_info[name]
        data = self.get_mmap_tensor(name)
        ggml_type = t["ggml_type"]
        data = torch.from_numpy(data)
        return data, ggml_type

    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":

@@ -418,6 +438,9 @@ def read_value(f, data_type):
        elem_type, count = struct.unpack("<IQ", f.read(4 + 8))
        return [read_value(f, elem_type) for _ in range(count)]

    elif data_type == DATA_TYPES["FP8"]:
        return struct.unpack("<B", f.read(1))[0]

    else:
        raise NotImplementedError(f"Data type {data_type} not implemented")
ktransformers/util/custom_loader.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file

class SafeTensorLoader:
    tensor_file_map = {}
    tensor_type_map = {}
    file_handle_map = {}

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        # Normalize the given path and make sure we end up with a folder path
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            files = sorted(files)
            for file in files:
                if file.endswith(".safetensors"):
                    found_safetensor = True
                    file_path = os.path.join(root, file)
                    if file not in self.file_handle_map:
                        try:
                            handle = safe_open(file_path, framework="pt")
                            self.file_handle_map[file] = handle
                        except Exception as e:
                            print(f"Error opening Safetensor file {file_path}: {e}")
                            continue

                    f = self.file_handle_map.get(file)
                    if f is None:
                        continue
                    try:
                        for key in f.keys():
                            self.tensor_file_map[key] = file
                    except Exception as e:
                        print(f"Error reading Safetensor file {file_path}: {e}")

        # if not found_safetensor:
        #     raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str="cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_dequantized_tensor(self, key: str, device: str="cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key).to(device)
        if key.endswith(".weight"):
            if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
                weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
                tensor = weight_dequant(tensor, weight_scale_inv)
        return tensor.to(device)
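A quick sketch of how this loader is used; the checkpoint path and tensor key mirror the ones in the test file added in this PR, and in normal operation `GGUFLoader` creates the `SafeTensorLoader` automatically when the target folder contains .safetensors:

```python
from ktransformers.util.custom_loader import SafeTensorLoader

loader = SafeTensorLoader("/mnt/data/model/DeepSeek-V3")   # folder of FP8 .safetensors shards

# Raw FP8 tensor, exactly as stored on disk.
w_fp8 = loader.load_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")

# Same tensor, dequantized on the fly because a matching
# ".weight_scale_inv" entry exists in the shard.
w_deq = loader.load_dequantized_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")

loader.close_all_handles()
```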
@@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
    for name, param in local_state.items():
        key = prefix + name
        translated_key = translate_name_to_gguf(key)
        if translated_key in gguf_loader.tensor_file_map:

        # TODO: Merge all loaders.
        # I know this is ugly but let's do it for now.
        if gguf_loader.safetensor_loader is not None:
            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
        else:
            load_dequantized_tensor = gguf_loader.load_gguf_tensor
            tensor_file_map = gguf_loader.tensor_file_map

        if translated_key in tensor_file_map:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            print(f"loading {translated_key} to {device}")
            # device = "cpu" if "embd" in translated_key else "cuda"
            weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
            torch.cuda.empty_cache()  # To fit in 16G VRAM. By "wkGCaSS - Zhihu https://zhuanlan.zhihu.com/p/25491611225"
            # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
            set_param(module, name, weights)
            del weights
        else:
merge_tensors/merge_safetensor_gguf.py (new file, 214 lines)
@@ -0,0 +1,214 @@
# This script merges the fp8 safetensors and the gguf quantized tensors.

import os
# insert the path of the project
import sys
sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
from collections import defaultdict

def read_safetensor_keys_from_folder(folder_path) -> dict:
    """
    :param folder_path: folder path
    :return: key_to_file_map
    """
    # check that the folder path exists
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"GGUF dir not found: {folder_path}")
    if os.path.isfile(folder_path):
        folder_path = os.path.dirname(folder_path)

    key_to_file_map = {}

    found_safetensor = False
    for root, dirs, files in os.walk(folder_path):
        # sort files
        files = sorted(files)
        for file in files:
            if file.endswith(".safetensors"):
                found_safetensor = True
                file_path = os.path.join(root, file)
                try:
                    with safe_open(file_path, framework="pt") as f:
                        for key in f.keys():
                            if "model.layers.61" in key:
                                # skip MTP layer
                                continue
                            # try:
                            #     if int(key.split('.')[2]) > 4:
                            #         continue
                            # except:
                            #     pass
                            key_to_file_map[key] = file_path
                except Exception as e:
                    print(f"Error reading Safetensor file {file_path}: {e}")

    if not found_safetensor:
        raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    return key_to_file_map

tensor_from_gguf = []  # todo: add keys in gguf that should be used in the final tensor

def translate_name(name: str) -> str:
    """
    :param name: name of the tensor
    :return: translated name
    """
    name = translate_name_to_gguf(name)
    name = name.replace(".up_proj.", ".ffn_up_exps.")
    name = name.replace(".down_proj.", ".ffn_down_exps.")
    name = name.replace(".gate_proj.", ".ffn_gate_exps.")
    name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
    return name


def combine_tensor_sources(safetensor_path: str, gguf_path: str):
    gguf_loader = GGUFLoader(gguf_path)
    gguf_tensor_file_map = gguf_loader.tensor_file_map
    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)

    # build a map for the key to the tensor
    # according to the key, we can get the tensor from the file

    target_tensor_map = {}
    for key in safetensor_tensor_file_map.keys():
        # for all experts, we use the gguf tensor
        if ".mlp.experts." in key:
            if '.weight_scale_inv' in key:
                continue
            key = '.'.join(key.split('.')[:5] + key.split('.')[-2:])
            translated_key = translate_name(key)
            target_tensor_map[key] = gguf_tensor_file_map[translated_key]
            continue

        if any(target_key in key for target_key in tensor_from_gguf):
            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
        else:
            target_tensor_map[key] = safetensor_tensor_file_map[key]

    return target_tensor_map, gguf_loader

def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)

    # Cache for safetensor file handles and GGUF loaders
    safetensors_cache = {}
    gguf_cache = {}

    # Group tensors by layer
    layer_groups = defaultdict(list)
    non_layer_keys = []
    layer_pattern = re.compile(r'\.layers\.(\d+)\.')

    for key in target_tensor_map:
        match = layer_pattern.search(key)
        if match:
            layer_num = int(match.group(1))
            layer_groups[layer_num].append(key)
        else:
            non_layer_keys.append(key)

    # Calculate total shards
    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
    if total_shards == 0:
        raise ValueError("No tensors to save")

    shard_idx = 0

    # Save non-layer tensors to the first shard if they exist
    if non_layer_keys:
        tensors = {}
        for key in non_layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving non-layer tensors to {output_file}")
        save_file(tensors, output_file)
        print(tensors.keys())

        shard_idx += 1

    # Save each layer's tensors to subsequent shards
    for layer_num in sorted(layer_groups.keys()):
        layer_keys = layer_groups[layer_num]
        tensors = {}
        for key in layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
                tensor_info = tensor.shape
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
                # tensor_info = gguf_loader.tensor_info[gguf_name]
                # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving layer {layer_num} to {output_file}")
        # print(tensors.keys())
        save_file(tensors, output_file)
        shard_idx += 1

    return

def main():
    # Build the command-line argument parser
    parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files")
    parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
    parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
    parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")

    # print all the arguments
    print("All the arguments:")
    print(parser.parse_args())

    # Parse the command-line arguments
    args = parser.parse_args()

    safetensor_path = args.safetensor_path
    gguf_path = args.gguf_path
    output_path = args.output_path

    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
    write_combined_tensor(target_tensor_map, output_path, gguf_loader)

    return

if __name__ == "__main__":
    main()
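A minimal invocation of this merge script using the argparse options defined above (the paths are placeholders; the defaults in `main()` point to the author's local machine):

```shell
python merge_tensors/merge_safetensor_gguf.py \
  --safetensor_path <fp8_safetensor_folder> \
  --gguf_path <q4km_gguf_folder> \
  --output_path <merged_output_folder>
```

The resulting folder of merged .safetensors shards is what `--gguf_path` should point to when launching local_chat.py with the FP8 optimize rules, as described in doc/en/fp8_kernel.md above.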
@@ -4,4 +4,6 @@ numpy
torch>=2.3.0
packaging
cpufeature
protobuf
protobuf
tiktoken
blobfile