xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/execution_engine.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_EXECUTION_ENGINE_H_
17 #define XLA_RUNTIME_EXECUTION_ENGINE_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ExecutionEngine/JITEventListener.h"
26 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
27 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
28 #include "llvm/Support/MemoryBuffer.h"
29 
30 namespace xla {
31 namespace runtime {
32 
33 // A pre-fabricated wrapper around ORC JIT stack for running XLA executables.
34 //
35 // It allows to run jit-compiled XLA executables, AOT compile them or load
36 // previously AOT compiled executables.
37 //
38 // XLA executable itself is responsible for the function signature verification,
39 // arguments packing according to the ABI, results decoding and linking the
40 // executable with runtime intrinsics. Execution engine only helps with setting
41 // up ORC JIT stack to support the execution, but itself doesn't know what it is
42 // executing.
43 class ExecutionEngine {
44  public:
45   // Pointer to a compiled XLA entrypoint function.
46   //
47   // XLA entrypoint function expects all arguments to be passed as an array of
48   // opaque pointers to the actual values. In C++ it would look like this:
49   //
50   //   void entrypoint(int32_t arg0, float arg1, ...);
51   //
52   //   void __xla_entrypoint(void** args) {
53   //      int32_t arg0 = *reinterpret_cast<int32_t*>(args[0]);
54   //      float arg1 = *reinterpret_cast<float*>(args[1]);
55   //      ...
56   //      entrypoint(arg0, arg1, ...);
57   //   }
58   //
59   // This is required to avoid dealing with ABI of the compiled function. See
60   // `SetUpEntrypointFunction` for implementation details.
61   using EntrypointFunctionPtr = void (*)(void **);
62 
63   // Callback to register symbols with the execution engine (e.g. to register
64   // custom runtime intrinsics for Gpu integration).
65   using SymbolsBinding =
66       std::function<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>;
67 
68   // Callback to run optimization passes on the compiled LLVM module.
69   using OptimizingTransformer = std::function<llvm::Error(llvm::Module *)>;
70 
71   // Callback to construct an optimizing transformer for the given options.
72   using MakeOptimizingTransformer = std::function<OptimizingTransformer(
73       unsigned opt_level, unsigned size_level,
74       llvm::TargetMachine *targetMachine)>;
75 
76   // Compose multiple symbol bindings into a single symbol binding function.
77   static SymbolsBinding BindAll(std::vector<SymbolsBinding> bindings);
78 
79   //------------------------------------------------------------------------- //
80   // Options for creating execution engine from an LLVM module.
81   //------------------------------------------------------------------------- //
82 
83   struct JitOptions {
84     // User-provided codegen optimization level.
85     llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Level::Default;
86 
87     // User-provided target machine specification.
88     llvm::TargetMachine *target_machine = nullptr;
89 
90     // User-provided builder for the optimizing transformer.
91     MakeOptimizingTransformer make_optimizing_transformer;
92 
93     // User-provided memory mapper for allocating memory for executables.
94     llvm::SectionMemoryManager::MemoryMapper *section_memory_mapper = nullptr;
95 
96     // User-provided bindings for symbols.
97     SymbolsBinding symbols_binding = nullptr;
98 
99     // Notify the llvm's global GDB notifications listener.
100     bool enable_gdb_listener = true;
101 
102     // Notify the llvm's global Perf notifications listener.
103     bool enable_perf_listener = true;
104 
105     // Save compiled object file.
106     bool save_compiled_obj_file = true;
107   };
108 
109   // Creates a new execution engine by compiling the provided LLVM module to
110   // a native function using LLVM ORC stack.
111   static llvm::Expected<std::unique_ptr<ExecutionEngine>> CreateFromModule(
112       std::unique_ptr<llvm::LLVMContext> ctx,
113       std::unique_ptr<llvm::Module> module, llvm::StringRef entrypoint,
114       JitOptions options);
115 
116   //------------------------------------------------------------------------- //
117   // Options for creating execution engine from an AOT compiled object file.
118   //------------------------------------------------------------------------- //
119 
120   struct AotOptions {
121     // User-provided memory mapper for allocating memory for executables.
122     llvm::SectionMemoryManager::MemoryMapper *section_memory_mapper = nullptr;
123 
124     // User-provided bindings for symbols.
125     SymbolsBinding symbols_binding = nullptr;
126 
127     // Notify the llvm's global GDB notifications listener.
128     bool enable_gdb_listener = true;
129 
130     // Notify the llvm's global Perf notifications listener.
131     bool enable_perf_listener = true;
132   };
133 
134   // Creates a new execution engine by loading AOT compiled XLA executable
135   // object file.
136   static llvm::Expected<std::unique_ptr<ExecutionEngine>> CreateFromObjFile(
137       std::unique_ptr<llvm::MemoryBuffer>, llvm::StringRef entrypoint,
138       AotOptions options);
139 
140   //------------------------------------------------------------------------- //
141 
142   // Returns a pointer to the XLA entrypoint function.
entrypoint()143   EntrypointFunctionPtr entrypoint() const { return entrypoint_ptr_; }
144 
145   // Return a memory buffer with a object file behind this execution engine. Can
146   // be null if execution engine didn't save the compiled object file.
147   std::unique_ptr<llvm::MemoryBuffer> obj_file() const;
148 
149  private:
150   ExecutionEngine(bool enable_gdb_listener, bool enable_perf_listener);
151 
152   // We build execution engine on top of the ORC LLJIT API, which owns all
153   // compiled/loaded object files and does the linking at run time.
154   std::unique_ptr<llvm::orc::LLJIT> jit_;
155 
156   // Pointer to a resolved entrypoint function.
157   EntrypointFunctionPtr entrypoint_ptr_ = nullptr;
158 
159   // Object file that has the compiled entrypoint function. Can be null.
160   std::unique_ptr<llvm::MemoryBuffer> obj_file_;
161 
162   llvm::JITEventListener *gdb_listener_ = nullptr;
163   llvm::JITEventListener *perf_listener_ = nullptr;
164 };
165 
166 }  // namespace runtime
167 }  // namespace xla
168 
169 #endif  // XLA_RUNTIME_EXECUTION_ENGINE_H_
170