/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_

#include <unordered_map>

#include "absl/types/optional.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloModule;
class HloComputation;
class HloInstruction;
class Shape;

// HLO bounded dynamic shapes can be converted to either MLIR dynamic shapes
// (which lose the bound information) or cast to static shapes using the
// bounds.
enum class DynamicShapeHandlingMode { kDynamic, kConvertToStatic };

// Helper class for importing HloComputations.
class HloFunctionImporter {
 public:
  // Imports the given computation as a function in the given module. This also
  // imports any computations referred to by instructions in this computation.
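  //
  // Example usage (illustrative sketch; assumes `hlo_module` is an
  // already-parsed xla::HloModule and `module` is the destination
  // mlir::ModuleOp):
  //
  //   std::unordered_map<const xla::HloComputation*, mlir::func::FuncOp> map;
  //   mlir::Builder builder(module.getContext());
  //   TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc(
  //       *hlo_module->entry_computation(), module, &map, &builder,
  //       /*is_main=*/true));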
  static Status ImportAsFunc(
      const xla::HloComputation& computation, mlir::ModuleOp module,
      std::unordered_map<const xla::HloComputation*, mlir::func::FuncOp>*
          function_map,
      mlir::Builder* builder, bool is_main);

  // Imports the given HLO computation into the specified region. If
  // 'flatten_region_arg_tuple' is true, the tuple-typed region argument(s)
  // and return value(s) are flattened.
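  //
  // Example usage (illustrative sketch; assumes `op` is a freshly built
  // region-holding operation, e.g. an mhlo while/if op, and
  // `body_computation` is the matching HloComputation):
  //
  //   mlir::Region* region = &op->getRegion(0);
  //   TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsRegion(
  //       *body_computation, region, &builder,
  //       /*flatten_region_arg_tuple=*/true));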
  static Status ImportAsRegion(const xla::HloComputation& computation,
                               mlir::Region* region, mlir::Builder* builder,
                               bool flatten_region_arg_tuple = false);

  // Imports the given computation at the insertion point of `builder`.
  // `arguments` contains values for all parameters.
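  //
  // Example usage (illustrative sketch; `arg0` and `arg1` stand for
  // already-created mlir::Values that match the computation's parameters):
  //
  //   llvm::SmallVector<mlir::Value, 4> arguments = {arg0, arg1};
  //   TF_ASSIGN_OR_RETURN(mlir::Value root,
  //                       HloFunctionImporter::ImportInstructions(
  //                           computation, arguments, &op_builder));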
  static StatusOr<mlir::Value> ImportInstructions(
      const xla::HloComputation& computation,
      const llvm::SmallVectorImpl<mlir::Value>& arguments,
      mlir::OpBuilder* builder);

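  // Imports a single HLO instruction at the insertion point of `builder`,
  // using `operands` as the already-imported operand values.
  //
  // Example usage (illustrative sketch; `instr` and `operands` are assumed to
  // come from the caller's own traversal of the HLO graph):
  //
  //   TF_ASSIGN_OR_RETURN(mlir::Operation* new_op,
  //                       HloFunctionImporter::ImportInstruction(
  //                           instr, operands, &op_builder));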
  static StatusOr<mlir::Operation*> ImportInstruction(
      const xla::HloInstruction* instr,
      const llvm::SmallVectorImpl<mlir::Value>& operands,
      mlir::OpBuilder* builder,
      DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic);

  static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape,
                               llvm::StringRef attr_name);

  // TODO(b/179166199): move this to attribute_importer.h.
  // Converts XLA instruction source-target pairs to an MLIR attribute.
  static mlir::NamedAttribute ConvertSourceTargetPairs(
      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
      mlir::Builder* builder);

  // TODO(b/179166199): move this to attribute_importer.h.
  // Converts replica groups to an attribute.
  static mlir::NamedAttribute ConvertReplicaGroups(
      absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder);

  // For mlir::IfOp or mlir::CaseOp, replaces the uses of their regions' block
  // arguments with 'implicit_operands'. Here |implicit_operands| == sum of the
  // number of arguments in all the regions of the IfOp or CaseOp.
  void ReplaceBlockArgumentsWithImplicitOperands(
      mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands);

  // Creates a TupleOp using the results of 'op' if 'type' is an
  // mlir::TupleType. Otherwise, returns 'op'.
  mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder,
                                            mlir::Location loc,
                                            mlir::Operation* op,
                                            mlir::Type type);

  // FlattenTupleType flattens the types in the (nested) tuple type 'type' and
  // stores them in 'flattened_types'.
  static void FlattenTupleType(
      mlir::Type type, llvm::SmallVectorImpl<mlir::Type>& flattened_types);

  // FlattenTupleValue flattens the values in the (nested) tuple-typed 'value'
  // and stores them in 'flattened_values'.
  static void FlattenTupleValue(
      mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Value value,
      llvm::SmallVectorImpl<mlir::Value>& flattened_values);

  // CreateTupleValue creates a root TupleOp of (nested) tuple type 'type'
  // using the non-tuple-typed values in 'flatten_values'.
  //
  // e.g., given 'flatten_values': [V1, V2, V3] and 'type': tuple<T1, tuple<T2, T3>>,
  //      the function returns %t2 such that:
  //       %t1 = mhlo.tuple(V2, V3) : (T2, T3) -> tuple<T2, T3>
  //       %t2 = mhlo.tuple(V1, %t1) : (T1, tuple<T2, T3>) -> tuple<T1, tuple<T2, T3>>
  //
  // Note: 1. FlattenTupleValue and CreateTupleValue are a pair of functions
  //          that resp. flatten and create tuples in the exact same order.
  //       2. `flatten_values`, which initially stores the flattened values, is
  //          consumed and left empty by the end of the call.
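  //
  // Example round trip (illustrative sketch; `tuple_value` is assumed to be a
  // tuple-typed mlir::Value and `loc`/`func_builder` come from the caller):
  //
  //   llvm::SmallVector<mlir::Value, 4> flat;
  //   FlattenTupleValue(func_builder, loc, tuple_value, flat);
  //   llvm::MutableArrayRef<mlir::Value> flat_ref(flat);
  //   mlir::Value rebuilt =
  //       CreateTupleValue(func_builder, loc, flat_ref, tuple_value.getType());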
  static mlir::Value CreateTupleValue(
      mlir::OpBuilder* func_builder, mlir::Location loc,
      llvm::MutableArrayRef<mlir::Value>& flatten_values, mlir::Type type);

 private:
  HloFunctionImporter(mlir::ModuleOp module,
                      std::unordered_map<const xla::HloComputation*,
                                         mlir::func::FuncOp>* function_map,
                      mlir::Builder* builder)
      : context_(module.getContext()),
        module_(module),
        builder_(builder),
        function_map_(function_map) {
    context_->loadDialect<mlir::arith::ArithmeticDialect>();
    context_->loadDialect<mlir::func::FuncDialect>();
    context_->loadDialect<mlir::mhlo::MhloDialect>();
  }

  // Imports the given computation as a new function, if it hasn't already
  // been imported.
  StatusOr<mlir::func::FuncOp> ImportAsFunc(
      const xla::HloComputation& computation, bool is_main);

  // Imports the given computation into the specified region.
  tensorflow::Status ImportAsRegion(const HloComputation& computation,
                                    mlir::Region* region,
                                    bool flatten_region_arg_tuple = false);

  // Imports instructions from the given computation into the specified block.
  // Assumes that the block already has correct arguments populated.
  tensorflow::Status ImportInstructions(const HloComputation& computation,
                                        mlir::Block* block,
                                        bool flatten_region_arg_tuple);
  StatusOr<mlir::Value> ImportInstructionsImpl(
      const xla::HloComputation& computation,
      const llvm::SmallVectorImpl<mlir::Value>& arguments,
      mlir::OpBuilder* builder);

  // Imports an instruction.
  StatusOr<mlir::Operation*> ImportInstructionWithLayout(
      const xla::HloInstruction* instruction,
      const llvm::SmallVectorImpl<mlir::Value>& operands,
      mlir::OpBuilder* func_builder,
      DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic);

  StatusOr<mlir::Operation*> ImportInstructionImpl(
      const HloInstruction* instruction,
      const llvm::SmallVectorImpl<mlir::Value>& operands,
      mlir::OpBuilder* func_builder,
      DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic);

  // Gets the MLIR operand values from an HLO instruction.
  StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands(
      const xla::HloInstruction* instruction);

  // Converts an XLA tensor type to the corresponding MLIR type.
  StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);

  // Converts an XLA shape/layout to the corresponding MLIR layout, in
  // 'flattened_attr', while flattening the tuple layout.
  Status ConvertShapeToMlirLayout(
      const xla::Shape& shape,
      llvm::SmallVectorImpl<mlir::Attribute>& flattened_attr);

  // Returns the output type of an HloInstruction.
  StatusOr<mlir::Type> GetReturnType(const xla::HloInstruction* instruction);

  // Takes a list of HloInstructions and generates the list of types used for
  // input, bypassing tuples to subsets.
  Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions,
                      llvm::SmallVectorImpl<mlir::Type>* types);

  // Returns the MLIR Value for the corresponding HloInstruction.
  StatusOr<mlir::Value> GetMlirValue(const xla::HloInstruction* instruction);

  // Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
  mlir::NamedAttribute ConvertComparisonDirection(
      ComparisonDirection direction);

  // Converts an XLA Comparison::Type to the corresponding MLIR attribute.
  mlir::NamedAttribute ConvertComparisonType(Comparison::Type type);

  // Converts the dimensions of an HLO instruction into an MLIR attribute.
  mlir::DenseIntElementsAttr ConvertDimensions(
      absl::Span<const int64_t> op_dimensions);

  // Converts an array ref to a DenseIntElementsAttr.
  mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);

  // Converts an array ref of bools to a DenseIntElementsAttr of i1 type.
  mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<bool> elements);

  // Converts an array ref to a padding attribute. Input is a flattened list of
  // padding low and padding high for each of the spatial dimensions.
  mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);

  // Converts a channel id to an attribute.
  mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id);

  // Converts the use-global-device-ids flag to an attribute.
  mlir::NamedAttribute ConvertUseGlobalDeviceIds();

  // Converts a channel handle to an attribute.
  mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);

  mlir::MLIRContext* context_;
  mlir::ModuleOp module_;
  mlir::Builder* builder_;

  // Mapping from HloComputation to the created MLIR function.
  std::unordered_map<const xla::HloComputation*, mlir::func::FuncOp>*
      function_map_;

  // Mapping from HloInstructions to the associated MLIR values.
  std::unordered_map<const xla::HloInstruction*, mlir::Value>
      instruction_value_map_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_