# xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/setup.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport os
2*da0073e9SAndroid Build Coastguard Workerimport sys
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom setuptools import setup
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch.cuda
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_WINDOWS
8*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.cpp_extension import (
9*da0073e9SAndroid Build Coastguard Worker    BuildExtension,
10*da0073e9SAndroid Build Coastguard Worker    CppExtension,
11*da0073e9SAndroid Build Coastguard Worker    CUDA_HOME,
12*da0073e9SAndroid Build Coastguard Worker    CUDAExtension,
13*da0073e9SAndroid Build Coastguard Worker    ROCM_HOME,
14*da0073e9SAndroid Build Coastguard Worker)
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
# Host C++ compiler flags shared by the extensions below.
#   * Non-Windows: compile with debug info (-g).
#   * Windows (MSVC): always /sdl; also /permissive- unless the toolset version
#     in VCToolsVersion starts with "14.16.".
#     NOTE(review): the reason 14.16.x is excluded from /permissive- is not
#     stated here; presumably a toolset incompatibility -- confirm.
if sys.platform != "win32":
    CXX_FLAGS = ["-g"]
else:
    vc_version = os.getenv("VCToolsVersion", "")
    CXX_FLAGS = ["/sdl"]
    if not vc_version.startswith("14.16."):
        CXX_FLAGS.append("/permissive-")

# True only when the environment sets USE_NINJA to exactly "1". The value is
# passed to BuildExtension as use_ninja and also gates the cuda_dlink extension.
USE_NINJA = os.environ.get("USE_NINJA") == "1"

# Extensions that are always built: plain C++ extensions, one source file
# each, all compiled with the shared host flags. Listed as
# (module name, source file) pairs; later sections append GPU/MPS extensions.
ext_modules = [
    CppExtension(module_name, [source], extra_compile_args=CXX_FLAGS)
    for module_name, source in (
        ("torch_test_cpp_extension.cpp", "extension.cpp"),
        ("torch_test_cpp_extension.maia", "maia_extension.cpp"),
        ("torch_test_cpp_extension.rng", "rng_extension.cpp"),
    )
]

# GPU extensions. These are built only when torch sees a GPU and either a CUDA
# or a ROCm toolkit was located (CUDA_HOME / ROCM_HOME come from
# torch.utils.cpp_extension). Both extensions share this single guard, so the
# condition is evaluated and maintained in one place.
# Per-compiler flags: the shared host flags for "cxx", and -O2 for "nvcc".
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
    extension = CUDAExtension(
        "torch_test_cpp_extension.cuda",
        [
            "cuda_extension.cpp",
            "cuda_extension_kernel.cu",
            "cuda_extension_kernel2.cu",
        ],
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
    )
    ext_modules.append(extension)

    extension = CUDAExtension(
        "torch_test_cpp_extension.torch_library",
        ["torch_library.cu"],
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
    )
    ext_modules.append(extension)

# Apple MPS extension, added only when torch reports the MPS backend as
# available. Its single .mm (Objective-C++) source is built as a CppExtension
# with the shared host flags.
if torch.backends.mps.is_available():
    extension = CppExtension(
        name="torch_test_cpp_extension.mps",
        sources=["mps_extension.mm"],
        extra_compile_args=CXX_FLAGS,
    )
    ext_modules.append(extension)

# TODO(mkozuki): figure out the root cause. NOTE(review): presumably this refers
# to why these two extensions are skipped on Windows -- confirm.
if not IS_WINDOWS and torch.cuda.is_available() and CUDA_HOME is not None:
    # Link cuBLAS / cuSOLVER explicitly. Per malfet, one should not assume
    # that PyTorch re-exports its CUDA dependencies. When torch.version.hip is
    # set (a HIP build), no extra library is requested.
    cublas_extension = CUDAExtension(
        name="torch_test_cpp_extension.cublas_extension",
        sources=["cublas_extension.cpp"],
        libraries=[] if torch.version.hip is not None else ["cublas"],
    )
    cusolver_extension = CUDAExtension(
        name="torch_test_cpp_extension.cusolver_extension",
        sources=["cusolver_extension.cpp"],
        libraries=[] if torch.version.hip is not None else ["cusolver"],
    )
    ext_modules.extend([cublas_extension, cusolver_extension])

# Device-link extension. Sources are compiled with nvcc -dc and linked with
# dlink=True. It is only added for ninja builds on non-Windows hosts with a
# visible GPU and a located CUDA toolkit. Condition order is kept, so
# torch.cuda.is_available() is only queried when the cheaper checks pass.
if USE_NINJA and not IS_WINDOWS and torch.cuda.is_available() and CUDA_HOME is not None:
    extension = CUDAExtension(
        name="torch_test_cpp_extension.cuda_dlink",
        dlink=True,
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2", "-dc"]},
        sources=[
            "cuda_dlink_extension.cpp",
            "cuda_dlink_extension_kernel.cu",
            "cuda_dlink_extension_add.cu",
        ],
    )
    ext_modules.append(extension)

# Package every collected extension into torch_test_cpp_extension. The build
# runs through torch's BuildExtension, with ninja toggled by USE_NINJA.
setup(
    name="torch_test_cpp_extension",
    packages=["torch_test_cpp_extension"],
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)},
    # NOTE(review): a plain string, not a list. The build_ext command accepts a
    # str here and splits it on os.pathsep, so it acts as a one-element list.
    # Given the directory name, this is presumably deliberate (it exercises that
    # code path) -- confirm before converting it to a list.
    include_dirs="self_compiler_include_dirs_test",
    # Declares an entry point in the "torch.backends" group that points at
    # torch_test_cpp_extension:_autoload.
    entry_points={
        "torch.backends": ["device_backend = torch_test_cpp_extension:_autoload"],
    },
)