mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-02-02 12:39:27 +08:00
fix some bug in compile in linux
This commit is contained in:
parent
0a2fd52cea
commit
1d9d397525
@ -205,7 +205,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_
|
|||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
|
||||||
include_directories("D:/CUDA/v12.5/include")
|
if (WIN32)
|
||||||
|
include_directories("$ENV{CUDA_PATH}/include")
|
||||||
|
elseif (UNIX)
|
||||||
|
find_package(CUDA REQUIRED)
|
||||||
|
include_directories("${CUDA_INCLUDE_DIRS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
||||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
|
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
|
||||||
@ -216,4 +221,8 @@ message(STATUS "ALL_SOURCES: ${ALL_SOURCES}")
|
|||||||
|
|
||||||
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
|
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
|
||||||
target_link_libraries(${PROJECT_NAME} PRIVATE llama)
|
target_link_libraries(${PROJECT_NAME} PRIVATE llama)
|
||||||
target_link_libraries(${PROJECT_NAME} PRIVATE "D:/CUDA/v12.5/lib/x64/cudart.lib")#CUDA::cudart
|
if(WIN32)
|
||||||
|
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
|
||||||
|
elseif(UNIX)
|
||||||
|
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
|
||||||
|
endif()
|
||||||
@ -3,8 +3,8 @@
|
|||||||
* @Author : chenht2022
|
* @Author : chenht2022
|
||||||
* @Date : 2024-07-16 10:43:18
|
* @Date : 2024-07-16 10:43:18
|
||||||
* @Version : 1.0.0
|
* @Version : 1.0.0
|
||||||
* @LastEditors : chenht2022
|
* @LastEditors : chenxl
|
||||||
* @LastEditTime : 2024-07-25 10:33:47
|
* @LastEditTime : 2024-08-08 04:23:51
|
||||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
**/
|
**/
|
||||||
#ifndef CPUINFER_TASKQUEUE_H
|
#ifndef CPUINFER_TASKQUEUE_H
|
||||||
@ -25,7 +25,7 @@ class custom_mutex {
|
|||||||
private:
|
private:
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
HANDLE global_mutex;
|
HANDLE global_mutex;
|
||||||
#elif
|
#else
|
||||||
std::mutex global_mutex;
|
std::mutex global_mutex;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ public:
|
|||||||
{
|
{
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
WaitForSingleObject(global_mutex, INFINITE);
|
WaitForSingleObject(global_mutex, INFINITE);
|
||||||
#elif
|
#else
|
||||||
global_mutex.lock();
|
global_mutex.lock();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@ -50,7 +50,7 @@ public:
|
|||||||
{
|
{
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
ReleaseMutex(global_mutex);
|
ReleaseMutex(global_mutex);
|
||||||
#elif
|
#else
|
||||||
global_mutex.lock();
|
global_mutex.lock();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,7 +3,8 @@ requires = [
|
|||||||
"setuptools",
|
"setuptools",
|
||||||
"torch >= 2.3.0",
|
"torch >= 2.3.0",
|
||||||
"ninja",
|
"ninja",
|
||||||
"packaging"
|
"packaging",
|
||||||
|
"cpufeature"
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
|||||||
15
setup.py
15
setup.py
@ -6,7 +6,7 @@ Author : chenxl
|
|||||||
Date : 2024-07-27 16:15:27
|
Date : 2024-07-27 16:15:27
|
||||||
Version : 1.0.0
|
Version : 1.0.0
|
||||||
LastEditors : chenxl
|
LastEditors : chenxl
|
||||||
LastEditTime : 2024-07-31 09:44:46
|
LastEditTime : 2024-08-08 02:45:15
|
||||||
Adapted from:
|
Adapted from:
|
||||||
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
|
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
|
||||||
Copyright (c) 2023, Tri Dao.
|
Copyright (c) 2023, Tri Dao.
|
||||||
@ -19,6 +19,7 @@ import re
|
|||||||
import ast
|
import ast
|
||||||
import subprocess
|
import subprocess
|
||||||
import platform
|
import platform
|
||||||
|
import shutil
|
||||||
import http.client
|
import http.client
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import urllib.error
|
import urllib.error
|
||||||
@ -27,6 +28,7 @@ from packaging.version import parse
|
|||||||
import torch.version
|
import torch.version
|
||||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||||
from setuptools import setup, Extension
|
from setuptools import setup, Extension
|
||||||
|
from cpufeature.extension import CPUFeature
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||||
|
|
||||||
class CpuInstructInfo:
|
class CpuInstructInfo:
|
||||||
@ -100,7 +102,14 @@ class VersionInfo:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported cpu Instructions: {}".format(flags_line))
|
"Unsupported cpu Instructions: {}".format(flags_line))
|
||||||
elif sys.platform == "win32":
|
elif sys.platform == "win32":
|
||||||
return 'native'
|
if CPUFeature.get("AVX512bw", False):
|
||||||
|
return 'fancy'
|
||||||
|
if CPUFeature.get("AVX512f", False):
|
||||||
|
return 'avx512'
|
||||||
|
if CPUFeature.get("AVX2", False):
|
||||||
|
return 'avx2'
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported cpu Instructions: {}".format(str(CPUFeature)))
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
||||||
|
|
||||||
@ -158,7 +167,7 @@ class BuildWheelsCommand(_bdist_wheel):
|
|||||||
|
|
||||||
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
||||||
print("Raw wheel path", wheel_path)
|
print("Raw wheel path", wheel_path)
|
||||||
os.rename(wheel_filename, wheel_path)
|
shutil.move(wheel_filename, wheel_path)
|
||||||
except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
|
except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
|
||||||
print("Precompiled wheel not found. Building from source...")
|
print("Precompiled wheel not found. Building from source...")
|
||||||
# If the wheel could not be downloaded, build from source
|
# If the wheel could not be downloaded, build from source
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user