xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.h"
16 
17 #include <stdlib.h>
18 
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/string_view.h"
21 #include "third_party/gpus/cuda/cuda_config.h"
22 #include "tensorflow/compiler/xla/stream_executor/lib/env.h"
23 #include "tensorflow/compiler/xla/stream_executor/lib/error.h"
24 #include "tensorflow/compiler/xla/stream_executor/lib/path.h"
25 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
26 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
27 #include "third_party/tensorrt/tensorrt_config.h"
28 
29 #if TENSORFLOW_USE_ROCM
30 #include "rocm/rocm_config.h"
31 #endif
32 
33 namespace stream_executor {
34 namespace internal {
35 
36 namespace {
GetCudaVersion()37 string GetCudaVersion() { return TF_CUDA_VERSION; }
GetCudaRtVersion()38 string GetCudaRtVersion() { return TF_CUDART_VERSION; }
GetCudnnVersion()39 string GetCudnnVersion() { return TF_CUDNN_VERSION; }
GetCublasVersion()40 string GetCublasVersion() { return TF_CUBLAS_VERSION; }
GetCusolverVersion()41 string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
GetCurandVersion()42 string GetCurandVersion() { return TF_CURAND_VERSION; }
GetCufftVersion()43 string GetCufftVersion() { return TF_CUFFT_VERSION; }
GetCusparseVersion()44 string GetCusparseVersion() { return TF_CUSPARSE_VERSION; }
GetTensorRTVersion()45 string GetTensorRTVersion() { return TF_TENSORRT_VERSION; }
46 
GetDsoHandle(const string & name,const string & version)47 port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
48   auto filename = port::Env::Default()->FormatLibraryFileName(name, version);
49   void* dso_handle;
50   port::Status status =
51       port::Env::Default()->LoadDynamicLibrary(filename.c_str(), &dso_handle);
52   if (status.ok()) {
53     VLOG(1) << "Successfully opened dynamic library " << filename;
54     return dso_handle;
55   }
56 
57   auto message = absl::StrCat("Could not load dynamic library '", filename,
58                               "'; dlerror: ", status.error_message());
59 #if !defined(PLATFORM_WINDOWS)
60   if (const char* ld_library_path = getenv("LD_LIBRARY_PATH")) {
61     message += absl::StrCat("; LD_LIBRARY_PATH: ", ld_library_path);
62   }
63 #endif
64   LOG(WARNING) << message;
65   return port::Status(port::error::FAILED_PRECONDITION, message);
66 }
67 }  // namespace
68 
69 namespace DsoLoader {
GetCudaDriverDsoHandle()70 port::StatusOr<void*> GetCudaDriverDsoHandle() {
71 #if defined(PLATFORM_WINDOWS)
72   return GetDsoHandle("nvcuda", "");
73 #elif defined(__APPLE__)
74   // On Mac OS X, CUDA sometimes installs libcuda.dylib instead of
75   // libcuda.1.dylib.
76   auto handle_or = GetDsoHandle("cuda", "");
77   if (handle_or.ok()) {
78     return handle_or;
79   }
80 #endif
81   return GetDsoHandle("cuda", "1");
82 }
83 
GetCudaRuntimeDsoHandle()84 port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
85   return GetDsoHandle("cudart", GetCudaRtVersion());
86 }
87 
GetCublasDsoHandle()88 port::StatusOr<void*> GetCublasDsoHandle() {
89   return GetDsoHandle("cublas", GetCublasVersion());
90 }
91 
GetCublasLtDsoHandle()92 port::StatusOr<void*> GetCublasLtDsoHandle() {
93   return GetDsoHandle("cublasLt", GetCublasVersion());
94 }
95 
GetCufftDsoHandle()96 port::StatusOr<void*> GetCufftDsoHandle() {
97   return GetDsoHandle("cufft", GetCufftVersion());
98 }
99 
GetCusolverDsoHandle()100 port::StatusOr<void*> GetCusolverDsoHandle() {
101   return GetDsoHandle("cusolver", GetCusolverVersion());
102 }
103 
GetCusparseDsoHandle()104 port::StatusOr<void*> GetCusparseDsoHandle() {
105   return GetDsoHandle("cusparse", GetCusparseVersion());
106 }
107 
GetCurandDsoHandle()108 port::StatusOr<void*> GetCurandDsoHandle() {
109   return GetDsoHandle("curand", GetCurandVersion());
110 }
111 
GetCuptiDsoHandle()112 port::StatusOr<void*> GetCuptiDsoHandle() {
113   // Load specific version of CUPTI this is built.
114   auto status_or_handle = GetDsoHandle("cupti", GetCudaVersion());
115   if (status_or_handle.ok()) return status_or_handle;
116   // Load whatever libcupti.so user specified.
117   return GetDsoHandle("cupti", "");
118 }
119 
GetCudnnDsoHandle()120 port::StatusOr<void*> GetCudnnDsoHandle() {
121   return GetDsoHandle("cudnn", GetCudnnVersion());
122 }
123 
GetNvInferDsoHandle()124 port::StatusOr<void*> GetNvInferDsoHandle() {
125 #if defined(PLATFORM_WINDOWS)
126   return GetDsoHandle("nvinfer", "");
127 #else
128   return GetDsoHandle("nvinfer", GetTensorRTVersion());
129 #endif
130 }
131 
GetNvInferPluginDsoHandle()132 port::StatusOr<void*> GetNvInferPluginDsoHandle() {
133 #if defined(PLATFORM_WINDOWS)
134   return GetDsoHandle("nvinfer_plugin", "");
135 #else
136   return GetDsoHandle("nvinfer_plugin", GetTensorRTVersion());
137 #endif
138 }
139 
GetRocblasDsoHandle()140 port::StatusOr<void*> GetRocblasDsoHandle() {
141   return GetDsoHandle("rocblas", "");
142 }
143 
GetMiopenDsoHandle()144 port::StatusOr<void*> GetMiopenDsoHandle() {
145   return GetDsoHandle("MIOpen", "");
146 }
147 
GetHipfftDsoHandle()148 port::StatusOr<void*> GetHipfftDsoHandle() {
149   return GetDsoHandle("hipfft", "");
150 }
151 
GetRocrandDsoHandle()152 port::StatusOr<void*> GetRocrandDsoHandle() {
153   return GetDsoHandle("rocrand", "");
154 }
155 
GetRocsolverDsoHandle()156 port::StatusOr<void*> GetRocsolverDsoHandle() {
157   return GetDsoHandle("rocsolver", "");
158 }
159 
160 #if TF_ROCM_VERSION >= 40500
GetHipsolverDsoHandle()161 port::StatusOr<void*> GetHipsolverDsoHandle() {
162   return GetDsoHandle("hipsolver", "");
163 }
164 #endif
165 
GetRoctracerDsoHandle()166 port::StatusOr<void*> GetRoctracerDsoHandle() {
167   return GetDsoHandle("roctracer64", "");
168 }
169 
GetHipsparseDsoHandle()170 port::StatusOr<void*> GetHipsparseDsoHandle() {
171   return GetDsoHandle("hipsparse", "");
172 }
173 
GetHipDsoHandle()174 port::StatusOr<void*> GetHipDsoHandle() { return GetDsoHandle("amdhip64", ""); }
175 
176 }  // namespace DsoLoader
177 
178 namespace CachedDsoLoader {
GetCudaDriverDsoHandle()179 port::StatusOr<void*> GetCudaDriverDsoHandle() {
180   static auto result = new auto(DsoLoader::GetCudaDriverDsoHandle());
181   return *result;
182 }
183 
GetCudaRuntimeDsoHandle()184 port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
185   static auto result = new auto(DsoLoader::GetCudaRuntimeDsoHandle());
186   return *result;
187 }
188 
GetCublasDsoHandle()189 port::StatusOr<void*> GetCublasDsoHandle() {
190   static auto result = new auto(DsoLoader::GetCublasDsoHandle());
191   return *result;
192 }
193 
GetCublasLtDsoHandle()194 port::StatusOr<void*> GetCublasLtDsoHandle() {
195   static auto result = new auto(DsoLoader::GetCublasLtDsoHandle());
196   return *result;
197 }
198 
GetCurandDsoHandle()199 port::StatusOr<void*> GetCurandDsoHandle() {
200   static auto result = new auto(DsoLoader::GetCurandDsoHandle());
201   return *result;
202 }
203 
GetCufftDsoHandle()204 port::StatusOr<void*> GetCufftDsoHandle() {
205   static auto result = new auto(DsoLoader::GetCufftDsoHandle());
206   return *result;
207 }
208 
GetCusolverDsoHandle()209 port::StatusOr<void*> GetCusolverDsoHandle() {
210   static auto result = new auto(DsoLoader::GetCusolverDsoHandle());
211   return *result;
212 }
213 
GetCusparseDsoHandle()214 port::StatusOr<void*> GetCusparseDsoHandle() {
215   static auto result = new auto(DsoLoader::GetCusparseDsoHandle());
216   return *result;
217 }
218 
GetCuptiDsoHandle()219 port::StatusOr<void*> GetCuptiDsoHandle() {
220   static auto result = new auto(DsoLoader::GetCuptiDsoHandle());
221   return *result;
222 }
223 
GetCudnnDsoHandle()224 port::StatusOr<void*> GetCudnnDsoHandle() {
225   static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
226   return *result;
227 }
228 
GetRocblasDsoHandle()229 port::StatusOr<void*> GetRocblasDsoHandle() {
230   static auto result = new auto(DsoLoader::GetRocblasDsoHandle());
231   return *result;
232 }
233 
GetMiopenDsoHandle()234 port::StatusOr<void*> GetMiopenDsoHandle() {
235   static auto result = new auto(DsoLoader::GetMiopenDsoHandle());
236   return *result;
237 }
238 
GetHipfftDsoHandle()239 port::StatusOr<void*> GetHipfftDsoHandle() {
240   static auto result = new auto(DsoLoader::GetHipfftDsoHandle());
241   return *result;
242 }
243 
GetRocrandDsoHandle()244 port::StatusOr<void*> GetRocrandDsoHandle() {
245   static auto result = new auto(DsoLoader::GetRocrandDsoHandle());
246   return *result;
247 }
248 
GetRoctracerDsoHandle()249 port::StatusOr<void*> GetRoctracerDsoHandle() {
250   static auto result = new auto(DsoLoader::GetRoctracerDsoHandle());
251   return *result;
252 }
253 
GetRocsolverDsoHandle()254 port::StatusOr<void*> GetRocsolverDsoHandle() {
255   static auto result = new auto(DsoLoader::GetRocsolverDsoHandle());
256   return *result;
257 }
258 
259 #if TF_ROCM_VERSION >= 40500
GetHipsolverDsoHandle()260 port::StatusOr<void*> GetHipsolverDsoHandle() {
261   static auto result = new auto(DsoLoader::GetHipsolverDsoHandle());
262   return *result;
263 }
264 #endif
265 
GetHipsparseDsoHandle()266 port::StatusOr<void*> GetHipsparseDsoHandle() {
267   static auto result = new auto(DsoLoader::GetHipsparseDsoHandle());
268   return *result;
269 }
270 
GetHipDsoHandle()271 port::StatusOr<void*> GetHipDsoHandle() {
272   static auto result = new auto(DsoLoader::GetHipDsoHandle());
273   return *result;
274 }
275 
276 }  // namespace CachedDsoLoader
277 }  // namespace internal
278 }  // namespace stream_executor
279