mirror of https://github.com/RYDE-WORK/ktransformers.git
synced 2026-01-19 12:43:16 +08:00

Merge branch 'develop-0.2.2' into support-fp8
Update README.md

This commit is contained in commit 91c1619296.
.devcontainer/Dockerfile (new file, 19 lines)
@@ -0,0 +1,19 @@
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server
WORKDIR /workspace
ENV CUDA_HOME /usr/local/cuda
RUN <<EOF
apt update -y && apt install -y --no-install-recommends \
    git \
    wget \
    vim \
    gcc \
    g++ \
    cmake &&
rm -rf /var/lib/apt/lists/* &&
cd ktransformers &&
pip install ninja pyproject numpy cpufeature &&
pip install flash-attn &&
cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/
EOF
# Set the default shell to bash
CMD ["/bin/bash"]
.devcontainer/devcontainer.json (new file, 34 lines)
@@ -0,0 +1,34 @@
{
    "name": "Ktrans Dev Container",
    "privileged": true,
    "build": {
        "dockerfile": "Dockerfile",
        "context": "..",
        "args": {
            "http_proxy": "${env:http_proxy}",
            "https_proxy": "${env:https_proxy}",
        }
    },
    "runArgs": [
        "--network=host",
        "--gpus",
        "all"
        // "--gpu all"
    ],
    "workspaceFolder": "/workspace",
    "workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=cached",
    "mounts": [
        "source=/mnt/data,target=/mnt/incontainer,type=bind,consistency=cached"
    ],
    "customizations": {
        "vscode": {
            "extensions": [
            ],
            "settings": {
                "terminal.integrated.shell.linux": "/bin/bash",
                "cmake.configureOnOpen": true,
                "cmake.generator": "Ninja"
            }
        }
    }
}
.gitignore (vendored, 8 lines changed)
@@ -19,13 +19,9 @@ ktransformers/server/local_store/
 ktransformers/server_test1.db
 *.patch
 img/
-tmp1.txt
+tmp*.txt
-test_65_300_1536.txt
 test.txt
 book
-ktransformers/tests/mmlu_result_silicon.json
 ktransformers/tests/chat_txt.txt
-mmlu_result_q4km.json
+mmlu_result*
-mmlu_result_q4km.log
-ktransformers/tests/mmlu_result_silicon.log
 ktransformers/ktransformers_ext/cuda_musa/
@@ -23,7 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
 
 <h2 id="Updates">🔥 Updates</h2>
 
-* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; Longer Context (from 8K to 128K for 24GB VRAM).
+* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
 * **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
 * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).
 * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
@@ -21,7 +21,7 @@ KTransformers is a flexible, Python-centric framework designed with extensibility at its core
 
 <h2 id="Updates">🔥 Updates</h2>
 
-* **Feb 15, 2025**: Support the [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3/R1; support longer context (from 8K to 128K with only 24GB VRAM).
+* **Feb 15, 2025**: Support the [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3/R1; support longer context ([tutorial](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context)).
 * **Feb 15, 2025**: Longer context (from 4K to 8K with 24GB VRAM) & slightly faster speed (+15%, up to 16 tokens/s); see the [docs](./doc/en/DeepseekR1_V3_tutorial.md) and the [online guide](https://kvcache-ai.github.io/ktransformers/).
 * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi GPU and 382G DRAM, with up to 3~28x speedup. See the detailed tutorial [here](./doc/en/DeepseekR1_V3_tutorial.md).
 * **Aug 28, 2024**: Support 1M context with the InternLM2.5-7B-Chat-1M model, using 24GB of VRAM and 150GB of DRAM. See the detailed tutorial [here](./doc/en/long_context_tutorial.md).
@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
 
 <h2 id="Updates">🔥 Updates</h2>
 
+* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
 * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md).
 * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./en/long_context_tutorial.md).
 * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
@@ -72,3 +72,28 @@ The detailed error:
 Running `conda install -c conda-forge libstdcxx-ng` can solve the problem.
 
 
+### Q: When running the bfloat16 MoE model, the data shows NaN
+The detailed error:
+```shell
+Traceback (most recent call last):
+  File "/root/ktransformers/ktransformers/local_chat.py", line 183, in <module>
+    fire.Fire(local_chat)
+  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 135, in Fire
+    component_trace = _Fire(component, args, parsed_flag_args, context, name)
+  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 468, in _Fire
+    component, remaining_args = _CallAndUpdateTrace(
+  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
+    component = fn(*varargs, **kwargs)
+  File "/root/ktransformers/ktransformers/local_chat.py", line 177, in local_chat
+    generated = prefill_and_generate(
+  File "/root/ktransformers/ktransformers/util/utils.py", line 204, in prefill_and_generate
+    next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
+  File "/root/ktransformers/ktransformers/util/utils.py", line 128, in decode_one_tokens
+    next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
+RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
+```
+**SOLUTION**: On Ubuntu 22.04 this is caused by the system's g++ being too old, so its pre-defined macros do not include avx_bf16. We have tested and confirmed that it works with g++ 11.4 on Ubuntu 22.04.
+
+### Q: FP8 prefill is very slow
+
+The FP8 kernel is built by JIT, so the first run will be slow; subsequent runs will be faster.
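The FAQ answer above attributes the bfloat16 NaN problem to a toolchain that is too old to pre-define the AVX-512 BF16 macro. As a hedged illustration (not part of this commit), here is a minimal Python sketch that checks whether the locally installed g++ exposes `__AVX512BF16__` for the native architecture; the helper name is ours, not part of ktransformers:

```python
import subprocess

def gxx_supports_avx512bf16() -> bool:
    """Hypothetical helper: dump g++'s pre-defined macros for -march=native
    and look for the AVX-512 BF16 feature macro."""
    out = subprocess.run(
        ["g++", "-march=native", "-dM", "-E", "-x", "c++", "/dev/null"],
        capture_output=True, text=True, check=True,
    ).stdout
    return "__AVX512BF16__" in out

if __name__ == "__main__":
    # If this prints False, upgrade g++ (e.g. to 11.4 on Ubuntu 22.04)
    # before building the bf16 CPU kernels.
    print("AVX512-BF16 macro available:", gxx_supports_avx512bf16())
```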
doc/en/benchmark.md (new file, 43 lines)
@@ -0,0 +1,43 @@
## Benchmark

To conduct a quick and convenient check, we employed a simple Python script, available [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/tests), to assess the precision of our **[ktransformers](https://github.com/kvcache-ai/ktransformers)** project. For this evaluation we used the same dataset, shuffled in a consistent manner and limited to the first 1,000 data points, to test our implementation across a variety of CPU kernels, MLA kernels, and quantization formats.

We selected the DeepSeek-V3 model in its bf16, int8, and q4km versions for this test. The MMLU dataset, which can be found [here](https://huggingface.co/datasets/cais/mmlu), was used (we selected all subsets and shuffled them with a fixed random seed).

**!!! However, we skipped the few-shot part and only chose the first 1,000 data points for a quick check.** Please note that this approach may produce results that are not consistent with DeepSeek-V3's technical report. Tests of R1 and further tests are ongoing.

To verify our results, we chose the [cloud service platform](https://cloud.siliconflow.cn/models) as the baseline. All tests were conducted using the same script and datasets, allowing us to make a preliminary assessment of our project's precision.

We set the argument `temperature=0.6`, and to simplify the test process we skipped the few-shot part and used the following prompt: `There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter. \nQuestion: {question}\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: '`. For more details, please refer to the [script](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/tests/mmlu_test.py).

Given that we have only tested 1,000 cases, which provides only a preliminary judgment, some fluctuation in the results is reasonable. We selected all subsets and shuffled them with a fixed random seed to ensure consistency.

## Some Details

- The bf16 model of DeepSeek-V3 is available [here](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16/tree/main) (you may convert it to gguf with llama.cpp). The q4km model can be found [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M).

- The optimization YAML files are located [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules). For the matrix-multiplication kernel, you can change `KLinearMarlin` to `KLinearTorch`.

- To switch the MLA kernel from Triton to Torch, you can check and modify [this file](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/attention.py), specifically by using the `forward_windows` method.

- When attempting to run the bf16 test (both CPU weight and GPU weight), you may encounter issues stemming from older versions of g++ and as, particularly on Ubuntu 20 or earlier. To make it easier to reproduce our results, we provide a development container with a pre-configured environment. Note that the container does not have the ktransformers package installed, so you may still need to install certain packages manually.

- You may configure the model mount directory in `devcontainer/devcontainer.json`; check the `"mounts":` config.

## The Result Table

| DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models) | Ktrans Point |
| ------------------------ | ----------------- | ---------- | ----------------- | ------- | ---------- | -------------------------------------------------- | ------------ |
| MMLU<br><br>(shuffle 1k) | bf16 | cpuinfer | bf16 | torch | torch | 81.6 | 81.9 |
| | int8 | cpuinfer | bf16 | torch | torch | 81.6 | 83.1 |
| | q4km | cpuinfer | bf16 | torch | torch | 81.6 | 82.8 |
| | q4km | cpuinfer | bf16 | torch | triton | 81.6 | 81.4 |
| | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 |
| | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 |
| | q4km | cpuinfer | fp8 | marlin | triton | 81.6 | 81.5 |
| MMLU-pro | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 |
| MMLU-pro | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 |
| HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
| GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd |
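To make the evaluation setup described in doc/en/benchmark.md above concrete, here is a minimal sketch (not part of this commit) that formats one MMLU item with the quoted prompt template and sends it to an OpenAI-compatible chat endpoint. The endpoint URL and model name are assumptions borrowed from the test scripts touched in this diff, and `ask_one` is a hypothetical helper:

```python
import requests

API_URL = "http://localhost:15488/v1/chat/completions"  # assumed local ktransformers server
MODEL = "Pro/deepseek-ai/DeepSeek-V3"                    # model name used by the test scripts

PROMPT_TEMPLATE = (
    "There is a single choice question. Answer the question by replying A, B, C, D. "
    "No other answers are accepted. Just the letter. "
    "\nQuestion: {question}\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: '"
)

def ask_one(question: str, options: list) -> str:
    """Format a single MMLU item and return the model's raw answer text."""
    prompt = PROMPT_TEMPLATE.format(
        question=question,
        option_a=options[0], option_b=options[1],
        option_c=options[2], option_d=options[3],
    )
    resp = requests.post(API_URL, json={
        "model": MODEL,
        "temperature": 0.6,
        "messages": [{"role": "user", "content": prompt}],
    }, timeout=600)
    resp.raise_for_status()
    # OpenAI-compatible response layout: first choice's message content
    return resp.json()["choices"][0]["message"]["content"].strip()

if __name__ == "__main__":
    print(ask_one("What is 2 + 2?", ["3", "4", "5", "6"]))  # expect "B"
```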
@@ -1,7 +1,9 @@
 #pragma once
 
 #include <musa_runtime.h>
+#include <musa_bf16.h>
 
 #define cudaLaunchHostFunc musaLaunchHostFunc
 #define cudaStream_t musaStream_t
 #define cudaHostFn_t musaHostFn_t
+#define nv_bfloat16 mt_bfloat16
@ -1,9 +1,9 @@
|
|||||||
/**
|
/**
|
||||||
* @Description :
|
* @Description :
|
||||||
* @Author : Azure-Tang, Boxin Zhang
|
* @Author : Azure-Tang, Boxin Zhang
|
||||||
* @Date : 2024-07-25 13:38:30
|
* @Date : 2024-07-25 13:38:30
|
||||||
* @Version : 0.2.2
|
* @Version : 0.2.2
|
||||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
**/
|
**/
|
||||||
|
|
||||||
#include "custom_gguf/ops.h"
|
#include "custom_gguf/ops.h"
|
||||||
@ -20,38 +20,45 @@
|
|||||||
|
|
||||||
PYBIND11_MODULE(KTransformersOps, m) {
|
PYBIND11_MODULE(KTransformersOps, m) {
|
||||||
|
|
||||||
m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q8_0 data.",
|
}, "Function to dequantize q8_0 data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q6_k data.",
|
}, "Function to dequantize q6_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q5_k data.",
|
}, "Function to dequantize q5_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q4_k data.",
|
}, "Function to dequantize q4_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q3_k data.",
|
}, "Function to dequantize q3_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q2_k data.",
|
}, "Function to dequantize q2_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize iq4_xs data.",
|
}, "Function to dequantize iq4_xs data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
 /**
  * @Description :
  * @Author : Azure-Tang
  * @Date : 2024-07-22 09:27:55
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
  * @LastEditTime : 2024-08-12 03:48:46
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #pragma once
 
@@ -13,10 +13,10 @@
 #include <torch/extension.h>
 #include <torch/torch.h>
 
-torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
@@ -102,7 +102,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
-    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    with torch.cuda.device(x.device):
+        weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y
 
 
@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
|||||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
|
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
|
||||||
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||||
from ktransformers.util.utils import prefill_and_generate
|
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
|
||||||
from ktransformers.server.config.config import Config
|
from ktransformers.server.config.config import Config
|
||||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||||
|
|
||||||
@ -64,7 +64,6 @@ def local_chat(
|
|||||||
force_think: bool = False,
|
force_think: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
Config().cpu_infer = cpu_infer
|
Config().cpu_infer = cpu_infer
|
||||||
@ -169,7 +168,7 @@ def local_chat(
|
|||||||
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
||||||
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
||||||
|
|
||||||
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
|
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
|
||||||
generated = prefill_and_generate(
|
generated = prefill_and_generate(
|
||||||
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
|
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
|
||||||
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
|
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                  prefill_device: str = "cuda",
                  generate_device: str = "cuda",
                  chunck_size: int = 1000,
+                 absorb_for_prefill: bool = False,
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
         self.mla_wrapper = None
+        self.absorb_for_prefill = absorb_for_prefill
 
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
             q_nope = q_nope.transpose(1, 2)
-            assert q_nope.is_contiguous()
+            #assert q_nope.is_contiguous()
 
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
             attn_output = attn_output.transpose(1, 2)
             attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.transpose(1, 2)
 
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
             attn_output = self.o_proj(attn_output)
@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
 
         # decode
-        if q_len == 1:
+        if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
@@ -395,29 +399,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
             q_nope = q_nope.transpose(1, 2)
-            assert q_nope.is_contiguous()
+            q_nope = q_nope.contiguous()
+            #assert q_nope.is_contiguous()
 
             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
-            q_nope.squeeze_(1)
-            q_pe.squeeze_(1)
+            q_nope.squeeze_(0)
+            q_pe.squeeze_(0)
 
             # flash attn doesn't support head_dim bigger than 256, use flashinfer
             if self.mla_wrapper is None:
                 self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
             if self.mla_wrapper.need_plan:
                 self.mla_wrapper.need_plan = False
+                if q_len == 1:
                     self.mla_wrapper.plan(None,None,None,
                                           position_ids.squeeze(1)+1,
                                           self.num_heads,
                                           self.kv_lora_rank,
                                           self.qk_rope_head_dim,
                                           past_key_value.page_size,
                                           self.softmax_scale,
                                           q_nope.dtype,
                                           compressed_kv.dtype)
+                else:
+                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
+                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
+                    self.mla_wrapper.plan(qo_indptr,None,None,
+                                          kv_len_arr,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
 
             """
             k = (
                 torch.cat([compressed_kv, k_pe], dim=-1)
@@ -441,12 +458,13 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
+            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
+            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]
 
-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
 
             return attn_output, None, past_key_value
         else:
             if past_key_value is not None:
@@ -571,7 +589,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt':
+        if os.name == 'nt' or get_compute_capability()<8:
+            print("for Windows or GPU before ampere, use forward_windows")
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -159,7 +159,7 @@ class KExpertsCPU(KExpertsBase):
         down_ptr = ctypes.addressof(
             ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
-        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
+        #print(self.gate_type, self.up_type, self.down_type)
         n_routed_experts = self.n_routed_experts
         # n_routed_experts = len(self.orig_module)
         moe_config = MOEConfig(
@@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase):
             self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
             self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
 
-        self.up = torch.cat(self.gate, dim=0)
+        self.up = torch.cat(self.up, dim=0)
         self.gate = torch.cat(self.gate, dim=0)
-        self.down = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.down, dim=0)
         return
 
     def unload(self):
@@ -9,7 +9,7 @@ flashinfer_enabled = False
 
 try:
     import flashinfer
-    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
+    flashinfer_enabled = True
     print("found flashinfer")
 
 except ImportError:
@@ -122,7 +122,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
 
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -132,14 +132,14 @@ class MLAWrapper():
             head_dim_ckv,
             head_dim_kpe,
             page_size,
-            False, # causal is False for decoding
+            True, # causal
             sm_scale,
             q_data_type,
             kv_data_type,
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
+        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
     wrappers:dict = {}
@@ -179,6 +179,24 @@ class MLAWrapperSingleton():
                 sm_scale,
                 q_data_type,
                 kv_data_type,)
+            wrapper.need_plan = False
+
+    @classmethod
+    def need_plan_all(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.need_plan = True
+
+    @classmethod
+    def reset_buffer(cls):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here.
+
+    @classmethod
+    def update_buffer(cls, max_pages):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
+            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
+            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
 
 
 if __name__ == "__main__":
@@ -187,8 +205,9 @@ if __name__ == "__main__":
     page_size = 64
     num_heads = 128
 
-    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    q_len = 10
+    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
     k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
 
@@ -199,10 +218,10 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
-        None,
+        qo_indptr,
         None,
         None,
         kv_len_arr,
@@ -216,6 +235,7 @@ if __name__ == "__main__":
     )
 
     attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    print(attn_output.shape)
 
     k = (
         torch.cat([ckv, k_pe], dim=-1)
@@ -235,6 +255,7 @@ if __name__ == "__main__":
         False,
         192 ** (-0.5)
     )
+    print(attn_ref.shape)
 
     torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
     print("test past")
@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
+from ktransformers.util.utils import InferenceState, get_compute_capability
 from ktransformers.util.custom_gguf import GGUFLoader
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.models.modeling_llama import (
@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            causal_mask = self._update_causal_mask(
-                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-            )
+            if os.name == 'nt' or get_compute_capability()<8:
+                print("for Windows or GPU before ampere, use forward_windows")
+                # only use mask in forward windows or can't flash attn
+                causal_mask = self._update_causal_mask(
+                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+                )
+            else:
+                causal_mask = None
 
         # embed positions
         hidden_states = inputs_embeds
@@ -293,6 +293,7 @@
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
+      absorb_for_prefill: False
 
 # GPU 1: layers 15–29
 - match:
@@ -302,6 +303,7 @@
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
+      absorb_for_prefill: False
 
 # GPU 2: layers 30–44
 - match:
@@ -311,6 +313,7 @@
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
+      absorb_for_prefill: False
 
 # GPU 3: layers 45–60
 - match:
@@ -320,6 +323,7 @@
     kwargs:
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
+      absorb_for_prefill: False
 
 # === Overall Model Replacement with Transfer Map ===
 
@@ -60,6 +60,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
 - match:
     name: "^model$"
   replace:
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml (new file, 86 lines)
@@ -0,0 +1,86 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance)
#- match:
#    name: "^model\\.layers\\..*\\.mlp\\.experts$"
#  replace:
#    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
#    kwargs:
#      prefill_device: "cuda"
#      prefill_op: "KExpertsTorch"
#      generate_device: "cuda"
#      generate_op: "KExpertsMarlin"
#  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 
 warm_uped = False
@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
         input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
         torch.cuda.set_device(device)
+        if flashinfer_enabled:
+            MLAWrapperSingleton.need_plan_all()
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,
@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.args.max_new_tokens):
 
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                if i > 1 and flashinfer_enabled:
+                if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
                         num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                         head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
@@ -173,8 +173,8 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="API Generate Tester")
     parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations")
     parser.add_argument("--file", type=str, default="TIGER-Lab/MMLU-Pro", help="Path to the mmlu.jsonl file")
-    parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
-    parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
+    parser.add_argument("--result", type=str, default="./mmlu_result_pro.json", help="Path to save the result JSON file")
+    parser.add_argument("--log", type=str, default="./mmlu_result_pro.log", help="Path to save the log file")
     parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
     parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
     # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
@@ -330,6 +330,8 @@ class GGUFLoader:
         values = GGML_DEQUANTIZE[ggml_name](data)
         values = torch.from_numpy(values.copy())
 
+        if ggml_name == "BF16":
+            values = values.view(torch.bfloat16)
         values = values.view(shape[-2::-1])
 
         return values
@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 
 warm_uped = False
 
+def get_compute_capability(device:torch.device = None):
+    if torch.cuda.is_available():
+        if device is None:
+            num_gpus = torch.cuda.device_count()
+            min_compute_capability_major = 100
+            for gpu_id in range(num_gpus):
+                gpu_props = torch.cuda.get_device_properties(gpu_id)
+                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
+            return min_compute_capability_major
+        else:
+            return torch.cuda.get_device_properties(device)
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
@@ -164,6 +176,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
     else:
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+    if use_flashinfer_mla:
+        MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
+        MLAWrapperSingleton.need_plan_all()
+
     logits = model(
         inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
     )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
@@ -186,6 +202,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         first_token_time = time.time() - start_time
+
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.reset_buffer()
 
         prefill_count = seq_length
         prefill_time = first_token_time
@@ -203,22 +222,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
 
         start_time = time.time()
         for i in range(1, max_new_tokens):
+            if use_flashinfer_mla:
+                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
+                    num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                    q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
             global warm_uped
             if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-            if i > 1 and use_flashinfer_mla:
-                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
-                    num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                    q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
             next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
             tokens.append(int(next_token))
             seq_length += 1
 
-            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
+            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
                 print(stream.end(), end="", flush=True)
                 break
             else:
setup.py (1 line changed)
@@ -350,6 +350,7 @@ elif MUSA_HOME is not None:
         "at::cuda": "at::musa",
         "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
         "#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
+        "nv_bfloat16": "mt_bfloat16",
     }).run()
     ops_module = MUSAExtension('KTransformersOps', [
         'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',