1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18
19 #include "llvm/ADT/DenseSet.h"
20 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
23 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "mlir/Support/LogicalResult.h" // from @llvm-project
26 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
29 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
31 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
32 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
33
34 namespace mlir {
35 namespace mhlo {
36 namespace {
37
38 class LegalizeTFNoFallback
39 : public LegalizeTFNoFallbackBase<LegalizeTFNoFallback> {
40 public:
LegalizeTFNoFallback(bool allow_partial_conversion)41 explicit LegalizeTFNoFallback(bool allow_partial_conversion) {
42 allow_partial_conversion_ = allow_partial_conversion;
43 }
44 /// Performs the lowering to HLO dialect.
45 void runOnOperation() override;
46 };
47
runOnOperation()48 void LegalizeTFNoFallback::runOnOperation() {
49 Operation *op = getOperation();
50 MLIRContext *context = op->getContext();
51 RewritePatternSet patterns(context);
52
53 // Add TF->HLO legalization patterns.
54 PopulateLegalizeTfPatterns(context, &patterns);
55
56 // ConstantLike op is convenient to create splat constants, but is
57 // canonicalized to plain HLO constant if statically shaped. Add the
58 // canonicalization pattern to pattern list to enable multi-hop lowering.
59 chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
60
61 ConversionTarget target(*context);
62 target.addLegalDialect<arith::ArithmeticDialect>();
63 target.addLegalDialect<chlo::ChloDialect>();
64 target.addLegalDialect<MhloDialect>();
65 target.addLegalDialect<func::FuncDialect>();
66 target.addLegalDialect<tensor::TensorDialect>();
67 target.addLegalDialect<shape::ShapeDialect>();
68 target.addLegalOp<func::CallOp>();
69
70 // Add TF->TF lowering patterns.
71 TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns);
72 if (!allow_partial_conversion_) {
73 // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
74 target.addLegalOp<ModuleOp, func::FuncOp, ::mlir::func::ReturnOp>();
75 llvm::DenseSet<Operation *> nonlegalized_ops;
76 LogicalResult result = applyPartialConversion(
77 op, target, std::move(patterns), &nonlegalized_ops);
78 // In order to enforce that the conversion result is fully converted,
79 // fail if there are any nonlegalized ops in the set.
80 if (failed(result) || !nonlegalized_ops.empty()) {
81 return signalPassFailure();
82 }
83 } else if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
84 signalPassFailure();
85 }
86 }
87
88 } // end namespace
89
createLegalizeTFNoFallbackPass(bool allow_partial_conversion)90 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFNoFallbackPass(
91 bool allow_partial_conversion) {
92 return std::make_unique<LegalizeTFNoFallback>(allow_partial_conversion);
93 }
94
95 } // end namespace mhlo
96 } // end namespace mlir
97