/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/runtime/executable.h"

#include <memory>
#include <string>
#include <utility>

#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/Support/ErrorOr.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/c_runner_utils.h"
#include "tensorflow/compiler/xla/runtime/custom_call.h"
#include "tensorflow/compiler/xla/runtime/custom_call_registry.h"
#include "tensorflow/compiler/xla/runtime/errors.h"
#include "tensorflow/compiler/xla/runtime/runtime.h"
#include "tensorflow/compiler/xla/runtime/type_id.h"

namespace xla {
namespace runtime {

using llvm::dyn_cast;

using llvm::Error;
using llvm::ErrorOr;
using llvm::Expected;

using llvm::orc::MangleAndInterner;
using llvm::orc::SymbolMap;

// KernelContext encapsulates all the data that is required to implement the
// XLA Runtime <-> XLA Executable integration API.
struct KernelContext {
  // The results memory layout is owned by the executable, and stays alive
  // after the entrypoint function execution completes.
  const Executable::ResultsMemoryLayout* results_memory_layout = nullptr;

  // The CallFrame's lifetime is bound to the entrypoint function execution,
  // and it is destroyed immediately when the function returns. Only the
  // entrypoint function itself reads the arguments and writes to the function
  // results storage.
  Executable::CallFrame* call_frame = nullptr;

  // User-defined data for custom call handlers.
  CustomCall::UserData* custom_call_data = nullptr;

  // User-defined diagnostic engine for reporting diagnostics.
  DiagnosticEngine* diagnostic_engine = nullptr;
};
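
// Note: the compiled entrypoint receives its arguments as a single type
// erased `void**` array (see `Execute` below), so a minimal sketch of the
// calling convention looks like this:
//
//   void compiled_entrypoint(void** args);  // args[0] == &kernel_context_ptr
//
// where the remaining `args` slots hold pointers to the packed arguments.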

//===----------------------------------------------------------------------===//
// Conversion from custom calls library and type id registry to symbols binding.
//===----------------------------------------------------------------------===//

ExecutionEngine::SymbolsBinding ToSymbolsBinding(
    DirectCustomCallLibrary lib, TypeIDNameRegistry::RegistrationFn types) {
  return [=](MangleAndInterner mangle) {
    SymbolMap symbol_map;

    // Always register canonical custom call types with the registry.
    TypeIDNameRegistry registry;
    PopulateCustomCallTypeIdNames(registry);
    if (types) types(registry);

    // Register direct custom calls.
    using DirectCustomCall = DirectCustomCallLibrary::DirectCustomCall;
    lib.ForEach([&](llvm::StringRef name, DirectCustomCall custom_call) {
      symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
          llvm::pointerToJITTargetAddress(custom_call), llvm::JITSymbolFlags());
    });

    // Register type id symbols.
    registry.ForEach([&](llvm::StringRef name, TypeID type_id) {
      auto type_id_ptr =
          reinterpret_cast<std::uintptr_t>(type_id.getAsOpaquePointer());
      symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
          static_cast<llvm::JITTargetAddress>(type_id_ptr),
          llvm::JITSymbolFlags());
    });

    return symbol_map;
  };
}
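
// A hypothetical usage sketch (the custom call name and handler below are
// made up for illustration, and we assume DirectCustomCallLibrary exposes an
// insertion API alongside ForEach):
//
//   DirectCustomCallLibrary lib;
//   lib.Insert("xla.test.custom_call", &TestCustomCallImpl);
//   auto binding = ToSymbolsBinding(std::move(lib), /*types=*/nullptr);
//
// The resulting binding is invoked by the ExecutionEngine at link time to
// resolve the custom call and type id symbols referenced by compiled code.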

//===----------------------------------------------------------------------===//
// Register XLA runtime symbols with XLA execution engine.
//===----------------------------------------------------------------------===//

static SymbolMap RuntimeApiSymbolMap(MangleAndInterner);

//===----------------------------------------------------------------------===//
// Construct a symbols binding for XLA executable.
//===----------------------------------------------------------------------===//

ExecutionEngine::SymbolsBinding RuntimeSymbolsBinding(
    ExecutionEngine::SymbolsBinding custom_binding) {
  return ExecutionEngine::BindAll(
      {// Register MLIR C Runner API intrinsics (defined in CRunnerUtils).
       CRunnerUtilsSymbolMap,
       // Register Async Runtime API intrinsics.
       AsyncRuntimeApiSymbolMap,
       // Register memory allocation functions (malloc, free, ...).
       AsyncRuntimeMemoryAllocationSymbolMap,
       // Register Runtime API intrinsics (returning results and errors).
       RuntimeApiSymbolMap,
       // Register any additional user-defined symbol bindings.
       std::move(custom_binding)});
}

//===----------------------------------------------------------------------===//
// Get executable arguments and results memory layouts.
//===----------------------------------------------------------------------===//

/*static*/ Expected<Executable::ArgumentsMemoryLayout>
Executable::GetArgumentsMemoryLayout(const FunctionType& signature) {
  // Requirements for passing function arguments.
  ArgumentsMemoryLayout layout;

  for (unsigned i = 0; i < signature.num_operands(); ++i) {
    const Type* type = signature.operand(i);

    // Check if the type defines the ABI for passing it as an argument.
    if (ErrorOr<Type::ArgumentAbi> abi = type->AsArgument()) {
      layout.num_args_ptrs += abi->num_ptrs;
      continue;
    }

    return MakeStringError("unknown argument ABI for operand #", i, ": ",
                           *type);
  }

  return layout;
}
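
// For example, under the default MLIR calling convention a memref argument
// unpacks into `3 + 2 * rank` pointer-sized values (allocated and aligned
// base pointers, offset, sizes and strides), so a signature taking the kernel
// context plus one rank-2 memref would report `num_args_ptrs == 1 + 7 == 8`
// (assuming the kernel context ABI reports a single pointer).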

/*static*/ Expected<Executable::ResultsMemoryLayout>
Executable::GetResultsMemoryLayout(const FunctionType& signature) {
  // Requirements for returning function results.
  ResultsMemoryLayout layout;
  layout.offsets.reserve(signature.num_results());

  // TODO(ezhulenev): We should support allocating storage for results with
  // non-standard alignment requirements.

  for (unsigned i = 0; i < signature.num_results(); ++i) {
    const Type* type = signature.result(i);

    // Keep track of whether the function has asynchronous results.
    layout.has_async_results |= llvm::isa<AsyncTokenType, AsyncValueType>(type);

    // Check if the type defines the ABI for returning it as a result.
    if (ErrorOr<Type::ResultAbi> abi = type->AsResult()) {
      layout.offsets.emplace_back(layout.size);
      layout.size += abi->size;
      continue;
    }

    return MakeStringError("unknown result ABI for result #", i, ": ", *type);
  }

  return layout;
}
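
// For example, a signature returning two results whose ABIs report sizes of 4
// and 8 bytes produces `offsets == {0, 4}` and `size == 12`: offsets are
// assigned sequentially with no extra alignment padding (see the TODO above).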

//===----------------------------------------------------------------------===//
// Executable CallFrame initialization.
//===----------------------------------------------------------------------===//

// Always verify executable arguments in debug mode.
static bool VerifyArguments(bool verify_arguments) {
#if defined(NDEBUG)
  return verify_arguments;
#endif
  return true;
}

Error Executable::InitializeCallFrame(ArgumentsRef arguments,
                                      CallFrame* call_frame,
                                      bool verify_arguments) const {
  // TODO(ezhulenev): If the executable is specialized for concrete shapes then
  // there is no need to verify them once more here. However, currently we rely
  // on a hash code to look up specializations, and this can lead to collisions.
  if (VerifyArguments(verify_arguments)) {
    // We verify run time arguments against the run time signature.
    const FunctionType& signature = runtime_signature_;

    // Make sure that we call the executable with the correct number of
    // arguments. We subtract one argument from the signature because it
    // corresponds to the context that we prepend to the given arguments.
    if (LLVM_UNLIKELY(arguments.size() != signature.num_operands() - 1))
      return MakeStringError(
          "number of arguments doesn't match the function signature: ",
          arguments.size(), " vs ", signature.num_operands() - 1);

    // Verify that all arguments passed at run time are compatible with the
    // compiled function signature.
    auto kctx = dyn_cast<KernelContextOperandType>(signature.operand(0));
    if (LLVM_UNLIKELY(!kctx)) {
      return MakeStringError(
          "expected KernelContext in first argument of signature, got: ",
          signature.operand(0));
    }

    // We use a 0-based index for arguments, because the kernel context
    // argument is an internal implementation detail, and in case of an error
    // users should get back an argument index corresponding to the
    // user-provided signature.
    for (unsigned i = 0; i < arguments.size(); ++i) {
      unsigned idx = i + 1;  // use 1-based index to fetch signature operand
      if (auto err = arguments[i].Verify(*signature.operand(idx)))
        return MakeStringError("argument #", i,
                               " doesn't match the signature: ", err);
    }
  }

  size_t num_args_ptrs = arguments_memory_layout_.num_args_ptrs;
  call_frame->args.resize_for_overwrite(num_args_ptrs);

  // Add a placeholder for the kernel context as the first argument.
  call_frame->args[0] = nullptr;

  // Keep the offset of the next argument in the `args` array, and update it
  // every time we pack a new argument.
  size_t offset = 1;

  // Pack all arguments according to the ABI into the call frame arguments.
  for (unsigned i = 0; i < arguments.size(); ++i)
    offset = arguments[i].Pack(call_frame->args, offset);

  assert(offset == num_args_ptrs &&
         "reserved number of args must match the argument offset");

  // Allocate storage for results.
  call_frame->results.resize_for_overwrite(results_memory_layout_.size);

  // Mark results memory initialized to suppress potential msan errors.
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(call_frame->results.data(),
                                      call_frame->results.size());

  return Error::success();
}
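
// After initialization the call frame has the following shape (a sketch; each
// `args` slot is a type-erased pointer to a packed value, following the
// mlir::ExecutionEngine `packFunctionArguments` convention):
//
//   args:    [ nullptr, <arg #0 pointers...>, <arg #1 pointers...>, ... ]
//              ^-- placeholder, patched to &kernel_context_ptr in Execute()
//   results: [ results_memory_layout_.size bytes of uninitialized storage ]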

//===----------------------------------------------------------------------===//
// Execute the compiled XLA runtime executable.
//===----------------------------------------------------------------------===//

Error Executable::Execute(ArgumentsRef arguments,
                          const ResultConverter& results,
                          const ExecuteOpts& opts,
                          bool verify_arguments) const {
  // The CallFrame can be allocated on the stack because the compiled function
  // will unpack all the arguments it needs, and async regions will not access
  // the data after the initial function returns the result.
  CallFrame call_frame;

  // Touch every byte of the memref arguments to trigger a memory sanitizer
  // error if some of the memrefs are already deallocated. Unfortunately,
  // sanitizers do not work inside JIT compiled code, and compiled executables
  // can still do out-of-bounds memory accesses; however, this sanity check
  // lets us catch obvious errors earlier.
#if defined(MEMORY_SANITIZER)
  auto do_not_optimize = [&](const auto& value) -> void {
    asm volatile("" : : "r,m"(value) : "memory");
  };

  for (unsigned i = 0; i < arguments.size(); ++i) {
    auto* memref = dyn_cast<MemrefDesc>(&arguments[i]);
    if (!memref) continue;

    int64_t size_in_bytes = GetHostSize(memref->dtype());
    for (int64_t size : memref->sizes()) size_in_bytes *= size;

    uint8_t* data = static_cast<uint8_t*>(memref->data());
    for (int64_t j = 0; j < size_in_bytes; ++j) {
      uint8_t value = data[j];
      do_not_optimize(value);
    }
  }
#endif

  // The compiled function takes arguments and results as a `void**` type
  // erased pointer. See mlir::ExecutionEngine `packFunctionArguments` for the
  // details.
  if (auto err = InitializeCallFrame(arguments, &call_frame, verify_arguments))
    return (results.ReturnError(err), std::move(err));

  Execute(call_frame, opts);

  // Convert compiled function return values into results.
  if (auto err = ReturnResults(results, &call_frame)) return err;

  return Error::success();
}
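
// A hypothetical calling sequence (a sketch only: the arguments container and
// the result converter below are illustrative, not the exact API):
//
//   Executable::ExecuteOpts opts;
//   opts.async_task_runner = task_runner;  // user-provided async task runner
//   if (auto err = executable.Execute(arguments, converter, opts))
//     return err;  // the converter has already been notified of the error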

void Executable::Execute(CallFrame& call_frame, const ExecuteOpts& opts) const {
  // Set the AsyncRuntime to be used by all async tasks spawned by the
  // executable.
  AsyncRuntime::Set(AsyncRuntime(opts.async_task_runner));

  // The runtime kernel context can be used only by the entrypoint function,
  // and can be safely allocated on the stack.
  KernelContext kernel_context = {&results_memory_layout_, &call_frame,
                                  opts.custom_call_data,
                                  opts.diagnostic_engine};

  // Override the kernel context argument.
  KernelContext* kernel_context_ptr = &kernel_context;
  assert(call_frame.args.size() == arguments_memory_layout_.num_args_ptrs);
  assert(call_frame.args[0] == nullptr && "expected to see a placeholder");
  call_frame.args[0] = &kernel_context_ptr;

  // Call the compiled function.
  (*fptr_)(call_frame.args.data());
}

Error Executable::ReturnResults(const ResultConverter& results,
                                CallFrame* call_frame) const {
  // If execution failed, forward the error to all results.
  if (call_frame->is_error) {
    auto err = MakeStringError("run time error: ", call_frame->error);
    return (results.ReturnError(err), std::move(err));
  }

  // Try to convert results using the registered conversion functions.
  bool converted = true;

  for (unsigned i = 0; i < runtime_signature_.num_results(); ++i) {
    const Type* type = signature_.result(i);
    const Type* runtime_type = runtime_signature_.result(i);
    void* ret = &call_frame->results[results_memory_layout_.offsets[i]];
    bool res = mlir::succeeded(results.ReturnValue(i, type, runtime_type, ret));
    converted = converted && res;
  }

  if (LLVM_UNLIKELY(!converted))
    return MakeStringError("failed to convert all returned values");

  return Error::success();
}

//===----------------------------------------------------------------------===//
// Load AOT compiled executable from an object file.
//===----------------------------------------------------------------------===//

/*static*/ Expected<Executable> Executable::LoadFromObjFile(
    llvm::StringRef name, std::unique_ptr<llvm::MemoryBuffer> obj_file,
    llvm::StringRef entrypoint, FunctionType signature,
    FunctionType runtime_signature,
    ExecutionEngine::SymbolsBinding symbols_binding,
    llvm::StringRef memory_region_name) {
  // Memory region name to mmap executable code.
  std::string mapper_name = llvm::formatv(
      "/xla_aot{0}{1}:@{2}::@{3}", memory_region_name.empty() ? "" : ":",
      EscapeMemRegionName(memory_region_name), name, entrypoint);
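
  // For example, with memory_region_name == "foo", name == "module" and
  // entrypoint == "main" this produces "/xla_aot:foo:@module::@main"
  // (assuming EscapeMemRegionName leaves "foo" unchanged); the ":<region>"
  // prefix is omitted when memory_region_name is empty.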

  // Custom memory mapper to tag memory allocated for XLA executables.
  std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper =
      XlaRuntimeMemoryMapper::Create(std::move(mapper_name));

  // Construct options for the XLA execution engine.
  ExecutionEngine::AotOptions options;
  options.section_memory_mapper = memory_mapper.get();
  options.symbols_binding = RuntimeSymbolsBinding(std::move(symbols_binding));

  auto engine = ExecutionEngine::CreateFromObjFile(std::move(obj_file),
                                                   entrypoint, options);
  if (auto err = engine.takeError()) return std::move(err);

  // Get the memory layout for passing function arguments.
  auto arguments_memory_layout = GetArgumentsMemoryLayout(runtime_signature);
  if (auto err = arguments_memory_layout.takeError()) return std::move(err);

  // Get the memory layout for returning function results.
  auto results_memory_layout = GetResultsMemoryLayout(runtime_signature);
  if (auto err = results_memory_layout.takeError()) return std::move(err);

  return Executable(name.str(), std::move(memory_mapper), std::move(*engine),
                    std::move(signature), std::move(runtime_signature),
                    std::move(*arguments_memory_layout),
                    std::move(*results_memory_layout),
                    /*specialization=*/llvm::None,
                    /*time_to_compile=*/std::chrono::milliseconds(0));
}

//===----------------------------------------------------------------------===//

unsigned Executable::num_results() const {
  return runtime_signature_.num_results();
}

const FunctionType& Executable::signature() const { return signature_; }

const FunctionType& Executable::runtime_signature() const {
  return runtime_signature_;
}

std::chrono::milliseconds Executable::time_to_compile() const {
  return time_to_compile_;
}

std::unique_ptr<llvm::MemoryBuffer> Executable::obj_file() const {
  return engine_->obj_file();
}

CustomCall::UserData* Executable::GetUserData(KernelContext* ctx) {
  return ctx->custom_call_data;
}

DiagnosticEngine* Executable::GetDiagnosticEngine(KernelContext* ctx) {
  return ctx->diagnostic_engine;
}

mlir::LogicalResult Executable::Call(KernelContext* ctx, class CustomCall& call,
                                     void** args, void** attrs) {
  return call.call(args, attrs, ctx->custom_call_data, ctx->diagnostic_engine);
}

//===----------------------------------------------------------------------===//
// Register XLA runtime symbols with XLA execution engine.
//===----------------------------------------------------------------------===//

SymbolMap RuntimeApiSymbolMap(MangleAndInterner mangle) {
  SymbolMap symbol_map;

  auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
    symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
        llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
  };

  bind("runtimeGetResultStorage", &GetResultStorage);
  bind("runtimeSetError", &SetError);
  bind("runtimeCustomCall", &CustomCall);

  return symbol_map;
}
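
// Code generated by the XLA runtime references these intrinsics by the
// unmangled names bound above ("runtimeGetResultStorage", "runtimeSetError",
// "runtimeCustomCall"); the symbol map resolves them to the functions below.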

//----------------------------------------------------------------------------//
// Implement XLA Runtime <-> XLA Executable integration API.
//----------------------------------------------------------------------------//

void* GetResultStorage(KernelContext* ctx, int64_t index) {
  assert(ctx && "kernel context must not be null");
  assert(!ctx->call_frame->is_error && "error must not be set");
  size_t offset = ctx->results_memory_layout->offsets[index];
  assert(offset < ctx->call_frame->results.size() && "offset is out of bounds");
  ctx->call_frame->has_set_outputs = true;
  return &ctx->call_frame->results[offset];
}

void SetError(KernelContext* ctx, const char* error) {
  assert(ctx && "kernel context must not be null");
  assert(error && "runtime error must not be null");
  assert(!ctx->call_frame->is_error && "error must be set only once");
  assert(!ctx->call_frame->has_set_outputs && "outputs must be undefined");
  ctx->call_frame->is_error = true;
  ctx->call_frame->error = {error};
}

bool CustomCall(KernelContext* ctx, const char* target, void** args,
                void** attrs) {
  assert(ctx && target && args && attrs && "all arguments must not be null");

  // Default custom calls registry for the XLA executables.
  static CustomCallRegistry* registry = []() {
    auto* registry = new CustomCallRegistry();
    RegisterStaticCustomCalls(registry);
    return registry;
  }();

  auto* custom_call = registry->Find(target);
  assert(custom_call && "custom call not found");
  if (custom_call == nullptr) return false;

  return succeeded(custom_call->call(args, attrs, ctx->custom_call_data,
                                     ctx->diagnostic_engine));
}

}  // namespace runtime
}  // namespace xla