import os
import sys

from setuptools import setup

import torch.cuda
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDA_HOME,
    CUDAExtension,
    ROCM_HOME,
)


# Host C++ compiler flags shared by the extensions below.
if sys.platform == "win32":
    vc_version = os.getenv("VCToolsVersion", "")
    # NOTE(review): "/permissive-" is dropped for MSVC toolset 14.16.*;
    # presumably that toolset cannot build these sources in conformance
    # mode — confirm before removing this special case.
    if vc_version.startswith("14.16."):
        CXX_FLAGS = ["/sdl"]
    else:
        CXX_FLAGS = ["/sdl", "/permissive-"]
else:
    CXX_FLAGS = ["-g"]

# Ninja builds are opt-in: only the exact value "1" enables them.
USE_NINJA = os.getenv("USE_NINJA") == "1"

# Computed once so the conditions guarding the GPU extensions below cannot
# drift apart from each other.
# True when torch reports a usable GPU and either CUDA_HOME or ROCM_HOME
# was located by torch.utils.cpp_extension.
HAS_GPU_TOOLKIT = torch.cuda.is_available() and (
    CUDA_HOME is not None or ROCM_HOME is not None
)
# Narrower variant: torch reports a usable GPU and CUDA_HOME was located.
HAS_CUDA_TOOLKIT = torch.cuda.is_available() and CUDA_HOME is not None

# Extensions that are always built (host C++ only).
ext_modules = [
    CppExtension(
        "torch_test_cpp_extension.cpp", ["extension.cpp"], extra_compile_args=CXX_FLAGS
    ),
    CppExtension(
        "torch_test_cpp_extension.maia",
        ["maia_extension.cpp"],
        extra_compile_args=CXX_FLAGS,
    ),
    CppExtension(
        "torch_test_cpp_extension.rng",
        ["rng_extension.cpp"],
        extra_compile_args=CXX_FLAGS,
    ),
]

# GPU extensions that work with either a CUDA or a ROCm toolkit.
if HAS_GPU_TOOLKIT:
    extension = CUDAExtension(
        "torch_test_cpp_extension.cuda",
        [
            "cuda_extension.cpp",
            "cuda_extension_kernel.cu",
            "cuda_extension_kernel2.cu",
        ],
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
    )
    ext_modules.append(extension)

    extension = CUDAExtension(
        "torch_test_cpp_extension.torch_library",
        ["torch_library.cu"],
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
    )
    ext_modules.append(extension)

# Objective-C++ extension, built only when the MPS backend is available.
if torch.backends.mps.is_available():
    extension = CppExtension(
        "torch_test_cpp_extension.mps",
        ["mps_extension.mm"],
        extra_compile_args=CXX_FLAGS,
    )
    ext_modules.append(extension)

# todo(mkozuki): Figure out the root cause
if (not IS_WINDOWS) and HAS_CUDA_TOOLKIT:
    # malfet: One should not assume that PyTorch re-exports CUDA dependencies
    # (hence cublas/cusolver are linked explicitly; on HIP builds no extra
    # library is requested).
    cublas_extension = CUDAExtension(
        name="torch_test_cpp_extension.cublas_extension",
        sources=["cublas_extension.cpp"],
        libraries=["cublas"] if torch.version.hip is None else [],
    )
    ext_modules.append(cublas_extension)

    cusolver_extension = CUDAExtension(
        name="torch_test_cpp_extension.cusolver_extension",
        sources=["cusolver_extension.cpp"],
        libraries=["cusolver"] if torch.version.hip is None else [],
    )
    ext_modules.append(cusolver_extension)

# Device-linked (dlink=True, nvcc "-dc") extension; only attempted for
# Ninja builds on non-Windows hosts with CUDA_HOME available.
if USE_NINJA and (not IS_WINDOWS) and HAS_CUDA_TOOLKIT:
    extension = CUDAExtension(
        name="torch_test_cpp_extension.cuda_dlink",
        sources=[
            "cuda_dlink_extension.cpp",
            "cuda_dlink_extension_kernel.cu",
            "cuda_dlink_extension_add.cu",
        ],
        dlink=True,
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2", "-dc"]},
    )
    ext_modules.append(extension)

setup(
    name="torch_test_cpp_extension",
    packages=["torch_test_cpp_extension"],
    ext_modules=ext_modules,
    # NOTE(review): deliberately a plain string, not a list; the directory
    # name suggests this exercises string-valued include_dirs handling in
    # build_ext — confirm before "fixing" it into a list.
    include_dirs="self_compiler_include_dirs_test",
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)},
    # Registers torch_test_cpp_extension._autoload under the "torch.backends"
    # entry-point group; presumably consumed by torch's device-backend
    # autoloading — confirm.
    entry_points={
        "torch.backends": [
            "device_backend = torch_test_cpp_extension:_autoload",
        ],
    },
)