From 7e5962af3d16af570f3108acd3238d9dae330a45 Mon Sep 17 00:00:00 2001
From: Azure
Date: Tue, 25 Feb 2025 10:52:29 +0000
Subject: [PATCH] fix fp8 multi gpu; update FAQ

---
 doc/en/FAQ.md                                     | 6 +++++-
 ktransformers/ktransformers_ext/triton/fp8gemm.py | 3 ++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md
index 3269f6b..a8399bd 100644
--- a/doc/en/FAQ.md
+++ b/doc/en/FAQ.md
@@ -92,4 +92,8 @@ Traceback (most recent call last):
     next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
 RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
 ```
-**SOLUTION**: The issue of running ktransformers on Ubuntu 22.04 is caused by the current system's g++ version being too old, and the pre-defined macros do not include avx_bf16. We have tested and confirmed that it works on g++ 11.4 in Ubuntu 22.04.
\ No newline at end of file
+**SOLUTION**: The issue of running ktransformers on Ubuntu 22.04 is caused by the current system's g++ version being too old, and the pre-defined macros do not include avx_bf16. We have tested and confirmed that it works on g++ 11.4 in Ubuntu 22.04.
+
+### Q: FP8 prefill is very slow.
+
+The FP8 kernel is built by JIT, so the first run will be slow. Subsequent runs will be faster.
\ No newline at end of file

diff --git a/ktransformers/ktransformers_ext/triton/fp8gemm.py b/ktransformers/ktransformers_ext/triton/fp8gemm.py
index 7d5b72e..d5c913d 100644
--- a/ktransformers/ktransformers_ext/triton/fp8gemm.py
+++ b/ktransformers/ktransformers_ext/triton/fp8gemm.py
@@ -102,7 +102,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
-    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    with torch.cuda.device(x.device):
+        weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y
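
Note on the FAQ entry: Triton compiles kernels just-in-time on the first
launch, so the one-time compile cost can be paid up front with a warmup call.
A minimal sketch, not part of the patch -- the import path follows the file
layout above, while the shapes and the availability of torch.float8_e4m3fn
in your PyTorch build are assumptions:

    import torch
    from ktransformers.ktransformers_ext.triton.fp8gemm import weight_dequant

    # Illustrative shapes: a 256x256 weight with block_size=128 needs a
    # 2x2 grid of per-block dequantization scales.
    w = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
    s = torch.ones(2, 2, device="cuda", dtype=torch.float32)

    weight_dequant(w, s)  # first call pays the Triton JIT compile cost
    weight_dequant(w, s)  # later calls reuse the cached compiled kernel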
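
Note on the fp8gemm.py change: CUDA kernel launches, including Triton's, go
to the process's *current* device, which on a multi-GPU machine may not be
the device that holds x; wrapping the launch in torch.cuda.device(x.device)
makes the two agree. A minimal sketch, not from the patch, of the mismatch
the guard prevents:

    import torch

    if torch.cuda.device_count() > 1:
        x = torch.randn(4, device="cuda:1")
        # The process-wide current device is still cuda:0, so a raw kernel
        # launch here would target the wrong GPU for x.
        assert torch.cuda.current_device() != x.device.index
        with torch.cuda.device(x.device):
            # Inside the guard, launches target the GPU that actually holds x.
            assert torch.cuda.current_device() == x.device.index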