/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h"

#include <cstddef>
#include <cstdlib>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/TargetSelect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"  // from @llvm-project
#include "mlir/ExecutionEngine/OptUtils.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.pb.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/stream.h"

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h"
#endif

static constexpr absl::string_view kTFJitCacheDirEnvVar = "TF_JIT_CACHE_DIR";

namespace mlir {
namespace kernel_gen {
namespace tf_framework {
namespace {

using tensorflow::Allocator;
using tensorflow::AllocatorAttributes;

Allocator* GetAllocator(void* op_kernel_ctx) {
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  // TODO(pifon): Figure out how to set AllocatorAttributes correctly.
  AllocatorAttributes attrs;
  return ctx->get_allocator(attrs);
}

}  // namespace

extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_elements,
                                       size_t element_size,
                                       int32_t output_index,
                                       int32_t num_candidates,
                                       int32_t* candidate_input_indices) {
  // An output index of kAmbiguousOutputIndex (-1) means the target output was
  // not statically known, in which case input forwarding is skipped.
  static constexpr int kAmbiguousOutputIndex = -1;
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  if (output_index != kAmbiguousOutputIndex) {
    // Create a 1D shape, because the shapes don't have to match exactly for
    // input forwarding. Only the number of elements must be the same.
    tensorflow::TensorShape output_shape;
    output_shape.AddDim(num_elements);

    // Iterate over indices of all inputs that can potentially be used for
    // forwarding.
    for (int i = 0; i < num_candidates; ++i) {
      auto tensor = ctx->forward_input(candidate_input_indices[i], output_index,
                                       ctx->expected_output_dtype(output_index),
                                       output_shape,
                                       ctx->output_memory_type(output_index),
                                       ctx->output_alloc_attr(output_index));
      if (tensor != nullptr) {
        return tensor->data();
      }
    }

    CHECK(!ctx->output_expects_forwarding(output_index));
  }

  // If no forwarding happened, allocate a chunk of memory.
  return GetAllocator(op_kernel_ctx)
      ->AllocateRaw(Allocator::kAllocatorAlignment,
                    num_elements * element_size);
}
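
// Example (illustrative sketch, not generated code): for an op whose inputs 0
// and 1 may be forwarded to output 0, a caller could allocate a 1024 x float
// buffer like this:
//   int32_t candidates[] = {0, 1};
//   void* buf = _mlir_ciface_tf_alloc(op_kernel_ctx, /*num_elements=*/1024,
//                                     /*element_size=*/sizeof(float),
//                                     /*output_index=*/0,
//                                     /*num_candidates=*/2, candidates);
// The indices and sizes above are hypothetical.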

extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
  GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr);
}

extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx,
                                             int32_t error_code, char* msg) {
  Optional<ErrorCode> symbol = symbolizeErrorCode(error_code);
  if (!symbol.has_value()) {
    LOG(ERROR) << "No valid conversion from integer value = " << error_code
               << " to ErrorCode attribute";
    return;
  }
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  ctx->CtxFailureWithWarning(
      tensorflow::Status{ConvertAttrToEnumValue(symbol.getValue()), msg});
}

static void ReportError(void* op_kernel_ctx, ErrorCode error_code,
                        const char* msg) {
  _mlir_ciface_tf_report_error(op_kernel_ctx, static_cast<int32_t>(error_code),
                               const_cast<char*>(msg));
}

namespace {

std::string GetFileCachePath(const std::string cache_dir,
                             const std::string& code) {
  size_t hash = llvm::hash_value(code);
  return tensorflow::io::JoinPath(cache_dir, std::to_string(hash));
}
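
// For illustration only: with TF_JIT_CACHE_DIR=/tmp/tf_jit_cache (a
// hypothetical setting) and llvm::hash_value(code) == 1234, the cache file
// path would be "/tmp/tf_jit_cache/1234".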

// A callback to register all externally defined symbols needed by the kernel.
llvm::orc::SymbolMap TFFrameworkSymbolMap(llvm::orc::MangleAndInterner mangle) {
  llvm::orc::SymbolMap symbol_map;
  auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
    symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
        llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
  };

  // Register TF framework symbols.
  bind("_mlir_ciface_tf_alloc", &_mlir_ciface_tf_alloc);
  bind("_mlir_ciface_tf_dealloc", &_mlir_ciface_tf_dealloc);
  bind("_mlir_ciface_tf_report_error", &_mlir_ciface_tf_report_error);
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
  bind("_mlir_ciface_tf_launch_kernel", &_mlir_ciface_tf_launch_kernel);
#endif

  // Register malloc/free to avoid unexpected implementations from shared libs.
  bind("malloc", &malloc);
  bind("free", &free);

  return symbol_map;
}
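
// Note: `registerSymbols` (called after engine creation in `Compile` below)
// consults this map when the JIT-compiled module references an external
// symbol, so a call to e.g. `_mlir_ciface_tf_alloc` inside the kernel resolves
// to the function defined in this file. Any additional runtime symbol would be
// registered the same way, e.g. (a hypothetical extension):
//   bind("_mlir_ciface_my_symbol", &_mlir_ciface_my_symbol);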

llvm::Expected<std::unique_ptr<ExecutionEngine>> Compile(
    const std::string code, llvm::SmallVectorImpl<std::string>& architectures,
    llvm::SmallVectorImpl<int64_t>& tile_sizes,
    llvm::SmallVectorImpl<int64_t>& unroll_factors, int64_t max_supported_rank,
    bool enable_ftz, bool index_64bit) {
  std::string cache_dir;
  if (const char* dir = getenv(kTFJitCacheDirEnvVar.data())) {
    cache_dir = dir;
  }

  // Check if we already have a partially compiled module in the filesystem
  // based cache.
  CompilationCacheItem item;
  auto tenv = tensorflow::Env::Default();
  if (!cache_dir.empty() && tenv->RecursivelyCreateDir(cache_dir).ok()) {
    std::string data;
    if (tensorflow::ReadFileToString(tenv, GetFileCachePath(cache_dir, code),
                                     &data)
            .ok()) {
      item.ParseFromString(data);
      if (item.original_module() != code) {
        item.Clear();
      }
    }
  }

  // Create the kernel.
  mlir::OwningOpRef<mlir::ModuleOp> module;
  mlir::MLIRContext context;

  if (item.result_module().empty()) {
    // There is no cached result, so compile the module now.
    tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> status_or_module =
        tensorflow::kernel_gen::GenerateKernelForTfCode(
            context, code, architectures, tile_sizes, unroll_factors,
            max_supported_rank, /*embed_memref_prints=*/false,
            /*print_ptx=*/false, /*print_llvmir=*/false, enable_ftz,
            index_64bit,
            /*jit_compile=*/false,
            /*jit_i64_indexed_for_large_tensors=*/false,
            /*apply_cl_options=*/false);
    if (!status_or_module.ok()) return nullptr;
    module = std::move(status_or_module.ValueOrDie());

    if (!cache_dir.empty() && tenv->RecursivelyCreateDir(cache_dir).ok()) {
      // Save the compilation result here for future processes to use.
      item.set_original_module(code);
      llvm::raw_string_ostream stream(*item.mutable_result_module());
      module.get().print(stream);
      stream.flush();

      tensorflow::WriteStringToFile(tenv, GetFileCachePath(cache_dir, code),
                                    item.SerializeAsString())
          .IgnoreError();
    }
  } else {
    module = tensorflow::kernel_gen::SetupContextAndParseModule(
                 context, item.result_module())
                 .ValueOrDie();
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // Create execution engine with an inner optimization pipeline.
  auto opt_pipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/2, /*sizeLevel=*/0, /*targetMachine=*/nullptr);
  mlir::ExecutionEngineOptions engine_options;
  engine_options.transformer = opt_pipeline;
  llvm::Expected<std::unique_ptr<ExecutionEngine>> engine =
      mlir::ExecutionEngine::create(module.get(), engine_options);
  if (!engine) return nullptr;

  // Finally, register the missing symbols.
  engine.get()->registerSymbols(TFFrameworkSymbolMap);
  return engine;
}
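
// The returned engine exposes the kernel under the entry point "main"; it is
// invoked via `invokePacked` in `_mlir_ciface_tf_jit_execute` below.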

template <typename T, typename U = T>
llvm::SmallVector<T, 8> SmallVectorFromCArray(int64_t num_elements,
                                              U* elements_ptr) {
  llvm::SmallVector<T, 8> result;
  result.reserve(num_elements);
  for (int i = 0; i < num_elements; ++i) result.push_back(elements_ptr[i]);
  return result;
}
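
// Usage sketch (hypothetical values):
//   int64_t raw[] = {16, 64};
//   llvm::SmallVector<int64_t, 8> tiles =
//       SmallVectorFromCArray<int64_t>(2, raw);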

}  // namespace

extern "C" void* _mlir_ciface_tf_jit_compile(
    void* op_kernel_ctx, char* code, int64_t num_tile_sizes,
    int64_t* tile_sizes_ptr, int64_t num_unroll_factors,
    int64_t* unroll_factors_ptr, int64_t max_supported_rank, bool enable_ftz,
    bool index_64bit) {
  // Get the resource manager.
  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
  tensorflow::ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "No resource manager.");
    return nullptr;
  }

  // Get the JIT cache.
  JITCache* jit_cache = nullptr;
  auto status = rm->LookupOrCreate<JITCache>(rm->default_container(),
                                             JITCache::kDefaultResourceName,
                                             &jit_cache, JITCache::Create);
  tensorflow::core::ScopedUnref jit_cache_ref(jit_cache);
  if (!status.ok()) {
    ReportError(op_kernel_ctx, ErrorCode::UNKNOWN,
                "Failed to find or create JIT cache.");
    return nullptr;
  }

  // Determine the unique architecture for the current GPU, if any.
  SmallVector<std::string, 1> architectures;
#if defined(GOOGLE_CUDA)
  stream_executor::CudaComputeCapability cc =
      ctx->op_device_context()->stream()->GetCudaComputeCapability();
  architectures.push_back(absl::StrCat("sm_", cc.major, cc.minor));
#elif defined(TENSORFLOW_USE_ROCM)
  stream_executor::RocmComputeCapability cc =
      ctx->op_device_context()->stream()->GetRocmComputeCapability();
  architectures.push_back(cc.gcn_arch_name());
#endif

  // Construct `SmallVector`s from arguments.
  llvm::SmallVector<int64_t, 8> tile_sizes =
      SmallVectorFromCArray<int64_t>(num_tile_sizes, tile_sizes_ptr);
  llvm::SmallVector<int64_t, 8> unroll_factors =
      SmallVectorFromCArray<int64_t>(num_unroll_factors, unroll_factors_ptr);

  // Lookup or compile the execution module.
  ExecutionEngine* engine = jit_cache->LookupOrCompile(code, [&]() {
    return Compile(code, architectures, tile_sizes, unroll_factors,
                   max_supported_rank, enable_ftz, index_64bit);
  });
  if (engine == nullptr) {
    ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "JIT compilation failed.");
    return nullptr;
  }
  return engine;
}
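
// The opaque pointer returned above is an `ExecutionEngine*`; generated code
// is expected to pass it back unchanged as the `callable` argument of
// `_mlir_ciface_tf_jit_execute` below.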

extern "C" void _mlir_ciface_tf_jit_execute(void* op_kernel_ctx, void* callable,
                                            void* result, int64_t num_args,
                                            void* args_ptr) {
  // JIT compilation must have failed earlier if there is no callable ptr.
  // Return some empty memory descriptor to prevent a crash.
  if (callable == nullptr) {
    auto* desc = static_cast<::UnrankedMemRefType<void>*>(result);
    desc->rank = 0;
    auto* inner_desc = static_cast<StridedMemRefType<int8_t, 0>*>(
        malloc(sizeof(StridedMemRefType<int8_t, 0>)));
    inner_desc->basePtr = nullptr;
    inner_desc->data = nullptr;
    inner_desc->offset = 0;
    desc->descriptor = inner_desc;
    return;
  }

  // Build the argument array according to `ExecutionEngine`'s calling
  // convention.
  auto* typed_args_ptr = static_cast<::UnrankedMemRefType<void>*>(args_ptr);
  llvm::SmallVector<void*, 8> args_array = {&op_kernel_ctx};
  for (int i = 0; i < num_args; i++) {
    auto& desc = typed_args_ptr[i];
    args_array.push_back(&desc.rank);
    args_array.push_back(&desc.descriptor);
  }
  args_array.push_back(result);
323   llvm::Error invocation_result =
324       static_cast<ExecutionEngine*>(callable)->invokePacked("main", args_array);
325   if (invocation_result)
326     ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "JIT invocation failed.");
327 }

}  // namespace tf_framework
}  // namespace kernel_gen
}  // namespace mlir