xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/executable.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/runtime/executable.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "llvm/ExecutionEngine/Orc/Core.h"
23 #include "llvm/Support/ErrorOr.h"
24 #include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
25 #include "tensorflow/compiler/xla/mlir/utils/runtime/c_runner_utils.h"
26 #include "tensorflow/compiler/xla/runtime/custom_call.h"
27 #include "tensorflow/compiler/xla/runtime/custom_call_registry.h"
28 #include "tensorflow/compiler/xla/runtime/errors.h"
29 #include "tensorflow/compiler/xla/runtime/runtime.h"
30 #include "tensorflow/compiler/xla/runtime/type_id.h"
31 
32 namespace xla {
33 namespace runtime {
34 
35 using llvm::dyn_cast;
36 
37 using llvm::Error;
38 using llvm::ErrorOr;
39 using llvm::Expected;
40 
41 using llvm::orc::MangleAndInterner;
42 using llvm::orc::SymbolMap;
43 
44 
45 // KernelContext encapsulates all the data that is required to implement XLA
46 // Runtime <-> XLA Executable integration API.
47 struct KernelContext {
48   // Results memory layout is owned by the executable, and stays alive after
49   // the entrypoint function execution completes.
50   const Executable::ResultsMemoryLayout* results_memory_layout = nullptr;
51 
52   // CallFrame life time bound to the entrypoint function execution and
53   // destroyed immediately when the function returns. Only the entrypoint
54   // function itself reads the arguments and writes to the function results
55   // storage.
56   Executable::CallFrame* call_frame = nullptr;
57 
58   // User-defined data for custom call handlers.
59   CustomCall::UserData* custom_call_data = nullptr;
60 
61   // User-defined diagnostic engine for reporting diagnostics.
62   DiagnosticEngine* diagnostic_engine = nullptr;
63 };
64 
65 //===----------------------------------------------------------------------===//
66 // Conversion from custom calls library and type id registry to symbols binding.
67 //===----------------------------------------------------------------------===//
68 
ToSymbolsBinding(DirectCustomCallLibrary lib,TypeIDNameRegistry::RegistrationFn types)69 ExecutionEngine::SymbolsBinding ToSymbolsBinding(
70     DirectCustomCallLibrary lib, TypeIDNameRegistry::RegistrationFn types) {
71   return [=](MangleAndInterner mangle) {
72     SymbolMap symbol_map;
73 
74     // Always register canonical custom call types with the registry.
75     TypeIDNameRegistry registry;
76     PopulateCustomCallTypeIdNames(registry);
77     if (types) types(registry);
78 
79     // Register direct custom calls.
80     using DirectCustomCall = DirectCustomCallLibrary::DirectCustomCall;
81     lib.ForEach([&](llvm::StringRef name, DirectCustomCall custom_call) {
82       symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
83           llvm::pointerToJITTargetAddress(custom_call), llvm::JITSymbolFlags());
84     });
85 
86     // Register type id symbols.
87     registry.ForEach([&](llvm::StringRef name, TypeID type_id) {
88       auto type_id_ptr =
89           reinterpret_cast<std::uintptr_t>(type_id.getAsOpaquePointer());
90       symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
91           static_cast<llvm::JITTargetAddress>(type_id_ptr),
92           llvm::JITSymbolFlags());
93     });
94 
95     return symbol_map;
96   };
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // Register XLA runtime symbols with XLA execution engine.
101 //===----------------------------------------------------------------------===//
102 
103 static SymbolMap RuntimeApiSymbolMap(MangleAndInterner);
104 
105 //===----------------------------------------------------------------------===//
106 // Construct a symbols binding for XLA executable.
107 //===----------------------------------------------------------------------===//
108 
RuntimeSymbolsBinding(ExecutionEngine::SymbolsBinding custom_binding)109 ExecutionEngine::SymbolsBinding RuntimeSymbolsBinding(
110     ExecutionEngine::SymbolsBinding custom_binding) {
111   return ExecutionEngine::BindAll(
112       {// Register MLIR C Runner API intrinsics (defined in CRunnerUtils).
113        CRunnerUtilsSymbolMap,
114        // Register Async Runtime API intrinsics.
115        AsyncRuntimeApiSymbolMap,
116        // Register memory allocation functions (malloc, free, ...).
117        AsyncRuntimeMemoryAllocationSymbolMap,
118        // Register Runtime API intrinsics (returning results and errors).
119        RuntimeApiSymbolMap,
120        // Register any additional user-defined symbol bindings
121        std::move(custom_binding)});
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // Get executable arguments and results memory layouts.
126 //===----------------------------------------------------------------------===//
127 
128 /*static*/ Expected<Executable::ArgumentsMemoryLayout>
GetArgumentsMemoryLayout(const FunctionType & signature)129 Executable::GetArgumentsMemoryLayout(const FunctionType& signature) {
130   // Requirements for passing function arguments.
131   ArgumentsMemoryLayout layout;
132 
133   for (unsigned i = 0; i < signature.num_operands(); ++i) {
134     const Type* type = signature.operand(i);
135 
136     // Check if the type defines the ABI for passing it as an argument.
137     if (ErrorOr<Type::ArgumentAbi> abi = type->AsArgument()) {
138       layout.num_args_ptrs += abi->num_ptrs;
139       continue;
140     }
141 
142     return MakeStringError("unknown operand #", i, " argument ABI: ", *type);
143   }
144 
145   return layout;
146 }
147 
148 /*static*/ Expected<Executable::ResultsMemoryLayout>
GetResultsMemoryLayout(const FunctionType & signature)149 Executable::GetResultsMemoryLayout(const FunctionType& signature) {
150   // Requirements for returning function results.
151   ResultsMemoryLayout layout;
152   layout.offsets.reserve(signature.num_results());
153 
154   // TODO(ezhulenev): We should support allocating storage for results with non
155   // standard alignment requirements.
156 
157   for (unsigned i = 0; i < signature.num_results(); ++i) {
158     const Type* type = signature.result(i);
159 
160     // Keep track if the function has asynchronous results.
161     layout.has_async_results |= llvm::isa<AsyncTokenType, AsyncValueType>(type);
162 
163     // Check if the type defines the ABI for returning it as a result.
164     if (ErrorOr<Type::ResultAbi> abi = type->AsResult()) {
165       layout.offsets.emplace_back(layout.size);
166       layout.size += abi->size;
167       continue;
168     }
169 
170     return MakeStringError("unknown result #", i, " type result ABI: ", *type);
171   }
172 
173   return layout;
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Executable CallFrame initialization.
178 //===----------------------------------------------------------------------===//
179 
// Decides whether executable arguments must be verified before the call:
// in debug builds arguments are always verified; in optimized (NDEBUG) builds
// verification is controlled by the caller. The explicit #else branch avoids
// the unreachable trailing `return` in optimized builds and documents intent.
static bool VerifyArguments(bool verify_arguments) {
#if defined(NDEBUG)
  return verify_arguments;
#else
  (void)verify_arguments;  // ignored: always verify in debug builds
  return true;
#endif
}
187 
InitializeCallFrame(ArgumentsRef arguments,CallFrame * call_frame,bool verify_arguments) const188 Error Executable::InitializeCallFrame(ArgumentsRef arguments,
189                                       CallFrame* call_frame,
190                                       bool verify_arguments) const {
191   // TODO(ezhulenev): If executable is specialized for concrete shapes then
192   // there is no need to verify them once more here. However currently we rely
193   // on a hash code to look up specializations, and this can lead to collisions.
194   if (VerifyArguments(verify_arguments)) {
195     // We verify run time arguments against the run time signature.
196     const FunctionType& signature = runtime_signature_;
197 
198     // Make sure that we call the executable with the correct number of
199     // arguments. We subtract one argument from the signature because it
200     // corresponds to the context that we prepend to the given arguments.
201     if (LLVM_UNLIKELY(arguments.size() != signature.num_operands() - 1))
202       return MakeStringError(
203           "number of arguments doesn't match the function signature: ",
204           arguments.size(), " vs ", signature.num_operands() - 1);
205 
206     // Verify that all arguments passed at runtime are compatible with compiled
207     // function signature.
208     auto kctx = dyn_cast<KernelContextOperandType>(signature.operand(0));
209     if (LLVM_UNLIKELY(!kctx)) {
210       return MakeStringError(
211           "expected KernelContext in first argument of signature, got: ",
212           signature.operand(0));
213     }
214 
215     // We use 0-based index for arguments, because the kernel context argument
216     // is an internal implementation detail, and in case of an error users
217     // should get back argument index corresponding to the user provided
218     // signature.
219     for (unsigned i = 0; i < arguments.size(); ++i) {
220       unsigned idx = i + 1;  // use 1-based index to fetch signature operand
221       if (auto err = arguments[i].Verify(*signature.operand(idx)))
222         return MakeStringError("argument #", i,
223                                " doesn't match the signature: ", err);
224     }
225   }
226 
227   size_t num_args_ptrs = arguments_memory_layout_.num_args_ptrs;
228   call_frame->args.resize_for_overwrite(num_args_ptrs);
229 
230   // Add a placeholder for the kernel context as the first argument.
231   call_frame->args[0] = nullptr;
232 
233   // Keep offset of the next argument in the `args` array, and update it every
234   // time we pack a new argument.
235   size_t offset = 1;
236 
237   // Pack all arguments according to the ABI to the call frame arguments.
238   for (unsigned i = 0; i < arguments.size(); ++i)
239     offset = arguments[i].Pack(call_frame->args, offset);
240 
241   assert(offset == num_args_ptrs &&
242          "reserved number of args must match the argument offset");
243 
244   // Allocate storage for results.
245   call_frame->results.resize_for_overwrite(results_memory_layout_.size);
246 
247   // Mark results memory initialized to supress potential msan errors.
248   ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(call_frame->results.data(),
249                                       call_frame->results.size());
250 
251   return Error::success();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // Execute the compiled XLA runtime executable.
256 //===----------------------------------------------------------------------===//
257 
Execute(ArgumentsRef arguments,const ResultConverter & results,const ExecuteOpts & opts,bool verify_arguments) const258 Error Executable::Execute(ArgumentsRef arguments,
259                           const ResultConverter& results,
260                           const ExecuteOpts& opts,
261                           bool verify_arguments) const {
262   // CallFrame can be allocated on the stack because compiled function will
263   // unpack all the arguments it needs, and async regions will not access
264   // the data after the initial function will return the result.
265   CallFrame call_frame;
266 
267   // Touch every byte of the memref arguments, to trigger memory sanitizer error
268   // if some of the memrefs are already deallocated. Unfortunatelly sanitizers
269   // do not work inside the JIT compiled code, and compiled executables still
270   // can do out of bounds memory access, however this sanity check allows to
271   // catch obvious errors earlier.
272 #if defined(MEMORY_SANITIZER)
273   auto do_not_optimize = [&](const auto& value) -> void {
274     asm volatile("" : : "r,m"(value) : "memory");
275   };
276 
277   for (unsigned i = 0; i < arguments.size(); ++i) {
278     auto* memref = dyn_cast<MemrefDesc>(&arguments[i]);
279     if (!memref) continue;
280 
281     int64_t size_in_bytes = GetHostSize(memref->dtype());
282     for (int64_t size : memref->sizes()) size_in_bytes *= size;
283 
284     uint8_t* data = static_cast<uint8_t*>(memref->data());
285     for (int64_t i = 0; i < size_in_bytes; ++i) {
286       uint8_t value = data[i];
287       do_not_optimize(value);
288     }
289   }
290 #endif
291 
292   // Compiled function takes arguments and results as `void**` type erased
293   // pointer. See mlir::ExecutionEngine `packFunctionArguments` for the details.
294   if (auto err = InitializeCallFrame(arguments, &call_frame, verify_arguments))
295     return (results.ReturnError(err), std::move(err));
296 
297   Execute(call_frame, opts);
298 
299   // Convert compiled function return values into results.
300   if (auto err = ReturnResults(results, &call_frame)) return err;
301 
302   return Error::success();
303 }
304 
Execute(CallFrame & call_frame,const ExecuteOpts & opts) const305 void Executable::Execute(CallFrame& call_frame, const ExecuteOpts& opts) const {
306   // Set the AsyncRuntime to be used by all async tasks spawned by the
307   // executable.
308   AsyncRuntime::Set(AsyncRuntime(opts.async_task_runner));
309 
310   // Runtime kernel context can be used only by the entrypoint function and can
311   // be safely allocated on the stack.
312   KernelContext kernel_context = {&results_memory_layout_, &call_frame,
313                                   opts.custom_call_data,
314                                   opts.diagnostic_engine};
315 
316   // Override the kernel context argument.
317   KernelContext* kernel_context_ptr = &kernel_context;
318   assert(call_frame.args.size() == arguments_memory_layout_.num_args_ptrs);
319   assert(call_frame.args[0] == nullptr && "expected to see a placeholder");
320   call_frame.args[0] = &kernel_context_ptr;
321 
322   // Call the compiled function.
323   (*fptr_)(call_frame.args.data());
324 }
325 
ReturnResults(const ResultConverter & results,CallFrame * call_frame) const326 Error Executable::ReturnResults(const ResultConverter& results,
327                                 CallFrame* call_frame) const {
328   // If execution failed, forward error to all results.
329   if (call_frame->is_error) {
330     auto err = MakeStringError("run time error: ", call_frame->error);
331     return (results.ReturnError(err), std::move(err));
332   }
333 
334   // Try to convert results using registered conversion functions.
335   bool converted = true;
336 
337   for (unsigned i = 0; i < runtime_signature_.num_results(); ++i) {
338     const Type* type = signature_.result(i);
339     const Type* runtime_type = runtime_signature_.result(i);
340     void* ret = &call_frame->results[results_memory_layout_.offsets[i]];
341     bool res = mlir::succeeded(results.ReturnValue(i, type, runtime_type, ret));
342     converted = converted && res;
343   }
344 
345   if (LLVM_UNLIKELY(!converted))
346     return MakeStringError("failed to convert all returned values");
347   else
348     return Error::success();
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // Load AOT compiled executable from an object file.
353 //===----------------------------------------------------------------------===//
354 
LoadFromObjFile(llvm::StringRef name,std::unique_ptr<llvm::MemoryBuffer> obj_file,llvm::StringRef entrypoint,FunctionType signature,FunctionType runtime_signature,ExecutionEngine::SymbolsBinding symbols_binding,llvm::StringRef memory_region_name)355 /*static*/ Expected<Executable> Executable::LoadFromObjFile(
356     llvm::StringRef name, std::unique_ptr<llvm::MemoryBuffer> obj_file,
357     llvm::StringRef entrypoint, FunctionType signature,
358     FunctionType runtime_signature,
359     ExecutionEngine::SymbolsBinding symbols_binding,
360     llvm::StringRef memory_region_name) {
361   // Memory region name to mmap executable code.
362   std::string mapper_name = llvm::formatv(
363       "/xla_aot{0}{1}:@{2}::@{3}", memory_region_name.empty() ? "" : ":",
364       EscapeMemRegionName(memory_region_name), name, entrypoint);
365 
366   // Custom memory mapper to tag memory allocated for XLA executables.
367   std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper =
368       XlaRuntimeMemoryMapper::Create(std::move(mapper_name));
369 
370   // Construct options for the XLA execution engine.
371   ExecutionEngine::AotOptions options;
372   options.section_memory_mapper = memory_mapper.get();
373   options.symbols_binding = RuntimeSymbolsBinding(std::move(symbols_binding));
374 
375   auto engine = ExecutionEngine::CreateFromObjFile(std::move(obj_file),
376                                                    entrypoint, options);
377 
378   // Get the memory layout for passing function arguments.
379   auto arguments_memory_layout = GetArgumentsMemoryLayout(runtime_signature);
380   if (auto err = arguments_memory_layout.takeError()) return std::move(err);
381 
382   // Get the memory layout for returning function results.
383   auto results_memory_layout = GetResultsMemoryLayout(runtime_signature);
384   if (auto err = results_memory_layout.takeError()) return std::move(err);
385 
386   return Executable(name.str(), std::move(memory_mapper), std::move(*engine),
387                     std::move(signature), std::move(runtime_signature),
388                     std::move(*arguments_memory_layout),
389                     std::move(*results_memory_layout),
390                     /*specialization=*/llvm::None,
391                     /*time_to_compile*/ std::chrono::milliseconds(0));
392 }
393 
394 //===----------------------------------------------------------------------===//
395 
num_results() const396 unsigned Executable::num_results() const {
397   return runtime_signature_.num_results();
398 }
399 
signature() const400 const FunctionType& Executable::signature() const { return signature_; }
401 
runtime_signature() const402 const FunctionType& Executable::runtime_signature() const {
403   return runtime_signature_;
404 }
405 
time_to_compile() const406 std::chrono::milliseconds Executable::time_to_compile() const {
407   return time_to_compile_;
408 }
409 
obj_file() const410 std::unique_ptr<llvm::MemoryBuffer> Executable::obj_file() const {
411   return engine_->obj_file();
412 }
413 
GetUserData(KernelContext * ctx)414 CustomCall::UserData* Executable::GetUserData(KernelContext* ctx) {
415   return ctx->custom_call_data;
416 }
417 
GetDiagnosticEngine(KernelContext * ctx)418 DiagnosticEngine* Executable::GetDiagnosticEngine(KernelContext* ctx) {
419   return ctx->diagnostic_engine;
420 }
421 
Call(KernelContext * ctx,class CustomCall & call,void ** args,void ** attrs)422 mlir::LogicalResult Executable::Call(KernelContext* ctx, class CustomCall& call,
423                                      void** args, void** attrs) {
424   return call.call(args, attrs, ctx->custom_call_data, ctx->diagnostic_engine);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // Register XLA runtime symbols with XLA execution engine.
429 //===----------------------------------------------------------------------===//
430 
RuntimeApiSymbolMap(MangleAndInterner mangle)431 SymbolMap RuntimeApiSymbolMap(MangleAndInterner mangle) {
432   SymbolMap symbol_map;
433 
434   auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
435     symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
436         llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
437   };
438 
439   bind("runtimeGetResultStorage", &GetResultStorage);
440   bind("runtimeSetError", &SetError);
441   bind("runtimeCustomCall", &CustomCall);
442 
443   return symbol_map;
444 }
445 
446 //----------------------------------------------------------------------------//
447 // Implement XLA Runtime <-> XLA Executable integration API.
448 //----------------------------------------------------------------------------//
449 
GetResultStorage(KernelContext * ctx,int64_t index)450 void* GetResultStorage(KernelContext* ctx, int64_t index) {
451   assert(ctx && "kernel context must be not null");
452   assert(!ctx->call_frame->is_error && "error must not be set");
453   size_t offset = ctx->results_memory_layout->offsets[index];
454   assert(offset < ctx->call_frame->results.size() && "offset is out of bounds");
455   ctx->call_frame->has_set_outputs = true;
456   return &ctx->call_frame->results[offset];
457 }
458 
SetError(KernelContext * ctx,const char * error)459 void SetError(KernelContext* ctx, const char* error) {
460   assert(ctx && "kernel context must be not null");
461   assert(error && "runtime error must be not null");
462   assert(!ctx->call_frame->is_error && "error must be set only once");
463   assert(!ctx->call_frame->has_set_outputs && "outputs must be undefined");
464   ctx->call_frame->is_error = true;
465   ctx->call_frame->error = {error};
466 }
467 
CustomCall(KernelContext * ctx,const char * target,void ** args,void ** attrs)468 bool CustomCall(KernelContext* ctx, const char* target, void** args,
469                 void** attrs) {
470   assert(ctx && target && args && attrs && "all arguments must be not null");
471 
472   // Default custom calls registry for the XLA executables.
473   static CustomCallRegistry* registry = []() {
474     auto* registry = new CustomCallRegistry();
475     RegisterStaticCustomCalls(registry);
476     return registry;
477   }();
478 
479   auto* custom_call = registry->Find(target);
480   assert(custom_call && "custom call not found");
481   if (custom_call == nullptr) return false;
482 
483   return succeeded(custom_call->call(args, attrs, ctx->custom_call_data,
484                                      ctx->diagnostic_engine));
485 }
486 
487 }  // namespace runtime
488 }  // namespace xla
489