# xref: /aosp_15_r20/external/pytorch/cmake/public/LoadHIP.cmake (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Start pessimistic; this is only switched to TRUE later, once find_package(HIP)
# has succeeded.
set(PYTORCH_FOUND_HIP FALSE)

# ROCm installation root: $ENV{ROCM_PATH} when that variable is defined,
# otherwise the conventional /opt/rocm location.
set(ROCM_PATH /opt/rocm)
if(DEFINED ENV{ROCM_PATH})
  set(ROCM_PATH $ENV{ROCM_PATH})
endif()

# ROCm header directory: $ENV{ROCM_INCLUDE_DIRS} when defined, otherwise the
# include/ directory underneath ROCM_PATH.
set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include)
if(DEFINED ENV{ROCM_INCLUDE_DIRS})
  set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS})
endif()

# No ROCm installation on disk: stop processing the rest of this file,
# leaving PYTORCH_FOUND_HIP at FALSE.
if(NOT EXISTS ${ROCM_PATH})
  return()
endif()

# MAGMA_HOME: honour a user-provided $ENV{MAGMA_HOME}. Otherwise default to the
# magma directory inside the ROCm installation, and also export that default
# into the environment (presumably so later lookups or child processes that
# read $ENV{MAGMA_HOME} see the same value; TODO confirm).
if(DEFINED ENV{MAGMA_HOME})
  set(MAGMA_HOME $ENV{MAGMA_HOME})
else()
  set(MAGMA_HOME ${ROCM_PATH}/magma)
  set(ENV{MAGMA_HOME} ${MAGMA_HOME})
endif()

# Resolve the list of GPU architectures to compile for. torch_hip_get_arch_list
# is defined elsewhere in the build and presumably stores its result in the
# variable named by its argument (PYTORCH_ROCM_ARCH here).
torch_hip_get_arch_list(PYTORCH_ROCM_ARCH)
# The expansion is quoted on purpose. If PYTORCH_ROCM_ARCH ends up unset, an
# unquoted `if(PYTORCH_ROCM_ARCH STREQUAL "")` compares the literal string
# "PYTORCH_ROCM_ARCH" against "" and this empty-arch check would never fire.
if("${PYTORCH_ROCM_ARCH}" STREQUAL "")
  message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.")
endif()
message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}")

# Put ROCm's HIP cmake directory in front of any existing module search path,
# so the find_package(HIP) call below looks there first (presumably for the
# FindHIP module shipped with this ROCm install).
set(CMAKE_MODULE_PATH
  "${ROCM_PATH}/lib/cmake/hip"
  ${CMAKE_MODULE_PATH})

# find_package_and_print_version(<package> [<find_package arguments>...])
#
# Calls find_package(<package> ...), forwarding every extra argument
# (version, REQUIRED, ...) unchanged, then prints "<package> VERSION: <ver>"
# using the <package>_VERSION variable.
# This is intentionally a macro rather than a function. Results such as
# <package>_FOUND, <package>_VERSION and <package>_LIBRARIES therefore land in
# the caller's scope instead of a private function scope.
macro(find_package_and_print_version pkg_name)
  find_package("${pkg_name}" ${ARGN})
  message("${pkg_name} VERSION: ${${pkg_name}_VERSION}")
endmacro()

# Locate HIP, requiring at least version 1.0. REQUIRED is deliberately absent:
# when HIP is missing, HIP_FOUND stays false and the configuration below is
# skipped.
find_package_and_print_version(
  HIP
  "1.0")

if(HIP_FOUND)
  set(PYTORCH_FOUND_HIP TRUE)
  set(FOUND_ROCM_VERSION_H FALSE)

  set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
  set(file "${PROJECT_BINARY_DIR}/detect_rocm_version.cc")

  # ---------------------------------------------------------------------------
  # ROCm version detection.
  # ROCm 5.0 and later provide rocm_version.h as a header API for version
  # management. Write a small probe program that includes it and prints
  # "MAJOR.MINOR.PATCH", then compile and run that probe.
  # ---------------------------------------------------------------------------
  if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm_version.h)
    set(FOUND_ROCM_VERSION_H TRUE)
    file(WRITE ${file} ""
      "#include <rocm_version.h>\n"
      )
  elseif(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h)
    set(FOUND_ROCM_VERSION_H TRUE)
    file(WRITE ${file} ""
      "#include <rocm-core/rocm_version.h>\n"
      )
  else()
    message("********************* rocm_version.h couldnt be found ******************\n")
  endif()

  if(FOUND_ROCM_VERSION_H)
    # Rest of the probe. ROCM_VERSION_PATCH falls back to 0 when the header
    # does not define it; the patch level is stringified as a token.
    file(APPEND ${file} ""
      "#include <cstdio>\n"
      "#ifndef ROCM_VERSION_PATCH\n"
      "#define ROCM_VERSION_PATCH 0\n"
      "#endif\n"
      "#define STRINGIFYHELPER(x) #x\n"
      "#define STRINGIFY(x) STRINGIFYHELPER(x)\n"
      "int main() {\n"
      "  printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n"
      "  return 0;\n"
      "}\n"
      )

    try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
      RUN_OUTPUT_VARIABLE rocm_version_from_header
      COMPILE_OUTPUT_VARIABLE output_var
      )
    # We expect the compile to be successful if the include directory exists.
    # The output is quoted so ';' characters in compiler diagnostics survive.
    if(NOT compile_result)
      message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " "${output_var}")
    endif()
    # A probe that compiled can still fail to run: try_run then reports a
    # non-zero exit code or a non-numeric failure marker. Its printed version
    # is untrustworthy in that case, so say so instead of continuing silently.
    if(NOT "${run_result}" EQUAL 0)
      message(WARNING "Caffe2: rocm_version.h probe did not run successfully (result: '${run_result}'); the detected ROCm version may be missing or wrong.")
    endif()
    message(STATUS "Caffe2: Header version is: " "${rocm_version_from_header}")
    set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header})
    message("\n***** ROCm version from rocm_version.h ****\n")
  endif()

  # Split "MAJOR.MINOR.PATCH<anything>" into its numeric components.
  # - "[.]" matches a literal dot. Inside a quoted CMake argument "\." is an
  #   identity escape that collapses to a bare ".", which would match any
  #   character.
  # - The input is quoted so that an unset ROCM_VERSION_DEV_RAW (header not
  #   found above) still passes one (empty) argument. Unquoted, it expands to
  #   nothing and string(REGEX MATCH) aborts configuration for too few
  #   arguments.
  string(REGEX MATCH "^([0-9]+)[.]([0-9]+)[.]([0-9]+).*$" ROCM_VERSION_DEV_MATCH "${ROCM_VERSION_DEV_RAW}")

  if(ROCM_VERSION_DEV_MATCH)
    set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
    set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
    set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
    set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
    # Encode as MAJOR*10000 + MINOR*100 + PATCH, e.g. 6.2.0 -> 60200.
    math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
  endif()

  message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
  message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
  message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
  message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
  message("ROCM_VERSION_DEV_INT:   ${ROCM_VERSION_DEV_INT}")

  # HIP version encoded as MAJOR*100 + MINOR. HIP_VERSION_MAJOR/MINOR are
  # presumably set by the find_package(HIP) call above; TODO confirm.
  math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}")
  message("HIP_VERSION_MAJOR: ${HIP_VERSION_MAJOR}")
  message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}")
  message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}")

  # Best-effort diagnostics only. No RESULT_VARIABLE is captured, so on hosts
  # without dpkg (or when a package is absent) these print nothing and
  # configuration carries on.
  message("\n***** Library versions from dpkg *****\n")
  execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hip-base COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")

  message("\n***** Library versions from cmake find_package *****\n")

  set(CMAKE_HIP_CLANG_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
  set(CMAKE_HIP_CLANG_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
  ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###

  # Point each <package>_DIR at its config-file directory under ROCM_PATH, so
  # the find_package() calls below resolve against this ROCm installation.
  # NOTE(review): AMDDeviceLibs and rocfft get a _DIR hint but are not looked
  # up directly in this file; presumably other packages' config files find
  # them transitively. Confirm.
  set(hip_DIR ${ROCM_PATH}/lib/cmake/hip)
  set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
  set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs)
  set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr)
  set(rocrand_DIR ${ROCM_PATH}/lib/cmake/rocrand)
  set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand)
  set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas)
  set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas)
  set(hipblaslt_DIR ${ROCM_PATH}/lib/cmake/hipblaslt)
  set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen)
  set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft)
  set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft)
  set(hipsparse_DIR ${ROCM_PATH}/lib/cmake/hipsparse)
  set(rccl_DIR ${ROCM_PATH}/lib/cmake/rccl)
  set(rocprim_DIR ${ROCM_PATH}/lib/cmake/rocprim)
  set(hipcub_DIR ${ROCM_PATH}/lib/cmake/hipcub)
  set(rocthrust_DIR ${ROCM_PATH}/lib/cmake/rocthrust)
  set(hipsolver_DIR ${ROCM_PATH}/lib/cmake/hipsolver)
  set(hiprtc_DIR ${ROCM_PATH}/lib/cmake/hiprtc)

  # Every ROCm component is mandatory except rccl, which is searched for
  # without REQUIRED.
  find_package_and_print_version(hip REQUIRED)
  find_package_and_print_version(hsa-runtime64 REQUIRED)
  find_package_and_print_version(amd_comgr REQUIRED)
  find_package_and_print_version(rocrand REQUIRED)
  find_package_and_print_version(hiprand REQUIRED)
  find_package_and_print_version(rocblas REQUIRED)
  find_package_and_print_version(hipblas REQUIRED)
  find_package_and_print_version(hipblaslt REQUIRED)
  find_package_and_print_version(miopen REQUIRED)
  find_package_and_print_version(hipfft REQUIRED)
  find_package_and_print_version(hipsparse REQUIRED)
  find_package_and_print_version(rccl)
  find_package_and_print_version(rocprim REQUIRED)
  find_package_and_print_version(hipcub REQUIRED)
  find_package_and_print_version(rocthrust REQUIRED)
  find_package_and_print_version(hipsolver REQUIRED)
  find_package_and_print_version(hiprtc REQUIRED)

  find_library(PYTORCH_HIP_LIBRARIES amdhip64 HINTS ${ROCM_PATH}/lib)
  # TODO: miopen_LIBRARIES should return fullpath to the library file,
  # however currently it's just the lib name
  if(TARGET ${miopen_LIBRARIES})
    set(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES})
  else()
    find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${ROCM_PATH}/lib)
  endif()
  # TODO: rccl_LIBRARIES should return fullpath to the library file,
  # however currently it's just the lib name
  if(TARGET ${rccl_LIBRARIES})
    set(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES})
  else()
    find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${ROCM_PATH}/lib)
  endif()
  find_library(ROCM_HIPRTC_LIB hiprtc HINTS ${ROCM_PATH}/lib)
  # roctx is part of roctracer
  find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)

  # Check whether HIP declares the newer hipDataType enumerators: the probe
  # compiles only if HIP_R_8F_E4M3_FNUZ exists in <hip/library_types.h>.
  set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
  file(WRITE ${file} ""
    "#include <hip/library_types.h>\n"
    "int main() {\n"
    "    hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
    "    return 0;\n"
    "}\n"
    )

  try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
    CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
    COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
    OUTPUT_VARIABLE hip_compile_output)

  if(hip_compile_result)
    set(HIP_NEW_TYPE_ENUMS ON)
    message("HIP is using new type enums")
  else()
    set(HIP_NEW_TYPE_ENUMS OFF)
    message("HIP is NOT using new type enums")
  endif()

endif()