From 3758afb526d2dcd02a45843dafb650108f4fd48e Mon Sep 17 00:00:00 2001 From: Azure Date: Fri, 13 Sep 2024 08:34:23 +0000 Subject: [PATCH] fix bug where some dequant functions don't support multi-GPU --- ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu index 1583cf7..0c49fa7 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -292,6 +292,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) { int num_blocks = data.numel() / blk_size; + const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({data.numel()}, options); @@ -330,6 +331,7 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) { int num_blocks = data.numel() / blk_size; + const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({data.numel()}, options); @@ -348,6 +350,7 @@ torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device de torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) { int num_blocks = data.numel() / blk_size; + const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({data.numel()}, options); @@ 
-366,6 +369,7 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) { int num_blocks = data.numel() / blk_size; + const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({data.numel()}, options);