// xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/shared/cudnn.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// The clang-tidy job seems to complain that it can't find cudnn.h without this
// guard. This file should only be compiled when this condition holds, so the
// guard should be safe.
4 #if defined(USE_CUDNN) || defined(USE_ROCM)
5 #include <torch/csrc/utils/pybind.h>
6 
7 #include <array>
8 #include <tuple>
9 
namespace {
// A (major, minor, patch) version triple, as returned by the version query
// helpers in this file.
using version_tuple = std::tuple<size_t, size_t, size_t>;
} // namespace
13 
14 #ifdef USE_CUDNN
15 #include <cudnn.h>
16 
17 namespace {
18 
getCompileVersion()19 version_tuple getCompileVersion() {
20   return version_tuple(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
21 }
22 
getRuntimeVersion()23 version_tuple getRuntimeVersion() {
24 #ifndef USE_STATIC_CUDNN
25   int major, minor, patch;
26   cudnnGetProperty(MAJOR_VERSION, &major);
27   cudnnGetProperty(MINOR_VERSION, &minor);
28   cudnnGetProperty(PATCH_LEVEL, &patch);
29   return version_tuple((size_t)major, (size_t)minor, (size_t)patch);
30 #else
31   return getCompileVersion();
32 #endif
33 }
34 
getVersionInt()35 size_t getVersionInt() {
36 #ifndef USE_STATIC_CUDNN
37   return cudnnGetVersion();
38 #else
39   return CUDNN_VERSION;
40 #endif
41 }
42 
43 } // namespace
44 #elif defined(USE_ROCM)
45 #include <miopen/miopen.h>
46 #include <miopen/version.h>
47 
48 namespace {
49 
getCompileVersion()50 version_tuple getCompileVersion() {
51   return version_tuple(
52       MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);
53 }
54 
getRuntimeVersion()55 version_tuple getRuntimeVersion() {
56   // MIOpen doesn't include runtime version info before 2.3.0
57 #if (MIOPEN_VERSION_MAJOR > 2) || \
58     (MIOPEN_VERSION_MAJOR == 2 && MIOPEN_VERSION_MINOR > 2)
59   size_t major, minor, patch;
60   miopenGetVersion(&major, &minor, &patch);
61   return version_tuple(major, minor, patch);
62 #else
63   return getCompileVersion();
64 #endif
65 }
66 
getVersionInt()67 size_t getVersionInt() {
68   // miopen version is MAJOR*1000000 + MINOR*1000 + PATCH
69   auto [major, minor, patch] = getRuntimeVersion();
70   return major * 1000000 + minor * 1000 + patch;
71 }
72 
73 } // namespace
74 #endif
75 
76 namespace torch::cuda::shared {
77 
initCudnnBindings(PyObject * module)78 void initCudnnBindings(PyObject* module) {
79   auto m = py::handle(module).cast<py::module>();
80 
81   auto cudnn = m.def_submodule("_cudnn", "libcudnn.so bindings");
82 
83   py::enum_<cudnnRNNMode_t>(cudnn, "RNNMode")
84       .value("rnn_relu", CUDNN_RNN_RELU)
85       .value("rnn_tanh", CUDNN_RNN_TANH)
86       .value("lstm", CUDNN_LSTM)
87       .value("gru", CUDNN_GRU);
88 
89   // The runtime version check in python needs to distinguish cudnn from miopen
90 #ifdef USE_CUDNN
91   cudnn.attr("is_cuda") = true;
92 #else
93   cudnn.attr("is_cuda") = false;
94 #endif
95 
96   cudnn.def("getRuntimeVersion", getRuntimeVersion);
97   cudnn.def("getCompileVersion", getCompileVersion);
98   cudnn.def("getVersionInt", getVersionInt);
99 }
100 
101 } // namespace torch::cuda::shared
102 #endif
103