xref: /aosp_15_r20/external/pytorch/cmake/Modules_CUDA_fix/FindCUDNN.cmake (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
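# Find the cuDNN libraries.
#
# The following variables are optionally used as search hints:
#  CUDNN_ROOT: Base directory where cuDNN is installed. The legacy
#              CUDNN_ROOT_DIR environment variable is still read but deprecated.
#  CUDNN_INCLUDE_DIR: Directory where the cuDNN header (cudnn.h) is searched for
#  CUDNN_LIBRARY: Directory where the cuDNN library is searched for
#  CUDNN_STATIC: Whether to look for the static library (default: OFF)
#
# CUDNN_INCLUDE_DIR and CUDNN_LIBRARY default to the environment variables of
# the same name.
#
# The following are set after configuration is done:
#  CUDNN_FOUND
#  CUDNN_INCLUDE_PATH
#  CUDNN_LIBRARY_PATH
#  CUDNN_VERSION: "<major>.<minor>.<patch>" (or "?" if it could not be parsed);
#                 only set when CUDNN_FOUND is true
#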
14
15include(FindPackageHandleStandardArgs)
16
# CUDNN_ROOT: base directory of the cuDNN installation. The cache default is
# seeded from the legacy CUDNN_ROOT_DIR environment variable; a value that is
# already in the cache (e.g. from -DCUDNN_ROOT=...) is left untouched.
set(CUDNN_ROOT "$ENV{CUDNN_ROOT_DIR}" CACHE PATH "Folder containing NVIDIA cuDNN")

# `DEFINED ENV{...}` asks whether the environment variable exists. The previous
# form, `DEFINED $ENV{CUDNN_ROOT_DIR}`, expanded the variable's *value* and then
# tested for a CMake variable named after that path, so the deprecation warning
# never fired.
if(DEFINED ENV{CUDNN_ROOT_DIR})
  message(WARNING "CUDNN_ROOT_DIR is deprecated. Please set CUDNN_ROOT instead.")
endif()

# Build the list of candidate roots: the CUDNN_ROOT value, then the legacy
# environment variable, then the CUDA toolkit root. The expansions are left
# unquoted on purpose so that unset values contribute no (empty) list entries.
# This creates a normal variable that shadows the cache entry from here on.
list(APPEND CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})

# Compatibility layer for CMake <3.12. With CMake >=3.12 (policy CMP0074),
# CUDNN_ROOT is also consulted by the find_* commands themselves when this
# module runs via find_package(CUDNN).
list(APPEND CMAKE_PREFIX_PATH ${CUDNN_ROOT})
25
# Optional hint for the header directory; defaults to the CUDNN_INCLUDE_DIR
# environment variable when the cache entry is first created.
set(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR}
    CACHE PATH "Folder containing NVIDIA cuDNN header files")

# Locate cudnn.h. The hint directory is tried before the default search
# locations (which include the CMAKE_PREFIX_PATH entries); each location is
# checked directly and under the cuda/include, cuda and include subdirectories.
find_path(CUDNN_INCLUDE_PATH
  NAMES cudnn.h
  HINTS ${CUDNN_INCLUDE_DIR}
  PATH_SUFFIXES cuda/include cuda include)
31
# CUDNN_STATIC switches the searched name from the shared-library link name
# "cudnn" to the literal static archive file name "libcudnn_static.a".
option(CUDNN_STATIC "Look for static CUDNN" OFF)
set(CUDNN_LIBNAME "cudnn")
if(CUDNN_STATIC)
  set(CUDNN_LIBNAME "libcudnn_static.a")
endif()
38
# Optional hint for the cuDNN library, defaulting to the CUDNN_LIBRARY
# environment variable. The cache description asks for the library *file*,
# while find_library() treats PATHS entries as directories; both forms are
# accepted below.
set(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY} CACHE PATH "Path to the cudnn library file (e.g., libcudnn.so)")
if(CUDNN_LIBRARY MATCHES ".*cudnn_static.a" AND NOT CUDNN_STATIC)
  message(WARNING "CUDNN_LIBRARY points to a static library (${CUDNN_LIBRARY}) but CUDNN_STATIC is OFF.")
endif()

# Search the hint as given (the directory form) and, when it names an existing
# file, also that file's parent directory -- a file path handed to PATHS on its
# own is never searched.
set(_cudnn_library_hints ${CUDNN_LIBRARY})
if(IS_ABSOLUTE "${CUDNN_LIBRARY}"
   AND EXISTS "${CUDNN_LIBRARY}"
   AND NOT IS_DIRECTORY "${CUDNN_LIBRARY}")
  get_filename_component(_cudnn_library_dir "${CUDNN_LIBRARY}" DIRECTORY)
  list(APPEND _cudnn_library_hints "${_cudnn_library_dir}")
  unset(_cudnn_library_dir)
endif()

# CUDNN_LIBNAME is either the link name "cudnn" or the literal archive file
# name "libcudnn_static.a", chosen by CUDNN_STATIC.
find_library(CUDNN_LIBRARY_PATH ${CUDNN_LIBNAME}
  PATHS ${_cudnn_library_hints}
  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
unset(_cudnn_library_hints)
47
# Set CUDNN_FOUND (and print the standard found / not-found message) from
# whether both the library and the header were located.
find_package_handle_standard_args(CUDNN DEFAULT_MSG
  CUDNN_LIBRARY_PATH
  CUDNN_INCLUDE_PATH)

if(CUDNN_FOUND)
  # Read the header holding the version macros: cudnn_version.h when it is
  # present next to cudnn.h, otherwise cudnn.h itself.
  set(_cudnn_version_file "${CUDNN_INCLUDE_PATH}/cudnn.h")
  if(EXISTS "${CUDNN_INCLUDE_PATH}/cudnn_version.h")
    set(_cudnn_version_file "${CUDNN_INCLUDE_PATH}/cudnn_version.h")
  endif()
  file(READ "${_cudnn_version_file}" CUDNN_HEADER_CONTENTS)

  # For each component, first grab the whole "define CUDNN_<X> <n>" text, then
  # reduce it to just the digits. A macro that is absent leaves the component
  # empty.
  set(_cudnn_major_re "define CUDNN_MAJOR * +([0-9]+)")
  set(_cudnn_minor_re "define CUDNN_MINOR * +([0-9]+)")
  set(_cudnn_patch_re "define CUDNN_PATCHLEVEL * +([0-9]+)")

  string(REGEX MATCH "${_cudnn_major_re}" CUDNN_VERSION_MAJOR "${CUDNN_HEADER_CONTENTS}")
  string(REGEX REPLACE "${_cudnn_major_re}" "\\1" CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}")

  string(REGEX MATCH "${_cudnn_minor_re}" CUDNN_VERSION_MINOR "${CUDNN_HEADER_CONTENTS}")
  string(REGEX REPLACE "${_cudnn_minor_re}" "\\1" CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}")

  string(REGEX MATCH "${_cudnn_patch_re}" CUDNN_VERSION_PATCH "${CUDNN_HEADER_CONTENTS}")
  string(REGEX REPLACE "${_cudnn_patch_re}" "\\1" CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}")

  unset(_cudnn_version_file)
  unset(_cudnn_major_re)
  unset(_cudnn_minor_re)
  unset(_cudnn_patch_re)

  # "major.minor.patch" when a major version was parsed, "?" otherwise.
  if(CUDNN_VERSION_MAJOR)
    set(CUDNN_VERSION
        "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}")
  else()
    set(CUDNN_VERSION "?")
  endif()
endif()
77
# Hide the user-facing hint cache entries from the default cmake-gui/ccmake view.
# CUDNN_VERSION is deliberately not listed: it is a normal variable, not a cache
# entry, so marking it is a no-op under policy CMP0102 NEW and creates a spurious
# empty, UNINITIALIZED cache entry under the OLD behavior.
mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY)
79