# xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/setup.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os
2import sys
3
4from setuptools import setup
5
6import torch.cuda
7from torch.testing._internal.common_utils import IS_WINDOWS
8from torch.utils.cpp_extension import (
9    BuildExtension,
10    CppExtension,
11    CUDA_HOME,
12    CUDAExtension,
13    ROCM_HOME,
14)
15
16
17if sys.platform == "win32":
18    vc_version = os.getenv("VCToolsVersion", "")
19    if vc_version.startswith("14.16."):
20        CXX_FLAGS = ["/sdl"]
21    else:
22        CXX_FLAGS = ["/sdl", "/permissive-"]
23else:
24    CXX_FLAGS = ["-g"]
25
26USE_NINJA = os.getenv("USE_NINJA") == "1"
27
# CPU-only extensions, built unconditionally on every platform.
ext_modules = [
    CppExtension(
        "torch_test_cpp_extension.cpp", ["extension.cpp"], extra_compile_args=CXX_FLAGS
    ),
    CppExtension(
        "torch_test_cpp_extension.maia",
        ["maia_extension.cpp"],
        extra_compile_args=CXX_FLAGS,
    ),
    CppExtension(
        "torch_test_cpp_extension.rng",
        ["rng_extension.cpp"],
        extra_compile_args=CXX_FLAGS,
    ),
]
43
# GPU extensions: require both a visible device (torch.cuda.is_available())
# and an installed toolkit (CUDA_HOME for CUDA, ROCM_HOME for ROCm).
# The original file repeated this identical guard twice; check it once and
# append both extensions under it.
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
    ext_modules.append(
        CUDAExtension(
            "torch_test_cpp_extension.cuda",
            [
                "cuda_extension.cpp",
                "cuda_extension_kernel.cu",
                "cuda_extension_kernel2.cu",
            ],
            extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
        )
    )
    ext_modules.append(
        CUDAExtension(
            "torch_test_cpp_extension.torch_library",
            ["torch_library.cu"],
            extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
        )
    )
63
# Apple MPS extension (Objective-C++ source), built only where the MPS
# backend reports availability.
if torch.backends.mps.is_available():
    ext_modules.append(
        CppExtension(
            "torch_test_cpp_extension.mps",
            ["mps_extension.mm"],
            extra_compile_args=CXX_FLAGS,
        )
    )
71
# TODO(mkozuki): Figure out the root cause (these are skipped on Windows).
if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
    # malfet: One should not assume that PyTorch re-exports CUDA dependencies,
    # so link cublas explicitly on CUDA builds (HIP provides it differently).
    cublas_extension = CUDAExtension(
        name="torch_test_cpp_extension.cublas_extension",
        sources=["cublas_extension.cpp"],
        libraries=["cublas"] if torch.version.hip is None else [],
    )
    ext_modules.append(cublas_extension)

    # Same reasoning as above, for cusolver.
    cusolver_extension = CUDAExtension(
        name="torch_test_cpp_extension.cusolver_extension",
        sources=["cusolver_extension.cpp"],
        libraries=["cusolver"] if torch.version.hip is None else [],
    )
    ext_modules.append(cusolver_extension)
88
# Device-linked (relocatable device code) extension: requires ninja for the
# dlink step, a CUDA device, and an installed toolkit; skipped on Windows.
if (
    USE_NINJA
    and (not IS_WINDOWS)
    and torch.cuda.is_available()
    and CUDA_HOME is not None
):
    extension = CUDAExtension(
        name="torch_test_cpp_extension.cuda_dlink",
        sources=[
            "cuda_dlink_extension.cpp",
            "cuda_dlink_extension_kernel.cu",
            "cuda_dlink_extension_add.cu",
        ],
        # -dc compiles relocatable device code; dlink=True adds the separate
        # device-link step that resolves cross-.cu device calls.
        dlink=True,
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2", "-dc"]},
    )
    ext_modules.append(extension)
106
setup(
    name="torch_test_cpp_extension",
    packages=["torch_test_cpp_extension"],
    ext_modules=ext_modules,
    # NOTE(review): a bare string rather than a list; the sentinel-like name
    # suggests this deliberately exercises cpp_extension's handling of
    # string include_dirs — confirm before changing it to a list.
    include_dirs="self_compiler_include_dirs_test",
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)},
    entry_points={
        # Register the package's _autoload hook under the "torch.backends"
        # entry-point group.
        "torch.backends": [
            "device_backend = torch_test_cpp_extension:_autoload",
        ],
    },
)
119