/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h"

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

static void ReportInternalError(tensorflow::OpKernelContext *ctx,
                                const std::string msg) {
  if (ctx == nullptr) {
    LOG(WARNING) << msg << "\n";
    return;
  }
  ctx->CtxFailureWithWarning(
      tensorflow::Status{tensorflow::error::INTERNAL, msg});
}

#if GOOGLE_CUDA
using GPUResult = CUresult;
#endif
#if TENSORFLOW_USE_ROCM
using GPUResult = hipError_t;
#endif

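// Note: both CUresult and hipError_t encode success as zero (CUDA_SUCCESS
// and hipSuccess respectively), so a nonzero GPUResult below always denotes
// an error.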
void GPUReportIfError(GPUResult result, tensorflow::OpKernelContext *ctx,
                      const char *expr_str) {
  if (!result) return;
  const char *name = nullptr;

#if GOOGLE_CUDA
  cuGetErrorName(result, &name);
#endif
#if TENSORFLOW_USE_ROCM
  name = hipGetErrorName(result);
#endif

  if (!name) name = "<unknown>";
  std::string msg = absl::StrCat("'", expr_str, "' failed with '", name, "'");
  ReportInternalError(ctx, msg);
}

#define GPU_REPORT_IF_ERROR_WITH_CTX(expr, ctx) \
  GPUReportIfError(expr, ctx, #expr)
#define GPU_REPORT_IF_ERROR(expr) GPU_REPORT_IF_ERROR_WITH_CTX(expr, nullptr)
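
// Example usage (sketch): wrap a driver API call so that a failure is
// reported on the given OpKernelContext, or is merely logged when no context
// is available:
//   GPU_REPORT_IF_ERROR_WITH_CTX(cuCtxSynchronize(), op_kernel_ctx);
//   GPU_REPORT_IF_ERROR(cuInit(0));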

// Implements the GPU module cache, sharing loaded modules where possible.

namespace mlir {
namespace kernel_gen {
namespace tf_framework {

GPURuntimeCache::~GPURuntimeCache() {
  tensorflow::mutex_lock lock(mu_);
  for (auto it : gpu_module_by_data_ptr_) {
#if GOOGLE_CUDA
    GPU_REPORT_IF_ERROR(cuModuleUnload(it.second));
#endif
#if TENSORFLOW_USE_ROCM
    GPU_REPORT_IF_ERROR(hipModuleUnload(it.second));
#endif
  }
}

tensorflow::Status GPURuntimeCache::Create(GPURuntimeCache **dst) {
  *dst = new GPURuntimeCache;
  return ::tensorflow::OkStatus();
}

std::string GPURuntimeCache::DebugString() const { return "GPU runtime cache"; }

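// Loads a module at most once and caches it keyed by the address of its
// blob. This assumes the blob lives at a stable address for the lifetime of
// the cache, e.g. because it is a constant embedded in the generated binary.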
GPURuntimeCache::GPUModule GPURuntimeCache::LookupOrLoadModule(void *data) {
  tensorflow::mutex_lock lock(mu_);
  GPUModule &module = gpu_module_by_data_ptr_[data];

#if GOOGLE_CUDA
  if (!module) GPU_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
#endif
#if TENSORFLOW_USE_ROCM
  if (!module) GPU_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
#endif

  return module;
}

// Implements a C wrapper around the TensorFlow runtime and CUDA (or ROCm)
// library that allows launching a kernel on the current device and stream
// from a binary blob for the module and function name.
// The wrapper uses intptr_t instead of the unsigned int used by the CUDA and
// ROCm launch APIs to match MLIR's index type and thereby avoid casts in the
// generated MLIR code.
extern "C" void _mlir_ciface_tf_launch_kernel(void *ctx, void *module_blob,
                                              char *kernel_name, intptr_t gridX,
                                              intptr_t gridY, intptr_t gridZ,
                                              intptr_t blockX, intptr_t blockY,
                                              intptr_t blockZ, void **params) {
  // For empty grids, we don't need to do anything.
  if (!gridX || !gridY || !gridZ) return;

  // Get the GPU module cache.
  auto *op_kernel_ctx = static_cast<tensorflow::OpKernelContext *>(ctx);
  auto *rm = op_kernel_ctx->resource_manager();
  if (rm == nullptr) {
    ReportInternalError(op_kernel_ctx, "expected resource_manager");
    return;
  }
  GPURuntimeCache *cache = nullptr;
  OP_REQUIRES_OK(op_kernel_ctx, rm->LookupOrCreate<GPURuntimeCache>(
                                    rm->default_container(),
                                    GPURuntimeCache::kDefaultResourceName,
                                    &cache, GPURuntimeCache::Create));
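  // On failure, OP_REQUIRES_OK reports the status on the context and returns
  // from this function early, so `cache` is valid past this point.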
  assert(cache != nullptr && "cache creation must not fail");
  tensorflow::core::ScopedUnref ref(cache);

  // Get the GPU module.
  stream_executor::Stream *se_stream =
      op_kernel_ctx->op_device_context()->stream();
  void *stream = se_stream->implementation()->GpuStreamHack();
  GPURuntimeCache::GPUModule module = cache->LookupOrLoadModule(module_blob);

#if GOOGLE_CUDA
  CUfunction function;
  GPU_REPORT_IF_ERROR_WITH_CTX(
      cuModuleGetFunction(&function, module, kernel_name), op_kernel_ctx);
  GPU_REPORT_IF_ERROR_WITH_CTX(
      cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ,
                     /*sharedMemBytes=*/0, reinterpret_cast<CUstream>(stream),
                     params, nullptr),
      op_kernel_ctx);
#endif
#if TENSORFLOW_USE_ROCM
  hipFunction_t function;
  GPU_REPORT_IF_ERROR_WITH_CTX(
      hipModuleGetFunction(&function, module, kernel_name), op_kernel_ctx);
  GPU_REPORT_IF_ERROR_WITH_CTX(
      hipModuleLaunchKernel(
          function, gridX, gridY, gridZ, blockX, blockY, blockZ,
          /*sharedMemBytes=*/0, reinterpret_cast<hipStream_t>(stream), params,
          nullptr),
      op_kernel_ctx);
#endif
}
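
// Example of a call as it might appear in lowered host code (a sketch with
// illustrative names; the actual call site is emitted by the kernel_gen
// lowering):
//   _mlir_ciface_tf_launch_kernel(ctx, module_blob, "my_kernel",
//                                 /*gridX=*/num_blocks, /*gridY=*/1,
//                                 /*gridZ=*/1, /*blockX=*/256, /*blockY=*/1,
//                                 /*blockZ=*/1, kernel_params);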

}  // namespace tf_framework
}  // namespace kernel_gen
}  // namespace mlir