xref: /aosp_15_r20/external/tensorflow/tensorflow/stream_executor/rocm/hipsparse_wrapper.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file wraps hipsparse API calls with dso loader so that we don't need to
17 // have explicit linking to libhipsparse. All TF hipsparse API usage should route
18 // through this wrapper.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
22 
23 #include "rocm/include/hipsparse/hipsparse.h"
24 #include "tensorflow/stream_executor/lib/env.h"
25 #include "tensorflow/stream_executor/platform/dso_loader.h"
26 #include "tensorflow/stream_executor/platform/port.h"
27 
28 namespace tensorflow {
29 namespace wrap {
30 
31 #ifdef PLATFORM_GOOGLE
32 
// Defines, for the hipsparse entry point `__name`, a functor type
// `WrapperShim__<name>` plus one object of that type that is itself called
// `__name`. Calling the object passes its arguments (by value) to the
// global-namespace function `::__name` and returns that function's
// hipsparseStatus_t result unchanged.
#define HIPSPARSE_API_WRAPPER(__name)               \
  struct WrapperShim__##__name {                    \
    template <typename... Args>                     \
    hipsparseStatus_t operator()(Args... args) {    \
      hipsparseStatus_t retval = ::__name(args...); \
      return retval;                                \
    }                                               \
  } __name;
41 
42 #else
43 
// Defines, for the hipsparse entry point `__name`, a functor type
// `DynLoadShim__<name>` plus one object of that type that is itself called
// `__name`. The functor resolves the symbol at run time instead of calling
// `::__name` directly:
//   * kName        - the symbol name as a string (the stringified `__name`).
//   * FuncPtrT     - pointer-to-function type derived from the declaration of
//                    `::__name`, so the call below is type-checked against it.
//   * GetDsoHandle - obtains the library handle from
//                    CachedDsoLoader::GetHipsparseDsoHandle() and unwraps it
//                    with ValueOrDie().
//   * LoadOrDie    - looks kName up in that handle via
//                    Env::Default()->GetSymbolFromLibrary(); a failed lookup
//                    trips the CHECK with the symbol name and error text.
//   * DynLoad      - runs LoadOrDie() once through a function-local static
//                    and returns the stored pointer on every call.
//   * operator()   - passes its arguments (by value) to the loaded function
//                    and returns its hipsparseStatus_t result.
#define HIPSPARSE_API_WRAPPER(__name)                                          \
  struct DynLoadShim__##__name {                                               \
    static const char* kName;                                                  \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;               \
    static void* GetDsoHandle() {                                              \
      auto s =                                                                 \
          stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \
      return s.ValueOrDie();                                                   \
    }                                                                          \
    static FuncPtrT LoadOrDie() {                                              \
      void* f;                                                                 \
      auto s =                                                                 \
          Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                              \
                    << " in hipsparse DSO; dlerror: " << s.error_message();    \
      return reinterpret_cast<FuncPtrT>(f);                                    \
    }                                                                          \
    static FuncPtrT DynLoad() {                                                \
      static FuncPtrT f = LoadOrDie();                                         \
      return f;                                                                \
    }                                                                          \
    template <typename... Args>                                                \
    hipsparseStatus_t operator()(Args... args) {                               \
      return DynLoad()(args...);                                               \
    }                                                                          \
  } __name;                                                                    \
  const char* DynLoadShim__##__name::kName = #__name;
71 
72 #endif
73 
74 // clang-format off
// X-macro listing the hipsparse entry points that are always wrapped.
// `__macro` is applied once to every name below; entries are grouped by the
// operation named in the symbol, and within a group by the S/D/C/Z type
// letter. Only block comments may appear inside the list, because the lines
// are joined by the trailing backslashes.
#define FOREACH_HIPSPARSE_API(__macro)          \
  /* Create / Destroy / Set entry points */     \
  __macro(hipsparseCreate)                      \
  __macro(hipsparseDestroy)                     \
  __macro(hipsparseSetStream)                   \
  __macro(hipsparseCreateMatDescr)              \
  __macro(hipsparseDestroyMatDescr)             \
  __macro(hipsparseSetMatIndexBase)             \
  __macro(hipsparseSetMatType)                  \
  /* coo2csr / csr2coo / csr2csc */             \
  __macro(hipsparseXcoo2csr)                    \
  __macro(hipsparseXcsr2coo)                    \
  __macro(hipsparseScsr2csc)                    \
  __macro(hipsparseDcsr2csc)                    \
  __macro(hipsparseCcsr2csc)                    \
  __macro(hipsparseZcsr2csc)                    \
  /* csrgeam2 */                                \
  __macro(hipsparseXcsrgeam2Nnz)                \
  __macro(hipsparseScsrgeam2_bufferSizeExt)     \
  __macro(hipsparseDcsrgeam2_bufferSizeExt)     \
  __macro(hipsparseCcsrgeam2_bufferSizeExt)     \
  __macro(hipsparseZcsrgeam2_bufferSizeExt)     \
  __macro(hipsparseScsrgeam2)                   \
  __macro(hipsparseDcsrgeam2)                   \
  __macro(hipsparseCcsrgeam2)                   \
  __macro(hipsparseZcsrgeam2)                   \
  /* csrgemm */                                 \
  __macro(hipsparseXcsrgemmNnz)                 \
  __macro(hipsparseScsrgemm)                    \
  __macro(hipsparseDcsrgemm)                    \
  __macro(hipsparseCcsrgemm)                    \
  __macro(hipsparseZcsrgemm)                    \
  /* csrmm / csrmm2 */                          \
  __macro(hipsparseScsrmm)                      \
  __macro(hipsparseDcsrmm)                      \
  __macro(hipsparseCcsrmm)                      \
  __macro(hipsparseZcsrmm)                      \
  __macro(hipsparseScsrmm2)                     \
  __macro(hipsparseDcsrmm2)                     \
  __macro(hipsparseCcsrmm2)                     \
  __macro(hipsparseZcsrmm2)                     \
  /* csrmv */                                   \
  __macro(hipsparseScsrmv)                      \
  __macro(hipsparseDcsrmv)                      \
  __macro(hipsparseCcsrmv)                      \
  __macro(hipsparseZcsrmv)
115 
// Additional entry points that are wrapped only when TF_ROCM_VERSION is at
// least 40200. The list is expanded with HIPSPARSE_API_WRAPPER right away and
// then undefined, so the list macro does not outlive this conditional section.
#if TF_ROCM_VERSION >= 40200
#define FOREACH_HIPSPARSE_ROCM42_API(__macro)   \
  /* CreateCsr / CreateDnMat / Destroy */       \
  __macro(hipsparseCreateCsr)                   \
  __macro(hipsparseCreateDnMat)                 \
  __macro(hipsparseDestroyDnMat)                \
  __macro(hipsparseDestroySpMat)                \
  /* SpMM */                                    \
  __macro(hipsparseSpMM_bufferSize)             \
  __macro(hipsparseSpMM)                        \
  /* csru2csr */                                \
  __macro(hipsparseScsru2csr_bufferSizeExt)     \
  __macro(hipsparseDcsru2csr_bufferSizeExt)     \
  __macro(hipsparseCcsru2csr_bufferSizeExt)     \
  __macro(hipsparseZcsru2csr_bufferSizeExt)     \
  __macro(hipsparseScsru2csr)                   \
  __macro(hipsparseDcsru2csr)                   \
  __macro(hipsparseCcsru2csr)                   \
  __macro(hipsparseZcsru2csr)

FOREACH_HIPSPARSE_ROCM42_API(HIPSPARSE_API_WRAPPER)

#undef FOREACH_HIPSPARSE_ROCM42_API
#endif  // TF_ROCM_VERSION >= 40200
138 
139 // clang-format on
140 
141 FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER)
142 
143 #undef FOREACH_HIPSPARSE_API
144 #undef HIPSPARSE_API_WRAPPER
145 
146 }  // namespace wrap
147 }  // namespace tensorflow
148 
149 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
150