/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_

#include "absl/types/optional.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace mlir {

// This class processes an HloModule together with the supplied
// BufferAssignment and populates the MLIR ModuleOp with the computation
// converted to the LHLO dialect.
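//
// Example usage (a minimal sketch; `assignment`, `computation`, `module`, and
// `instr` are assumed to be provided by the surrounding code, and most
// callers go through HloToLhloModule() declared below instead of driving the
// emitter directly):
//
//   mlir::LhloDialectEmitter emitter(assignment, computation, module);
//   TF_RETURN_IF_ERROR(emitter.Initialize());
//   TF_ASSIGN_OR_RETURN(mlir::Operation* op, emitter.EmitOp(instr));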
class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
 public:
  // Initializes internal data structures. It must be called before calling
  // any of the visitors.
  tensorflow::Status Initialize();

  LhloDialectEmitter(const xla::BufferAssignment& assignment,
                     const xla::HloComputation& computation, ModuleOp module)
      : assignment_(assignment),
        computation_(computation),
        module_(module),
        builder_(module.getContext()),
        i8_type_(builder_.getIntegerType(8)) {}

  xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);

  static xla::StatusOr<mhlo::ScatterDimensionNumbersAttr>
  GetScatterDimensionNumbers(const xla::HloInstruction* instr,
                             mlir::MLIRContext* context);

 private:
  xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitGemm(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitCublasLtMatmul(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnConvolution(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnBatchNorm(
      const xla::HloCustomCallInstruction* custom_call);

  xla::StatusOr<memref::GetGlobalOp> EmitConstant(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::AllReduceStartOp> EmitAllReduceStartOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::AllReduceDoneOp> EmitAllReduceDoneOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::ReduceScatterOp> EmitReduceScatterOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<Operation*> EmitBitcast(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::CaseOp> EmitCaseOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::WhileOp> EmitWhileOp(const xla::HloInstruction* instr);

  xla::Status ImportAsLmhloRegion(xla::HloComputation* computation,
                                  mlir::Region* region);

  // Since the LMHLO dialect does not define token types, this enum controls
  // how token operands/results from XLA:HLO are lowered to MLIR.
  enum class TokenLoweringMode {
    kFailToLower,  // Fail lowering if token inputs are encountered.
    kUseNull,      // Use a null Value in the operand list for each token.
    // kSkip,  // Skip any token inputs or outputs (not yet needed)
  };

  // Create LHLO operation operands given an XLA HLO instruction. By default,
  // all XLA HLO operands and results are converted to MLIR and appended to
  // `operands`. If `num_operands` is specified, only the first `num_operands`
  // operands of the instruction are converted to MLIR. The function returns
  // the actual number of operands and results generated for MLIR in
  // `num_arguments` and `num_results`.
  xla::Status CreateOperands(const xla::HloInstruction* instr,
                             std::optional<int64_t> num_operands,
                             TokenLoweringMode token_mode,
                             SmallVectorImpl<Value>& operands,
                             size_t& num_arguments, size_t& num_results);

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr,
      std::optional<int64_t> num_operands = std::nullopt) {
    size_t unused;
    return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
  }

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr, size_t& num_arguments,
      size_t& num_results, std::optional<int64_t> num_operands = std::nullopt);

  template <typename OpType>
  OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
                              ValueRange operands);

  xla::StatusOr<mlir::Operation*> CreateOpInFusion(
      const xla::HloInstruction* instr, ValueRange buffer_operands,
      size_t num_arguments, size_t num_results);

  xla::StatusOr<mlir::Operation*> CreateOpInFusion(
      const xla::HloInstruction* instr);

  template <typename T>
  DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
    return builder_.getI64TensorAttr(
        {container.data(), static_cast<size_t>(container.size())});
  }
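  // Builds a 1-D i64 elements attribute by extracting one field from every
  // dimension of an xla::Window. Illustrative use (a sketch only; `window` is
  // assumed to come from the instruction being converted):
  //
  //   DenseIntElementsAttr strides = GetWindowElements(
  //       window,
  //       [](const xla::WindowDimension& dim) { return dim.stride(); });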
  DenseIntElementsAttr GetWindowElements(
      const xla::Window& window,
      std::function<int64_t(const xla::WindowDimension& dim)> getter) {
    llvm::SmallVector<int64_t, 4> elements;
    elements.reserve(window.dimensions_size());
    for (const xla::WindowDimension& dim : window.dimensions()) {
      elements.push_back(getter(dim));
    }
    return GetI64DenseElementsAttr(elements);
  }

  static mlir::DenseIntElementsAttr GetLayoutAttribute(
      const xla::Layout& layout, Builder* builder);

  tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;

  // Computation parameters don't need any specific handling when they are
  // visited; they are already processed when we enter a new computation.
  tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
    return ::tensorflow::OkStatus();
  }

  // Helper function that recursively visits the tuple structure in
  // `current_shape` and reconstructs a matching lmhlo::TupleOp.
  // Each leaf node is converted to an std.view op with corresponding offsets.
  // If the shape is not a tuple, it simply returns a view of the buffer.
  tensorflow::Status GetOrCreateViewImpl(const xla::HloInstruction* instr,
                                         const xla::Shape& current_shape,
                                         xla::ShapeIndex* current_shape_index,
                                         SmallVectorImpl<Value>* values,
                                         TokenLoweringMode token_mode);

  // Helper function to create a view (or tuple of views) of a buffer for a
  // given instruction result. `result_subset` can be used for instructions
  // that have a tuple result when the MLIR conversion needs to convert only
  // one of the tuple elements. Note that if needed, this can be extended to
  // take a list of ShapeIndex values in case we need finer control over which
  // elements of the output tuple are converted to MLIR.
  tensorflow::Status GetOrCreateView(
      const xla::HloInstruction* instr, SmallVectorImpl<Value>* values,
      const xla::ShapeIndex& result_subset = {},
      TokenLoweringMode token_mode = TokenLoweringMode::kFailToLower);

  xla::StatusOr<Value> GetOrCreateArrayView(
      const xla::HloInstruction* instr, const xla::Shape& current_shape,
      const xla::ShapeIndex& current_shape_index);

  xla::StatusOr<Value> RewriteFusionOperand(const xla::HloInstruction* root,
                                            const xla::Shape& shape,
                                            xla::ShapeIndex* shape_index,
                                            OpBuilder* b, Location loc);

  // Return an MLIR location for an HLO instruction.
  Location getLocation(const xla::HloInstruction* inst) {
    return NameLoc::get(builder_.getStringAttr(inst->name()));
  }

  // This map provides access to MLIR buffers for each HLO buffer allocation.
  // The MLIR buffers are all `memref<{size}xi8>` and correspond to function
  // parameters. It is populated at the beginning of the processing with all
  // the buffer allocations and is unchanged afterward. Every HloInstruction
  // uses a "slice" of a buffer allocation, which supplies the shape, layout,
  // and dtype. An MLIR view is used separately to model slices into the
  // allocations (see below).
  llvm::DenseMap<const xla::BufferAllocation*, Value> allocations_;

  // This map provides access to MLIR buffers for each HLO instruction, keyed
  // by instruction identity. A slice is contained in a BufferAllocation and
  // has an offset and a size.
  //
  // As for why we don't use HloInstruction*, see GetOrCreateView(); mostly we
  // want to make better use of aliased buffers.
  //
  // If the HloInstruction is a tuple, all leaf nodes are stored flattened.
  // Otherwise, there will be a single buffer.
  //
  // An MLIR buffer is either an input parameter, or a ViewOp in the case
  // where the slice is only part of its allocation.
  //
  // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
  // process every instruction.
  absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
                      Value>
      slices_;

  // The BufferAssignment computed by XLA ahead of time.
  const xla::BufferAssignment& assignment_;

  // The HLO computation that will be converted.
  const xla::HloComputation& computation_;

  // This is the MLIR module in which a function will be created for every HLO
  // computation.
  ModuleOp module_;

  // The builder keeps track of the current insertion point in the MLIR
  // module.
  OpBuilder builder_;

  // Convenient "cached" access to this widely used MLIR type (i8).
  Type i8_type_;

  // Map all-reduce-start ops to their LHLO op, so we can connect the
  // all-reduce-done op with the correct token.
  absl::flat_hash_map<const xla::HloInstruction*, lmhlo_gpu::AllReduceStartOp>
      all_reduce_start_ops_;
};

// Populate the MLIR `module` with the computation from the `hlo_module` using
// the provided buffer `assignment`. The returned `Status` indicates success
// or failure of the conversion.
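//
// Illustrative call (a sketch; `assignment`, `hlo_module`, and `module` are
// assumed to be provided by the caller, inside a Status-returning function):
//
//   TF_RETURN_IF_ERROR(mlir::HloToLhloModule(assignment, hlo_module, module));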
tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment,
                                   const xla::HloModule& hlo_module,
                                   ModuleOp module);

tensorflow::Status OptimizeAndConvertHloToLmhlo(
    std::unique_ptr<xla::HloModule> hlo_module, ModuleOp module,
    StringRef platform_name, bool optimize_xla_hlo);

OwningOpRef<mlir::ModuleOp> HloTextToLhloTranslateFunction(
    llvm::StringRef input, MLIRContext* context, bool optimize_xla_hlo);

// This registers the MLIR pass with the command line.
void RegisterMhloToLhloWithXlaPass();

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_