# LoadHIP: locate a ROCm installation and the HIP toolchain for a PyTorch
# ROCm build.
#
# Environment inputs (read at configure time):
#   ROCM_PATH          ROCm install prefix; defaults to /opt/rocm.
#   ROCM_INCLUDE_DIRS  ROCm include directory; defaults to ${ROCM_PATH}/include.
#   MAGMA_HOME         MAGMA prefix; defaults to ${ROCM_PATH}/magma, and that
#                      default is also written back to ENV{MAGMA_HOME}.
#
# torch_hip_get_arch_list() is NOT defined in this file; the including code
# must provide it. It is expected to fill PYTORCH_ROCM_ARCH (the error message
# below suggests it reads the PYTORCH_ROCM_ARCH environment variable).
#
# Variables set for the including scope:
#   PYTORCH_FOUND_HIP      TRUE once find_package(HIP) succeeds, else FALSE.
#   ROCM_PATH, ROCM_INCLUDE_DIRS, MAGMA_HOME, PYTORCH_ROCM_ARCH
#   ROCM_VERSION_DEV, ROCM_VERSION_DEV_{MAJOR,MINOR,PATCH,INT}
#                          ROCm version parsed from rocm_version.h
#                          (INT = MAJOR*10000 + MINOR*100 + PATCH).
#   TORCH_HIP_VERSION      HIP_VERSION_MAJOR*100 + HIP_VERSION_MINOR.
#   PYTORCH_HIP_LIBRARIES, PYTORCH_MIOPEN_LIBRARIES, PYTORCH_RCCL_LIBRARIES,
#   ROCM_HIPRTC_LIB, ROCM_ROCTX_LIB
#                          Library paths (or imported targets) found below.
#   HIP_NEW_TYPE_ENUMS     ON if <hip/library_types.h> declares
#                          HIP_R_8F_E4M3_FNUZ, OFF otherwise.
#
# If ${ROCM_PATH} does not exist, processing of this file stops early via
# return() with PYTORCH_FOUND_HIP left FALSE.

set(PYTORCH_FOUND_HIP FALSE)

if(NOT DEFINED ENV{ROCM_PATH})
  set(ROCM_PATH /opt/rocm)
else()
  set(ROCM_PATH "$ENV{ROCM_PATH}")
endif()
if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS})
  set(ROCM_INCLUDE_DIRS "${ROCM_PATH}/include")
else()
  set(ROCM_INCLUDE_DIRS "$ENV{ROCM_INCLUDE_DIRS}")
endif()

if(NOT EXISTS "${ROCM_PATH}")
  return()
endif()

# MAGMA_HOME
if(NOT DEFINED ENV{MAGMA_HOME})
  set(MAGMA_HOME "${ROCM_PATH}/magma")
  set(ENV{MAGMA_HOME} "${ROCM_PATH}/magma")
else()
  set(MAGMA_HOME "$ENV{MAGMA_HOME}")
endif()

torch_hip_get_arch_list(PYTORCH_ROCM_ARCH)
# Quoted so an *undefined* PYTORCH_ROCM_ARCH is also treated as empty; the bare
# name form compares the literal string "PYTORCH_ROCM_ARCH" when it is unset.
if("${PYTORCH_ROCM_ARCH}" STREQUAL "")
  message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.")
endif()
message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}")

# Add HIP to the CMAKE Module Path
set(CMAKE_MODULE_PATH "${ROCM_PATH}/lib/cmake/hip" ${CMAKE_MODULE_PATH})

# find_package() wrapper that also prints <PACKAGE_NAME>_VERSION.
# Extra arguments (version, REQUIRED, ...) are forwarded to find_package().
# Kept as a macro so the variables find_package() sets stay in the caller's
# scope.
macro(find_package_and_print_version PACKAGE_NAME)
  find_package("${PACKAGE_NAME}" ${ARGN})
  message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
endmacro()

# Find the HIP Package
find_package_and_print_version(HIP 1.0)

if(HIP_FOUND)
  set(PYTORCH_FOUND_HIP TRUE)
  set(FOUND_ROCM_VERSION_H FALSE)

  set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
  set(file "${PROJECT_BINARY_DIR}/detect_rocm_version.cc")

  # Find ROCM version for checks.
  # ROCm 5.0 and later ship rocm_version.h (either directly in the include
  # dir or under rocm-core/); start a small C++ probe that includes it.
  if(EXISTS "${ROCM_INCLUDE_DIRS}/rocm_version.h")
    set(FOUND_ROCM_VERSION_H TRUE)
    file(WRITE "${file}" ""
      "#include <rocm_version.h>\n"
    )
  elseif(EXISTS "${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h")
    set(FOUND_ROCM_VERSION_H TRUE)
    file(WRITE "${file}" ""
      "#include <rocm-core/rocm_version.h>\n"
    )
  else()
    message("********************* rocm_version.h couldnt be found ******************\n")
  endif()

  if(FOUND_ROCM_VERSION_H)
    # The probe prints "MAJOR.MINOR.PATCH"; PATCH is stringified because the
    # header may not define it (it then defaults to 0).
    file(APPEND "${file}" ""
      "#include <cstdio>\n"

      "#ifndef ROCM_VERSION_PATCH\n"
      "#define ROCM_VERSION_PATCH 0\n"
      "#endif\n"
      "#define STRINGIFYHELPER(x) #x\n"
      "#define STRINGIFY(x) STRINGIFYHELPER(x)\n"
      "int main() {\n"
      " printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n"
      " return 0;\n"
      "}\n"
    )

    try_run(run_result compile_result "${PROJECT_RANDOM_BINARY_DIR}" "${file}"
      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
      RUN_OUTPUT_VARIABLE rocm_version_from_header
      COMPILE_OUTPUT_VARIABLE output_var
    )
    # We expect the compile to be successful if the include directory exists.
    # Output is quoted so ';' characters in compiler logs are not dropped.
    if(NOT compile_result)
      message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " "${output_var}")
    endif()
    # The probe always returns 0; anything else (a non-zero exit code or
    # FAILED_TO_RUN) means its output cannot be trusted as a version string.
    if(NOT run_result EQUAL 0)
      message(FATAL_ERROR "Caffe2: ROCm version probe failed to run (result: ${run_result}): " "${rocm_version_from_header}")
    endif()
    message(STATUS "Caffe2: Header version is: " "${rocm_version_from_header}")
    set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header})
    message("\n***** ROCm version from rocm_version.h ****\n")
  endif()

  # "\\." reaches the regex engine as "\." (a literal dot); a bare "\." in a
  # quoted CMake argument is an identity escape that collapses to ".".
  # The input is quoted so an unset ROCM_VERSION_DEV_RAW produces "no match"
  # rather than a string(REGEX MATCH) argument-count error.
  string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+).*$"
    ROCM_VERSION_DEV_MATCH "${ROCM_VERSION_DEV_RAW}")

  if(ROCM_VERSION_DEV_MATCH)
    set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
    set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
    set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
    set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
    math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
  else()
    message(WARNING "Could not parse a ROCm version from '${ROCM_VERSION_DEV_RAW}'")
  endif()

  message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
  message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
  message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
  message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
  message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}")

  math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}")
  message("HIP_VERSION_MAJOR: ${HIP_VERSION_MAJOR}")
  message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}")
  message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}")

  # Informational only: execute_process results are not checked, so these are
  # best-effort (e.g. they print nothing where dpkg is unavailable).
  message("\n***** Library versions from dpkg *****\n")
  execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hip-base COMMAND awk "{print $2 \" VERSION: \" $3}")
  execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")

  message("\n***** Library versions from cmake find_package *****\n")

  set(CMAKE_HIP_CLANG_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
  set(CMAKE_HIP_CLANG_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
  ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###

  # Point each package's config-file search at the ROCm install.
  set(hip_DIR "${ROCM_PATH}/lib/cmake/hip")
  set(hsa-runtime64_DIR "${ROCM_PATH}/lib/cmake/hsa-runtime64")
  set(AMDDeviceLibs_DIR "${ROCM_PATH}/lib/cmake/AMDDeviceLibs")
  set(amd_comgr_DIR "${ROCM_PATH}/lib/cmake/amd_comgr")
  set(rocrand_DIR "${ROCM_PATH}/lib/cmake/rocrand")
  set(hiprand_DIR "${ROCM_PATH}/lib/cmake/hiprand")
  set(rocblas_DIR "${ROCM_PATH}/lib/cmake/rocblas")
  set(hipblas_DIR "${ROCM_PATH}/lib/cmake/hipblas")
  set(hipblaslt_DIR "${ROCM_PATH}/lib/cmake/hipblaslt")
  set(miopen_DIR "${ROCM_PATH}/lib/cmake/miopen")
  set(rocfft_DIR "${ROCM_PATH}/lib/cmake/rocfft")
  set(hipfft_DIR "${ROCM_PATH}/lib/cmake/hipfft")
  set(hipsparse_DIR "${ROCM_PATH}/lib/cmake/hipsparse")
  set(rccl_DIR "${ROCM_PATH}/lib/cmake/rccl")
  set(rocprim_DIR "${ROCM_PATH}/lib/cmake/rocprim")
  set(hipcub_DIR "${ROCM_PATH}/lib/cmake/hipcub")
  set(rocthrust_DIR "${ROCM_PATH}/lib/cmake/rocthrust")
  set(hipsolver_DIR "${ROCM_PATH}/lib/cmake/hipsolver")
  set(hiprtc_DIR "${ROCM_PATH}/lib/cmake/hiprtc")

  find_package_and_print_version(hip REQUIRED)
  find_package_and_print_version(hsa-runtime64 REQUIRED)
  find_package_and_print_version(amd_comgr REQUIRED)
  find_package_and_print_version(rocrand REQUIRED)
  find_package_and_print_version(hiprand REQUIRED)
  find_package_and_print_version(rocblas REQUIRED)
  find_package_and_print_version(hipblas REQUIRED)
  find_package_and_print_version(hipblaslt REQUIRED)
  find_package_and_print_version(miopen REQUIRED)
  find_package_and_print_version(hipfft REQUIRED)
  find_package_and_print_version(hipsparse REQUIRED)
  # rccl is optional (no REQUIRED); see the guarded lookup below.
  find_package_and_print_version(rccl)
  find_package_and_print_version(rocprim REQUIRED)
  find_package_and_print_version(hipcub REQUIRED)
  find_package_and_print_version(rocthrust REQUIRED)
  find_package_and_print_version(hipsolver REQUIRED)
  find_package_and_print_version(hiprtc REQUIRED)

  find_library(PYTORCH_HIP_LIBRARIES amdhip64 HINTS "${ROCM_PATH}/lib")
  # TODO: miopen_LIBRARIES should return fullpath to the library file,
  # however currently it's just the lib name
  if(TARGET "${miopen_LIBRARIES}")
    set(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES})
  else()
    find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS "${ROCM_PATH}/lib")
  endif()
  # TODO: rccl_LIBRARIES should return fullpath to the library file,
  # however currently it's just the lib name.
  # rccl is optional, so only search for it when a library name is known;
  # otherwise find_library() would be called without any name.
  if(TARGET "${rccl_LIBRARIES}")
    set(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES})
  elseif(rccl_LIBRARIES)
    find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS "${ROCM_PATH}/lib")
  endif()
  find_library(ROCM_HIPRTC_LIB hiprtc HINTS "${ROCM_PATH}/lib")
  # roctx is part of roctracer
  find_library(ROCM_ROCTX_LIB roctx64 HINTS "${ROCM_PATH}/lib")

  # Check whether HIP declares the new hipDataType enums by compiling a probe
  # that uses HIP_R_8F_E4M3_FNUZ.
  set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
  file(WRITE "${file}" ""
    "#include <hip/library_types.h>\n"
    "int main() {\n"
    " hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
    " return 0;\n"
    "}\n"
  )

  try_compile(hip_compile_result "${PROJECT_RANDOM_BINARY_DIR}" "${file}"
    CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
    COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
    OUTPUT_VARIABLE hip_compile_output)

  if(hip_compile_result)
    set(HIP_NEW_TYPE_ENUMS ON)
    message("HIP is using new type enums")
  else()
    set(HIP_NEW_TYPE_ENUMS OFF)
    message("HIP is NOT using new type enums")
  endif()

endif()