xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
17 #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
18 
19 #include "absl/types/optional.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
27 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
28 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
29 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 
34 namespace mlir {
35 
36 // This class will process an HloModule with the supplied BufferAssignment and
37 // populate the MLIR ModuleOp with the computation converted in the LHLO
38 // dialect.
39 class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
40  public:
41   // Initializes internal data structures. It must be called before calling any
42   // of the visitors.
43   tensorflow::Status Initialize();
44 
LhloDialectEmitter(const xla::BufferAssignment & assignment,const xla::HloComputation & computation,ModuleOp module)45   LhloDialectEmitter(const xla::BufferAssignment& assignment,
46                      const xla::HloComputation& computation, ModuleOp module)
47       : assignment_(std::move(assignment)),
48         computation_(computation),
49         module_(module),
50         builder_(module.getContext()),
51         i8_type_(builder_.getIntegerType(8)) {}
52 
53   xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);
54 
55   static xla::StatusOr<mhlo::ScatterDimensionNumbersAttr>
56   GetScatterDimensionNumbers(const xla::HloInstruction* instr,
57                              mlir::MLIRContext* context);
58 
59  private:
60   xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
61   xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
62   xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
63       const xla::HloInstruction* instr);
64   xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
65       const xla::HloInstruction* instr);
66 
67   xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
68   xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
69       const xla::HloCustomCallInstruction* custom_call);
70   xla::StatusOr<Operation*> EmitGemm(
71       const xla::HloCustomCallInstruction* custom_call);
72   xla::StatusOr<Operation*> EmitCublasLtMatmul(
73       const xla::HloCustomCallInstruction* custom_call);
74   xla::StatusOr<Operation*> EmitDnnConvolution(
75       const xla::HloCustomCallInstruction* custom_call);
76   xla::StatusOr<Operation*> EmitDnnBatchNorm(
77       const xla::HloCustomCallInstruction* custom_call);
78 
79   xla::StatusOr<memref::GetGlobalOp> EmitConstant(
80       const xla::HloInstruction* instr);
81 
82   xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
83   xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
84       const xla::HloInstruction* instr);
85 
86   xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
87       const xla::HloInstruction* instr);
88   xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
89       const xla::HloInstruction* instr);
90   xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
91       const xla::HloInstruction* instr);
92   xla::StatusOr<lmhlo_gpu::AllReduceStartOp> EmitAllReduceStartOp(
93       const xla::HloInstruction* instr);
94   xla::StatusOr<lmhlo_gpu::AllReduceDoneOp> EmitAllReduceDoneOp(
95       const xla::HloInstruction* instr);
96   xla::StatusOr<lmhlo::ReduceScatterOp> EmitReduceScatterOp(
97       const xla::HloInstruction* instr);
98   xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
99       const xla::HloInstruction* instr);
100 
101   xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
102       const xla::HloInstruction* instr);
103   xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
104   xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
105       const xla::HloInstruction* instr);
106   xla::StatusOr<Operation*> EmitBitcast(const xla::HloInstruction* instr);
107 
108   xla::StatusOr<lmhlo::CaseOp> EmitCaseOp(const xla::HloInstruction* instr);
109 
110   xla::StatusOr<lmhlo::WhileOp> EmitWhileOp(const xla::HloInstruction* instr);
111 
112   xla::Status ImportAsLmhloRegion(xla::HloComputation* computation,
113                                   mlir::Region* region);
114 
115   // Since LMHLO dialect does not define token types, this enum controls how
116   // token operand/results from XLA:HLO are lowered to MLIR.
117   enum class TokenLoweringMode {
118     kFailToLower,  // Fail lowering if token inputs are encountered.
119     kUseNull,      // Use a null Value in the operand list for each token.
120     // kSkip,        // Skip any token inputs or outputs (not yet needed)
121   };
122 
123   // Create LHLO operation operands given an XLA HLO instruction. By default,
124   // all XLA HLO operands and results are converted to MLIR and appended to
125   // `operands`. If `num_operands` is specified, only the first `num_operand`
126   // operands of the instruction are converted to MLIR. The function returns the
127   // actual number of operands and results generated for MLIR in `num_arguments`
128   // and `num_results`.
129   xla::Status CreateOperands(const xla::HloInstruction* instr,
130                              std::optional<int64_t> num_operands,
131                              TokenLoweringMode token_mode,
132                              SmallVectorImpl<Value>& operands,
133                              size_t& num_arguments, size_t& num_results);
134 
135   template <typename OpType>
136   xla::StatusOr<OpType> CreateOpWithoutAttrs(
137       const xla::HloInstruction* instr,
138       std::optional<int64_t> num_operands = std::nullopt) {
139     size_t unused;
140     return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
141   }
142 
143   template <typename OpType>
144   xla::StatusOr<OpType> CreateOpWithoutAttrs(
145       const xla::HloInstruction* instr, size_t& num_arguments,
146       size_t& num_results, std::optional<int64_t> num_operands = std::nullopt);
147 
148   template <typename OpType>
149   OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
150                               ValueRange operands);
151 
152   xla::StatusOr<mlir::Operation*> CreateOpInFusion(
153       const xla::HloInstruction* instr, ValueRange buffer_operands,
154       size_t num_arguments, size_t num_results);
155 
156   xla::StatusOr<mlir::Operation*> CreateOpInFusion(
157       const xla::HloInstruction* instr);
158 
159   template <typename T>
GetI64DenseElementsAttr(const T & container)160   DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
161     return builder_.getI64TensorAttr(
162         {container.data(), static_cast<size_t>(container.size())});
163   }
164 
GetWindowElements(const xla::Window & window,std::function<int64_t (const xla::WindowDimension & dim)> getter)165   DenseIntElementsAttr GetWindowElements(
166       const xla::Window& window,
167       std::function<int64_t(const xla::WindowDimension& dim)> getter) {
168     llvm::SmallVector<int64_t, 4> elements;
169     elements.reserve(window.dimensions_size());
170     for (const xla::WindowDimension& dim : window.dimensions()) {
171       elements.push_back(getter(dim));
172     }
173     return GetI64DenseElementsAttr(elements);
174   }
175 
176   static mlir::DenseIntElementsAttr GetLayoutAttribute(
177       const xla::Layout& layout, Builder* builder);
178 
179   tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;
180 
181   // Computation parameters don't need any specific handling when they are
182   // visited, they are already processed when we enter a new computation.
HandleParameter(const xla::HloInstruction * instr)183   tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
184     return ::tensorflow::OkStatus();
185   }
186 
187   // Helper function that recursively visits the tuple structure in
188   // `current_shape`, and reconstruct a matching lmhlo::TupleOp.
189   // Each leaf node is converted to an std.view op with corresponding offsets.
190   // If no tuple presents, it simply returns a view of the buffer.
191   tensorflow::Status GetOrCreateViewImpl(const xla::HloInstruction* instr,
192                                          const xla::Shape& current_shape,
193                                          xla::ShapeIndex* current_shape_index,
194                                          SmallVectorImpl<Value>* values,
195                                          TokenLoweringMode token_mode);
196 
197   // Helper function to create view/tuple of views to a buffer for a given
198   // instruction result. `result_subset` can be used to for instructions that
199   // have a tuple result and MLIR conversion needs to convert only one of the
200   // tuple elements. Note that if needed, this can be extended to take a list of
201   // ShapeIndex values in case we need finer control on what elements of the
202   // output tuple to be converted to MLIR.
203   tensorflow::Status GetOrCreateView(
204       const xla::HloInstruction* instr, SmallVectorImpl<Value>* values,
205       const xla::ShapeIndex& result_subset = {},
206       TokenLoweringMode token_mode = TokenLoweringMode::kFailToLower);
207 
208   xla::StatusOr<Value> GetOrCreateArrayView(
209       const xla::HloInstruction* instr, const xla::Shape& current_shape,
210       const xla::ShapeIndex& current_shape_index);
211 
212   xla::StatusOr<Value> RewriteFusionOperand(const xla::HloInstruction* root,
213                                             const xla::Shape& shape,
214                                             xla::ShapeIndex* shape_index,
215                                             OpBuilder* b, Location loc);
216 
217   // Return an MLIR location for an HLO instruction.
getLocation(const xla::HloInstruction * inst)218   Location getLocation(const xla::HloInstruction* inst) {
219     return NameLoc::get(builder_.getStringAttr(inst->name()));
220   }
221 
222   // This map provides access to MLIR buffers for each HLO buffer allocation.
223   // The MLIR buffers are all `memref<{size}xi8>` and correspond to function
224   // parameters. It is populated at the beginning of the processing with all
225   // the buffer allocations and is unchanged afterward. Every HLOInstruction
226   // is using a "slice" of the buffer allocation and providing shape, layout,
227   // and Dtype. An MLIR view is used separately to model slices into the
228   // allocations (see below).
229   llvm::DenseMap<const xla::BufferAllocation*, Value> allocations_;
230 
231   // This map provides access to MLIR buffers for each HLO instruction, keyed
232   // instruction identity. A slice is contained in a BufferAllocation, and has
233   // an offset and a size.
234   //
235   // As for why we don't use HloInstruction*, see GetOrCreateView(), but
236   // mostly we want to leverage better of the aliased buffers.
237   //
238   // If the HloInstruction is a tuple, all leaf nodes are stored flattened.
239   // Otherwise, there will be a single buffer.
240   //
241   // An MLIR buffer is either an input parameter, or a ViewOp in the case
242   // where the slice is only part of its allocation.
243   //
244   // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
245   // process every instruction.
246   absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
247                       Value>
248       slices_;
249 
250   // The BufferAssignment computed by XLA ahead of time.
251   const xla::BufferAssignment& assignment_;
252 
253   // The HLO module that will be converted.
254   const xla::HloComputation& computation_;
255 
256   // This is the MLIR module in which a function will be created for every HLO
257   // computation.
258   ModuleOp module_;
259 
260   // The builder keeps track of the current insertion point in the MLIR
261   // module.
262   OpBuilder builder_;
263   // Convenient "cached" access to this widely used MLIR type (i8).
264   Type i8_type_;
265 
266   // Map all-reduce-start ops to their LHLO op, so we can connect the
267   // all-reduce-done op with the correct token.
268   absl::flat_hash_map<const xla::HloInstruction*, lmhlo_gpu::AllReduceStartOp>
269       all_reduce_start_ops_;
270 };
271 
272 // Populate the MLIR `module` with the computation from the `hlo_module` using
273 // the provided buffer `assignment`. The returned `Status` indicates success
274 // or failure in the conversion.
275 tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment,
276                                    const xla::HloModule& hlo_module,
277                                    ModuleOp module);
278 
279 tensorflow::Status OptimizeAndConvertHloToLmhlo(
280     std::unique_ptr<xla::HloModule> hlo_module, ModuleOp module,
281     StringRef platform_name, bool optimize_xla_hlo);
282 OwningOpRef<mlir::ModuleOp> HloTextToLhloTranslateFunction(
283     llvm::StringRef input, MLIRContext* context, bool optimize_xla_hlo);
284 
285 // This register the MLIR pass with the command line.
286 void RegisterMhloToLhloWithXlaPass();
287 
288 }  // namespace mlir
289 
290 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
291