diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
new file mode 100644
index 0000000..60df01f
--- /dev/null
+++ b/.github/workflows/docker-image.yml
@@ -0,0 +1,90 @@
+name: DockerHub CI
+
+on:
+ release:
+ types: [published]
+ # push:
+ # branches:
+ # - main
+env:
+ DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Run tests
+ run: |
+ if [ -f docker-compose.test.yml ]; then
+ docker-compose --file docker-compose.test.yml build
+ docker-compose --file docker-compose.test.yml run sut
+ else
+ docker build . --file Dockerfile
+ fi
+
+ docker_task:
+ needs: test
+ name: ${{ matrix.instruct}}
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ # for amd64
+ - {instruct: "FANCY", platform: "linux/amd64"}
+ - {instruct: "AVX512", platform: "linux/amd64"}
+ - {instruct: "AVX2", platform: "linux/amd64"}
+ - {instruct: "NATIVE", platform: "linux/amd64"}
+ # for arm64
+ - {instruct: "NATIVE", platform: "linux/arm64"}
+
+ steps:
+ - name: Move Docker data directory
+ run: |
+ sudo systemctl stop docker
+ sudo mkdir -p /mnt/docker
+ sudo rsync -avz /var/lib/docker/ /mnt/docker
+ sudo rm -rf /var/lib/docker
+ sudo ln -s /mnt/docker /var/lib/docker
+ sudo systemctl start docker
+
+ -
+ name: Set up QEMU
+ uses: docker/setup-qemu-action@v3
+
+ -
+ name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ -
+ name: Login to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+ -
+ name: Build and push for amd64
+ if: matrix.platform == 'linux/amd64'
+ uses: docker/build-push-action@v6
+ with:
+ push: true
+ platforms: |
+ linux/amd64
+ tags: |
+ ${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
+ ${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
+ build-args: |
+ CPU_INSTRUCT=${{ matrix.instruct }}
+ -
+ name: Build and push for arm64
+ if: matrix.platform == 'linux/arm64'
+ uses: docker/build-push-action@v6
+ with:
+ push: true
+ platforms: |
+ linux/arm64
+ tags: |
+ ${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
+ ${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
+ build-args: |
+ CPU_INSTRUCT=${{ matrix.instruct }}
\ No newline at end of file
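For reference, a minimal sketch of consuming the tags this workflow publishes; the Docker Hub namespace stands in for the `DOCKERHUB_USERNAME` secret and the release version shown is illustrative:

```bash
# Pull an instruction-set-specific build (namespace and release tag are placeholders).
docker pull <namespace>/ktransformers:latest-AVX2
docker pull <namespace>/ktransformers:v0.2.1-FANCY
# The image idles via its ENTRYPOINT, so start it detached and exec into it.
docker run --gpus all -d --name ktransformers <namespace>/ktransformers:latest-AVX2
docker exec -it ktransformers bash
```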
diff --git a/.gitignore b/.gitignore
index f8da261..58250d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt
mmlu_result_q4km.json
mmlu_result_q4km.log
ktransformers/tests/mmlu_result_silicon.log
+ktransformers/ktransformers_ext/cuda_musa/
diff --git a/Dockerfile b/Dockerfile
index 6d4b214..1807150 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,6 +11,7 @@ EOF
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server
+ARG CPU_INSTRUCT=NATIVE
WORKDIR /workspace
ENV CUDA_HOME /usr/local/cuda
COPY --from=web_compile /home/ktransformers /workspace/ktransformers
@@ -28,8 +29,9 @@ git submodule init &&
git submodule update &&
pip install ninja pyproject numpy cpufeature &&
pip install flash-attn &&
-CPU_INSTRUCT=NATIVE KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose &&
-pip cache purge
+CPU_INSTRUCT=${CPU_INSTRUCT} KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose &&
+pip cache purge &&
+cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/
EOF
ENTRYPOINT ["tail", "-f", "/dev/null"]
\ No newline at end of file
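With the new `CPU_INSTRUCT` build argument, a local build could look like the following sketch (run from the repository root; omitting `--build-arg` keeps the `NATIVE` default):

```bash
# Build the image with a non-default instruction set.
docker build --build-arg CPU_INSTRUCT=AVX512 -t ktransformers:avx512 .
```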
diff --git a/doc/en/DeepseekR1_V3_tutorial.md b/doc/en/DeepseekR1_V3_tutorial.md
index fe5cb7a..52e9d32 100644
--- a/doc/en/DeepseekR1_V3_tutorial.md
+++ b/doc/en/DeepseekR1_V3_tutorial.md
@@ -226,6 +226,7 @@ Intel is currently the only CPU vendor that supports AMX-like instructions, whic
### Easier
* Official Docker images to simplify installation
* Fix the server integration for web API access
+* Fix local chat accepting only a single-line prompt (currently a newline (\n) triggers generation)
* Support for more quantization types, including the highly requested dynamic quantization from unsloth
Stay tuned for more updates!
diff --git a/doc/en/api/server/website.md b/doc/en/api/server/website.md
index bd380cd..a057898 100644
--- a/doc/en/api/server/website.md
+++ b/doc/en/api/server/website.md
@@ -8,6 +8,20 @@ This document provides the necessary steps to set up and run the web service for
Before you can compile the web code, make sure you have installed [Node.js](https://nodejs.org) version 18.3 or higher
+Note: The Node.js version in the Ubuntu and Debian GNU/Linux repositories is too old and causes compilation errors. You can instead install a recent Node.js from the NodeSource repository; uninstall the outdated version first.
+
+```bash
+
+ # sudo apt-get remove nodejs npm -y && sudo apt-get autoremove -y
+ sudo apt-get update -y && sudo apt-get install -y apt-transport-https ca-certificates curl gnupg
+ curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | sudo gpg --dearmor -o /usr/share/keyrings/nodesource.gpg
+ sudo chmod 644 /usr/share/keyrings/nodesource.gpg
+ echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_23.x nodistro main" | sudo tee /etc/apt/sources.list.d/nodesource.list
+ sudo apt-get update -y
+ sudo apt-get install nodejs -y
+
+```
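A quick check that the NodeSource package satisfies the version requirement above:

```bash
node -v   # should report >= 18.3
npm -v
```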
+
Once npm is installed, navigate to the `ktransformers/website` directory:
```bash
diff --git a/doc/en/install.md b/doc/en/install.md
index 7abb6c2..b51c54f 100644
--- a/doc/en/install.md
+++ b/doc/en/install.md
@@ -27,11 +27,11 @@ Some preparation:
fi
```
-- Linux-x86_64 with gcc, g++ and cmake
+- Linux-x86_64 with gcc, g++ and cmake (using Ubuntu as an example)
```sh
sudo apt-get update
- sudo apt-get install gcc g++ cmake ninja-build
+ sudo apt-get install build-essential cmake ninja-build
```
- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX-3.4.32`
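One way to verify the `GLIBCXX_3.4.32` requirement mentioned above, assuming the default `~/anaconda3` location:

```bash
strings ~/anaconda3/lib/libstdc++.so.6 | grep GLIBCXX_3.4.32
```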
diff --git a/doc/zh/DeepseekR1_V3_tutorial_zh.md b/doc/zh/DeepseekR1_V3_tutorial_zh.md
index b8ef85e..ba9d7e8 100644
--- a/doc/zh/DeepseekR1_V3_tutorial_zh.md
+++ b/doc/zh/DeepseekR1_V3_tutorial_zh.md
@@ -160,9 +160,14 @@ DeepSeek 的 MLA 操作符计算密集。虽然全部在 CPU 上运行是可行
5. 为什么选择英特尔 CPU?
英特尔目前是唯一支持 AMX 类似指令的 CPU 供应商,与仅支持 AVX 的替代方案相比,性能显著更好。
+
## 常见问题解答
### R1 不返回思考过程
注意!如果测试 R1 可能会跳过思考。因此,可以添加参数:`--force_think true`。详细信息在 [常见问题解答](./FAQ.md) 部分中。
+## 问题
+* 修复服务器集成功能以实现网络API访问支持
+* 修复本地聊天功能仅支持单行提示输入的问题(目前输入换行符(\n)即开始生成提示)
+
### 更多常见问题解答
[详见](./FAQ.md)
diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt
index d9ecd7a..ecce9b7 100644
--- a/ktransformers/ktransformers_ext/CMakeLists.txt
+++ b/ktransformers/ktransformers_ext/CMakeLists.txt
@@ -30,6 +30,8 @@ if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
+option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
+option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
if (WIN32)
include_directories("$ENV{CUDA_PATH}/include")
elseif (UNIX)
- find_package(CUDA REQUIRED)
- include_directories("${CUDA_INCLUDE_DIRS}")
+ if (KTRANSFORMERS_USE_CUDA)
+ find_package(CUDA REQUIRED)
+ include_directories("${CUDA_INCLUDE_DIRS}")
+ add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
+ endif()
+
+ if (KTRANSFORMERS_USE_MUSA)
+ if (NOT EXISTS "$ENV{MUSA_PATH}")
+ if (NOT EXISTS /opt/musa)
+ set(MUSA_PATH /usr/local/musa)
+ else()
+ set(MUSA_PATH /opt/musa)
+ endif()
+ else()
+ set(MUSA_PATH "$ENV{MUSA_PATH}")
+ endif()
+
+ list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
+
+ find_package(MUSAToolkit)
+ if (MUSAToolkit_FOUND)
+ message(STATUS "MUSA Toolkit found")
+ add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
+ endif()
+ endif()
endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
@@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
- if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
- set(ENV{CUDA_HOME} "/usr/local/cuda")
+ if(KTRANSFORMERS_USE_CUDA)
+ if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
+ set(ENV{CUDA_HOME} "/usr/local/cuda")
+ endif()
+ target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
+ endif()
+ if(KTRANSFORMERS_USE_MUSA)
+ target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
- target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
# Define the USE_NUMA option
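For a standalone configure of the extension, the new options would be passed as in this sketch (normally `setup.py` selects them automatically from `CUDA_HOME`/`MUSA_HOME`):

```bash
# CUDA backend
cmake -B build -S ktransformers/ktransformers_ext -DKTRANSFORMERS_USE_CUDA=ON
# or MUSA backend (toolkit expected under $MUSA_PATH, /opt/musa, or /usr/local/musa)
cmake -B build -S ktransformers/ktransformers_ext -DKTRANSFORMERS_USE_MUSA=ON
cmake --build build
```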
diff --git a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
index 9618e6b..d0f0c60 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -17,7 +17,11 @@
#include
#include
#include
-#include "cuda_runtime.h"
+#ifdef KTRANSFORMERS_USE_CUDA
+#include "vendors/cuda.h"
+#elif defined(KTRANSFORMERS_USE_MUSA)
+#include "vendors/musa.h"
+#endif
#include "backend.h"
#include "task_queue.h"
diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/README.md b/ktransformers/ktransformers_ext/cpu_backend/vendors/README.md
new file mode 100644
index 0000000..d179f66
--- /dev/null
+++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/README.md
@@ -0,0 +1,3 @@
+## TODO
+
+This directory can be removed after updating the version of `llama.cpp`.
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
new file mode 100644
index 0000000..082ad2c
--- /dev/null
+++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#include <cuda_runtime.h>
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
new file mode 100644
index 0000000..7c94102
--- /dev/null
+++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <musa_runtime.h>
+
+#define cudaLaunchHostFunc musaLaunchHostFunc
+#define cudaStream_t musaStream_t
+#define cudaHostFn_t musaHostFn_t
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp
index 96ee9d8..75cfcdb 100644
--- a/ktransformers/ktransformers_ext/cuda/binding.cpp
+++ b/ktransformers/ktransformers_ext/cuda/binding.cpp
@@ -7,7 +7,9 @@
**/
#include "custom_gguf/ops.h"
+#ifdef KTRANSFORMERS_USE_CUDA
#include "gptq_marlin/ops.h"
+#endif
// Python bindings
#include
#include
@@ -52,7 +54,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
-
+
+#ifdef KTRANSFORMERS_USE_CUDA
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 7915654..85378ee 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -58,18 +58,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
- q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
- out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
- self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
- bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
- self.q_absorb.weight.data = q_absorb
- self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
- bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
- self.out_absorb.weight.data = out_absorb
- #del self.orig_module.kv_b_proj
- q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
- out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
- return q_absorb, out_absorb
+ self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
+ self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
+
+ return self.q_absorb, self.out_absorb
def forward_chunck(
self,
@@ -105,7 +97,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
@@ -129,8 +121,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
q_absorb, out_absorb = self.get_absorbed()
- # if hasattr(self.orig_module, 'kv_b_proj'):
- # del self.orig_module.kv_b_proj
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@@ -227,7 +217,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
@@ -379,7 +369,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ f"The cache structure has changed since version transformer verision v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py
index b7d0938..b3b9dd1 100644
--- a/ktransformers/operators/flashinfer_wrapper.py
+++ b/ktransformers/operators/flashinfer_wrapper.py
@@ -9,7 +9,7 @@ flashinfer_enabled = False
try:
import flashinfer
- flashinfer_enabled = False
+ flashinfer_enabled = False # disabled for now; TODO: enable once a newer flashinfer version is supported
print("found flashinfer")
except ImportError:
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index deb6cfa..8211933 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -381,13 +381,13 @@ class TransformersInterface(BackendInterfaceBase):
self.profiler.create_and_start_timer("prefill")
-
+ if Config().user_force_think:
+ think = '\n'
+ print(think, end="",flush=True)
+ yield think
+
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
# output think token after prefill done
- if Config().user_force_think:
- think = '\n'
- print(think, end="",flush=True)
- yield think
if t is not None:
print(t, end="",flush=True)
yield t
diff --git a/ktransformers/tests/mmlu_pro_test.py b/ktransformers/tests/mmlu_pro_test.py
index d44be2a..27eb9b2 100644
--- a/ktransformers/tests/mmlu_pro_test.py
+++ b/ktransformers/tests/mmlu_pro_test.py
@@ -176,7 +176,7 @@ if __name__ == "__main__":
parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
- parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
+ parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
# parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
args = parser.parse_args()
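With the default changed, a run against a locally served model would look like this sketch (override `--api_url` if the server listens on a different port):

```bash
python ktransformers/tests/mmlu_pro_test.py --api_url http://localhost:15488/v1/chat/completions
```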
diff --git a/setup.py b/setup.py
index ddd4835..8061713 100644
--- a/setup.py
+++ b/setup.py
@@ -1,16 +1,16 @@
#!/usr/bin/env python
# coding=utf-8
'''
-Description :
+Description :
Author : chenxl
Date : 2024-07-27 16:15:27
Version : 1.0.0
-LastEditors : chenxl
+LastEditors : chenxl
LastEditTime : 2024-08-14 16:36:19
Adapted from:
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao.
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os
@@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension
from cpufeature.extension import CPUFeature
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+try:
+ from torch_musa.utils.simple_porting import SimplePorting
+ from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
+except ImportError:
+ MUSA_HOME=None
class CpuInstructInfo:
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -40,7 +45,7 @@ class CpuInstructInfo:
CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"
-
+
class VersionInfo:
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "ktransformers"
@@ -49,6 +54,16 @@ class VersionInfo:
)
FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"
+ def get_musa_bare_metal_version(self, musa_dir):
+ raw_output = subprocess.run(
+ [musa_dir + "/bin/mcc", "-v"], check=True,
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
+ output = raw_output.split()
+ release_idx = output.index("version") + 1
+ bare_metal_version = parse(output[release_idx].split(",")[0])
+ musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
+ return musa_version
+
def get_cuda_bare_metal_version(self, cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -58,7 +73,7 @@ class VersionInfo:
cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
return cuda_version
- def get_cuda_version_of_torch(self,):
+ def get_cuda_version_of_torch(self):
torch_cuda_version = parse(torch.version.cuda)
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
return cuda_version
@@ -117,7 +132,7 @@ class VersionInfo:
torch_version_raw = parse(torch.__version__)
torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}"
return torch_version
-
+
def get_flash_version(self,):
version_file = os.path.join(
Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
@@ -128,12 +143,21 @@ class VersionInfo:
return flash_version
def get_package_version(self, full_version=False):
- flash_version = self.get_flash_version()
- package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}"
+ flash_version = str(self.get_flash_version())
+ torch_version = self.get_torch_version()
+ cpu_instruct = self.get_cpu_instruct()
+ backend_version = ""
+ if CUDA_HOME is not None:
+ backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+ elif MUSA_HOME is not None:
+ backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+ else:
+ raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+ package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
if full_version:
return package_version
if not VersionInfo.FORCE_BUILD:
- return str(flash_version)
+ return flash_version
return package_version
@@ -218,11 +242,19 @@ class CMakeBuild(BuildExtension):
f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
]
+
+ if CUDA_HOME is not None:
+ cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
+ elif MUSA_HOME is not None:
+ cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+ else:
+ raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+
build_args = []
if "CMAKE_ARGS" in os.environ:
cmake_args += [
item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
-
+
if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
cpu_args = CpuInstructInfo.CMAKE_FANCY
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
@@ -231,7 +263,7 @@ class CMakeBuild(BuildExtension):
cpu_args = CpuInstructInfo.CMAKE_AVX2
else:
cpu_args = CpuInstructInfo.CMAKE_NATIVE
-
+
cmake_args += [
item for item in cpu_args.split(" ") if item
]
@@ -288,28 +320,55 @@ class CMakeBuild(BuildExtension):
print("Standard output:", result.stdout)
print("Standard error:", result.stderr)
subprocess.run(
- ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
+ ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
)
+if CUDA_HOME is not None:
+ ops_module = CUDAExtension('KTransformersOps', [
+ 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
+ 'ktransformers/ktransformers_ext/cuda/binding.cpp',
+ 'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
+ ],
+ extra_compile_args={
+ 'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
+ 'nvcc': [
+ '-O3',
+ '--use_fast_math',
+ '-Xcompiler', '-fPIC',
+ '-DKTRANSFORMERS_USE_CUDA',
+ ]
+ }
+ )
+elif MUSA_HOME is not None:
+ SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
+ # Common rules
+ "at::cuda": "at::musa",
+ "#include ": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
+ "#include ": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
+ }).run()
+ ops_module = MUSAExtension('KTransformersOps', [
+ 'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
+ 'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
+ # TODO: Add Marlin support for MUSA.
+ # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
+ ],
+ extra_compile_args={
+ 'cxx': ['force_mcc'],
+ 'mcc': [
+ '-O3',
+ '-DKTRANSFORMERS_USE_MUSA',
+ '-DTHRUST_IGNORE_CUB_VERSION_CHECK',
+ ]
+ }
+ )
+else:
+ raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
setup(
version=VersionInfo().get_package_version(),
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=[
CMakeExtension("cpuinfer_ext"),
- CUDAExtension('KTransformersOps', [
- 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
- 'ktransformers/ktransformers_ext/cuda/binding.cpp',
- 'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
- ],
- extra_compile_args={
- 'cxx': ['-O3'],
- 'nvcc': [
- '-O3',
- '--use_fast_math',
- '-Xcompiler', '-fPIC',
- ]
- }
- )
+ ops_module,
]
)
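A hedged sketch of how the backend selection above plays out at install time: with a CUDA toolchain, `CUDA_HOME` is detected and `-DKTRANSFORMERS_USE_CUDA=ON` is passed to CMake; with `torch_musa` installed, `MUSA_HOME` is used instead, the CUDA sources are ported into `ktransformers/ktransformers_ext/cuda_musa/` (the directory now listed in `.gitignore`), and `-DKTRANSFORMERS_USE_MUSA=ON` is passed.

```bash
# Same command for either backend; detection happens inside setup.py.
CPU_INSTRUCT=AVX512 KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation --verbose
```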