/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h"

#include <cassert>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

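// Reports `msg` as an INTERNAL error on the kernel context, or logs a warning
// when no context is available (e.g. when called from a destructor).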
static void ReportInternalError(tensorflow::OpKernelContext *ctx,
                                const std::string msg) {
  if (ctx == nullptr) {
    LOG(WARNING) << msg << "\n";
    return;
  }
  ctx->CtxFailureWithWarning(
      tensorflow::Status{tensorflow::error::INTERNAL, msg});
}

#if GOOGLE_CUDA
using GPUResult = CUresult;
#endif
#if TENSORFLOW_USE_ROCM
using GPUResult = hipError_t;
#endif

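// Reports an internal error on `ctx` if `result` signals a GPU driver failure.
// A zero result (CUDA_SUCCESS / hipSuccess) means success and is ignored.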
void GPUReportIfError(GPUResult result, tensorflow::OpKernelContext *ctx,
                      const char *expr_str) {
  if (!result) return;
  const char *name = nullptr;

#if GOOGLE_CUDA
  cuGetErrorName(result, &name);
#endif
#if TENSORFLOW_USE_ROCM
  name = hipGetErrorName(result);
#endif

  if (!name) name = "<unknown>";
  std::string msg = absl::StrCat("'", expr_str, "' failed with '", name, "'");
  ReportInternalError(ctx, msg);
}

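// Wrappers around GPUReportIfError that stringify the checked expression. For
// example, GPU_REPORT_IF_ERROR_WITH_CTX(cuLaunchKernel(...), ctx) reports a
// failed launch on `ctx`; GPU_REPORT_IF_ERROR has no context and falls back to
// LOG(WARNING).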
#define GPU_REPORT_IF_ERROR_WITH_CTX(expr, ctx) \
  GPUReportIfError(expr, ctx, #expr)
#define GPU_REPORT_IF_ERROR(expr) GPU_REPORT_IF_ERROR_WITH_CTX(expr, nullptr)

// Implement the GPU module cache and share what can be shared.

namespace mlir {
namespace kernel_gen {
namespace tf_framework {

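// Unloads every GPU module that was loaded through LookupOrLoadModule.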
GPURuntimeCache::~GPURuntimeCache() {
  tensorflow::mutex_lock lock(mu_);
  for (auto it : gpu_module_by_data_ptr_) {
#if GOOGLE_CUDA
    GPU_REPORT_IF_ERROR(cuModuleUnload(it.second));
#endif
#if TENSORFLOW_USE_ROCM
    GPU_REPORT_IF_ERROR(hipModuleUnload(it.second));
#endif
  }
}

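// Factory function in the form expected by ResourceMgr::LookupOrCreate; see
// the launch wrapper below.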
tensorflow::Status GPURuntimeCache::Create(GPURuntimeCache **dst) {
  *dst = new GPURuntimeCache;
  return ::tensorflow::OkStatus();
}

std::string GPURuntimeCache::DebugString() const { return "GPU runtime cache"; }

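// Returns the GPU module loaded from `data`, loading and caching it on first
// use. Holding `mu_` ensures that each blob is loaded at most once.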
GPURuntimeCache::GPUModule GPURuntimeCache::LookupOrLoadModule(void *data) {
  tensorflow::mutex_lock lock(mu_);
  GPUModule &module = gpu_module_by_data_ptr_[data];

#if GOOGLE_CUDA
  if (!module) GPU_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
#endif
#if TENSORFLOW_USE_ROCM
  if (!module) GPU_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
#endif

  return module;
}

// Implements a C wrapper around the TensorFlow runtime and the CUDA (or ROCm)
// driver library that launches a kernel on the current device and stream,
// given a binary blob for the module and the kernel name.
// The wrapper takes the launch dimensions as intptr_t rather than the unsigned
// int used by the CUDA and ROCm launch APIs, to match MLIR's index type and
// avoid casts in the generated MLIR code.
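//
// A minimal sketch of a call, with illustrative names and launch dimensions
// only: the generated code passes the kernel context, the module blob (e.g. a
// cubin/fatbin or a ROCm code object), the kernel name, the grid and block
// dimensions, and an array of pointers to the kernel arguments.
//
//   void *params[] = {&arg0, &arg1};
//   _mlir_ciface_tf_launch_kernel(op_kernel_ctx, module_blob, kernel_name,
//                                 /*gridX=*/1, /*gridY=*/1, /*gridZ=*/1,
//                                 /*blockX=*/256, /*blockY=*/1, /*blockZ=*/1,
//                                 params);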
extern "C" void _mlir_ciface_tf_launch_kernel(void *ctx, void *module_blob,
                                              char *kernel_name, intptr_t gridX,
                                              intptr_t gridY, intptr_t gridZ,
                                              intptr_t blockX, intptr_t blockY,
                                              intptr_t blockZ, void **params) {
  // For empty grids, we don't need to do anything.
  if (!gridX || !gridY || !gridZ) return;

  // Get the GPU module cache.
  auto *op_kernel_ctx = static_cast<tensorflow::OpKernelContext *>(ctx);
  auto *rm = op_kernel_ctx->resource_manager();
  if (rm == nullptr) {
    ReportInternalError(op_kernel_ctx, "expected resource_manager");
    return;
  }
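  // LookupOrCreate finds the cache in the resource manager's default container
  // or creates it on first use. It returns a new reference, which ScopedUnref
  // releases on return; the resource manager keeps the cache alive across
  // launches.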
  GPURuntimeCache *cache = nullptr;
  OP_REQUIRES_OK(op_kernel_ctx, rm->LookupOrCreate<GPURuntimeCache>(
                                    rm->default_container(),
                                    GPURuntimeCache::kDefaultResourceName,
                                    &cache, GPURuntimeCache::Create));
  assert(cache != nullptr && "cache creation must not fail");
  tensorflow::core::ScopedUnref ref(cache);

  // Get the GPU module.
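  // GpuStreamHack exposes the underlying driver-level stream (CUstream or
  // hipStream_t) so that the launch below runs on the same stream as the
  // surrounding op.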
  stream_executor::Stream *se_stream =
      op_kernel_ctx->op_device_context()->stream();
  void *stream = se_stream->implementation()->GpuStreamHack();
  GPURuntimeCache::GPUModule module = cache->LookupOrLoadModule(module_blob);

#if GOOGLE_CUDA
  CUfunction function;
  GPU_REPORT_IF_ERROR_WITH_CTX(
      cuModuleGetFunction(&function, module, kernel_name), op_kernel_ctx);
  GPU_REPORT_IF_ERROR_WITH_CTX(
      cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ,
                     /*sharedMemBytes=*/0, reinterpret_cast<CUstream>(stream),
                     params, nullptr),
      op_kernel_ctx);
#endif
#if TENSORFLOW_USE_ROCM
  hipFunction_t function;
  GPU_REPORT_IF_ERROR_WITH_CTX(
      hipModuleGetFunction(&function, module, kernel_name), op_kernel_ctx);
  GPU_REPORT_IF_ERROR_WITH_CTX(
      hipModuleLaunchKernel(
          function, gridX, gridY, gridZ, blockX, blockY, blockZ,
          /*sharedMemBytes=*/0, reinterpret_cast<hipStream_t>(stream), params,
          nullptr),
      op_kernel_ctx);
#endif
}

}  // namespace tf_framework
}  // namespace kernel_gen
}  // namespace mlir