xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/executable.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_EXECUTABLE_H_
17 #define XLA_RUNTIME_EXECUTABLE_H_
18 
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/async_runtime.h"
#include "tensorflow/compiler/xla/runtime/custom_call.h"
#include "tensorflow/compiler/xla/runtime/diagnostics.h"
#include "tensorflow/compiler/xla/runtime/execution_engine.h"
#include "tensorflow/compiler/xla/runtime/logical_result.h"
#include "tensorflow/compiler/xla/runtime/memory_mapper.h"
#include "tensorflow/compiler/xla/runtime/results.h"
#include "tensorflow/compiler/xla/runtime/type_id.h"
#include "tensorflow/compiler/xla/runtime/types.h"
35 
36 namespace xla {
37 namespace runtime {
38 
39 class KernelContext;
40 class JitCompiler;
41 
42 // Returns a symbols binding for running XLA executable with a custom symbols
43 // provided by the user.
44 ExecutionEngine::SymbolsBinding RuntimeSymbolsBinding(
45     ExecutionEngine::SymbolsBinding custom_binding);
46 
47 // Converts a custom call library and custom type id name registration function
48 // (types required by the library) to the execution engine symbols binding. This
49 // function automatically registeres type id symbols for all canonical types
50 // supported by the XLA runtime custom calls.
51 ExecutionEngine::SymbolsBinding ToSymbolsBinding(
52     DirectCustomCallLibrary lib, TypeIDNameRegistry::RegistrationFn types = {});
53 
54 class Executable {
55  public:
56   // Forward declare types defined below.
57   struct ArgumentsMemoryLayout;
58   struct ResultsMemoryLayout;
59   struct CallFrame;
60   struct ExecuteOpts;
61 
62   // Initializes call frame by adding all arguments according to the executable
63   // ABI. Also allocates storage for the returned values according to the
64   // results memory layout.
65   //
66   // If `verify_arguments` is true (in debug mode it's always on, independent of
67   // the argument value) this function also verifies that operands passed at run
68   // time matches the executable entrypoint signature (e.g. all statically known
69   // dimensions of the memrefs matches the operands). Returns an error if finds
70   // a mismatch.
71   //
72   // This function leaves the kernel context argument (the first argument of an
73   // entry function) uninitialized. It will be initialized in the `Execute`
74   // function right before the actual execution.
75   llvm::Error InitializeCallFrame(ArgumentsRef arguments, CallFrame* call_frame,
76                                   bool verify_arguments = true) const;
77 
78   // Converts returned values owned by the call frame using provided result
79   // converter. If compiled function execution finished with an error (error
80   // flag is `true` in the call frame) returns error for all results.
81   llvm::Error ReturnResults(const ResultConverter& results,
82                             CallFrame* call_frame) const;
83 
84   // Executes compiled function with given arguments.
85   //
86   // If `verify_arguments` is true (in debug mode it's always on, independent of
87   // the argument value) this function also verifies that arguments passed at
88   // run time matches the executable entrypoint signature. If some of the
89   // arguments do not match the expected type, this function allocates error
90   // async values for all results and returns an error.
91   //
92   // Returns compiled function results via the user-provided results converter.
93   // If execution completed in the error state, returns error for all results.
94   llvm::Error Execute(ArgumentsRef arguments, const ResultConverter& results,
95                       const ExecuteOpts& opts,
96                       bool verify_arguments = true) const;
97 
98   // Executes compiled function using user provided call frame.
99   //
100   // It is the caller responsibility to handle the compiled function results
101   // stored in the call frame.
102   void Execute(CallFrame& call_frame, const ExecuteOpts& opts) const;
103 
IsAsync()104   bool IsAsync() const { return results_memory_layout_.has_async_results; }
105 
name()106   llvm::StringRef name() const { return name_; }
107 
specialization()108   llvm::Optional<size_t> specialization() const { return specialization_; }
109 
110   // Returns the number of results in the runtime signature.
111   unsigned num_results() const;
112 
113   // Signature of the compiled module entrypoint function before lowering to
114   // the runtime dialects. See JitExecutable's `signature_` for more details.
115   const FunctionType& signature() const;
116 
117   // Signature of the compiled module entrypoint function after lowering it from
118   // high level dialects to the dialects supported by the XLA runtime.
119   // See JitExecutable's `signature_` for more details.
120   const FunctionType& runtime_signature() const;
121 
122   std::chrono::milliseconds time_to_compile() const;
123 
124   // Get the object file behind this executable (on linux for example, it will
125   // be https://en.wikipedia.org/wiki/Executable_and_Linkable_Format
126   // executable). Can be null.
127   std::unique_ptr<llvm::MemoryBuffer> obj_file() const;
128 
129   // CallFrame provides a pointer-stable storage for packed function arguments
130   // and storage for returned values.
131   struct CallFrame {
132     // Pointers to executable arguments.
133     llvm::SmallVector<void*, 32> args;
134 
135     // We use single block of memory to store executable results. We need to be
136     // able to store pointers to async values and tokens, and strided memrefs
137     // which at runtime are represented as StridedMemrefType<T, rank>.
138     //
139     // Currently we only need to provide result storage for pointers and memref
140     // sizes and strides (int64_t type). If we'll need to support more complex
141     // return types we'll have to be more careful about alignment requirements.
142     static_assert(sizeof(uintptr_t) == sizeof(int64_t),
143                   "uintptr_t size must be the same as int64_t");
144 
145     // Memory where the executable will write its results.
146     llvm::SmallVector<uint8_t, 128> results;
147 
148     // Tracks whether any of the outputs were set.
149     bool has_set_outputs = false;
150 
151     // Indicates whether the execution finished with an error.
152     bool is_error = false;
153 
154     // The error message which is available only if `is_error` is true. The
155     // assumption is that the error message string is owned by the compiled
156     // binary and the call frame can safely keep a non-owning pointer.
157     llvm::StringRef error;
158   };
159 
160   // Requirements for passing arguments to the compiled function.
161   struct ArgumentsMemoryLayout {
162     // Currently we always pass arguments as an array of pointers.
163     size_t num_args_ptrs = 0;
164   };
165 
166   // Requirements for the contiguous block of memory to store compiled function
167   // results. When we invoke a compiled fuction we allocate a block of memory,
168   // and pass pointers to pre-computed offsets as output arguments to the
169   // function.
170   struct ResultsMemoryLayout {
171     bool has_async_results = false;     // true iff returns async results
172     size_t size = 0;                    // number of bytes required
173     llvm::SmallVector<size_t> offsets;  // offsets in the block of memory
174   };
175 
176   struct ExecuteOpts {
177     // Async task runner for executing async runtime tasks. Typically it
178     // schedules async tasks into the underlying thread pool. It's the caller's
179     // responsibility to guarantee that it will outlive the execution of all
180     // async tasks started by the executable.
181     AsyncTaskRunner* async_task_runner = nullptr;
182 
183     // A container for passing arbitrary user-provided data to the custom call
184     // handlers. Must outlive all async tasks launched by this executable.
185     CustomCall::UserData* custom_call_data = nullptr;
186 
187     // Diagnostic engine is responsible for passing runtime diagnostics back
188     // to the caller through the diagnostic handler.
189     DiagnosticEngine* diagnostic_engine = nullptr;
190   };
191 
192   // Loads executable from an object file. It is the caller responsibility to
193   // guarantee that signatures do match the compiled function in the object
194   // file, otherwise it will surely lead to crash.
195   static llvm::Expected<Executable> LoadFromObjFile(
196       llvm::StringRef name, std::unique_ptr<llvm::MemoryBuffer> obj_file,
197       llvm::StringRef entrypoint, FunctionType signature,
198       FunctionType runtime_signature,
199       ExecutionEngine::SymbolsBinding symbols_binding = {},
200       llvm::StringRef memory_region_name = "");
201 
202   // Verifies that all operands types in the entrypoint function signature are
203   // supported at run time . Returns a pre-computed layout for the function
204   // arguments. If some arguments are not supported returns an error.
205   static llvm::Expected<ArgumentsMemoryLayout> GetArgumentsMemoryLayout(
206       const FunctionType& signature);
207 
208   // Verifies that all results types in the entrypoint function signature are
209   // supported at run time . Returns a pre-computed layout for the function
210   // results. If some results are not supported returns an error.
211   static llvm::Expected<ResultsMemoryLayout> GetResultsMemoryLayout(
212       const FunctionType& signature);
213 
214   // TODO(ezhulenev): The following three functions should be decoupled from
215   // the executable header file (maybe move them to runtime.h?) so that custom
216   // call implementations do not have to depend on the `executable` target.
217 
218   // Returns the user data passed via the ExecuteOpts to the executable.
219   static CustomCall::UserData* GetUserData(KernelContext* ctx);
220 
221   // Returns the diagnostic engine passed via the ExecuteOpts to the executable.
222   static DiagnosticEngine* GetDiagnosticEngine(KernelContext* ctx);
223 
224   // Calls the custom call handler with the given runtime context, arguments and
225   // attributes.
226   static LogicalResult Call(KernelContext* ctx, CustomCall& call, void** args,
227                             void** attrs);
228 
229  private:
230   friend class JitCompiler;  // see `mlir/transforms/runtime/compiler.h`
231 
Executable(llvm::StringRef name,std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper,std::unique_ptr<ExecutionEngine> engine,FunctionType signature,FunctionType runtime_signature,ArgumentsMemoryLayout arguments_memory_layout,ResultsMemoryLayout results_memory_layout,llvm::Optional<size_t> specialization,std::chrono::milliseconds time_to_compile)232   Executable(llvm::StringRef name,
233              std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper,
234              std::unique_ptr<ExecutionEngine> engine, FunctionType signature,
235              FunctionType runtime_signature,
236              ArgumentsMemoryLayout arguments_memory_layout,
237              ResultsMemoryLayout results_memory_layout,
238              llvm::Optional<size_t> specialization,
239              std::chrono::milliseconds time_to_compile)
240       : name_(name.str()),
241         memory_mapper_(std::move(memory_mapper)),
242         engine_(std::move(engine)),
243         fptr_(engine_->entrypoint()),
244         signature_(std::move(signature)),
245         runtime_signature_(std::move(runtime_signature)),
246         arguments_memory_layout_(std::move(arguments_memory_layout)),
247         results_memory_layout_(std::move(results_memory_layout)),
248         specialization_(specialization),
249         time_to_compile_(time_to_compile) {
250     assert(fptr_ != nullptr && "executable function pointer must be not null");
251   }
252 
253   std::string name_;  // name of the compiled executable
254 
255   // Called by `engine_`'s destructor; must appear before it.
256   std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper_;  // optional
257 
258   // XLA runtime execution engine owns the LLVM ORC jit compilation stack.
259   std::unique_ptr<ExecutionEngine> engine_;
260 
261   // Compiled function owned by the `engine_`.
262   ExecutionEngine::EntrypointFunctionPtr fptr_;
263 
264   // Signature of the compiled module entrypoint function before lowering to
265   // the runtime dialects (see JitExecutable `signature_` for more details).
266   FunctionType signature_;
267 
268   // Signature of the compiled module entrypoint function after lowering it from
269   // high level dialects to the dialects supported by the XLA runtime.
270   //
271   // - Operands and results types converted to the types with well-defined ABI
272   //   (e.g. tensors converted to memrefs).
273   //
274   // - First argument is always a kernel context added to the function by the
275   //   lowering pipeline.
276   //
277   // From this signature executable infers how to pack runtime operands
278   // according to the expected memory layout, and how to convert results
279   // returned from the JIT-compiled function into high level types (e.g. how to
280   // convert StridedMemrefType into Tensorflow Tensor).
281   //
282   // To infer the type of the returned value, executable looks at the type
283   // defined by the `runtime_signature_` to get the memory layout of the
284   // returned value, and at the type defined by the `signature_` to get the type
285   // expected by the runtime.
286   FunctionType runtime_signature_;
287 
288   ArgumentsMemoryLayout arguments_memory_layout_;
289   ResultsMemoryLayout results_memory_layout_;
290 
291   // Specialization id if this executable is a specialization, or an empty
292   // optional if this executable is a default one.
293   llvm::Optional<size_t> specialization_;
294 
295   // The time it took to compile this binary.
296   std::chrono::milliseconds time_to_compile_;
297 };
298 
299 // Escape slashes, substituting them with double underscores to get a memory
300 // region name for the XlaRuntimeMemoryMapper.
301 //
302 // The profiler's UI might interpret slashes as callchain separators,
303 // whereas we want the region name to be shown in full.
EscapeMemRegionName(llvm::StringRef memory_region_name)304 inline std::string EscapeMemRegionName(llvm::StringRef memory_region_name) {
305   return llvm::join(llvm::split(memory_region_name, '/'), "__");
306 }
307 
308 }  // namespace runtime
309 }  // namespace xla
310 
311 #endif  // XLA_RUNTIME_EXECUTABLE_H_
312