xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <map>
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/Optional.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
31 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
32 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
33 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
36 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
39 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
41 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
42 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
43 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
44 
45 namespace mlir {
46 namespace mhlo {
47 namespace {
48 
49 class LegalizeTF : public LegalizeTFBase<LegalizeTF> {
50  public:
LegalizeTF(bool allow_partial_conversion,bool legalize_chlo,llvm::Optional<StringRef> tf2xla_fallback_device_type,bool prefer_tf2xla)51   explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
52                       llvm::Optional<StringRef> tf2xla_fallback_device_type,
53                       bool prefer_tf2xla) {
54     allow_partial_conversion_ = allow_partial_conversion;
55     legalize_chlo_ = legalize_chlo;
56     prefer_tf2xla_ = prefer_tf2xla;
57     use_tf2xla_fallback_ = tf2xla_fallback_device_type.has_value();
58     if (tf2xla_fallback_device_type.has_value()) {
59       device_type_ = tf2xla_fallback_device_type.getValue().str();
60     }
61   }
62   /// Performs the lowering to XLA dialect.
63   void runOnOperation() override;
64 };
65 
// Module-level legalization pass (runs on ModuleOp; see the factory below).
// Unlike `LegalizeTF`, it always requires a fallback device type and, per its
// runOnOperation, only applies TF2XLA fallback patterns in module-pass mode.
class LegalizeTFModulePass
    : public LegalizeTFModulePassBase<LegalizeTFModulePass> {
 public:
  // `tf2xla_fallback_device_type` names the device whose TF2XLA kernels are
  // used for the fallback lowering.
  explicit LegalizeTFModulePass(StringRef tf2xla_fallback_device_type) {
    device_type_ = tf2xla_fallback_device_type.str();
  }

  /// Performs the lowering to XLA dialect.
  void runOnOperation() override;
};
76 
77 // Emits debug information which includes the number of ops of each type which
78 // failed to legalize.
EmitLegalizationErrors(Operation * op,const DenseSet<Operation * > & nonlegalized_ops)79 void EmitLegalizationErrors(Operation *op,
80                             const DenseSet<Operation *> &nonlegalized_ops) {
81   // Track the legalization failures by mapping op name to information about
82   // that failure: the number of unlegalized occurrences of the op, and one
83   // example operation that failed.
84   std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
85   DenseSet<Operation *> error_ops;
86   for (Operation *nonlegalized_op : nonlegalized_ops) {
87     // Increment count of this legalization failure.
88     StringRef op_name = nonlegalized_op->getName().getStringRef();
89     // If this emplace is successful, it's the first time we've encountered
90     // this op type. Initialize count to 0 so that after increment, it is 1.
91     auto insertion_result = op_name_to_error_info.emplace(
92         op_name, std::make_pair(0, nonlegalized_op));
93     ++insertion_result.first->second.first;
94   }
95   std::vector<std::string> error_messages;
96   error_messages.reserve(op_name_to_error_info.size());
97   for (const auto &op_info : op_name_to_error_info) {
98     error_messages.push_back(
99         llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
100   }
101   Location loc = op->getLoc();
102   emitError(loc) << "The following operations cannot be legalized: "
103                  << llvm::join(error_messages, "; ")
104                  << ". These legalization failure(s) may be due to missing TF "
105                     "to HLO lowerings and/or unsupported attributes, etc.";
106   // Emit more information about the missing ops. This error message
107   // contains useful details beyond the op name (input and output shapes,
108   // attributes, etc.).
109   if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
110     emitError(loc)
111         << "Emitting more detail about one op that failed to legalize...";
112   } else if (VLOG_IS_ON(1)) {
113     emitError(loc) << "Emitting more detail about one of each type of op "
114                       "that failed to legalize...";
115   }
116   for (const auto &op_info : op_name_to_error_info) {
117     op_info.second.second->emitOpError() << "is not legalizable";
118     if (!VLOG_IS_ON(1)) break;
119   }
120 }
121 
/// Returns ops that should use MLIR legalization only in the case of
/// prefer_tf2xla. All other ops not in this list should use XlaOpKernel
/// legalization only or not be legalized by the new bridge.
const llvm::DenseSet<mlir::TypeID> &MlirPreferredOps() {
  // The static variable is a pointer in order to avoid destruction upon thread
  // termination.

  // clang-format off
  static const llvm::DenseSet<mlir::TypeID>* ops =
      new llvm::DenseSet<mlir::TypeID>{
    // Ops that are legalized in the old bridge using MlirXlaOpKernel
    TypeID::get<TF::AbsOp>(),
    TypeID::get<TF::AtanOp>(),
    TypeID::get<TF::AvgPool3DOp>(),
    TypeID::get<TF::BiasAddGradOp>(),
    TypeID::get<TF::CeilOp>(),
    TypeID::get<TF::CheckNumericsOp>(),
    TypeID::get<TF::ComplexOp>(),
    TypeID::get<TF::CosOp>(),
    TypeID::get<TF::DiagPartOp>(),
    TypeID::get<TF::DivOp>(),
    TypeID::get<TF::EinsumOp>(),
    TypeID::get<TF::ExpOp>(),
    TypeID::get<TF::Expm1Op>(),
    TypeID::get<TF::FakeQuantWithMinMaxArgsOp>(),
    TypeID::get<TF::FloorOp>(),
    TypeID::get<TF::GreaterEqualOp>(),
    TypeID::get<TF::IFFTOp>(),
    TypeID::get<TF::ImagOp>(),
    TypeID::get<TF::IsFiniteOp>(),
    TypeID::get<TF::IsInfOp>(),
    TypeID::get<TF::IsNanOp>(),
    TypeID::get<TF::LessEqualOp>(),
    TypeID::get<TF::LgammaOp>(),
    TypeID::get<TF::Log1pOp>(),
    TypeID::get<TF::LogicalOrOp>(),
    TypeID::get<TF::LogSoftmaxOp>(),
    TypeID::get<TF::MatrixBandPartOp>(),
    TypeID::get<TF::MaxPool3DGradOp>(),
    TypeID::get<TF::PreventGradientOp>(),
    TypeID::get<TF::RandomShuffleOp>(),
    TypeID::get<TF::RealOp>(),
    TypeID::get<TF::ReciprocalOp>(),
    TypeID::get<TF::ReluOp>(),
    TypeID::get<TF::Relu6Op>(),
    TypeID::get<TF::ReluGradOp>(),
    TypeID::get<TF::RsqrtOp>(),
    TypeID::get<TF::SelectOp>(),
    TypeID::get<TF::SigmoidOp>(),
    TypeID::get<TF::SignOp>(),
    TypeID::get<TF::SoftmaxOp>(),
    TypeID::get<TF::SqrtOp>(),
    TypeID::get<TF::SqrtGradOp>(),
    TypeID::get<TF::SquaredDifferenceOp>(),
    TypeID::get<TF::TanhOp>(),
    TypeID::get<TF::TanhGradOp>(),
    TypeID::get<TF::XlaConvV2Op>(),
    TypeID::get<TF::XlaDotOp>(),
    TypeID::get<TF::XlaDotV2Op>(),
    TypeID::get<TF::XlaDynamicSliceOp>(),
    TypeID::get<TF::XlaEinsumOp>(),
    TypeID::get<TF::XlaReduceWindowOp>(),
    TypeID::get<TF::XlaReplicaIdOp>(),
    TypeID::get<TF::XlaRngBitGeneratorOp>(),
    TypeID::get<TF::XlaSelectAndScatterOp>(),
    TypeID::get<TF::XlaSortOp>(),
    TypeID::get<TF::XlaVariadicReduceV2Op>(),
    TypeID::get<TF::XlaVariadicSortOp>(),
    TypeID::get<TF::XlogyOp>(),
    TypeID::get<TF::ZetaOp>(),

    // Ops that have no XlaOpKernel.
    TypeID::get<TF::RiscAddOp>(),
    TypeID::get<TF::RiscDotOp>(),

    // Const op has a simple legalization and it is much more efficient to lower
    // within MLIR.
    TypeID::get<TF::ConstOp>(),

    // AssertOp with string types are not supported by the fallback.
    TypeID::get<TF::AssertOp>(),

    // TF2XLA fallback pattern doesn't support these op as MLIR hlo builder
    // doesn't override the necessary builder methods. These ops have simple
    // lowering pattern so this should be safe.
    TypeID::get<TF::CrossReplicaSumOp>(),
    TypeID::get<TF::InfeedDequeueTupleOp>(),
    TypeID::get<TF::OutfeedEnqueueTupleOp>(),
    TypeID::get<TF::XlaShardingOp>(),

    // These ops have undetermined bugs, may not be legalizable with XlaOpKernel
    // legalization in TF2XLA fallback. By legalization with MLIR, we can fix
    // the bug. b/195583695 describes the motivation of this change.
    // See b/216355804 how to reproduce the bug regarding tf.RandomUniform Op
    // See b/216353817 how to reproduce the bug regarding tf.StridedSlice Op
    TypeID::get<TF::RandomUniformOp>(),
    TypeID::get<TF::StridedSliceOp>(),
  };
  // clang-format on
  return *ops;
}
223 
224 // Patterns whose root op is in the set `include_ops` are moved from the set
225 // `from` to the returned set. This is used to partition patterns by op so they
226 // can be cleanly migrated from the old bridge to the MLIR bridge.
PatternsIncludeOps(RewritePatternSet & from,const llvm::DenseSet<mlir::TypeID> & include_ops)227 RewritePatternSet PatternsIncludeOps(
228     RewritePatternSet &from, const llvm::DenseSet<mlir::TypeID> &include_ops) {
229   RewritePatternSet to(from.getContext());
230   // Filter NativePatterns.
231   for (auto &pattern : from.getNativePatterns()) {
232     Optional<OperationName> pat_op_name = pattern->getRootKind();
233     // If the pattern does not have a specific operation, always include it,
234     // If the pattern is in include_ops then include it.
235     bool include =
236         !pat_op_name ||
237         include_ops.count(pat_op_name->getRegisteredInfo()->getTypeID());
238     if (include) to.add(std::move(pattern));
239   }
240 
241   // Don't filter PDLPatterns.
242   to.add(std::move(from.getPDLPatterns()));
243 
244   return to;
245 }
246 
ApplyPatterns(Operation * op,RewritePatternSet & patterns,bool legalize_chlo,bool allow_partial_conversion)247 mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns,
248                                   bool legalize_chlo,
249                                   bool allow_partial_conversion) {
250   ConversionTarget target(*op->getContext());
251   if (legalize_chlo) {
252     target.addIllegalDialect<chlo::ChloDialect>();
253   } else {
254     target.addLegalDialect<chlo::ChloDialect>();
255   }
256   target.addLegalDialect<MhloDialect>();
257   target.addLegalDialect<arith::ArithmeticDialect>();
258   target.addLegalDialect<func::FuncDialect>();
259   target.addLegalDialect<tensor::TensorDialect>();
260   target.addLegalDialect<shape::ShapeDialect>();
261   target.addLegalOp<func::CallOp>();
262 
263   if (!allow_partial_conversion) {
264     // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
265     target.addLegalOp<ModuleOp, ::mlir::func::FuncOp, ::mlir::func::ReturnOp>();
266     DenseSet<Operation *> nonlegalized_ops;
267     LogicalResult result = applyPartialConversion(
268         op, target, std::move(patterns), &nonlegalized_ops);
269     // In order to enforce that the conversion result is fully converted,
270     // fail if there are any nonlegalized ops in the set.
271     if (failed(result) || !nonlegalized_ops.empty()) {
272       EmitLegalizationErrors(op, nonlegalized_ops);
273       return failure();
274     }
275     return result;
276   }
277 
278   return applyPartialConversion(op, target, std::move(patterns));
279 }
280 
281 /// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
282 /// patterns from TF2XLA fallback for provided device type (see
283 /// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
284 /// used.
legalizeTF(Operation * op,bool allow_partial_conversion,bool legalize_chlo,llvm::Optional<StringRef> tf2xla_fallback_device_type,bool prefer_tf2xla)285 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
286                          bool legalize_chlo,
287                          llvm::Optional<StringRef> tf2xla_fallback_device_type,
288                          bool prefer_tf2xla) {
289   MLIRContext *context = op->getContext();
290   RewritePatternSet legalize_lower_patterns(context);
291   // Note that the `OperationConverter` orders patterns lexicographically by:
292   // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
293   //    to arrive at conversion target). This requires relevant patterns to
294   //    specify the list of ops generated by it which most of patterns
295   //    implemented in C++ don't do so this comparison doesn't work in those
296   //    cases.
297   // 2) Descending pattern benefit.
298   // 3) Op specific patterns over patterns with MatchAnyOpTypeTag.
299   // 4) Order of patterns in `RewritePatternSet`.
300 
301   // Add TF->HLO legalization patterns.
302   PopulateLegalizeTfPatterns(context, &legalize_lower_patterns);
303 
304   // Add TF->TF lowering patterns.
305   TF::PopulateTFLoweringBeforeHLOPatterns(context, &legalize_lower_patterns);
306 
307   if (tf2xla_fallback_device_type && prefer_tf2xla) {
308     VLOG(1) << "TF to XLA legalization patterns are partitioned by op into "
309                "either native MLIR legalization, or TF2XLA fallback "
310                "legalzation, with a preference toward TF2XLA.";
311   } else if (tf2xla_fallback_device_type) {
312     VLOG(1) << "TF to XLA legalization patterns include all native patterns "
313                "and TF2XLA fallback patterns.";
314   } else {
315     VLOG(1) << "TF to XLA legalization patterns are native patterns only.";
316   }
317 
318   // Set patterns to legalize_lower_patters, where in the prefer_tf2xla case
319   // only patterns whose ops are in the set MlirPreferredOps are kept.
320   RewritePatternSet patterns =
321       (tf2xla_fallback_device_type && prefer_tf2xla)
322           ? PatternsIncludeOps(legalize_lower_patterns, MlirPreferredOps())
323           : std::move(legalize_lower_patterns);
324 
325   if (tf2xla_fallback_device_type) {
326     // Add TF->HLO legalization patterns via TF2XLA fallback.
327     PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
328                                          patterns, context, prefer_tf2xla);
329   }
330 
331   // Populate with CHLO->HLO lowerings to account for TF ops legalized to
332   // CHLO first.
333   if (legalize_chlo) {
334     chlo::populateDecomposeChloPatterns(context, &patterns);
335     chlo::populateChloBroadcastingPatterns(context, &patterns);
336   }
337   // ConstantLike op is convenient to create splat constants, but is
338   // canonicalized to plain HLO constant if statically shaped. Add the
339   // canonicalization pattern to pattern list to enable multi-hop lowering.
340   chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
341 
342   return ApplyPatterns(op, patterns, legalize_chlo, allow_partial_conversion);
343 }
344 
345 // Performs the lowering to XLA dialect.
runOnOperation()346 void LegalizeTF::runOnOperation() {
347   llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
348   if (use_tf2xla_fallback_) {
349     tf2xla_fallback_device_type = device_type_;
350   }
351   if (failed(legalizeTF(getOperation(), allow_partial_conversion_,
352                         legalize_chlo_, tf2xla_fallback_device_type,
353                         prefer_tf2xla_))) {
354     signalPassFailure();
355   }
356 }
357 
runOnOperation()358 void LegalizeTFModulePass::runOnOperation() {
359   // This pass should only be run when a fallback device is present.
360   if (!device_type_.hasValue()) {
361     return;
362   }
363   VLOG(1) << "TF to XLA legalization patterns include TF2XLA fallback "
364              "patterns for Ops that need to create functions.";
365   Operation *op = getOperation();
366   MLIRContext *context = op->getContext();
367   RewritePatternSet patterns(context);
368   PopulateLegalizeTfWithTf2XlaPatterns(device_type_, patterns, context,
369                                        /*prefer_tf2xla=*/false,
370                                        /*is_module_pass=*/true);
371 
372   if (failed(ApplyPatterns(op, patterns,
373                            /*legalize_chlo=*/false,
374                            /*allow_partial_conversion=*/true))) {
375     signalPassFailure();
376   }
377 }
378 
379 }  // end namespace
380 
createLegalizeTFPass(bool allow_partial_conversion,bool legalize_chlo,llvm::Optional<StringRef> tf2xla_fallback_device_type,bool prefer_tf2xla)381 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFPass(
382     bool allow_partial_conversion, bool legalize_chlo,
383     llvm::Optional<StringRef> tf2xla_fallback_device_type, bool prefer_tf2xla) {
384   return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
385                                       tf2xla_fallback_device_type,
386                                       prefer_tf2xla);
387 }
388 
createLegalizeTFModulePass(StringRef tf2xla_fallback_device_type)389 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFModulePass(
390     StringRef tf2xla_fallback_device_type) {
391   return std::make_unique<LegalizeTFModulePass>(tf2xla_fallback_device_type);
392 }
393 
394 }  // end namespace mhlo
395 }  // end namespace mlir
396