xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir/transforms/runtime/jit_compiler.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_RUNTIME_JIT_COMPILER_H_
17 #define XLA_MLIR_RUNTIME_JIT_COMPILER_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <string_view>
23 
24 #include "llvm/Support/Error.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
28 #include "mlir/Pass/PassManager.h"  // from @llvm-project
29 #include "tensorflow/compiler/xla/mlir/transforms/runtime/calling_convention.h"
30 #include "tensorflow/compiler/xla/mlir/transforms/runtime/specialization.h"
31 #include "tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
32 #include "tensorflow/compiler/xla/runtime/arguments.h"
33 #include "tensorflow/compiler/xla/runtime/constraints.h"
34 #include "tensorflow/compiler/xla/runtime/executable.h"
35 #include "tensorflow/compiler/xla/runtime/symbolic_shape.h"
36 #include "tfrt/support/error_util.h"  // from @tf_runtime
37 
38 namespace xla {
39 namespace runtime {
40 
41 // JitCompiler manages parsing, specialization and compilation of a single XLA
42 // module to the XLA runtime executable. It owns the MLIR context where the
43 // module is created, and handlers to capture all compilation diagnostics
44 // messages.
45 //
46 // TODO(ezhulenev): Allow constructing JitCompiler (and JitExecutable) from the
47 // MLIR module directly without serializing it to string first.
48 class JitCompiler {
49  public:
50   using SymbolicShape = SymbolicShapesResolver::SymbolicShape;
51 
52   struct Options {
53     // Register dialects that are allowed in the serialized module.
54     std::function<void(mlir::DialectRegistry&)> register_dialects;
55 
56     // Create a pass pipeline that is called whenever the compiled module
57     // gets specialized. This pipeline can use refined shape information and
58     // symbolic shape attributes to do the shape inference and canonicalization.
59     //
60     // Original input module might have an undefined calling convention (e.g.
61     // XLA runtime does not support unranked tensors), and specialization can be
62     // required as a precondition for compilation.
63     std::function<void(mlir::PassManager&)> create_specialization_pipeline;
64 
65     // Create a pass pipeline that lowers compiled module from high level
66     // dialects to the LLVM dialect. XLA runtime will use the LLVM ORC compiler
67     // API to compile the LLVM module at run time
68     // (https://llvm.org/docs/ORCv2.html).
69     //
70     // This compilation pipeline must create the entrypoint function with an ABI
71     // compatible with the calling convention advertised to the XLA through
72     // the `calling_convention` type conversion, and for that it usually must
73     // include `xla-rt-convert-to-entrypoint ` pass to convert regular functions
74     // to "XLA entrypoints".
75     std::function<void(mlir::PassManager&)> create_compilation_pipeline;
76 
77     // LLVM optimization level when JIT compiling a module.
78     llvm::CodeGenOpt::Level jit_code_opt_level =
79         llvm::CodeGenOpt::Level::Default;
80 
81     // Runtime symbols binding allows to pass user-defined bindings for symbols
82     // at JIT compilation time, e.g. to bind type ids or custom calls.
83     ExecutionEngine::SymbolsBinding symbols_binding;
84 
85     // Calling convention defines an ABI for XLA runtime to call an executable.
86     // See `CallingConvention` documentation for details.
87     CallingConvention calling_convention = DefaultCallingConvention();
88 
89     // Type converter converts MLIR types to the corresponding run time types.
90     // Executable uses its own type hierarchy, parallel to MLIR's, so that it
91     // doesn't depend on any parts of the MLIR after compilation produces an
92     // executable artifact, because keeping MLIR context alive can be expensive
93     // in terms of memory usage.
94     //
95     // As a side effect, it allows loading AOT compiled executables from the obj
96     // files without any dependencies on MLIR.
97     //
98     // Default type converter knows how to convert canonical MLIR types
99     // (memrefs, tensors, etc...). All user-defined types used at the compiled
100     // function boundary (arguments or results) should register a custom type
101     // conversion.
102     //
103     // When we compile the input IR, we first apply the `calling_convention` to
104     // get the MLIR function type for the entrypoint, and then we convert it to
105     // the corresponding run time function type.
106     TypeConverter type_converter;
107   };
108 
109   // Instantiates compiler from the serialized mlir source.
110   static llvm::Expected<std::unique_ptr<JitCompiler>> Instantiate(
111       Options opts, std::string_view mlir_module, std::string_view entrypoint);
112 
113   // Makes an executable from an instance of the JitCompiler. This is the end of
114   // life for the `JitCompiler`, it effectively converts the MLIR module
115   // to the executable (function pointer) using LLVM JIT code generation.
116   // Optional specialization identifier specifies if the compiled executable is
117   // a default one, or a specialization.
118   static llvm::Expected<Executable> Compile(
119       std::unique_ptr<JitCompiler> compiler,
120       std::string_view memory_region_name,
121       llvm::Optional<size_t> specialization = llvm::None);
122 
123   // Specialize compiled module to the arguments:
124   //
125   // - update all unknown dimensions according to the resolved symbolic shapes
126   // - attach symbolic shape attribute to the operands
127   // - sink small constants into the function body
128   //
129   // After entrypoint signature is updated, and all constant arguments
130   // materialized in the function body, runs the user-provided specialization
131   // pipeline to optimize the module based on the new information in the IR.
132   //
133   // Returns error if arguments are not compatible with compiled module
134   // entrypoint signature.
135   llvm::Error Specialize(ArgumentsRef arguments,
136                          llvm::ArrayRef<SymbolicShape> symbolic_shapes,
137                          llvm::ArrayRef<ArgumentConstraint> constraints,
138                          const SpecializationListener* listener = nullptr);
139 
options()140   const Options& options() const { return opts_; }
141 
name()142   llvm::StringRef name() const {
143     return module().getName().value_or("<unknown>");
144   }
145 
module()146   mlir::ModuleOp module() const {
147     assert(module_ && "failed to parse the mlir module");
148     return *module_;
149   }
150 
entrypoint()151   mlir::func::FuncOp entrypoint() const {
152     assert(entrypoint_ && "failed to resolve entrypoint function");
153     return entrypoint_;
154   }
155 
156  private:
157   JitCompiler(Options opts, std::string_view mlir_module,
158               std::string_view entrypoint);
159 
160   template <typename OriginalError>
Error(OriginalError original_error)161   llvm::Error Error(OriginalError original_error) {
162     return tfrt::MakeStringError(original_error, ":\n", diagnostic_);
163   }
164 
165   Options opts_;
166   std::unique_ptr<mlir::MLIRContext> context_;
167 
168   std::string diagnostic_;
169   llvm::raw_string_ostream diagnostic_os_;
170 
171   llvm::SourceMgr source_mgr_;
172   mlir::SourceMgrDiagnosticHandler handler_;
173 
174   mlir::OwningOpRef<mlir::ModuleOp> module_;  // can be null if failed to parse
175   mlir::func::FuncOp entrypoint_;             // can be null if failed to parse
176 
177   bool specialized_;
178 };
179 
180 }  // namespace runtime
181 }  // namespace xla
182 
183 #endif  // XLA_MLIR_RUNTIME_JIT_COMPILER_H_
184