mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-01-19 12:43:16 +08:00
fix some bug in compile in linux
This commit is contained in:
parent
0a2fd52cea
commit
1d9d397525
@ -205,7 +205,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
|
||||
include_directories("D:/CUDA/v12.5/include")
|
||||
if (WIN32)
|
||||
include_directories("$ENV{CUDA_PATH}/include")
|
||||
elseif (UNIX)
|
||||
find_package(CUDA REQUIRED)
|
||||
include_directories("${CUDA_INCLUDE_DIRS}")
|
||||
endif()
|
||||
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
|
||||
@ -216,4 +221,8 @@ message(STATUS "ALL_SOURCES: ${ALL_SOURCES}")
|
||||
|
||||
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE llama)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "D:/CUDA/v12.5/lib/x64/cudart.lib")#CUDA::cudart
|
||||
if(WIN32)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
|
||||
elseif(UNIX)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
|
||||
endif()
|
||||
@ -3,8 +3,8 @@
|
||||
* @Author : chenht2022
|
||||
* @Date : 2024-07-16 10:43:18
|
||||
* @Version : 1.0.0
|
||||
* @LastEditors : chenht2022
|
||||
* @LastEditTime : 2024-07-25 10:33:47
|
||||
* @LastEditors : chenxl
|
||||
* @LastEditTime : 2024-08-08 04:23:51
|
||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
**/
|
||||
#ifndef CPUINFER_TASKQUEUE_H
|
||||
@ -25,7 +25,7 @@ class custom_mutex {
|
||||
private:
|
||||
#ifdef _WIN32
|
||||
HANDLE global_mutex;
|
||||
#elif
|
||||
#else
|
||||
std::mutex global_mutex;
|
||||
#endif
|
||||
|
||||
@ -41,7 +41,7 @@ public:
|
||||
{
|
||||
#ifdef _WIN32
|
||||
WaitForSingleObject(global_mutex, INFINITE);
|
||||
#elif
|
||||
#else
|
||||
global_mutex.lock();
|
||||
#endif
|
||||
}
|
||||
@ -50,7 +50,7 @@ public:
|
||||
{
|
||||
#ifdef _WIN32
|
||||
ReleaseMutex(global_mutex);
|
||||
#elif
|
||||
#else
|
||||
global_mutex.lock();
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -3,7 +3,8 @@ requires = [
|
||||
"setuptools",
|
||||
"torch >= 2.3.0",
|
||||
"ninja",
|
||||
"packaging"
|
||||
"packaging",
|
||||
"cpufeature"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
|
||||
15
setup.py
15
setup.py
@ -6,7 +6,7 @@ Author : chenxl
|
||||
Date : 2024-07-27 16:15:27
|
||||
Version : 1.0.0
|
||||
LastEditors : chenxl
|
||||
LastEditTime : 2024-07-31 09:44:46
|
||||
LastEditTime : 2024-08-08 02:45:15
|
||||
Adapted from:
|
||||
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
|
||||
Copyright (c) 2023, Tri Dao.
|
||||
@ -19,6 +19,7 @@ import re
|
||||
import ast
|
||||
import subprocess
|
||||
import platform
|
||||
import shutil
|
||||
import http.client
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
@ -27,6 +28,7 @@ from packaging.version import parse
|
||||
import torch.version
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
from setuptools import setup, Extension
|
||||
from cpufeature.extension import CPUFeature
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
|
||||
class CpuInstructInfo:
|
||||
@ -100,7 +102,14 @@ class VersionInfo:
|
||||
raise ValueError(
|
||||
"Unsupported cpu Instructions: {}".format(flags_line))
|
||||
elif sys.platform == "win32":
|
||||
return 'native'
|
||||
if CPUFeature.get("AVX512bw", False):
|
||||
return 'fancy'
|
||||
if CPUFeature.get("AVX512f", False):
|
||||
return 'avx512'
|
||||
if CPUFeature.get("AVX2", False):
|
||||
return 'avx2'
|
||||
raise ValueError(
|
||||
"Unsupported cpu Instructions: {}".format(str(CPUFeature)))
|
||||
else:
|
||||
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
||||
|
||||
@ -158,7 +167,7 @@ class BuildWheelsCommand(_bdist_wheel):
|
||||
|
||||
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
||||
print("Raw wheel path", wheel_path)
|
||||
os.rename(wheel_filename, wheel_path)
|
||||
shutil.move(wheel_filename, wheel_path)
|
||||
except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
|
||||
print("Precompiled wheel not found. Building from source...")
|
||||
# If the wheel could not be downloaded, build from source
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user