Merge pull request #608 from makllama/fix_musa_ext

musa: support bf16
This commit is contained in:
Atream 2025-02-24 23:12:54 +08:00 committed by GitHub
commit 7b2a6690ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 1 deletions

View File

@ -1,7 +1,9 @@
#pragma once
#include <musa_runtime.h>
#include <musa_bf16.h>
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t
#define cudaHostFn_t musaHostFn_t
#define nv_bfloat16 mt_bfloat16

View File

@ -350,6 +350,7 @@ elif MUSA_HOME is not None:
"at::cuda": "at::musa",
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
"nv_bfloat16": "mt_bfloat16",
}).run()
ops_module = MUSAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',