From 7c4cb520bd935f944312604b4e6a79380da6a3ae Mon Sep 17 00:00:00 2001 From: BITcyman <815207911@qq.com> Date: Mon, 12 Aug 2024 12:53:12 +0000 Subject: [PATCH] [feature] support q2_k & q3_k dequantize on gpu --- .../ktransformers_ext/cuda/binding.cpp | 6 +- .../cuda/custom_gguf/binding.cpp | 5 + .../cuda/custom_gguf/dequant.cu | 134 +++++++++++++++++- .../ktransformers_ext/cuda/custom_gguf/ops.h | 6 +- ktransformers/util/custom_gguf.py | 22 ++- 5 files changed, 161 insertions(+), 12 deletions(-) diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp index f17382d..06ec5f3 100644 --- a/ktransformers/ktransformers_ext/cuda/binding.cpp +++ b/ktransformers/ktransformers_ext/cuda/binding.cpp @@ -4,7 +4,7 @@ * @Date : 2024-07-25 13:38:30 * @Version : 1.0.0 * @LastEditors : kkk1nak0 - * @LastEditTime : 2024-08-09 01:45:02 + * @LastEditTime : 2024-08-12 03:05:04 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ @@ -27,6 +27,10 @@ PYBIND11_MODULE(KTransformersOps, m) { py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", + py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", + py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp index 2cb46fc..70fc606 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp @@ -13,6 +13,7 @@ int test(){ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device); +torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device); PYBIND11_MODULE(cudaops, m) { m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.", @@ -23,6 +24,10 @@ PYBIND11_MODULE(cudaops, m) { py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", + py::arg("data"), py::arg("blk_size"), py::arg("device")); + m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", + py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("test", &test, "Function to test."); } diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu index aaa6453..cc5552b 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -4,7 +4,7 @@ * @Date : 2024-07-25 13:38:30 * @Version : 1.0.0 * @LastEditors : kkk1nak0 - * @LastEditTime : 2024-08-09 07:57:06 + * @LastEditTime : 2024-08-12 04:18:04 * Adapted from 
https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
 * Copyright (c) 2023-2024 The ggml authors
 * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -36,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
     }
 }
 
+__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
+
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
+
+        int is = 0;
+        float dl, ml;
+
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                uint8_t sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                sc = *scales;
+
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+    }
+}
+
+__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+    for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+
+        uint32_t aux[4];
+        const int8_t * scales = (const int8_t*)aux;
+
+        const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
+
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
+        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
+        uint8_t m = 1;
+
+
+        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
+
+        for (int i = 0; i < 3; i++) {
+            aux[i] = 0;
+            for (int j = 0; j < 4; j++) {
+                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
+            }
+        }
+
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+    }
+}
+
+
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
@@ -177,6 +268,24 @@
     return output;
 }
 
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
+
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
@@ -196,8 +305,7 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
     return output;
 }
 
-
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;
 
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
@@ -209,7 +317,25 @@
     auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
 
     // Launch kernel
-    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
+
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
 
     cudaDeviceSynchronize();
     return output;
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
index f5fde87..5196f88 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
@@ -4,7 +4,7 @@
  * @Date : 2024-07-22 09:27:55
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
- * @LastEditTime : 2024-08-09 01:44:21
+ * @LastEditTime : 2024-08-12 03:48:46
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/ #pragma once @@ -16,4 +16,6 @@ torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device); \ No newline at end of file +torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device); +torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device); +torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device); \ No newline at end of file diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index fe796a7..bd5c5b0 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-26 08:48:54 Version : 1.0.0 LastEditors : kkk1nak0 -LastEditTime : 2024-08-09 08:03:44 +LastEditTime : 2024-08-12 07:21:55 Adapted from https://github.com/99991/pygguf/blob/main/gguf.py Copyright (c) 2023-2024 The ggml authors Copyright (c) 2024 Thomas Germer @@ -390,8 +390,14 @@ def dequantize_q2_k(data): return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) -def dequantize_q2_k_gpu(data): - raise NotImplementedError() +def dequantize_q2_k_gpu(data, device:str ="cuda"): + block_size = GGML_BLOCK_SIZES["Q2_K"] + data = np.frombuffer(data, dtype=data.dtype) + device = torch.device(device) + # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, + # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. + data = torch.from_numpy(data) + return KTransformersOps.dequantize_q2_k(data, block_size, device) def dequantize_q3_k(data): # C implementation @@ -435,8 +441,14 @@ def dequantize_q3_k(data): (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) ], axis=1) -def dequantize_q3_k_gpu(data): - raise NotImplementedError() +def dequantize_q3_k_gpu(data, device:str ="cuda"): + block_size = GGML_BLOCK_SIZES["Q3_K"] + data = np.frombuffer(data, dtype=data.dtype) + device = torch.device(device) + # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, + # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. + data = torch.from_numpy(data) + return KTransformersOps.dequantize_q3_k(data, block_size, device) def dequantize_q4_k(data): # C implementation
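
Note on the block layouts behind the hard-coded offsets (not part of the patch): the new kernels assume ggml's Q2_K and Q3_K super-block layouts, 84 and 110 bytes per 256-weight block respectively, which is also what GGML_BLOCK_SIZES["Q2_K"] and GGML_BLOCK_SIZES["Q3_K"] evaluate to in custom_gguf.py. A minimal numpy sketch of those layouts, matching the byte offsets 16/80/82 used by dequantize_q2_k_kernel and 0/32/96/108 used by dequantize_q3_k_kernel:

    import numpy as np

    # Q2_K super-block: 256 weights in 84 bytes.
    q2_k_block = np.dtype([
        ("scales", np.uint8, 16),  # packed 4-bit (scale, min) pairs, one byte per 16-weight group (offset 0)
        ("qs",     np.uint8, 64),  # 2-bit quants, 4 weights per byte                              (offset 16)
        ("d",      np.float16),    # fp16 super-scale for the group scales                         (offset 80)
        ("dmin",   np.float16),    # fp16 super-scale for the group minimums                       (offset 82)
    ])

    # Q3_K super-block: 256 weights in 110 bytes.
    q3_k_block = np.dtype([
        ("hmask",  np.uint8, 32),  # high bit of each 3-bit quant   (offset 0)
        ("qs",     np.uint8, 64),  # low two bits of each quant     (offset 32)
        ("scales", np.uint8, 12),  # 16 six-bit scales, bit-packed  (offset 96)
        ("d",      np.float16),    # fp16 super-block scale         (offset 108)
    ])

    assert q2_k_block.itemsize == 84 and q3_k_block.itemsize == 110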
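For reference, a hedged usage sketch of the new binding (the input array below is a zero-filled placeholder, not real model data): KTransformersOps.dequantize_q2_k takes the raw block bytes as an int8 CPU tensor, the block size in bytes, and the target device, copies the data to that device, and returns a float32 tensor of shape (num_blocks, 256); dequantize_q3_k behaves the same way with 110-byte blocks.

    import numpy as np
    import torch
    import KTransformersOps  # the extension module built from ktransformers_ext/cuda

    # Placeholder for the raw bytes of 8 consecutive Q2_K blocks (84 bytes each),
    # e.g. a tensor payload sliced out of a GGUF file.
    raw = np.zeros(84 * 8, dtype=np.int8)
    out = KTransformersOps.dequantize_q2_k(torch.from_numpy(raw), 84, torch.device("cuda"))
    print(out.shape, out.dtype)  # torch.Size([8, 256]) torch.float32, resident on the GPU

This mirrors what the dequantize_q2_k_gpu / dequantize_q3_k_gpu helpers in ktransformers/util/custom_gguf.py do with data read from a GGUF file.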
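The least obvious step in dequantize_q3_k_kernel is the kmask1/kmask2 shuffle that expands the 12 packed scale bytes into 16 six-bit scales (each later used as scale - 32). A small numpy sketch that mirrors that bit manipulation, useful when sanity-checking the GPU output against the existing CPU dequantize_q3_k:

    import numpy as np

    KMASK1, KMASK2 = 0x03030303, 0x0F0F0F0F

    def unpack_q3_k_scales(packed: bytes) -> np.ndarray:
        # Expand the 12 packed scale bytes of one Q3_K block into 16 six-bit
        # scales (0..63), following the aux[]/kmask shuffle used in the kernel.
        a = np.frombuffer(packed, dtype=np.uint32)  # three little-endian 32-bit words
        aux = [int(a[0]), int(a[1]), 0, 0]
        tmp = int(a[2])
        aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4)
        aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4)
        aux[0] = (aux[0] & KMASK2) | (((tmp >> 0) & KMASK1) << 4)
        aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4)
        return np.array(aux, dtype=np.uint32).view(np.uint8)

    # 12 zero bytes unpack to 16 zero scales, i.e. an effective scale of -32 everywhere.
    print(unpack_q3_k_scales(bytes(12)))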