1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <map>
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/Optional.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
29 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
31 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
32 #include "mlir/IR/MLIRContext.h" // from @llvm-project
33 #include "mlir/IR/PatternMatch.h" // from @llvm-project
34 #include "mlir/Pass/Pass.h" // from @llvm-project
35 #include "mlir/Support/LogicalResult.h" // from @llvm-project
36 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
39 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
41 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
42 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
43 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
44
45 namespace mlir {
46 namespace mhlo {
47 namespace {
48
49 class LegalizeTF : public LegalizeTFBase<LegalizeTF> {
50 public:
LegalizeTF(bool allow_partial_conversion,bool legalize_chlo,llvm::Optional<StringRef> tf2xla_fallback_device_type,bool prefer_tf2xla)51 explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
52 llvm::Optional<StringRef> tf2xla_fallback_device_type,
53 bool prefer_tf2xla) {
54 allow_partial_conversion_ = allow_partial_conversion;
55 legalize_chlo_ = legalize_chlo;
56 prefer_tf2xla_ = prefer_tf2xla;
57 use_tf2xla_fallback_ = tf2xla_fallback_device_type.has_value();
58 if (tf2xla_fallback_device_type.has_value()) {
59 device_type_ = tf2xla_fallback_device_type.getValue().str();
60 }
61 }
62 /// Performs the lowering to XLA dialect.
63 void runOnOperation() override;
64 };
65
66 class LegalizeTFModulePass
67 : public LegalizeTFModulePassBase<LegalizeTFModulePass> {
68 public:
LegalizeTFModulePass(StringRef tf2xla_fallback_device_type)69 explicit LegalizeTFModulePass(StringRef tf2xla_fallback_device_type) {
70 device_type_ = tf2xla_fallback_device_type.str();
71 }
72
73 /// Performs the lowering to XLA dialect.
74 void runOnOperation() override;
75 };
76
77 // Emits debug information which includes the number of ops of each type which
78 // failed to legalize.
EmitLegalizationErrors(Operation * op,const DenseSet<Operation * > & nonlegalized_ops)79 void EmitLegalizationErrors(Operation *op,
80 const DenseSet<Operation *> &nonlegalized_ops) {
81 // Track the legalization failures by mapping op name to information about
82 // that failure: the number of unlegalized occurrences of the op, and one
83 // example operation that failed.
84 std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
85 DenseSet<Operation *> error_ops;
86 for (Operation *nonlegalized_op : nonlegalized_ops) {
87 // Increment count of this legalization failure.
88 StringRef op_name = nonlegalized_op->getName().getStringRef();
89 // If this emplace is successful, it's the first time we've encountered
90 // this op type. Initialize count to 0 so that after increment, it is 1.
91 auto insertion_result = op_name_to_error_info.emplace(
92 op_name, std::make_pair(0, nonlegalized_op));
93 ++insertion_result.first->second.first;
94 }
95 std::vector<std::string> error_messages;
96 error_messages.reserve(op_name_to_error_info.size());
97 for (const auto &op_info : op_name_to_error_info) {
98 error_messages.push_back(
99 llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
100 }
101 Location loc = op->getLoc();
102 emitError(loc) << "The following operations cannot be legalized: "
103 << llvm::join(error_messages, "; ")
104 << ". These legalization failure(s) may be due to missing TF "
105 "to HLO lowerings and/or unsupported attributes, etc.";
106 // Emit more information about the missing ops. This error message
107 // contains useful details beyond the op name (input and output shapes,
108 // attributes, etc.).
109 if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
110 emitError(loc)
111 << "Emitting more detail about one op that failed to legalize...";
112 } else if (VLOG_IS_ON(1)) {
113 emitError(loc) << "Emitting more detail about one of each type of op "
114 "that failed to legalize...";
115 }
116 for (const auto &op_info : op_name_to_error_info) {
117 op_info.second.second->emitOpError() << "is not legalizable";
118 if (!VLOG_IS_ON(1)) break;
119 }
120 }
121
/// Returns the TypeIDs of TF ops whose native MLIR legalization patterns are
/// kept when `prefer_tf2xla` is set together with a TF2XLA fallback device.
/// In that mode, native patterns rooted on any op NOT in this set are dropped.
/// Those ops are then left to the XlaOpKernel (TF2XLA fallback) legalization,
/// or are not legalized by the new bridge at all.
const llvm::DenseSet<mlir::TypeID> &MlirPreferredOps() {
  // The set is heap-allocated through a static pointer and intentionally never
  // freed, so it is never destroyed (e.g. on thread termination or program
  // exit) while callers may still hold the returned reference.

  // clang-format off
  static const llvm::DenseSet<mlir::TypeID>* ops =
      new llvm::DenseSet<mlir::TypeID>{
    // Ops that are legalized in the old bridge using MlirXlaOpKernel
    TypeID::get<TF::AbsOp>(),
    TypeID::get<TF::AtanOp>(),
    TypeID::get<TF::AvgPool3DOp>(),
    TypeID::get<TF::BiasAddGradOp>(),
    TypeID::get<TF::CeilOp>(),
    TypeID::get<TF::CheckNumericsOp>(),
    TypeID::get<TF::ComplexOp>(),
    TypeID::get<TF::CosOp>(),
    TypeID::get<TF::DiagPartOp>(),
    TypeID::get<TF::DivOp>(),
    TypeID::get<TF::EinsumOp>(),
    TypeID::get<TF::ExpOp>(),
    TypeID::get<TF::Expm1Op>(),
    TypeID::get<TF::FakeQuantWithMinMaxArgsOp>(),
    TypeID::get<TF::FloorOp>(),
    TypeID::get<TF::GreaterEqualOp>(),
    TypeID::get<TF::IFFTOp>(),
    TypeID::get<TF::ImagOp>(),
    TypeID::get<TF::IsFiniteOp>(),
    TypeID::get<TF::IsInfOp>(),
    TypeID::get<TF::IsNanOp>(),
    TypeID::get<TF::LessEqualOp>(),
    TypeID::get<TF::LgammaOp>(),
    TypeID::get<TF::Log1pOp>(),
    TypeID::get<TF::LogicalOrOp>(),
    TypeID::get<TF::LogSoftmaxOp>(),
    TypeID::get<TF::MatrixBandPartOp>(),
    TypeID::get<TF::MaxPool3DGradOp>(),
    TypeID::get<TF::PreventGradientOp>(),
    TypeID::get<TF::RandomShuffleOp>(),
    TypeID::get<TF::RealOp>(),
    TypeID::get<TF::ReciprocalOp>(),
    TypeID::get<TF::ReluOp>(),
    TypeID::get<TF::Relu6Op>(),
    TypeID::get<TF::ReluGradOp>(),
    TypeID::get<TF::RsqrtOp>(),
    TypeID::get<TF::SelectOp>(),
    TypeID::get<TF::SigmoidOp>(),
    TypeID::get<TF::SignOp>(),
    TypeID::get<TF::SoftmaxOp>(),
    TypeID::get<TF::SqrtOp>(),
    TypeID::get<TF::SqrtGradOp>(),
    TypeID::get<TF::SquaredDifferenceOp>(),
    TypeID::get<TF::TanhOp>(),
    TypeID::get<TF::TanhGradOp>(),
    TypeID::get<TF::XlaConvV2Op>(),
    TypeID::get<TF::XlaDotOp>(),
    TypeID::get<TF::XlaDotV2Op>(),
    TypeID::get<TF::XlaDynamicSliceOp>(),
    TypeID::get<TF::XlaEinsumOp>(),
    TypeID::get<TF::XlaReduceWindowOp>(),
    TypeID::get<TF::XlaReplicaIdOp>(),
    TypeID::get<TF::XlaRngBitGeneratorOp>(),
    TypeID::get<TF::XlaSelectAndScatterOp>(),
    TypeID::get<TF::XlaSortOp>(),
    TypeID::get<TF::XlaVariadicReduceV2Op>(),
    TypeID::get<TF::XlaVariadicSortOp>(),
    TypeID::get<TF::XlogyOp>(),
    TypeID::get<TF::ZetaOp>(),

    // Ops that have no XlaOpKernel, so the TF2XLA fallback cannot handle them.
    TypeID::get<TF::RiscAddOp>(),
    TypeID::get<TF::RiscDotOp>(),

    // Const op has a simple legalization and it is much more efficient to lower
    // within MLIR.
    TypeID::get<TF::ConstOp>(),

    // AssertOp with string types is not supported by the fallback.
    TypeID::get<TF::AssertOp>(),

    // The TF2XLA fallback pattern doesn't support these ops because the MLIR
    // HLO builder doesn't override the necessary builder methods. These ops
    // have simple lowering patterns, so this should be safe.
    TypeID::get<TF::CrossReplicaSumOp>(),
    TypeID::get<TF::InfeedDequeueTupleOp>(),
    TypeID::get<TF::OutfeedEnqueueTupleOp>(),
    TypeID::get<TF::XlaShardingOp>(),

    // These ops have undetermined bugs and may not be legalizable with the
    // XlaOpKernel legalization in the TF2XLA fallback. Legalizing them with
    // MLIR avoids the bug. b/195583695 describes the motivation of this change.
    // See b/216355804 for how to reproduce the bug regarding tf.RandomUniform Op
    // See b/216353817 for how to reproduce the bug regarding tf.StridedSlice Op
    TypeID::get<TF::RandomUniformOp>(),
    TypeID::get<TF::StridedSliceOp>(),
  };
  // clang-format on
  return *ops;
}
223
224 // Patterns whose root op is in the set `include_ops` are moved from the set
225 // `from` to the returned set. This is used to partition patterns by op so they
226 // can be cleanly migrated from the old bridge to the MLIR bridge.
PatternsIncludeOps(RewritePatternSet & from,const llvm::DenseSet<mlir::TypeID> & include_ops)227 RewritePatternSet PatternsIncludeOps(
228 RewritePatternSet &from, const llvm::DenseSet<mlir::TypeID> &include_ops) {
229 RewritePatternSet to(from.getContext());
230 // Filter NativePatterns.
231 for (auto &pattern : from.getNativePatterns()) {
232 Optional<OperationName> pat_op_name = pattern->getRootKind();
233 // If the pattern does not have a specific operation, always include it,
234 // If the pattern is in include_ops then include it.
235 bool include =
236 !pat_op_name ||
237 include_ops.count(pat_op_name->getRegisteredInfo()->getTypeID());
238 if (include) to.add(std::move(pattern));
239 }
240
241 // Don't filter PDLPatterns.
242 to.add(std::move(from.getPDLPatterns()));
243
244 return to;
245 }
246
ApplyPatterns(Operation * op,RewritePatternSet & patterns,bool legalize_chlo,bool allow_partial_conversion)247 mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns,
248 bool legalize_chlo,
249 bool allow_partial_conversion) {
250 ConversionTarget target(*op->getContext());
251 if (legalize_chlo) {
252 target.addIllegalDialect<chlo::ChloDialect>();
253 } else {
254 target.addLegalDialect<chlo::ChloDialect>();
255 }
256 target.addLegalDialect<MhloDialect>();
257 target.addLegalDialect<arith::ArithmeticDialect>();
258 target.addLegalDialect<func::FuncDialect>();
259 target.addLegalDialect<tensor::TensorDialect>();
260 target.addLegalDialect<shape::ShapeDialect>();
261 target.addLegalOp<func::CallOp>();
262
263 if (!allow_partial_conversion) {
264 // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
265 target.addLegalOp<ModuleOp, ::mlir::func::FuncOp, ::mlir::func::ReturnOp>();
266 DenseSet<Operation *> nonlegalized_ops;
267 LogicalResult result = applyPartialConversion(
268 op, target, std::move(patterns), &nonlegalized_ops);
269 // In order to enforce that the conversion result is fully converted,
270 // fail if there are any nonlegalized ops in the set.
271 if (failed(result) || !nonlegalized_ops.empty()) {
272 EmitLegalizationErrors(op, nonlegalized_ops);
273 return failure();
274 }
275 return result;
276 }
277
278 return applyPartialConversion(op, target, std::move(patterns));
279 }
280
281 /// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
282 /// patterns from TF2XLA fallback for provided device type (see
283 /// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
284 /// used.
legalizeTF(Operation * op,bool allow_partial_conversion,bool legalize_chlo,llvm::Optional<StringRef> tf2xla_fallback_device_type,bool prefer_tf2xla)285 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
286 bool legalize_chlo,
287 llvm::Optional<StringRef> tf2xla_fallback_device_type,
288 bool prefer_tf2xla) {
289 MLIRContext *context = op->getContext();
290 RewritePatternSet legalize_lower_patterns(context);
291 // Note that the `OperationConverter` orders patterns lexicographically by:
292 // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
293 // to arrive at conversion target). This requires relevant patterns to
294 // specify the list of ops generated by it which most of patterns
295 // implemented in C++ don't do so this comparison doesn't work in those
296 // cases.
297 // 2) Descending pattern benefit.
298 // 3) Op specific patterns over patterns with MatchAnyOpTypeTag.
299 // 4) Order of patterns in `RewritePatternSet`.
300
301 // Add TF->HLO legalization patterns.
302 PopulateLegalizeTfPatterns(context, &legalize_lower_patterns);
303
304 // Add TF->TF lowering patterns.
305 TF::PopulateTFLoweringBeforeHLOPatterns(context, &legalize_lower_patterns);
306
307 if (tf2xla_fallback_device_type && prefer_tf2xla) {
308 VLOG(1) << "TF to XLA legalization patterns are partitioned by op into "
309 "either native MLIR legalization, or TF2XLA fallback "
310 "legalzation, with a preference toward TF2XLA.";
311 } else if (tf2xla_fallback_device_type) {
312 VLOG(1) << "TF to XLA legalization patterns include all native patterns "
313 "and TF2XLA fallback patterns.";
314 } else {
315 VLOG(1) << "TF to XLA legalization patterns are native patterns only.";
316 }
317
318 // Set patterns to legalize_lower_patters, where in the prefer_tf2xla case
319 // only patterns whose ops are in the set MlirPreferredOps are kept.
320 RewritePatternSet patterns =
321 (tf2xla_fallback_device_type && prefer_tf2xla)
322 ? PatternsIncludeOps(legalize_lower_patterns, MlirPreferredOps())
323 : std::move(legalize_lower_patterns);
324
325 if (tf2xla_fallback_device_type) {
326 // Add TF->HLO legalization patterns via TF2XLA fallback.
327 PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
328 patterns, context, prefer_tf2xla);
329 }
330
331 // Populate with CHLO->HLO lowerings to account for TF ops legalized to
332 // CHLO first.
333 if (legalize_chlo) {
334 chlo::populateDecomposeChloPatterns(context, &patterns);
335 chlo::populateChloBroadcastingPatterns(context, &patterns);
336 }
337 // ConstantLike op is convenient to create splat constants, but is
338 // canonicalized to plain HLO constant if statically shaped. Add the
339 // canonicalization pattern to pattern list to enable multi-hop lowering.
340 chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
341
342 return ApplyPatterns(op, patterns, legalize_chlo, allow_partial_conversion);
343 }
344
345 // Performs the lowering to XLA dialect.
runOnOperation()346 void LegalizeTF::runOnOperation() {
347 llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
348 if (use_tf2xla_fallback_) {
349 tf2xla_fallback_device_type = device_type_;
350 }
351 if (failed(legalizeTF(getOperation(), allow_partial_conversion_,
352 legalize_chlo_, tf2xla_fallback_device_type,
353 prefer_tf2xla_))) {
354 signalPassFailure();
355 }
356 }
357
runOnOperation()358 void LegalizeTFModulePass::runOnOperation() {
359 // This pass should only be run when a fallback device is present.
360 if (!device_type_.hasValue()) {
361 return;
362 }
363 VLOG(1) << "TF to XLA legalization patterns include TF2XLA fallback "
364 "patterns for Ops that need to create functions.";
365 Operation *op = getOperation();
366 MLIRContext *context = op->getContext();
367 RewritePatternSet patterns(context);
368 PopulateLegalizeTfWithTf2XlaPatterns(device_type_, patterns, context,
369 /*prefer_tf2xla=*/false,
370 /*is_module_pass=*/true);
371
372 if (failed(ApplyPatterns(op, patterns,
373 /*legalize_chlo=*/false,
374 /*allow_partial_conversion=*/true))) {
375 signalPassFailure();
376 }
377 }
378
379 } // end namespace
380
createLegalizeTFPass(bool allow_partial_conversion,bool legalize_chlo,llvm::Optional<StringRef> tf2xla_fallback_device_type,bool prefer_tf2xla)381 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFPass(
382 bool allow_partial_conversion, bool legalize_chlo,
383 llvm::Optional<StringRef> tf2xla_fallback_device_type, bool prefer_tf2xla) {
384 return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
385 tf2xla_fallback_device_type,
386 prefer_tf2xla);
387 }
388
createLegalizeTFModulePass(StringRef tf2xla_fallback_device_type)389 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFModulePass(
390 StringRef tf2xla_fallback_device_type) {
391 return std::make_unique<LegalizeTFModulePass>(tf2xla_fallback_device_type);
392 }
393
394 } // end namespace mhlo
395 } // end namespace mlir
396