/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_GPU_RUNTIME_WRAPPERS_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_GPU_RUNTIME_WRAPPERS_H_

#include "absl/container/flat_hash_map.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"  // from @llvm-project
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/platform/mutex.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#endif

namespace mlir {
namespace kernel_gen {
namespace tf_framework {

class GPURuntimeCache : public tensorflow::ResourceBase {
 public:
#if GOOGLE_CUDA
  using GPUModule = CUmodule;
#endif
#if TENSORFLOW_USE_ROCM
  using GPUModule = hipModule_t;
#endif

  ~GPURuntimeCache() override;
  static constexpr const char* kDefaultResourceName = "mlir-gpu-runtime-cache";
  static tensorflow::Status Create(GPURuntimeCache** dst);
  std::string DebugString() const override;

  // Assumes that no two modules are loaded from the same memory location over
  // the lifetime of this cache, which allows the data pointer to be used as
  // the cache key. All modules are unloaded when this cache is destroyed.
  GPUModule LookupOrLoadModule(void* data);

 private:
  tensorflow::mutex mu_;
  absl::flat_hash_map<void*, GPUModule> gpu_module_by_data_ptr_
      TF_GUARDED_BY(mu_);
};

// Implements a C wrapper around the TensorFlow runtime and the CUDA (or ROCm)
// library that launches a kernel on the current device and stream, given a
// binary blob for the module and a kernel name.
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_launch_kernel(
    void* ctx, void* module_blob, char* kernel_name, intptr_t gridX,
    intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY,
    intptr_t blockZ, void** params);

}  // namespace tf_framework
}  // namespace kernel_gen
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_GPU_RUNTIME_WRAPPERS_H_
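
// Illustrative sketch (not part of this header): one way a caller could obtain
// the GPURuntimeCache from TensorFlow's resource manager and load a module.
// The variable names, the empty container string, and the surrounding setup
// are assumptions for illustration only; the real call site lives in the
// corresponding .cc wrapper.
//
//   tensorflow::OpKernelContext* op_ctx = ...;  // provided by the runtime
//   mlir::kernel_gen::tf_framework::GPURuntimeCache* cache = nullptr;
//   TF_RETURN_IF_ERROR(
//       op_ctx->resource_manager()
//           ->LookupOrCreate<mlir::kernel_gen::tf_framework::GPURuntimeCache>(
//               /*container=*/"",
//               mlir::kernel_gen::tf_framework::GPURuntimeCache::
//                   kDefaultResourceName,
//               &cache, [](auto** dst) {
//                 return mlir::kernel_gen::tf_framework::GPURuntimeCache::
//                     Create(dst);
//               }));
//   tensorflow::core::ScopedUnref unref(cache);
//   // The blob's address doubles as the cache key (see LookupOrLoadModule).
//   auto module = cache->LookupOrLoadModule(module_blob);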
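
// Illustrative sketch (not part of this header): how a launch through the C
// interface might look. In practice this call is emitted by generated code;
// the blob pointer, kernel name, argument values, and launch dimensions below
// are placeholders. `params` follows the usual MLIR GPU-launch convention of
// an array of pointers to the individual kernel arguments.
//
//   float* arg0 = ...;   // device pointer (placeholder)
//   int32_t arg1 = ...;  // scalar argument (placeholder)
//   void* params[] = {&arg0, &arg1};
//   _mlir_ciface_tf_launch_kernel(
//       /*ctx=*/op_ctx, /*module_blob=*/blob,
//       /*kernel_name=*/const_cast<char*>("example_kernel"),
//       /*gridX=*/1, /*gridY=*/1, /*gridZ=*/1,
//       /*blockX=*/256, /*blockY=*/1, /*blockZ=*/1, params);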