/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h"

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/TargetSelect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"  // from @llvm-project
#include "mlir/ExecutionEngine/OptUtils.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.pb.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/stream.h"

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h"
#endif

static constexpr absl::string_view kTFJitCacheDirEnvVar = "TF_JIT_CACHE_DIR";

namespace mlir {
namespace kernel_gen {
namespace tf_framework {
namespace {

using tensorflow::Allocator;
using tensorflow::AllocatorAttributes;

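// Returns the device allocator associated with the given OpKernelContext.
// Allocator attributes are currently left at their defaults (see TODO below).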
Allocator* GetAllocator(void* op_kernel_ctx) {
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  // TODO(pifon): Figure out how to set AllocatorAttributes correctly.
  AllocatorAttributes attrs;
  return ctx->get_allocator(attrs);
}

}  // namespace

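// Allocates the buffer for one kernel output. If the output index is known,
// the candidate inputs are first checked for buffer forwarding; otherwise, or
// if no candidate can be forwarded, a fresh buffer is allocated with the
// device allocator.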
extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_elements,
                                       size_t element_size,
                                       int32_t output_index,
                                       int32_t num_candidates,
                                       int32_t* candidate_input_indices) {
  static constexpr int kAmbiguousOutputIndex = -1;
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  if (output_index != kAmbiguousOutputIndex) {
    // Create a 1D shape, because the shapes don't have to match exactly for
    // input forwarding. Only the number of elements must be the same.
    tensorflow::TensorShape output_shape;
    output_shape.AddDim(num_elements);

    // Iterate over indices of all inputs that can potentially be used for
    // forwarding.
    for (int i = 0; i < num_candidates; ++i) {
      auto tensor = ctx->forward_input(candidate_input_indices[i], output_index,
                                       ctx->expected_output_dtype(output_index),
                                       output_shape,
                                       ctx->output_memory_type(output_index),
                                       ctx->output_alloc_attr(output_index));
      if (tensor != nullptr) {
        return tensor->data();
      }
    }

    CHECK(!ctx->output_expects_forwarding(output_index));
  }

  // If no forwarding happened, allocate a chunk of memory.
  return GetAllocator(op_kernel_ctx)
      ->AllocateRaw(Allocator::kAllocatorAlignment,
                    num_elements * element_size);
}

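// Returns a buffer previously obtained from `_mlir_ciface_tf_alloc` to the
// device allocator.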
extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
  GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr);
}

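// Reports a kernel failure on the OpKernelContext. The numeric `error_code`
// is converted back to the ErrorCode enum; invalid values are logged and
// otherwise ignored.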
extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx,
                                             int32_t error_code, char* msg) {
  Optional<ErrorCode> symbol = symbolizeErrorCode(error_code);
  if (!symbol.has_value()) {
    LOG(ERROR) << "No valid conversion from integer value = " << error_code
               << " to ErrorCode attribute";
    return;
  }
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  ctx->CtxFailureWithWarning(
      tensorflow::Status{ConvertAttrToEnumValue(symbol.getValue()), msg});
}

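// Convenience wrapper around `_mlir_ciface_tf_report_error` for reporting
// errors from within this translation unit using the ErrorCode enum directly.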
static void ReportError(void* op_kernel_ctx, ErrorCode error_code,
                        const char* msg) {
  _mlir_ciface_tf_report_error(op_kernel_ctx, static_cast<uint32_t>(error_code),
                               const_cast<char*>(msg));
}

namespace {

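// Returns the path of the file-system cache entry for `code`, derived from a
// hash of the code and rooted at `cache_dir`.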
std::string GetFileCachePath(const std::string cache_dir,
                             const std::string& code) {
  size_t hash = llvm::hash_value(code);
  return tensorflow::io::JoinPath(cache_dir, std::to_string(hash));
}

// A callback to register all externally defined symbols needed by the kernel.
llvm::orc::SymbolMap TFFrameworkSymbolMap(llvm::orc::MangleAndInterner mangle) {
  llvm::orc::SymbolMap symbol_map;
  auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
    symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
        llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
  };

  // Register TF framework symbols.
  bind("_mlir_ciface_tf_alloc", &_mlir_ciface_tf_alloc);
  bind("_mlir_ciface_tf_dealloc", &_mlir_ciface_tf_dealloc);
  bind("_mlir_ciface_tf_report_error", &_mlir_ciface_tf_report_error);
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
  bind("_mlir_ciface_tf_launch_kernel", &_mlir_ciface_tf_launch_kernel);
#endif

  // Register malloc/free to avoid unexpected implementations from shared libs.
  bind("malloc", &malloc);
  bind("free", &free);

  return symbol_map;
}

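// Lowers the given TF code to an ExecutionEngine. If TF_JIT_CACHE_DIR is set,
// a previously lowered module is reused from the file-system cache when
// available, and new compilation results are written back to it.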
llvm::Expected<std::unique_ptr<ExecutionEngine>> Compile(
    const std::string code, llvm::SmallVectorImpl<std::string>& architectures,
    llvm::SmallVectorImpl<int64_t>& tile_sizes,
    llvm::SmallVectorImpl<int64_t>& unroll_factors, int64_t max_supported_rank,
    bool enable_ftz, bool index_64bit) {
  std::string cache_dir;
  if (const char* dir = getenv(kTFJitCacheDirEnvVar.data())) {
    cache_dir = dir;
  }

  // Check if we already have a partially compiled module in the filesystem
  // based cache.
  CompilationCacheItem item;
  auto tenv = tensorflow::Env::Default();
  if (!cache_dir.empty() && tenv->RecursivelyCreateDir(cache_dir).ok()) {
    std::string data;
    if (tensorflow::ReadFileToString(tenv, GetFileCachePath(cache_dir, code),
                                     &data)
            .ok()) {
      item.ParseFromString(data);
      if (item.original_module() != code) {
        item.Clear();
      }
    }
  }

  // Create the kernel.
  mlir::OwningOpRef<mlir::ModuleOp> module;
  mlir::MLIRContext context;

  if (item.result_module().empty()) {
    // Otherwise, compile the module now.
    tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> status_or_module =
        tensorflow::kernel_gen::GenerateKernelForTfCode(
            context, code, architectures, tile_sizes, unroll_factors,
            max_supported_rank, /*embed_memref_prints=*/false,
            /*print_ptx=*/false, /*print_llvmir=*/false, enable_ftz,
            index_64bit,
            /*jit_compile=*/false,
            /*jit_i64_indexed_for_large_tensors=*/false,
            /*apply_cl_options=*/false);
    if (!status_or_module.ok()) return nullptr;
    module = std::move(status_or_module.ValueOrDie());

    if (!cache_dir.empty() && tenv->RecursivelyCreateDir(cache_dir).ok()) {
      // Save the compilation result here for future processes to use.
      item.set_original_module(code);
      llvm::raw_string_ostream stream(*item.mutable_result_module());
      module.get().print(stream);
      stream.flush();

      tensorflow::WriteStringToFile(tenv, GetFileCachePath(cache_dir, code),
                                    item.SerializeAsString())
          .IgnoreError();
    }
  } else {
    module = tensorflow::kernel_gen::SetupContextAndParseModule(
                 context, item.result_module())
                 .ValueOrDie();
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // Create execution engine with an inner optimization pipeline.
  auto opt_pipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/2, /*sizeLevel=*/0, /*targetMachine=*/nullptr);
  mlir::ExecutionEngineOptions engine_options;
  engine_options.transformer = opt_pipeline;
  llvm::Expected<std::unique_ptr<ExecutionEngine>> engine =
      mlir::ExecutionEngine::create(module.get(), engine_options);
  if (!engine) return nullptr;

  // Finally, register the missing symbols.
  engine.get()->registerSymbols(TFFrameworkSymbolMap);
  return engine;
}

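// Copies `num_elements` values from a C array into a SmallVector, converting
// each element from U to T.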
template <typename T, typename U = T>
llvm::SmallVector<T, 8> SmallVectorFromCArray(int64_t num_elements,
                                              U* elements_ptr) {
  llvm::SmallVector<T, 8> result;
  result.reserve(num_elements);
  for (int i = 0; i < num_elements; ++i) result.push_back(elements_ptr[i]);
  return result;
}

}  // namespace

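// JIT-compiles the given TF code for the current device and returns the
// resulting ExecutionEngine, which is owned by the JIT cache stored in the
// context's resource manager. Returns nullptr on failure.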
extern "C" void* _mlir_ciface_tf_jit_compile(
    void* op_kernel_ctx, char* code, int64_t num_tile_sizes,
    int64_t* tile_sizes_ptr, int64_t num_unroll_factors,
    int64_t* unroll_factors_ptr, int64_t max_supported_rank, bool enable_ftz,
    bool index_64bit) {
  // Get the resource manager.
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  tensorflow::ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "No resource manager.");
    return nullptr;
  }

  // Get the JIT cache.
  JITCache* jit_cache = nullptr;
  auto status = rm->LookupOrCreate<JITCache>(rm->default_container(),
                                             JITCache::kDefaultResourceName,
                                             &jit_cache, JITCache::Create);
  tensorflow::core::ScopedUnref jit_cache_ref(jit_cache);
  if (!status.ok()) {
    ReportError(op_kernel_ctx, ErrorCode::UNKNOWN,
                "Failed to find or create JIT cache.");
    return nullptr;
  }

  // Determine the unique architecture for the current GPU, if any.
  SmallVector<std::string, 1> architectures;
#if defined(GOOGLE_CUDA)
  stream_executor::CudaComputeCapability cc =
      ctx->op_device_context()->stream()->GetCudaComputeCapability();
  architectures.push_back(absl::StrCat("sm_", cc.major, cc.minor));
#elif defined(TENSORFLOW_USE_ROCM)
  stream_executor::RocmComputeCapability cc =
      ctx->op_device_context()->stream()->GetRocmComputeCapability();
  architectures.push_back(cc.gcn_arch_name());
#endif

  // Construct `SmallVector`s from arguments.
  llvm::SmallVector<int64_t, 8> tile_sizes =
      SmallVectorFromCArray<int64_t>(num_tile_sizes, tile_sizes_ptr);
  llvm::SmallVector<int64_t, 8> unroll_factors =
      SmallVectorFromCArray<int64_t>(num_unroll_factors, unroll_factors_ptr);

  // Lookup or compile the execution module.
  ExecutionEngine* engine = jit_cache->LookupOrCompile(code, [&]() {
    return Compile(code, architectures, tile_sizes, unroll_factors,
                   max_supported_rank, enable_ftz, index_64bit);
  });
  if (engine == nullptr) {
    ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "JIT compilation failed.");
    return nullptr;
  }
  return engine;
}

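// Invokes the `main` function of a JIT-compiled module through `callable` (an
// ExecutionEngine pointer returned by `_mlir_ciface_tf_jit_compile`), packing
// the kernel context and the unranked memref arguments according to the
// ExecutionEngine calling convention.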
extern "C" void _mlir_ciface_tf_jit_execute(void* op_kernel_ctx, void* callable,
                                            void* result, int64_t num_args,
                                            void* args_ptr) {
  // JIT compilation must have failed earlier if there is no callable ptr.
  // Return some empty memory descriptor to prevent a crash.
  if (callable == nullptr) {
    auto* desc = static_cast<::UnrankedMemRefType<void>*>(result);
    desc->rank = 0;
    auto* inner_desc = static_cast<StridedMemRefType<int8_t, 0>*>(
        malloc(sizeof(StridedMemRefType<int8_t, 0>)));
    inner_desc->basePtr = nullptr;
    inner_desc->data = nullptr;
    inner_desc->offset = 0;
    desc->descriptor = inner_desc;
    return;
  }

  // Build the argument array according to `ExecutionEngine`'s calling
  // convention.
  auto* typed_args_ptr = static_cast<::UnrankedMemRefType<void>*>(args_ptr);
  llvm::SmallVector<void*, 8> args_array = {&op_kernel_ctx};
  for (int i = 0; i < num_args; i++) {
    auto& desc = typed_args_ptr[i];
    args_array.push_back(&desc.rank);
    args_array.push_back(&desc.descriptor);
  }
  args_array.push_back(result);

  llvm::Error invocation_result =
      static_cast<ExecutionEngine*>(callable)->invokePacked("main", args_array);
  if (invocation_result)
    ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "JIT invocation failed.");
}

}  // namespace tf_framework
}  // namespace kernel_gen
}  // namespace mlir