1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_RUNTIME_EXECUTION_ENGINE_H_ 17 #define XLA_RUNTIME_EXECUTION_ENGINE_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/ExecutionEngine/JITEventListener.h" 26 #include "llvm/ExecutionEngine/Orc/LLJIT.h" 27 #include "llvm/ExecutionEngine/SectionMemoryManager.h" 28 #include "llvm/Support/MemoryBuffer.h" 29 30 namespace xla { 31 namespace runtime { 32 33 // A pre-fabricated wrapper around ORC JIT stack for running XLA executables. 34 // 35 // It allows to run jit-compiled XLA executables, AOT compile them or load 36 // previously AOT compiled executables. 37 // 38 // XLA executable itself is responsible for the function signature verification, 39 // arguments packing according to the ABI, results decoding and linking the 40 // executable with runtime intrinsics. Execution engine only helps with setting 41 // up ORC JIT stack to support the execution, but itself doesn't know what it is 42 // executing. 43 class ExecutionEngine { 44 public: 45 // Pointer to a compiled XLA entrypoint function. 46 // 47 // XLA entrypoint function expects all arguments to be passed as an array of 48 // opaque pointers to the actual values. In C++ it would look like this: 49 // 50 // void entrypoint(int32_t arg0, float arg1, ...); 51 // 52 // void __xla_entrypoint(void** args) { 53 // int32_t arg0 = *reinterpret_cast<int32_t*>(args[0]); 54 // float arg1 = *reinterpret_cast<float*>(args[1]); 55 // ... 56 // entrypoint(arg0, arg1, ...); 57 // } 58 // 59 // This is required to avoid dealing with ABI of the compiled function. See 60 // `SetUpEntrypointFunction` for implementation details. 61 using EntrypointFunctionPtr = void (*)(void **); 62 63 // Callback to register symbols with the execution engine (e.g. to register 64 // custom runtime intrinsics for Gpu integration). 65 using SymbolsBinding = 66 std::function<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>; 67 68 // Callback to run optimization passes on the compiled LLVM module. 69 using OptimizingTransformer = std::function<llvm::Error(llvm::Module *)>; 70 71 // Callback to construct an optimizing transformer for the given options. 72 using MakeOptimizingTransformer = std::function<OptimizingTransformer( 73 unsigned opt_level, unsigned size_level, 74 llvm::TargetMachine *targetMachine)>; 75 76 // Compose multiple symbol bindings into a single symbol binding function. 77 static SymbolsBinding BindAll(std::vector<SymbolsBinding> bindings); 78 79 //------------------------------------------------------------------------- // 80 // Options for creating execution engine from an LLVM module. 81 //------------------------------------------------------------------------- // 82 83 struct JitOptions { 84 // User-provided codegen optimization level. 85 llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Level::Default; 86 87 // User-provided target machine specification. 88 llvm::TargetMachine *target_machine = nullptr; 89 90 // User-provided builder for the optimizing transformer. 91 MakeOptimizingTransformer make_optimizing_transformer; 92 93 // User-provided memory mapper for allocating memory for executables. 94 llvm::SectionMemoryManager::MemoryMapper *section_memory_mapper = nullptr; 95 96 // User-provided bindings for symbols. 97 SymbolsBinding symbols_binding = nullptr; 98 99 // Notify the llvm's global GDB notifications listener. 100 bool enable_gdb_listener = true; 101 102 // Notify the llvm's global Perf notifications listener. 103 bool enable_perf_listener = true; 104 105 // Save compiled object file. 106 bool save_compiled_obj_file = true; 107 }; 108 109 // Creates a new execution engine by compiling the provided LLVM module to 110 // a native function using LLVM ORC stack. 111 static llvm::Expected<std::unique_ptr<ExecutionEngine>> CreateFromModule( 112 std::unique_ptr<llvm::LLVMContext> ctx, 113 std::unique_ptr<llvm::Module> module, llvm::StringRef entrypoint, 114 JitOptions options); 115 116 //------------------------------------------------------------------------- // 117 // Options for creating execution engine from an AOT compiled object file. 118 //------------------------------------------------------------------------- // 119 120 struct AotOptions { 121 // User-provided memory mapper for allocating memory for executables. 122 llvm::SectionMemoryManager::MemoryMapper *section_memory_mapper = nullptr; 123 124 // User-provided bindings for symbols. 125 SymbolsBinding symbols_binding = nullptr; 126 127 // Notify the llvm's global GDB notifications listener. 128 bool enable_gdb_listener = true; 129 130 // Notify the llvm's global Perf notifications listener. 131 bool enable_perf_listener = true; 132 }; 133 134 // Creates a new execution engine by loading AOT compiled XLA executable 135 // object file. 136 static llvm::Expected<std::unique_ptr<ExecutionEngine>> CreateFromObjFile( 137 std::unique_ptr<llvm::MemoryBuffer>, llvm::StringRef entrypoint, 138 AotOptions options); 139 140 //------------------------------------------------------------------------- // 141 142 // Returns a pointer to the XLA entrypoint function. entrypoint()143 EntrypointFunctionPtr entrypoint() const { return entrypoint_ptr_; } 144 145 // Return a memory buffer with a object file behind this execution engine. Can 146 // be null if execution engine didn't save the compiled object file. 147 std::unique_ptr<llvm::MemoryBuffer> obj_file() const; 148 149 private: 150 ExecutionEngine(bool enable_gdb_listener, bool enable_perf_listener); 151 152 // We build execution engine on top of the ORC LLJIT API, which owns all 153 // compiled/loaded object files and does the linking at run time. 154 std::unique_ptr<llvm::orc::LLJIT> jit_; 155 156 // Pointer to a resolved entrypoint function. 157 EntrypointFunctionPtr entrypoint_ptr_ = nullptr; 158 159 // Object file that has the compiled entrypoint function. Can be null. 160 std::unique_ptr<llvm::MemoryBuffer> obj_file_; 161 162 llvm::JITEventListener *gdb_listener_ = nullptr; 163 llvm::JITEventListener *perf_listener_ = nullptr; 164 }; 165 166 } // namespace runtime 167 } // namespace xla 168 169 #endif // XLA_RUNTIME_EXECUTION_ENGINE_H_ 170