1 // The clang-tidy job seems to complain that it can't find cudnn.h without this.
2 // This file should only be compiled if this condition holds, so it should be
3 // safe.
4 #if defined(USE_CUDNN) || defined(USE_ROCM)
5 #include <torch/csrc/utils/pybind.h>
6
7 #include <array>
8 #include <tuple>
9
namespace {
// (major, minor, patch) version triple returned by the cuDNN / MIOpen version
// helpers and exposed to Python by initCudnnBindings.
using version_tuple = std::tuple<std::size_t, std::size_t, std::size_t>;
} // namespace
13
14 #ifdef USE_CUDNN
15 #include <cudnn.h>
16
17 namespace {
18
getCompileVersion()19 version_tuple getCompileVersion() {
20 return version_tuple(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
21 }
22
getRuntimeVersion()23 version_tuple getRuntimeVersion() {
24 #ifndef USE_STATIC_CUDNN
25 int major, minor, patch;
26 cudnnGetProperty(MAJOR_VERSION, &major);
27 cudnnGetProperty(MINOR_VERSION, &minor);
28 cudnnGetProperty(PATCH_LEVEL, &patch);
29 return version_tuple((size_t)major, (size_t)minor, (size_t)patch);
30 #else
31 return getCompileVersion();
32 #endif
33 }
34
getVersionInt()35 size_t getVersionInt() {
36 #ifndef USE_STATIC_CUDNN
37 return cudnnGetVersion();
38 #else
39 return CUDNN_VERSION;
40 #endif
41 }
42
43 } // namespace
44 #elif defined(USE_ROCM)
45 #include <miopen/miopen.h>
46 #include <miopen/version.h>
47
48 namespace {
49
getCompileVersion()50 version_tuple getCompileVersion() {
51 return version_tuple(
52 MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);
53 }
54
getRuntimeVersion()55 version_tuple getRuntimeVersion() {
56 // MIOpen doesn't include runtime version info before 2.3.0
57 #if (MIOPEN_VERSION_MAJOR > 2) || \
58 (MIOPEN_VERSION_MAJOR == 2 && MIOPEN_VERSION_MINOR > 2)
59 size_t major, minor, patch;
60 miopenGetVersion(&major, &minor, &patch);
61 return version_tuple(major, minor, patch);
62 #else
63 return getCompileVersion();
64 #endif
65 }
66
getVersionInt()67 size_t getVersionInt() {
68 // miopen version is MAJOR*1000000 + MINOR*1000 + PATCH
69 auto [major, minor, patch] = getRuntimeVersion();
70 return major * 1000000 + minor * 1000 + patch;
71 }
72
73 } // namespace
74 #endif
75
76 namespace torch::cuda::shared {
77
initCudnnBindings(PyObject * module)78 void initCudnnBindings(PyObject* module) {
79 auto m = py::handle(module).cast<py::module>();
80
81 auto cudnn = m.def_submodule("_cudnn", "libcudnn.so bindings");
82
83 py::enum_<cudnnRNNMode_t>(cudnn, "RNNMode")
84 .value("rnn_relu", CUDNN_RNN_RELU)
85 .value("rnn_tanh", CUDNN_RNN_TANH)
86 .value("lstm", CUDNN_LSTM)
87 .value("gru", CUDNN_GRU);
88
89 // The runtime version check in python needs to distinguish cudnn from miopen
90 #ifdef USE_CUDNN
91 cudnn.attr("is_cuda") = true;
92 #else
93 cudnn.attr("is_cuda") = false;
94 #endif
95
96 cudnn.def("getRuntimeVersion", getRuntimeVersion);
97 cudnn.def("getCompileVersion", getCompileVersion);
98 cudnn.def("getVersionInt", getVersionInt);
99 }
100
101 } // namespace torch::cuda::shared
102 #endif
103