/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_MLIR_RUNTIME_JIT_COMPILER_H_
#define XLA_MLIR_RUNTIME_JIT_COMPILER_H_

#include <functional>
#include <memory>
#include <string>
#include <string_view>

#include "llvm/Support/Error.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/xla/mlir/transforms/runtime/calling_convention.h"
#include "tensorflow/compiler/xla/mlir/transforms/runtime/specialization.h"
#include "tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/constraints.h"
#include "tensorflow/compiler/xla/runtime/executable.h"
#include "tensorflow/compiler/xla/runtime/symbolic_shape.h"
#include "tfrt/support/error_util.h"  // from @tf_runtime

namespace xla {
namespace runtime {

// JitCompiler manages parsing, specialization and compilation of a single XLA
// module to the XLA runtime executable. It owns the MLIR context where the
// module is created, and a diagnostic handler to capture all compilation
// diagnostics messages.
//
// TODO(ezhulenev): Allow constructing JitCompiler (and JitExecutable) from the
// MLIR module directly without serializing it to string first.
class JitCompiler {
 public:
  using SymbolicShape = SymbolicShapesResolver::SymbolicShape;

  struct Options {
    // Register dialects that are allowed in the serialized module.
    std::function<void(mlir::DialectRegistry&)> register_dialects;

    // Create a pass pipeline that is called whenever the compiled module
    // gets specialized. This pipeline can use refined shape information and
    // symbolic shape attributes to do the shape inference and canonicalization.
    //
    // Original input module might have an undefined calling convention (e.g.
    // XLA runtime does not support unranked tensors), and specialization can be
    // required as a precondition for compilation.
    std::function<void(mlir::PassManager&)> create_specialization_pipeline;

    // Create a pass pipeline that lowers compiled module from high level
    // dialects to the LLVM dialect. XLA runtime will use the LLVM ORC compiler
    // API to compile the LLVM module at run time
    // (https://llvm.org/docs/ORCv2.html).
    //
    // This compilation pipeline must create the entrypoint function with an ABI
    // compatible with the calling convention advertised to the XLA through
    // the `calling_convention` type conversion, and for that it usually must
    // include the `xla-rt-convert-to-entrypoint` pass to convert regular
    // functions to "XLA entrypoints".
    std::function<void(mlir::PassManager&)> create_compilation_pipeline;

    // LLVM optimization level when JIT compiling a module.
    llvm::CodeGenOpt::Level jit_code_opt_level =
        llvm::CodeGenOpt::Level::Default;

    // Runtime symbols binding allows to pass user-defined bindings for symbols
    // at JIT compilation time, e.g. to bind type ids or custom calls.
    ExecutionEngine::SymbolsBinding symbols_binding;

    // Calling convention defines an ABI for XLA runtime to call an executable.
    // See `CallingConvention` documentation for details.
    CallingConvention calling_convention = DefaultCallingConvention();

    // Type converter converts MLIR types to the corresponding run time types.
    // Executable uses its own type hierarchy, parallel to MLIR's, so that it
    // doesn't depend on any parts of the MLIR after compilation produces an
    // executable artifact, because keeping MLIR context alive can be expensive
    // in terms of memory usage.
    //
    // As a side effect, it allows loading AOT compiled executables from the obj
    // files without any dependencies on MLIR.
    //
    // Default type converter knows how to convert canonical MLIR types
    // (memrefs, tensors, etc...). All user-defined types used at the compiled
    // function boundary (arguments or results) should register a custom type
    // conversion.
    //
    // When we compile the input IR, we first apply the `calling_convention` to
    // get the MLIR function type for the entrypoint, and then we convert it to
    // the corresponding run time function type.
    TypeConverter type_converter;
  };

  // Instantiates compiler from the serialized mlir source.
  static llvm::Expected<std::unique_ptr<JitCompiler>> Instantiate(
      Options opts, std::string_view mlir_module, std::string_view entrypoint);

  // Makes an executable from an instance of the JitCompiler. This is the end of
  // life for the `JitCompiler`, it effectively converts the MLIR module
  // to the executable (function pointer) using LLVM JIT code generation.
  // Optional specialization identifier specifies if the compiled executable is
  // a default one, or a specialization.
  static llvm::Expected<Executable> Compile(
      std::unique_ptr<JitCompiler> compiler,
      std::string_view memory_region_name,
      llvm::Optional<size_t> specialization = llvm::None);

  // Specialize compiled module to the arguments:
  //
  // - update all unknown dimensions according to the resolved symbolic shapes
  // - attach symbolic shape attribute to the operands
  // - sink small constants into the function body
  //
  // After entrypoint signature is updated, and all constant arguments
  // materialized in the function body, runs the user-provided specialization
  // pipeline to optimize the module based on the new information in the IR.
  //
  // Returns error if arguments are not compatible with compiled module
  // entrypoint signature.
  llvm::Error Specialize(ArgumentsRef arguments,
                         llvm::ArrayRef<SymbolicShape> symbolic_shapes,
                         llvm::ArrayRef<ArgumentConstraint> constraints,
                         const SpecializationListener* listener = nullptr);

  // Returns the options this compiler was instantiated with.
  const Options& options() const { return opts_; }

  // Returns the name of the compiled module, or "<unknown>" if the module
  // has no name attribute.
  llvm::StringRef name() const {
    return module().getName().value_or("<unknown>");
  }

  // Returns the parsed MLIR module. Must not be called if parsing failed
  // (asserts that the module is non-null).
  mlir::ModuleOp module() const {
    assert(module_ && "failed to parse the mlir module");
    return *module_;
  }

  // Returns the resolved entrypoint function. Must not be called if the
  // entrypoint could not be resolved (asserts that it is non-null).
  mlir::func::FuncOp entrypoint() const {
    assert(entrypoint_ && "failed to resolve entrypoint function");
    return entrypoint_;
  }

 private:
  JitCompiler(Options opts, std::string_view mlir_module,
              std::string_view entrypoint);

  // Wraps `original_error` together with all diagnostics captured so far into
  // a single error message.
  template <typename OriginalError>
  llvm::Error Error(OriginalError original_error) {
    return tfrt::MakeStringError(original_error, ":\n", diagnostic_);
  }

  Options opts_;
  std::unique_ptr<mlir::MLIRContext> context_;

  // Captured compilation diagnostics; `diagnostic_os_` streams into
  // `diagnostic_` (presumably wired to `handler_` below — confirm in the
  // implementation file).
  std::string diagnostic_;
  llvm::raw_string_ostream diagnostic_os_;

  // Source manager and diagnostic handler for the parsed MLIR module.
  llvm::SourceMgr source_mgr_;
  mlir::SourceMgrDiagnosticHandler handler_;

  mlir::OwningOpRef<mlir::ModuleOp> module_;  // can be null if failed to parse
  mlir::func::FuncOp entrypoint_;             // can be null if failed to parse

  // Tracks whether `Specialize` has been run on this module — TODO confirm
  // exact semantics in the implementation file.
  bool specialized_;
};

}  // namespace runtime
}  // namespace xla

#endif  // XLA_MLIR_RUNTIME_JIT_COMPILER_H_