"""Build script for the ``torch_test_cpp_extension`` test package.

The CPU extensions (``cpp``, ``maia``, ``rng``) are always built. The
remaining extensions are added only when the build host supports them:

* ``cuda`` and ``torch_library``: ``torch.cuda.is_available()`` is true and
  ``CUDA_HOME`` or ``ROCM_HOME`` is set.
* ``mps``: ``torch.backends.mps.is_available()`` is true.
* ``cublas_extension`` and ``cusolver_extension``: not Windows,
  ``torch.cuda.is_available()`` is true and ``CUDA_HOME`` is set.
* ``cuda_dlink``: same as the previous item, plus ``USE_NINJA=1``.

Setting ``USE_NINJA=1`` in the environment also passes ``use_ninja=True`` to
``BuildExtension``.
"""

import os
import sys

from setuptools import setup

import torch.cuda
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDA_HOME,
    CUDAExtension,
    ROCM_HOME,
)


if sys.platform == "win32":
    # MSVC toolset 14.16.* (per VCToolsVersion) is built without
    # "/permissive-"; every other toolset version gets it alongside "/sdl".
    vc_version = os.getenv("VCToolsVersion", "")
    if vc_version.startswith("14.16."):
        CXX_FLAGS = ["/sdl"]
    else:
        CXX_FLAGS = ["/sdl", "/permissive-"]
else:
    CXX_FLAGS = ["-g"]

USE_NINJA = os.getenv("USE_NINJA") == "1"

# Queried once and reused by every GPU gate below.
CUDA_AVAILABLE = torch.cuda.is_available()

# A GPU is usable and either a CUDA or a ROCm toolkit was located.
GPU_TOOLKIT_AVAILABLE = CUDA_AVAILABLE and (
    CUDA_HOME is not None or ROCM_HOME is not None
)

# Stricter gate: requires CUDA_HOME specifically (ROCM_HOME alone is not
# enough) and excludes Windows.
# TODO(mkozuki): Figure out the root cause (presumably of why these
# extensions have to be skipped on Windows).
CUDA_TOOLKIT_NON_WINDOWS = (
    (not IS_WINDOWS) and CUDA_AVAILABLE and CUDA_HOME is not None
)

ext_modules = [
    CppExtension(
        "torch_test_cpp_extension.cpp", ["extension.cpp"], extra_compile_args=CXX_FLAGS
    ),
    CppExtension(
        "torch_test_cpp_extension.maia",
        ["maia_extension.cpp"],
        extra_compile_args=CXX_FLAGS,
    ),
    CppExtension(
        "torch_test_cpp_extension.rng",
        ["rng_extension.cpp"],
        extra_compile_args=CXX_FLAGS,
    ),
]

if GPU_TOOLKIT_AVAILABLE:
    # Both GPU extensions share one gate; append order (cuda, then
    # torch_library) matches the original two separate checks.
    ext_modules.append(
        CUDAExtension(
            "torch_test_cpp_extension.cuda",
            [
                "cuda_extension.cpp",
                "cuda_extension_kernel.cu",
                "cuda_extension_kernel2.cu",
            ],
            extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
        )
    )
    ext_modules.append(
        CUDAExtension(
            "torch_test_cpp_extension.torch_library",
            ["torch_library.cu"],
            extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
        )
    )

if torch.backends.mps.is_available():
    ext_modules.append(
        CppExtension(
            "torch_test_cpp_extension.mps",
            ["mps_extension.mm"],
            extra_compile_args=CXX_FLAGS,
        )
    )

if CUDA_TOOLKIT_NON_WINDOWS:
    # malfet: One should not assume that PyTorch re-exports CUDA dependencies,
    # hence cublas/cusolver are linked explicitly (nothing extra under HIP).
    ext_modules.append(
        CUDAExtension(
            name="torch_test_cpp_extension.cublas_extension",
            sources=["cublas_extension.cpp"],
            libraries=["cublas"] if torch.version.hip is None else [],
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="torch_test_cpp_extension.cusolver_extension",
            sources=["cusolver_extension.cpp"],
            libraries=["cusolver"] if torch.version.hip is None else [],
        )
    )

if USE_NINJA and CUDA_TOOLKIT_NON_WINDOWS:
    # Relocatable device code: "-dc" compiles each .cu separately and
    # dlink=True asks CUDAExtension to device-link them.
    ext_modules.append(
        CUDAExtension(
            name="torch_test_cpp_extension.cuda_dlink",
            sources=[
                "cuda_dlink_extension.cpp",
                "cuda_dlink_extension_kernel.cu",
                "cuda_dlink_extension_add.cu",
            ],
            dlink=True,
            extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2", "-dc"]},
        )
    )

setup(
    name="torch_test_cpp_extension",
    packages=["torch_test_cpp_extension"],
    ext_modules=ext_modules,
    # Deliberately a bare string, not a list. NOTE(review): the value's name
    # suggests this exercises BuildExtension's handling of a string-valued
    # include_dirs; confirm before "fixing" it to a list.
    include_dirs="self_compiler_include_dirs_test",
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)},
    # Registers torch_test_cpp_extension:_autoload under the "torch.backends"
    # entry-point group (presumably for torch's backend autoloading -- confirm).
    entry_points={
        "torch.backends": [
            "device_backend = torch_test_cpp_extension:_autoload",
        ],
    },
)