mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-01-19 12:43:16 +08:00
musa: support bf16
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
parent
94ab2de3b9
commit
18b1d18367
@ -1,7 +1,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <musa_runtime.h>
|
#include <musa_runtime.h>
|
||||||
|
#include <musa_bf16.h>
|
||||||
|
|
||||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||||
#define cudaStream_t musaStream_t
|
#define cudaStream_t musaStream_t
|
||||||
#define cudaHostFn_t musaHostFn_t
|
#define cudaHostFn_t musaHostFn_t
|
||||||
|
#define nv_bfloat16 mt_bfloat16
|
||||||
1
setup.py
1
setup.py
@ -350,6 +350,7 @@ elif MUSA_HOME is not None:
|
|||||||
"at::cuda": "at::musa",
|
"at::cuda": "at::musa",
|
||||||
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
|
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
|
||||||
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
|
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
|
||||||
|
"nv_bfloat16": "mt_bfloat16",
|
||||||
}).run()
|
}).run()
|
||||||
ops_module = MUSAExtension('KTransformersOps', [
|
ops_module = MUSAExtension('KTransformersOps', [
|
||||||
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
|
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user