/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_

#include <cstdint>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloModule;
class HloComputation;
class HloInstruction;
class Shape;

// HLO bounded dynamic shapes can be converted to either MLIR dynamic shapes
// (which lose the bound information) or casted to static shape using the
46 // bounds. 47 enum class DynamicShapeHandlingMode { kDynamic, kConvertToStatic }; 48 49 // Helper class for importing HloComputations. 50 class HloFunctionImporter { 51 public: 52 // Imports the given computation as a function in the given module. This also 53 // imports any computations referred by instructions in this computation. 54 static Status ImportAsFunc( 55 const xla::HloComputation& computation, mlir::ModuleOp module, 56 std::unordered_map<const xla::HloComputation*, mlir::func::FuncOp>* 57 function_map, 58 mlir::Builder* builder, bool is_main); 59 60 // Imports the given hlo computation to the specified region. If 61 // 'flatten_region_arg_tuple' is true, then flatten the tuple-typed region 62 // argument(s) and return value(s). 63 static Status ImportAsRegion(const xla::HloComputation& computation, 64 mlir::Region* region, mlir::Builder* builder, 65 bool flatten_region_arg_tuple = false); 66 67 // Imports the given computation to the given place specified by `builder`. 68 // `arguments` contains values for all parameters. 69 static StatusOr<mlir::Value> ImportInstructions( 70 const xla::HloComputation& computation, 71 const llvm::SmallVectorImpl<mlir::Value>& arguments, 72 mlir::OpBuilder* builder); 73 74 static StatusOr<mlir::Operation*> ImportInstruction( 75 const xla::HloInstruction* instr, 76 const llvm::SmallVectorImpl<mlir::Value>& operands, 77 mlir::OpBuilder* builder, 78 DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); 79 80 static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape, 81 llvm::StringRef attr_name); 82 83 // TODO(b/179166199): move this to attribute_importer.h. 84 // Converts XLA instruction source target pairs to MLIR attribute. 85 static mlir::NamedAttribute ConvertSourceTargetPairs( 86 const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs, 87 mlir::Builder* builder); 88 89 // TODO(b/179166199): move this to attribute_importer.h. 
90 // Converts replica groups to attribute 91 static mlir::NamedAttribute ConvertReplicaGroups( 92 absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder); 93 94 // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block 95 // arguments with 'implicit_operands'. Here | implicit_operands | == sum of 96 // the number of arguments in all the regions in IfOp or CaseOp. 97 void ReplaceBlockArgumentsWithImplicitOperands( 98 mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands); 99 100 // Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType. 101 // Otherwise, return 'op'. 102 mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, 103 mlir::Location loc, 104 mlir::Operation* op, 105 mlir::Type type); 106 107 // FlattenTupleType flattens the types in (nested) tuple-type 'type' and 108 // stores them in 'types'. 109 static void FlattenTupleType( 110 mlir::Type type, llvm::SmallVectorImpl<mlir::Type>& flattened_types); 111 112 // FlattenTupleValue flattens the values in (nested) tuple-typed 'value' and 113 // stores them in 'flattened_values'. 114 static void FlattenTupleValue( 115 mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Value value, 116 llvm::SmallVectorImpl<mlir::Value>& flattened_values); 117 118 // CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using 119 // the non-tuple-typed values in 'flatten_values'. 120 // 121 // e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple<T1,tuple<T1,T2>>, 122 // The function returns %t2 such that: 123 // %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple<T2,T3> 124 // %t2 = mhlo.tuple(V1,%t1): (T1,tuple<T2,T3>) -> tuple<T1,tuple<T1,T2>> 125 // 126 // Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to 127 // resp. flatten and create tuples in the exact same order. 128 // 2. 
`flatten_values`, initially storing the flattened values, will be 129 // mutated to a 0-length array by the end of function invocation. 130 static mlir::Value CreateTupleValue( 131 mlir::OpBuilder* func_builder, mlir::Location loc, 132 llvm::MutableArrayRef<mlir::Value>& flatten_values, mlir::Type type); 133 134 private: HloFunctionImporter(mlir::ModuleOp module,std::unordered_map<const xla::HloComputation *,mlir::func::FuncOp> * function_map,mlir::Builder * builder)135 HloFunctionImporter(mlir::ModuleOp module, 136 std::unordered_map<const xla::HloComputation*, 137 mlir::func::FuncOp>* function_map, 138 mlir::Builder* builder) 139 : context_(module.getContext()), 140 module_(module), 141 builder_(builder), 142 function_map_(function_map) { 143 context_->loadDialect<mlir::arith::ArithmeticDialect>(); 144 context_->loadDialect<mlir::func::FuncDialect>(); 145 context_->loadDialect<mlir::mhlo::MhloDialect>(); 146 } 147 148 // Imports the given computation as a new function, if it hasn't been already 149 // imported. 150 StatusOr<mlir::func::FuncOp> ImportAsFunc( 151 const xla::HloComputation& computation, bool is_main); 152 153 // Imports the given computation in the specified region. 154 tensorflow::Status ImportAsRegion(const HloComputation& computation, 155 mlir::Region* region, 156 bool flatten_region_arg_tuple = false); 157 158 // Imports instructions from the given computation in the specified block. 159 // Assumes that the block already has correct arguments populated. 160 tensorflow::Status ImportInstructions(const HloComputation& computation, 161 mlir::Block* block, 162 bool flatten_region_arg_tuple); 163 StatusOr<mlir::Value> ImportInstructionsImpl( 164 const xla::HloComputation& computation, 165 const llvm::SmallVectorImpl<mlir::Value>& arguments, 166 mlir::OpBuilder* builder); 167 168 // Imports an instruction. 
169 StatusOr<mlir::Operation*> ImportInstructionWithLayout( 170 const xla::HloInstruction* instruction, 171 const llvm::SmallVectorImpl<mlir::Value>& operands, 172 mlir::OpBuilder* func_builder, 173 DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); 174 175 StatusOr<mlir::Operation*> ImportInstructionImpl( 176 const HloInstruction* instruction, 177 const llvm::SmallVectorImpl<mlir::Value>& operands, 178 mlir::OpBuilder* func_builder, 179 DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); 180 181 // Gets the MLIR operand values from an HLO Instruction. 182 StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands( 183 const xla::HloInstruction* instruction); 184 185 // Converts xla Tensor type to the corresponding MLIR type. 186 StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape); 187 188 // Converts an XLA shape/layout to the corresponding MLIR layout, in 189 // flattened_attr, while flattening the tuple layout. 190 Status ConvertShapeToMlirLayout( 191 const xla::Shape& shape, 192 llvm::SmallVectorImpl<mlir::Attribute>& flattened_attr); 193 194 // Returns the output type of an HloInstruction. 195 StatusOr<mlir::Type> GetReturnType(const xla::HloInstruction* instruction); 196 197 // Takes a list of HloInstructions and generates the list of types used for 198 // input, bypassing tuples to subsets. 199 Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions, 200 llvm::SmallVectorImpl<mlir::Type>* types); 201 202 // Returns the Mlir Value for the corresponding HloInstruction. 203 StatusOr<mlir::Value> GetMlirValue(const xla::HloInstruction* instruction); 204 205 // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. 206 mlir::NamedAttribute ConvertComparisonDirection( 207 ComparisonDirection direction); 208 209 // Converts an XLA Comparison::Type to the corresponding MLIR attribute. 
210 mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); 211 212 // Converts the dimensions of an HLO instruction into an MLIR attribute. 213 mlir::DenseIntElementsAttr ConvertDimensions( 214 absl::Span<const int64_t> op_dimensions); 215 216 // Converts Array ref to an DenseIntElementsAttr. 217 mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements); 218 219 // Converts Array ref of bools to a DenseIntElementsAttr of I1 type. 220 mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<bool> elements); 221 222 // Converts Array ref to padding attribute. Input is a flattened list of 223 // padding low and padding high for each of the spatial dimensions. 224 mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding); 225 226 // Converts channel id to attribute 227 mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id); 228 229 // Convert use global device ids flag to attribute 230 mlir::NamedAttribute ConvertUseGlobalDeviceIds(); 231 232 // Converts channel handle to attribute 233 mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel); 234 235 mlir::MLIRContext* context_; 236 mlir::ModuleOp module_; 237 mlir::Builder* builder_; 238 239 // Mapping from HloComputation to the created MLIR function. 240 std::unordered_map<const xla::HloComputation*, mlir::func::FuncOp>* 241 function_map_; 242 243 // Mapping from HloInstructions to the associative MLIR values. 244 std::unordered_map<const xla::HloInstruction*, mlir::Value> 245 instruction_value_map_; 246 }; 247 248 } // namespace xla 249 250 #endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_ 251