xref: /aosp_15_r20/external/pytorch/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Synopsis:
2#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
3#   -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
4#      target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
5#       - "Auto" detects local machine GPU compute arch at runtime.
6#       - "Common" and "All" cover common and entire subsets of architectures
7#      ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
#      NAME: Kepler Kepler+Tesla Maxwell Maxwell+Tegra Pascal Volta Volta+Tegra Turing
#            Ampere Ampere+Tegra Ada Hopper
#      NUM: Any number. Only those pairs are currently accepted by NVCC though:
#            3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5 8.0 8.6 8.7 8.9 9.0
11#      Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
12#      Additionally, sets ${out_variable}_readable to the resulting numeric list
13#      Example:
14#       CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
15#        LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
16#
17#      More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
18#
19
# When CUDA is enabled as a language, derive CUDA_VERSION from the compiler itself:
# keep only the leading "major.minor" part of the NVIDIA compiler's version string.
if(CMAKE_CUDA_COMPILER_LOADED)
  if(CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)"
      AND CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA")
    # CMAKE_MATCH_1 holds the "major.minor" capture from the MATCHES test above.
    set(CUDA_VERSION "${CMAKE_MATCH_1}")
  endif()
endif()
26
27# See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list
28
# Architecture names that the "All" selection of CUDA_SELECT_NVCC_ARCH_FLAGS expands to.
set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell")

# Compute capabilities used for the "Common" selection, and as the fallback when GPU
# autodetection fails. Order matters: the LAST entry is what autodetection substitutes for
# a detected capability that is at or above CUDA_LIMIT_GPU_ARCHITECTURE.
set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0")

# Compute capabilities known for the active CUDA version.
# NOTE(review): not read anywhere else in this file; presumably consumed by includers.
set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0")

# CUDA 11.0 and newer: Ampere (8.0).
if(CUDA_VERSION VERSION_GREATER "10.5")
  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere")
  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0")
  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0")

  if(CUDA_VERSION VERSION_LESS "11.1")
    # 8.0 is the newest capability here; detected GPUs at or above it map to "8.0+PTX".
    set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX")
  endif()
endif()

# CUDA 11.1 and newer: 8.6.
if(NOT CUDA_VERSION VERSION_LESS "11.1")
  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6")
  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6")
  # The original code also assigned "8.6" to a misspelled variable
  # (CUDA_LIMIT_GPU_ARCHITECUTRE) that nothing reads. That dead assignment is dropped on
  # purpose instead of being respelled: with the correct spelling CUDA >= 12.0 would keep
  # an "8.6" limit and autodetection would replace every GPU >= 8.6 with "9.0a".

  if(CUDA_VERSION VERSION_LESS "11.8")
    set(CUDA_LIMIT_GPU_ARCHITECTURE "8.9")
    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX")
  endif()
endif()

# CUDA 11.8 and newer: Ada (8.9) and Hopper (9.0).
if(NOT CUDA_VERSION VERSION_LESS "11.8")
  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ada" "Hopper")
  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9" "9.0")
  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9" "9.0")

  if(CUDA_VERSION VERSION_LESS "12.0")
    set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX" "9.0+PTX")
  endif()
endif()

# CUDA 12.0 and newer: add 9.0a and drop 3.5. No CUDA_LIMIT_GPU_ARCHITECTURE is set by
# this file for these versions, so autodetected capabilities pass through unfiltered.
if(NOT CUDA_VERSION VERSION_LESS "12.0")
  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a")
  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a")
  list(REMOVE_ITEM CUDA_COMMON_GPU_ARCHITECTURES "3.5")
  list(REMOVE_ITEM CUDA_ALL_GPU_ARCHITECTURES "3.5")
endif()
81
################################################################################################
# Detects the compute capabilities of the GPUs installed on the build machine.
# Usage:
#   CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
#
# A small probe program is written to PROJECT_BINARY_DIR, compiled and run with try_run();
# it prints "major.minor" for every CUDA device. A successful result is stored in the
# CUDA_GPU_DETECT_OUTPUT cache entry, and the probe is skipped whenever that entry is
# already set.
#
# OUT_VARIABLE is set in the caller's scope to:
#   - on success: a space-separated string of the detected capabilities (each entry is
#     preceded by a space), where every capability at or above CUDA_LIMIT_GPU_ARCHITECTURE
#     (if that variable is set) is replaced by the last entry of
#     CUDA_COMMON_GPU_ARCHITECTURES;
#   - on failure: the CUDA_COMMON_GPU_ARCHITECTURES list.
#
function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
  if(NOT CUDA_GPU_DETECT_OUTPUT)
    # Use a .cu source when CUDA is enabled as a language; otherwise a .cpp source that is
    # built against CUDA_INCLUDE_DIRS / CUDA_LIBRARIES.
    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
      set(detect_source "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu")
    else()
      set(detect_source "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp")
    endif()

    file(WRITE "${detect_source}" ""
      "#include <cuda_runtime.h>\n"
      "#include <cstdio>\n"
      "int main()\n"
      "{\n"
      "  int count = 0;\n"
      "  if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
      "  if (count == 0) return -1;\n"
      "  for (int device = 0; device < count; ++device)\n"
      "  {\n"
      "    cudaDeviceProp prop;\n"
      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
      "  }\n"
      "  return 0;\n"
      "}\n")

    # The result variable names (run_result / compile_result) are kept unchanged on
    # purpose: they are the names try_run() reports its results under.
    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
      try_run(run_result compile_result "${PROJECT_BINARY_DIR}" "${detect_source}"
              RUN_OUTPUT_VARIABLE compute_capabilities)
    else()
      try_run(run_result compile_result "${PROJECT_BINARY_DIR}" "${detect_source}"
              CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
              LINK_LIBRARIES ${CUDA_LIBRARIES}
              RUN_OUTPUT_VARIABLE compute_capabilities)
    endif()

    # Filter unrelated content out of the output; this yields a list of "N.N" entries.
    string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}")

    if(run_result EQUAL 0)
      # Rewrite capability 2.1 as "2.1(2.0)". Only an entry that is exactly "2.1" is
      # rewritten: a plain substring replace over the whole list would also corrupt
      # entries that merely contain "2.1", such as "12.1".
      set(normalized_capabilities "")
      foreach(capability IN LISTS compute_capabilities)
        if("${capability}" STREQUAL "2.1")
          list(APPEND normalized_capabilities "2.1(2.0)")
        else()
          list(APPEND normalized_capabilities "${capability}")
        endif()
      endforeach()
      set(CUDA_GPU_DETECT_OUTPUT "${normalized_capabilities}"
        CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
    endif()
  endif()

  if(NOT CUDA_GPU_DETECT_OUTPUT)
    message(STATUS "Automatic GPU detection failed. Building for common architectures.")
    set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
  else()
    # Filter based on CUDA version supported archs
    set(CUDA_GPU_DETECT_OUTPUT_FILTERED "")
    # Accept both space- and semicolon-separated values in CUDA_GPU_DETECT_OUTPUT.
    separate_arguments(CUDA_GPU_DETECT_OUTPUT)
    foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT})
      if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR
                                          ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE))
        # At or above the limit: substitute the last common architecture instead.
        list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM)
        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}")
      else()
        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}")
      endif()
    endforeach()

    set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE)
  endif()
endfunction()
152
153
################################################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list
# Usage:
#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs])
#
# Each list entry is an architecture name (e.g. Maxwell), a compute capability (e.g. 8.6 or
# 10.0), a capability with an explicit virtual architecture (e.g. "5.2(5.0)"), optionally
# followed by "+PTX". The single keywords Auto, Common and All select the autodetected,
# common and known architecture lists; an empty list means Auto.
#
# Sets in the caller's scope:
#   ${out_variable}          - list of "-gencode arch=...,code=..." flags
#   ${out_variable}_readable - space-separated summary (sm_NN / compute_NN)
function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
  set(CUDA_ARCH_LIST "${ARGN}")

  # Nothing requested: fall back to autodetection.
  if("X${CUDA_ARCH_LIST}" STREQUAL "X" )
    set(CUDA_ARCH_LIST "Auto")
  endif()

  set(cuda_arch_bin)
  set(cuda_arch_ptx)

  if("${CUDA_ARCH_LIST}" STREQUAL "All")
    set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES})
  elseif("${CUDA_ARCH_LIST}" STREQUAL "Common")
    set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES})
  elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto")
    CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST)
    message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}")
  endif()

  # Now process the list and look for names.
  # Whitespace-separated input is turned into a proper CMake list first.
  string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
  list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
  foreach(arch_name ${CUDA_ARCH_LIST})
    set(arch_bin)
    set(arch_ptx)
    set(add_ptx FALSE)
    # Check to see if we are compiling PTX
    if(arch_name MATCHES "(.*)\\+PTX$")
      set(add_ptx TRUE)
      set(arch_name ${CMAKE_MATCH_1})
    endif()
    # Numeric form: MAJOR.MINOR, optional "a" suffix, optional "(MAJOR.MINOR)" virtual arch.
    # The major version may have several digits (e.g. 10.0, 12.0).
    if(arch_name MATCHES "^([0-9]+\\.[0-9]a?(\\([0-9]+\\.[0-9]\\))?)$")
      set(arch_bin ${CMAKE_MATCH_1})
      set(arch_ptx ${arch_bin})
    else()
      # Look for it in our list of known architectures.
      # "${arch_name}" is quoted so that an empty name (e.g. the input "+PTX") or a name
      # that happens to be a variable is reported as unknown instead of breaking if().
      if("${arch_name}" STREQUAL "Kepler+Tesla")
        set(arch_bin 3.7)
      elseif("${arch_name}" STREQUAL "Kepler")
        set(arch_bin 3.5)
        set(arch_ptx 3.5)
      elseif("${arch_name}" STREQUAL "Maxwell+Tegra")
        set(arch_bin 5.3)
      elseif("${arch_name}" STREQUAL "Maxwell")
        set(arch_bin 5.0 5.2)
        set(arch_ptx 5.2)
      elseif("${arch_name}" STREQUAL "Pascal")
        set(arch_bin 6.0 6.1)
        set(arch_ptx 6.1)
      elseif("${arch_name}" STREQUAL "Volta+Tegra")
        set(arch_bin 7.2)
      elseif("${arch_name}" STREQUAL "Volta")
        set(arch_bin 7.0)
        set(arch_ptx 7.0)
      elseif("${arch_name}" STREQUAL "Turing")
        set(arch_bin 7.5)
        set(arch_ptx 7.5)
      elseif("${arch_name}" STREQUAL "Ampere+Tegra")
        set(arch_bin 8.7)
      elseif("${arch_name}" STREQUAL "Ampere")
        set(arch_bin 8.0 8.6)
        set(arch_ptx 8.0 8.6)
      elseif("${arch_name}" STREQUAL "Ada")
        set(arch_bin 8.9)
        set(arch_ptx 8.9)
      elseif("${arch_name}" STREQUAL "Hopper")
        set(arch_bin 9.0)
        set(arch_ptx 9.0)
      else()
        message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
      endif()
    endif()
    if(NOT arch_bin)
      message(SEND_ERROR "arch_bin wasn't set for some reason")
    endif()
    list(APPEND cuda_arch_bin ${arch_bin})
    if(add_ptx)
      # Names without a dedicated PTX target emit PTX for their binary architectures.
      if (NOT arch_ptx)
        set(arch_ptx ${arch_bin})
      endif()
      list(APPEND cuda_arch_ptx ${arch_ptx})
    endif()
  endforeach()

  # remove dots and convert to lists
  string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}")
  string(REGEX MATCHALL "[0-9()]+a?" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX MATCHALL "[0-9]+a?"   cuda_arch_ptx "${cuda_arch_ptx}")

  if(cuda_arch_bin)
    list(REMOVE_DUPLICATES cuda_arch_bin)
  endif()
  if(cuda_arch_ptx)
    list(REMOVE_DUPLICATES cuda_arch_ptx)
  endif()

  set(nvcc_flags "")
  set(nvcc_archs_readable "")

  # Tell NVCC to add binaries for the specified GPUs
  foreach(arch ${cuda_arch_bin})
    if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
      # User explicitly specified ARCH for the concrete CODE
      list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
      list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
    else()
      # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE
      list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
      list(APPEND nvcc_archs_readable sm_${arch})
    endif()
  endforeach()

  # Tell NVCC to add PTX intermediate code for the specified architectures
  foreach(arch ${cuda_arch_ptx})
    list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
    list(APPEND nvcc_archs_readable compute_${arch})
  endforeach()

  string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
  set(${out_variable}          ${nvcc_flags}          PARENT_SCOPE)
  set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()
281