1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file wraps hipsparse API calls with dso loader so that we don't need to 17 // have explicit linking to libhipsparse. All TF hipsarse API usage should route 18 // through this wrapper. 19 20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ 21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ 22 23 #include "rocm/include/hipsparse/hipsparse.h" 24 #include "tensorflow/stream_executor/lib/env.h" 25 #include "tensorflow/stream_executor/platform/dso_loader.h" 26 #include "tensorflow/stream_executor/platform/port.h" 27 28 namespace tensorflow { 29 namespace wrap { 30 31 #ifdef PLATFORM_GOOGLE 32 33 #define HIPSPARSE_API_WRAPPER(__name) \ 34 struct WrapperShim__##__name { \ 35 template <typename... Args> \ 36 hipsparseStatus_t operator()(Args... args) { \ 37 hipSparseStatus_t retval = ::__name(args...); \ 38 return retval; \ 39 } \ 40 } __name; 41 42 #else 43 44 #define HIPSPARSE_API_WRAPPER(__name) \ 45 struct DynLoadShim__##__name { \ 46 static const char* kName; \ 47 using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \ 48 static void* GetDsoHandle() { \ 49 auto s = \ 50 stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \ 51 return s.ValueOrDie(); \ 52 } \ 53 static FuncPtrT LoadOrDie() { \ 54 void* f; \ 55 auto s = \ 56 Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ 57 CHECK(s.ok()) << "could not find " << kName \ 58 << " in miopen DSO; dlerror: " << s.error_message(); \ 59 return reinterpret_cast<FuncPtrT>(f); \ 60 } \ 61 static FuncPtrT DynLoad() { \ 62 static FuncPtrT f = LoadOrDie(); \ 63 return f; \ 64 } \ 65 template <typename... Args> \ 66 hipsparseStatus_t operator()(Args... args) { \ 67 return DynLoad()(args...); \ 68 } \ 69 } __name; \ 70 const char* DynLoadShim__##__name::kName = #__name; 71 72 #endif 73 74 // clang-format off 75 #define FOREACH_HIPSPARSE_API(__macro) \ 76 __macro(hipsparseCreate) \ 77 __macro(hipsparseCreateMatDescr) \ 78 __macro(hipsparseCcsr2csc) \ 79 __macro(hipsparseCcsrgeam2) \ 80 __macro(hipsparseCcsrgeam2_bufferSizeExt) \ 81 __macro(hipsparseCcsrgemm) \ 82 __macro(hipsparseCcsrmm) \ 83 __macro(hipsparseCcsrmm2) \ 84 __macro(hipsparseCcsrmv) \ 85 __macro(hipsparseDcsr2csc) \ 86 __macro(hipsparseDcsrgeam2) \ 87 __macro(hipsparseDcsrgeam2_bufferSizeExt) \ 88 __macro(hipsparseDcsrgemm) \ 89 __macro(hipsparseDcsrmm) \ 90 __macro(hipsparseDcsrmm2) \ 91 __macro(hipsparseDcsrmv) \ 92 __macro(hipsparseDestroy) \ 93 __macro(hipsparseDestroyMatDescr) \ 94 __macro(hipsparseScsr2csc) \ 95 __macro(hipsparseScsrgeam2) \ 96 __macro(hipsparseScsrgeam2_bufferSizeExt) \ 97 __macro(hipsparseScsrgemm) \ 98 __macro(hipsparseScsrmm) \ 99 __macro(hipsparseScsrmm2) \ 100 __macro(hipsparseScsrmv) \ 101 __macro(hipsparseSetStream) \ 102 __macro(hipsparseSetMatIndexBase) \ 103 __macro(hipsparseSetMatType) \ 104 __macro(hipsparseXcoo2csr) \ 105 __macro(hipsparseXcsr2coo) \ 106 __macro(hipsparseXcsrgeam2Nnz) \ 107 __macro(hipsparseXcsrgemmNnz) \ 108 __macro(hipsparseZcsr2csc) \ 109 __macro(hipsparseZcsrgeam2) \ 110 __macro(hipsparseZcsrgeam2_bufferSizeExt) \ 111 __macro(hipsparseZcsrgemm) \ 112 __macro(hipsparseZcsrmm) \ 113 __macro(hipsparseZcsrmm2) \ 114 __macro(hipsparseZcsrmv) 115 116 #if TF_ROCM_VERSION >= 40200 117 #define FOREACH_HIPSPARSE_ROCM42_API(__macro) \ 118 __macro(hipsparseCcsru2csr_bufferSizeExt) \ 119 __macro(hipsparseCcsru2csr) \ 120 __macro(hipsparseCreateCsr) \ 121 __macro(hipsparseCreateDnMat) \ 122 __macro(hipsparseDestroyDnMat) \ 123 __macro(hipsparseDestroySpMat) \ 124 __macro(hipsparseDcsru2csr_bufferSizeExt) \ 125 __macro(hipsparseDcsru2csr) \ 126 __macro(hipsparseScsru2csr_bufferSizeExt) \ 127 __macro(hipsparseScsru2csr) \ 128 __macro(hipsparseSpMM_bufferSize) \ 129 __macro(hipsparseSpMM) \ 130 __macro(hipsparseZcsru2csr_bufferSizeExt) \ 131 __macro(hipsparseZcsru2csr) 132 133 134 FOREACH_HIPSPARSE_ROCM42_API(HIPSPARSE_API_WRAPPER) 135 136 #undef FOREACH_HIPSPARSE_ROCM42_API 137 #endif 138 139 // clang-format on 140 141 FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER) 142 143 #undef FOREACH_HIPSPARSE_API 144 #undef HIPSPARSE_API_WRAPPER 145 146 } // namespace wrap 147 } // namespace tensorflow 148 149 #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ 150