From 5ec33d046d11665dbb1c13fb557beef88753ac35 Mon Sep 17 00:00:00 2001 From: Atream Date: Sat, 22 Feb 2025 06:13:01 +0000 Subject: [PATCH] optimize gguf dequant, save mem, support Q2_K use marlin for lm_head, lm_head only calc last token for prefill extend context window to 19K for DeepSeek-V3/R1 within 24GB VRAM --- .../ktransformers_ext/cuda/binding.cpp | 64 +++-- .../cuda/custom_gguf/binding.cpp | 35 --- .../cuda/custom_gguf/dequant.cu | 258 +++++++++--------- .../ktransformers_ext/cuda/custom_gguf/ops.h | 14 +- .../ktransformers_ext/cuda/test_dequant.py | 16 ++ .../quantize/utils/marlin_utils.py | 4 +- .../quantize/utils/quant_utils.py | 12 +- ktransformers/models/modeling_deepseek.py | 3 +- ktransformers/models/modeling_deepseek_v3.py | 2 +- ktransformers/operators/flashinfer_wrapper.py | 2 +- ktransformers/operators/linear.py | 59 ++-- ktransformers/optimize/optimize.py | 2 + .../DeepSeek-V2-Chat-multi-gpu-4.yaml | 14 +- .../DeepSeek-V2-Chat-multi-gpu.yaml | 13 +- .../optimize_rules/DeepSeek-V2-Chat.yaml | 12 + .../DeepSeek-V2-Lite-Chat-multi-gpu.yaml | 13 +- .../optimize_rules/DeepSeek-V2-Lite-Chat.yaml | 12 + .../DeepSeek-V3-Chat-multi-gpu-4.yaml | 19 +- .../DeepSeek-V3-Chat-multi-gpu-8.yaml | 17 +- .../DeepSeek-V3-Chat-multi-gpu-marlin.yaml | 13 +- .../DeepSeek-V3-Chat-multi-gpu.yaml | 13 +- .../optimize/optimize_rules/Mixtral.yaml | 10 + .../Qwen2-57B-A14B-Instruct-multi-gpu.yaml | 12 +- .../Qwen2-57B-A14B-Instruct.yaml | 10 + ktransformers/util/custom_gguf.py | 52 +++- ktransformers/util/utils.py | 2 +- test_prompt.txt | 11 + 27 files changed, 435 insertions(+), 259 deletions(-) delete mode 100644 ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp create mode 100644 ktransformers/ktransformers_ext/cuda/test_dequant.py diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp index 1f89b31..96ee9d8 100644 --- a/ktransformers/ktransformers_ext/cuda/binding.cpp +++ b/ktransformers/ktransformers_ext/cuda/binding.cpp @@ -1,10 +1,8 @@ /** * @Description : - * @Author : Azure-Tang + * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 - * @Version : 1.0.0 - * @LastEditors : kkk1nak0 - * @LastEditTime : 2024-08-12 03:05:04 + * @Version : 0.2.2 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ @@ -19,22 +17,44 @@ // namespace py = pybind11; PYBIND11_MODULE(KTransformersOps, m) { - m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", - py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), - py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), - py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); + + m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q8_0 data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q6_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q5_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q4_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q3_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize q2_k data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { + return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + }, "Function to dequantize iq4_xs data.", + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); + + m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", + py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), + py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), + py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); } diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp b/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp deleted file mode 100644 index 2011247..0000000 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "ops.h" -// Python bindings -#include -#include -#include -#include -#include -// namespace py = pybind11; - -int test(){ - return 5; -} - -torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device); -torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device); - -PYBIND11_MODULE(cudaops, m) { - m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", - py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); - m.def("test", &test, "Function to test."); - -} diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu index d5184ce..e80efc4 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -2,9 +2,7 @@ * @Description : * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 - * @Version : 1.0.0 - * @LastEditors : kkk1nak0 - * @LastEditTime : 2024-08-12 04:18:04 + * @Version : 0.2.2 * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c * Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2024 by KVCache.AI, All Rights Reserved. @@ -18,45 +16,42 @@ #include #include -__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ - float* __restrict__ output_blk = (float*)(output + block_id * 256); + float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk); const int8_t* cur_block = data + block_id * blk_size; float scale = __half2float(*((half*)cur_block)); cur_block += 2; - for (int i = 0; i < 32; i++){ + for (int i = 0; i < ele_per_blk; i++){ output_blk[i] = scale * cur_block[i]; } - output_blk += 32; } } -__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) { - __half* __restrict__ output_blk = (__half*)(output + block_id * 256); + __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk); const int8_t* cur_block = data + block_id * blk_size; float scale = __half2float(*((half*)cur_block)); cur_block += 2; - for (int i = 0; i < 32; i++) { + for (int i = 0; i < ele_per_blk; i++) { output_blk[i] = __float2half(scale * cur_block[i]); } - output_blk += 32; } } -__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) { - nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); + nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk); const int8_t* cur_block = data + block_id * blk_size; float scale = __half2float(*((half*)cur_block)); cur_block += 2; - for (int i = 0; i < 32; i++) { + for (int i = 0; i < ele_per_blk; i++) { output_blk[i] = __float2bfloat16(scale * cur_block[i]); } - output_blk += 32; } } @@ -70,10 +65,10 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_ } } -__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); @@ -104,10 +99,10 @@ __global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, c } } -__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); @@ -138,10 +133,10 @@ __global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, } } -__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); @@ -172,13 +167,13 @@ __global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* out } } -__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t kmask1 = 0x03030303; const uint32_t kmask2 = 0x0f0f0f0f; for (long long block_id=global_idx; block_id(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; - for (int j = 0; j < blk_size; j += 64) { + for (int j = 0; j < ele_per_blk; j += 64) { uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; @@ -365,10 +360,10 @@ __global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, c } } -__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; - for (int j = 0; j < blk_size; j += 64) { + for (int j = 0; j < ele_per_blk; j += 64) { uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; @@ -389,10 +384,10 @@ __global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, } } -__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; - for (int j = 0; j < blk_size; j += 64) { + for (int j = 0; j < ele_per_blk; j += 64) { uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; @@ -413,10 +408,10 @@ __global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* out } } -__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ - float* __restrict__ output_blk = (float*)(output + block_id * 256); + float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk); const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); @@ -442,10 +437,10 @@ __global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, c } } -__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ - __half* __restrict__ output_blk = (__half*)(output + block_id * 256); + __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk); const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); @@ -471,10 +466,10 @@ __global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, } } -__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ - nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); + nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk); const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); @@ -500,10 +495,10 @@ __global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* out } } -__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); @@ -511,31 +506,30 @@ __global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, c const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); - //if (blk_size == 256){ - for (int n = 0; n < blk_size; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - output_blk[l + 0] = d * sc[is + 0] * q1; - output_blk[l + 32] = d * sc[is + 2] * q2; - output_blk[l + 64] = d * sc[is + 4] * q3; - output_blk[l + 96] = d * sc[is + 6] * q4; - } - output_blk += 128; - ql += 64; - qh += 32; - sc += 8; + for (int n = 0; n < ele_per_blk; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = d * sc[is + 0] * q1; + output_blk[l + 32] = d * sc[is + 2] * q2; + output_blk[l + 64] = d * sc[is + 4] * q3; + output_blk[l + 96] = d * sc[is + 6] * q4; } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } } } -__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); @@ -543,31 +537,30 @@ __global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); - //if (blk_size == 256){ - for (int n = 0; n < blk_size; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - output_blk[l + 0] = __float2half(d * sc[is + 0] * q1); - output_blk[l + 32] = __float2half(d * sc[is + 2] * q2); - output_blk[l + 64] = __float2half(d * sc[is + 4] * q3); - output_blk[l + 96] = __float2half(d * sc[is + 6] * q4); - } - output_blk += 128; - ql += 64; - qh += 32; - sc += 8; + for (int n = 0; n < ele_per_blk; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = __float2half(d * sc[is + 0] * q1); + output_blk[l + 32] = __float2half(d * sc[is + 2] * q2); + output_blk[l + 64] = __float2half(d * sc[is + 4] * q3); + output_blk[l + 96] = __float2half(d * sc[is + 6] * q4); } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } } } -__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); @@ -575,33 +568,32 @@ __global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* out const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); - //if (blk_size == 256){ - for (int n = 0; n < blk_size; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1); - output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2); - output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3); - output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4); - } - output_blk += 128; - ql += 64; - qh += 32; - sc += 8; + for (int n = 0; n < ele_per_blk; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1); + output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2); + output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3); + output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4); } + output_blk += 128; + ql += 64; + qh += 32; + sc += 8; + } } } static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); @@ -620,10 +612,10 @@ __global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, } } -__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); @@ -642,10 +634,10 @@ __global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output } } -__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { +__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); @@ -664,7 +656,7 @@ __global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* o } } -torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { +torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); @@ -679,13 +671,13 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int switch (target_dtype) { case torch::kFloat16: - dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: - dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: - dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); @@ -697,7 +689,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int } -torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { +torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { // data.numel%blk_size should be 0, else raise err int num_blocks = num_bytes / blk_size; @@ -713,13 +705,13 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int switch (target_dtype) { case torch::kFloat16: - dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: - dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: - dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); @@ -729,7 +721,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int return output; } -torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { +torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); @@ -744,13 +736,13 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int switch (target_dtype) { case torch::kFloat16: - dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: - dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: - dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); @@ -760,7 +752,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int return output; } -torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { +torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { // data.numel%blk_size should be 0, else raise err int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); @@ -776,13 +768,13 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int switch (target_dtype) { case torch::kFloat16: - dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: - dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: - dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); @@ -792,7 +784,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int return output; } -torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { +torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); @@ -807,13 +799,13 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int switch (target_dtype) { case torch::kFloat16: - dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: - dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: - dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); @@ -823,7 +815,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int return output; } -torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { +torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); @@ -838,13 +830,13 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int switch (target_dtype) { case torch::kFloat16: - dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: - dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: - dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); @@ -854,7 +846,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int return output; } -torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); @@ -869,13 +861,13 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i switch (target_dtype) { case torch::kFloat16: - dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, num_blocks); + dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: - dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: - dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, num_blocks); + dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h index b18c799..a52db2d 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h @@ -13,10 +13,10 @@ #include #include -torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); diff --git a/ktransformers/ktransformers_ext/cuda/test_dequant.py b/ktransformers/ktransformers_ext/cuda/test_dequant.py new file mode 100644 index 0000000..abca745 --- /dev/null +++ b/ktransformers/ktransformers_ext/cuda/test_dequant.py @@ -0,0 +1,16 @@ +import os +import sys +sys.path.insert(0,"/home/zbx/ktransformers") +from ktransformers.util.custom_gguf import GGUFLoader +import torch + +gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf") +gguf_loader_2 = GGUFLoader("/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/") + +torch.set_default_dtype(torch.bfloat16) + +tensor_1 = gguf_loader_1.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda") +tensor_2 = gguf_loader_2.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda") + +print(tensor_1[0, -64:]) +print(tensor_2[0, -64:]) \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py index accbc00..fadfb11 100644 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py +++ b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py @@ -90,7 +90,7 @@ def marlin_quantize( assert group_size <= size_k # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are @@ -107,7 +107,7 @@ def marlin_quantize( marlin_scale_perm_single[num_bits]) # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] for i in range(len(res_list)): res_list[i] = res_list[i].to(w.device) diff --git a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py index b3a0ba5..de73667 100644 --- a/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py +++ b/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py @@ -11,8 +11,7 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): - assert q_w.shape == w_ref.shape +def permute_rows(q_w: torch.Tensor, group_size: int): orig_device = q_w.device k_size, _ = q_w.shape @@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() return ( - w_ref.to(device=orig_device), q_w.to(device=orig_device), g_idx.to(device=orig_device), rand_perm.to(device=orig_device), @@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, q_w += half_q_val q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s - # Restore original shapes if group_size < size_k: @@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, return w q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) s = s.reshape((-1, size_n)).contiguous() @@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) - w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + q_w, g_idx, rand_perm = permute_rows(q_w, group_size) return ( - w_ref.to(device=orig_device), q_w.to(device=orig_device), s.to(device=orig_device), g_idx.to(device=orig_device), diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py index 692020d..e14a521 100644 --- a/ktransformers/models/modeling_deepseek.py +++ b/ktransformers/models/modeling_deepseek.py @@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits[:,-1,:].unsqueeze(0).float() + logits = self.lm_head(hidden_states[:,-1:,:]).float() loss = None if labels is not None: diff --git a/ktransformers/models/modeling_deepseek_v3.py b/ktransformers/models/modeling_deepseek_v3.py index 277258a..952eed7 100644 --- a/ktransformers/models/modeling_deepseek_v3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states.to(self.lm_head.weight.device)) + logits = self.lm_head(hidden_states[:,-1:,:]) logits = logits.float() loss = None diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py index 8d49187..b7d0938 100644 --- a/ktransformers/operators/flashinfer_wrapper.py +++ b/ktransformers/operators/flashinfer_wrapper.py @@ -9,7 +9,7 @@ flashinfer_enabled = False try: import flashinfer - flashinfer_enabled = True + flashinfer_enabled = False print("found flashinfer") except ImportError: diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 08a2cca..394aa03 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -21,6 +21,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl MarlinWorkspace, marlin_quantize, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MIN_THREAD_K, GPTQ_MARLIN_MAX_PARALLEL, ) from ktransformers.operators.base_operator import BaseInjectedModule @@ -64,6 +65,8 @@ class KLinearBase(ABC): self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] + self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill. + @abstractmethod def forward(self, x: torch.Tensor) -> torch.Tensor: pass @@ -134,6 +137,7 @@ class KLinearTorch(KLinearBase): return x def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if self.loaded: return if device is None: device = self.device if w is None: w = self.load_weight(device=device) # else: self.out_features = w.shape[0], self.in_features = w.shape[1] @@ -157,6 +161,7 @@ class KLinearTorch(KLinearBase): self.weight = self.weight.to(device) if self.has_bias: self.bias = self.bias.to(device) + self.loaded = True def unload(self): if self.weight is not None: @@ -190,20 +195,36 @@ class KLinearMarlin(KLinearBase): self.group_size = group_size self.act_order = act_order self.is_k_full = is_k_full + self.padding = False + self.orin_in_features = self.in_features + self.orin_out_features = self.out_features + if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0: + #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding") + self.padding = True + self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K + self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N + #print(f"After padding: in_features={in_features}, out_features={out_features}") + + self.k = self.in_features + self.n = self.out_features def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): + if self.loaded: return if device is None: device = self.device assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" + + #if self.in_features * self.out_features: if w is None: w = self.load_weight(device=device) if isinstance(w, nn.Parameter): # pad weight - weight = w.view(self.out_features, self.in_features).T + weight = w.view(self.orin_out_features, self.orin_in_features).T self.has_bias = False elif isinstance(w, tuple): w = list(w) - weight = w[0].view(self.out_features, self.in_features).T + weight = w[0].view(self.orin_out_features, self.orin_in_features).T + self.bias = w[1].view(self.orin_out_features) self.bias = w[1] self.has_bias = True else: @@ -211,8 +232,14 @@ class KLinearMarlin(KLinearBase): weight = weight.to(device) if self.has_bias: self.bias = self.bias.to(device) + + if self.padding: + padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device) + padded_weight[:self.orin_in_features, :self.orin_out_features] = weight + weight = padded_weight + # Pack Marlin linear - w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( weight, self.num_bits, self.group_size, self.act_order ) self.workspace = MarlinWorkspace( @@ -225,6 +252,7 @@ class KLinearMarlin(KLinearBase): self.sort_indices = sort_indices self.k = weight.shape[0] self.n = weight.shape[1] + self.loaded = True def forward(self, x: torch.Tensor) -> torch.Tensor: # Only support input x as BF16 and FP16 @@ -232,6 +260,11 @@ class KLinearMarlin(KLinearBase): orig_shape = list(x.shape) orig_dtype = x.dtype x = x.reshape(-1, orig_shape[-1]) + x = x.reshape(-1, x.shape[-1]) + if self.padding: + padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype) + padding_input[:,:self.orin_in_features] = x + x = padding_input marlin_s = self.marlin_s.to(x.dtype) x = KTransformersOps.gptq_marlin_gemm( x, @@ -246,6 +279,11 @@ class KLinearMarlin(KLinearBase): x.shape[-1], self.is_k_full, ) + if self.padding: + x = x[:,:self.orin_out_features] + orig_shape[-1] = self.orin_out_features + else: + orig_shape[-1] = self.out_features if self.has_bias: x = x + self.bias orig_shape[-1] = self.n @@ -388,24 +426,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): # build all the linear operators if prefill_op is not None: assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" - if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0): - print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.") - print(f"module info: key:{key} orig_module:{orig_module}") - self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs) - else: - self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs) else: self.prefill_linear = None if generate_op is not None: assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported" - if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0): - print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.") - print(f"module info: key:{key} orig_module:{orig_module}") - self.generate_op = "KLinearTorch" - self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs) - else: - self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) + self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) else: self.generate_linear = None self.mode = InferenceState.UNLOAD diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py index 32eab01..331e6cf 100644 --- a/ktransformers/optimize/optimize.py +++ b/ktransformers/optimize/optimize.py @@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo gguf_loader=GGUFLoader(gguf_path) with torch.device("meta"): inject(module, optimize_config, model_config, gguf_loader) + # pre load lm_head because its big inter result + load_weights(module.lm_head, gguf_loader, "lm_head.") load_weights(module, gguf_loader) module.gguf_loader = gguf_loader del_meta(module) diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml index a87a30c..66a420a 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml @@ -219,8 +219,20 @@ kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" + - match: - name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml index 269257e..f409376 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml @@ -118,7 +118,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml index b115aba..7f3e44e 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml @@ -15,6 +15,18 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" + +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml index 99d01c0..158892d 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml @@ -118,7 +118,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml index b115aba..7f3e44e 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml @@ -15,6 +15,18 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" + +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml index 84ab801..03c85a0 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml @@ -188,7 +188,7 @@ # !!!Do remember 'close' cuda graph if you are using marlin expert.!!! # !!!KExpertsTorch is untested, we don't have enough VRAM.!!! -# # GPU 0: layers 3–4 +# GPU 0: layers 3–4 # - match: # name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$" # replace: @@ -363,11 +363,20 @@ generate_device: "cuda:2" prefill_device: "cuda:2" -# don't inject lm_head if already inject marlin experts - -# For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config) - match: - name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +# For final modules (model.norm), ensure they are on GPU 3 (as in your original config) +- match: + name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml index a10b57f..b00d2b4 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml @@ -713,11 +713,20 @@ generate_device: "cuda:7" prefill_device: "cuda:7" -# don't inject lm_head if already inject marlin experts - -# For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config) - match: - name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:7" + prefill_device: "cuda:7" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +# For final modules (model.norm), ensure they are on GPU 7 (as in your original config) +- match: + name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml index 92571b5..6b39121 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml @@ -153,7 +153,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml index 06ab4db..50e282d 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -135,7 +135,18 @@ prefill_device: "cuda:0" - match: - name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/Mixtral.yaml b/ktransformers/optimize/optimize_rules/Mixtral.yaml index 7d48812..80a346a 100644 --- a/ktransformers/optimize/optimize_rules/Mixtral.yaml +++ b/ktransformers/optimize/optimize_rules/Mixtral.yaml @@ -15,6 +15,16 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.block_sparse_moe$" class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml index da4fb4a..da01c82 100644 --- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml @@ -77,9 +77,19 @@ kwargs: generate_device: "cpu" prefill_device: "cpu" +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" - match: - name: "(^model.norm)|(^lm_head)" + name: "(^model.norm)" replace: class: "default" kwargs: diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml index 0cc2edf..38e9e73 100644 --- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml @@ -15,6 +15,16 @@ prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py index 7ad13c7..0c42c9c 100644 --- a/ktransformers/util/custom_gguf.py +++ b/ktransformers/util/custom_gguf.py @@ -25,6 +25,7 @@ import os from enum import IntEnum import torch import KTransformersOps +import ctypes class GGMLQuantizationType(IntEnum): F32 = 0 @@ -307,7 +308,7 @@ class GGUFLoader: values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype) else: values = GGML_DEQUANTIZE[ggml_name](data) - values = torch.from_numpy(values) + values = torch.from_numpy(values.copy()) values = values.view(shape[-2::-1]) @@ -343,7 +344,7 @@ class GGUFLoader: cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) else: cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) - cur_values = torch.from_numpy(cur_values) + cur_values = torch.from_numpy(cur_values.copy()) cur_values = cur_values.view(-1, elements_per_block) values[blocks_begin : blocks_end] = cur_values @@ -455,11 +456,13 @@ def dequantize_q2_k(data): def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q2_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - return KTransformersOps.dequantize_q2_k(data.data, data.size, block_size, device, target_dtype) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q3_k(data): # C implementation @@ -505,11 +508,13 @@ def dequantize_q3_k(data): def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q3_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - return KTransformersOps.dequantize_q3_k(data.data, data.size, block_size, device, target_dtype) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q4_k(data): # C implementation @@ -534,11 +539,14 @@ def dequantize_q4_k(data): return factors * qs2 - offsets def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): + block_size = GGML_BLOCK_SIZES["Q4_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - return KTransformersOps.dequantize_q4_k(data.data, data.size, 144, device, target_dtype) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q5_k(data): # C implementation @@ -598,11 +606,13 @@ def dequantize_q5_k(data): def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q5_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. - return KTransformersOps.dequantize_q5_k(data.data, data.size, block_size, device, target_dtype) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q6_k(data): # C implementation @@ -655,10 +665,12 @@ def dequantize_q6_k(data): # @torch.jit.script def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q6_K"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) - return KTransformersOps.dequantize_q6_k(data.data, data.size, block_size, device, target_dtype) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8) @@ -694,10 +706,12 @@ def dequantize_iq4_xs(data): def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["IQ4_XS"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) - return KTransformersOps.dequantize_iq4_xs(data.data, data.size, block_size, device, target_dtype) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q4_0(data): # C implementation @@ -753,10 +767,13 @@ def dequantize_q8_0(data): def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 - num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"] + + block_size = GGML_BLOCK_SIZES["Q8_0"] + ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"] device = torch.device(device) data = np.frombuffer(data, dtype=data.dtype) - return KTransformersOps.dequantize_q8_0(data.data, data.size, 34, device, target_dtype) + c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) + return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_f32(data): @@ -764,8 +781,8 @@ def dequantize_f32(data): def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float32) - res = torch.from_numpy(data) - res_gpu = torch.empty_like(res, device=device) + res = torch.from_numpy(data.copy()) + res_gpu = torch.empty_like(res, device=device, dtype=target_dtype) res_gpu.copy_(res) return res_gpu @@ -774,7 +791,14 @@ def dequantize_f16(data): def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float16) - res = torch.from_numpy(data) + res = torch.from_numpy(data.copy()) + res_gpu = torch.empty_like(res, device=device, dtype=target_dtype) + res_gpu.copy_(res) + return res_gpu + +def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()): + data = np.frombuffer(data, dtype=np.float16) + res = torch.from_numpy(data.copy()) res_gpu = torch.empty_like(res, device=device) res_gpu.copy_(res) return res_gpu @@ -797,7 +821,7 @@ GGML_DEQUANTIZE = { GGML_DEQUANTIZE_GPU = { "F32": dequantize_f32_gpu, "F16": dequantize_f16_gpu, - "BF16": dequantize_f16_gpu, + "BF16": dequantize_bf16_gpu, "Q4_0": dequantize_q4_0_gpu, "Q5_0": dequantize_q5_0_gpu, "Q8_0": dequantize_q8_0_gpu, diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 7034ac9..cc4a323 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -79,7 +79,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str raise Exception(f"can't find {translated_key} in GGUF file!") def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''): - # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}") + #print(f"recursively loading weights {prefix}") if not isinstance(module, base_operator.BaseInjectedModule): load_cur_state_dict(module, gguf_loader, prefix) for name, child in module._modules.items(): diff --git a/test_prompt.txt b/test_prompt.txt index 69fd23b..c749c0e 100644 --- a/test_prompt.txt +++ b/test_prompt.txt @@ -6,4 +6,15 @@ Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. +Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair. None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. 阅读以上文字,并概括大意