/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_RUNTIME_EXECUTABLE_H_
#define XLA_RUNTIME_EXECUTABLE_H_

#include <cassert>
#include <chrono>
#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/MemoryBuffer.h"
#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/async_runtime.h"
#include "tensorflow/compiler/xla/runtime/custom_call.h"
#include "tensorflow/compiler/xla/runtime/diagnostics.h"
#include "tensorflow/compiler/xla/runtime/execution_engine.h"
#include "tensorflow/compiler/xla/runtime/logical_result.h"
#include "tensorflow/compiler/xla/runtime/memory_mapper.h"
#include "tensorflow/compiler/xla/runtime/results.h"
#include "tensorflow/compiler/xla/runtime/type_id.h"
#include "tensorflow/compiler/xla/runtime/types.h"

namespace xla {
namespace runtime {

class KernelContext;
class JitCompiler;

// Returns a symbols binding for running an XLA executable with custom symbols
// provided by the user.
ExecutionEngine::SymbolsBinding RuntimeSymbolsBinding(
    ExecutionEngine::SymbolsBinding custom_binding);

// Converts a custom call library and a custom type id name registration
// function (for the types required by the library) to an execution engine
// symbols binding. This function automatically registers type id symbols for
// all canonical types supported by the XLA runtime custom calls.
ExecutionEngine::SymbolsBinding ToSymbolsBinding(
    DirectCustomCallLibrary lib, TypeIDNameRegistry::RegistrationFn types = {});
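
// Example: converting a user-populated custom call library into a symbols
// binding. A minimal sketch; `PopulateMyCustomCalls` is a hypothetical helper
// that is not part of this API:
//
//   DirectCustomCallLibrary lib;
//   PopulateMyCustomCalls(lib);
//   ExecutionEngine::SymbolsBinding binding = ToSymbolsBinding(std::move(lib));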

class Executable {
 public:
  // Forward declare types defined below.
  struct ArgumentsMemoryLayout;
  struct ResultsMemoryLayout;
  struct CallFrame;
  struct ExecuteOpts;

  // Initializes the call frame by adding all arguments according to the
  // executable ABI. Also allocates storage for the returned values according
  // to the results memory layout.
  //
  // If `verify_arguments` is true (in debug mode it's always on, independent
  // of the argument value), this function also verifies that the operands
  // passed at run time match the executable entrypoint signature (e.g. all
  // statically known dimensions of the memrefs match the operands). Returns
  // an error if it finds a mismatch.
  //
  // This function leaves the kernel context argument (the first argument of
  // an entry function) uninitialized. It will be initialized in the `Execute`
  // function right before the actual execution.
  llvm::Error InitializeCallFrame(ArgumentsRef arguments, CallFrame* call_frame,
                                  bool verify_arguments = true) const;

  // Converts the values returned by the compiled function (owned by the call
  // frame) using the provided result converter. If the compiled function
  // execution finished with an error (the error flag is `true` in the call
  // frame), returns an error for all results.
  llvm::Error ReturnResults(const ResultConverter& results,
                            CallFrame* call_frame) const;

  // Executes the compiled function with the given arguments.
  //
  // If `verify_arguments` is true (in debug mode it's always on, independent
  // of the argument value), this function also verifies that the arguments
  // passed at run time match the executable entrypoint signature. If some of
  // the arguments do not match the expected types, this function allocates
  // error async values for all results and returns an error.
  //
  // Returns the compiled function results via the user-provided results
  // converter. If execution completed in the error state, returns an error
  // for all results.
  llvm::Error Execute(ArgumentsRef arguments, const ResultConverter& results,
                      const ExecuteOpts& opts,
                      bool verify_arguments = true) const;

  // Executes the compiled function using the user-provided call frame.
  //
  // It is the caller's responsibility to handle the compiled function results
  // stored in the call frame.
  void Execute(CallFrame& call_frame, const ExecuteOpts& opts) const;
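
  // A minimal sketch of the low-level execution flow using the call frame API
  // above (`executable`, `args`, `converter` and `opts` are assumed to exist;
  // they are not defined in this snippet):
  //
  //   Executable::CallFrame call_frame;
  //   if (auto err = executable.InitializeCallFrame(args, &call_frame))
  //     return err;
  //   executable.Execute(call_frame, opts);
  //   if (auto err = executable.ReturnResults(converter, &call_frame))
  //     return err;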

  // Returns true if the executable returns any async results.
  bool IsAsync() const { return results_memory_layout_.has_async_results; }

  llvm::StringRef name() const { return name_; }

  llvm::Optional<size_t> specialization() const { return specialization_; }

  // Returns the number of results in the runtime signature.
  unsigned num_results() const;

  // Signature of the compiled module entrypoint function before lowering to
  // the runtime dialects. See JitExecutable's `signature_` for more details.
  const FunctionType& signature() const;

  // Signature of the compiled module entrypoint function after lowering it
  // from high level dialects to the dialects supported by the XLA runtime.
  // See JitExecutable's `signature_` for more details.
  const FunctionType& runtime_signature() const;

  // Returns the time it took to compile this executable.
  std::chrono::milliseconds time_to_compile() const;

  // Get the object file behind this executable (on Linux, for example, it
  // will be an ELF executable:
  // https://en.wikipedia.org/wiki/Executable_and_Linkable_Format). Can be null.
  std::unique_ptr<llvm::MemoryBuffer> obj_file() const;

  // CallFrame provides pointer-stable storage for packed function arguments
  // and storage for the returned values.
  struct CallFrame {
    // Pointers to executable arguments.
    llvm::SmallVector<void*, 32> args;

    // We use a single block of memory to store the executable results. We
    // need to be able to store pointers to async values and tokens, and
    // strided memrefs, which at runtime are represented as
    // StridedMemrefType<T, rank>.
    //
    // Currently we only need to provide result storage for pointers and
    // memref sizes and strides (int64_t type). If we need to support more
    // complex return types, we'll have to be more careful about alignment
    // requirements.
    static_assert(sizeof(uintptr_t) == sizeof(int64_t),
                  "uintptr_t size must be the same as int64_t");

    // Memory where the executable will write its results.
    llvm::SmallVector<uint8_t, 128> results;

    // Tracks whether any of the outputs were set.
    bool has_set_outputs = false;

    // Indicates whether the execution finished with an error.
    bool is_error = false;

    // The error message, which is available only if `is_error` is true. The
    // assumption is that the error message string is owned by the compiled
    // binary and the call frame can safely keep a non-owning pointer.
    llvm::StringRef error;
  };

  // Requirements for passing arguments to the compiled function.
  struct ArgumentsMemoryLayout {
    // Currently we always pass arguments as an array of pointers.
    size_t num_args_ptrs = 0;
  };

  // Requirements for the contiguous block of memory used to store the
  // compiled function results. When we invoke a compiled function we allocate
  // a block of memory and pass pointers to pre-computed offsets as output
  // arguments to the function.
  struct ResultsMemoryLayout {
    bool has_async_results = false;     // true iff returns async results
    size_t size = 0;                    // number of bytes required
    llvm::SmallVector<size_t> offsets;  // offsets in the block of memory
  };

  struct ExecuteOpts {
    // Async task runner for executing async runtime tasks. Typically it
    // schedules async tasks into the underlying thread pool. It's the
    // caller's responsibility to guarantee that it will outlive the execution
    // of all async tasks started by the executable.
    AsyncTaskRunner* async_task_runner = nullptr;

    // A container for passing arbitrary user-provided data to the custom call
    // handlers. Must outlive all async tasks launched by this executable.
    CustomCall::UserData* custom_call_data = nullptr;

    // The diagnostic engine is responsible for passing runtime diagnostics
    // back to the caller through the diagnostic handler.
    DiagnosticEngine* diagnostic_engine = nullptr;
  };
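
  // A minimal sketch of a typical invocation (`executable`, `args`,
  // `converter` and `runner` are assumed to exist: an Executable instance, an
  // ArgumentsRef, a ResultConverter and a live AsyncTaskRunner respectively):
  //
  //   Executable::ExecuteOpts opts;
  //   opts.async_task_runner = runner;
  //
  //   if (auto err = executable.Execute(args, converter, opts))
  //     return err;  // execution (or argument verification) failed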

  // Loads an executable from an object file. It is the caller's
  // responsibility to guarantee that the signatures match the compiled
  // function in the object file, otherwise execution will almost certainly
  // crash.
  static llvm::Expected<Executable> LoadFromObjFile(
      llvm::StringRef name, std::unique_ptr<llvm::MemoryBuffer> obj_file,
      llvm::StringRef entrypoint, FunctionType signature,
      FunctionType runtime_signature,
      ExecutionEngine::SymbolsBinding symbols_binding = {},
      llvm::StringRef memory_region_name = "");
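
  // A minimal sketch of loading a previously compiled object file ("add.o",
  // "my_executable" and the "compute" entrypoint are illustrative
  // placeholders; `signature` and `runtime_signature` must be reconstructed
  // by the caller to match the compiled function):
  //
  //   auto buffer = llvm::MemoryBuffer::getFile("add.o");
  //   if (!buffer) { /* handle the file read error */ }
  //
  //   llvm::Expected<Executable> loaded = Executable::LoadFromObjFile(
  //       "my_executable", std::move(*buffer), "compute",
  //       std::move(signature), std::move(runtime_signature));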

  // Verifies that all operand types in the entrypoint function signature are
  // supported at run time. Returns a pre-computed layout for the function
  // arguments. If some arguments are not supported, returns an error.
  static llvm::Expected<ArgumentsMemoryLayout> GetArgumentsMemoryLayout(
      const FunctionType& signature);

  // Verifies that all result types in the entrypoint function signature are
  // supported at run time. Returns a pre-computed layout for the function
  // results. If some results are not supported, returns an error.
  static llvm::Expected<ResultsMemoryLayout> GetResultsMemoryLayout(
      const FunctionType& signature);
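
  // For example, given a runtime signature both layouts can be pre-computed
  // up front (a sketch; an error indicates unsupported argument or result
  // types):
  //
  //   auto args_layout = Executable::GetArgumentsMemoryLayout(signature);
  //   auto rets_layout = Executable::GetResultsMemoryLayout(signature);
  //   if (!args_layout || !rets_layout) { /* unsupported types */ }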

  // TODO(ezhulenev): The following three functions should be decoupled from
  // the executable header file (maybe move them to runtime.h?) so that custom
  // call implementations do not have to depend on the `executable` target.

  // Returns the user data passed via the ExecuteOpts to the executable.
  static CustomCall::UserData* GetUserData(KernelContext* ctx);

  // Returns the diagnostic engine passed via the ExecuteOpts to the
  // executable.
  static DiagnosticEngine* GetDiagnosticEngine(KernelContext* ctx);

  // Calls the custom call handler with the given runtime context, arguments
  // and attributes.
  static LogicalResult Call(KernelContext* ctx, CustomCall& call, void** args,
                            void** attrs);

 private:
  friend class JitCompiler;  // see `mlir/transforms/runtime/compiler.h`

  Executable(llvm::StringRef name,
             std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper,
             std::unique_ptr<ExecutionEngine> engine, FunctionType signature,
             FunctionType runtime_signature,
             ArgumentsMemoryLayout arguments_memory_layout,
             ResultsMemoryLayout results_memory_layout,
             llvm::Optional<size_t> specialization,
             std::chrono::milliseconds time_to_compile)
      : name_(name.str()),
        memory_mapper_(std::move(memory_mapper)),
        engine_(std::move(engine)),
        fptr_(engine_->entrypoint()),
        signature_(std::move(signature)),
        runtime_signature_(std::move(runtime_signature)),
        arguments_memory_layout_(std::move(arguments_memory_layout)),
        results_memory_layout_(std::move(results_memory_layout)),
        specialization_(specialization),
        time_to_compile_(time_to_compile) {
    assert(fptr_ != nullptr && "executable function pointer must not be null");
  }

  std::string name_;  // name of the compiled executable

  // Called by `engine_`'s destructor; must be declared before it.
  std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper_;  // optional

  // XLA runtime execution engine owns the LLVM ORC jit compilation stack.
  std::unique_ptr<ExecutionEngine> engine_;

  // Compiled function owned by the `engine_`.
  ExecutionEngine::EntrypointFunctionPtr fptr_;

  // Signature of the compiled module entrypoint function before lowering to
  // the runtime dialects (see JitExecutable `signature_` for more details).
  FunctionType signature_;

  // Signature of the compiled module entrypoint function after lowering it
  // from high level dialects to the dialects supported by the XLA runtime.
  //
  //  - Operand and result types are converted to types with a well-defined
  //    ABI (e.g. tensors converted to memrefs).
  //
  //  - The first argument is always a kernel context added to the function by
  //    the lowering pipeline.
  //
  // From this signature the executable infers how to pack runtime operands
  // according to the expected memory layout, and how to convert results
  // returned from the JIT-compiled function into high level types (e.g. how
  // to convert a StridedMemrefType into a Tensorflow Tensor).
  //
  // To infer the type of the returned value, the executable looks at the type
  // defined by the `runtime_signature_` to get the memory layout of the
  // returned value, and at the type defined by the `signature_` to get the
  // type expected by the runtime.
  FunctionType runtime_signature_;

  ArgumentsMemoryLayout arguments_memory_layout_;
  ResultsMemoryLayout results_memory_layout_;

  // Specialization id if this executable is a specialization, or an empty
  // optional if this executable is a default one.
  llvm::Optional<size_t> specialization_;

  // The time it took to compile this binary.
  std::chrono::milliseconds time_to_compile_;
};

// Escape slashes, substituting them with double underscores to get a memory
// region name for the XlaRuntimeMemoryMapper.
//
// The profiler's UI might interpret slashes as callchain separators,
// whereas we want the region name to be shown in full.
inline std::string EscapeMemRegionName(llvm::StringRef memory_region_name) {
  return llvm::join(llvm::split(memory_region_name, '/'), "__");
}
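
// For example (with an illustrative input):
//
//   EscapeMemRegionName("xla/gpu/module")  // returns "xla__gpu__module"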

}  // namespace runtime
}  // namespace xla

#endif  // XLA_RUNTIME_EXECUTABLE_H_