xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <iostream>
16 
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
18 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/Operation.h"  // from @llvm-project
22 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassManager.h"  // from @llvm-project
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
26 #include "mlir/Transforms/Passes.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
31 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
32 
33 namespace mlir {
34 namespace TF {
35 namespace {
36 
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_optimize.inc"
38 
39 // Returns a TF Constant tensor with the passed in values.
GetI64ConstantTensor(PatternRewriter & rewriter,ArrayRef<int64_t> values,Location location)40 TF::ConstOp GetI64ConstantTensor(PatternRewriter &rewriter,
41                                  ArrayRef<int64_t> values, Location location) {
42   auto cst_attr = rewriter.getI64TensorAttr(values);
43   return rewriter.create<TF::ConstOp>(location, cst_attr.getType(), cst_attr);
44 }
45 
46 // Rewrites broadcast->reshape to a reshape->broadcast that reduces
47 // the rank of the input and output of the broadcast.
48 class SimplifyBroadcastReshape : public OpRewritePattern<BroadcastToOp> {
49   using OpRewritePattern<BroadcastToOp>::OpRewritePattern;
50 
matchAndRewrite(BroadcastToOp op,PatternRewriter & rewriter) const51   LogicalResult matchAndRewrite(BroadcastToOp op,
52                                 PatternRewriter &rewriter) const override {
53     // Only rewrite if the Broadcast has only one consumer.
54     if (!op.output().hasOneUse()) return failure();
55 
56     Operation *user = *op.output().getUsers().begin();
57 
58     auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(user);
59     if (!reshape_op) return failure();
60 
61     auto reshape_type = reshape_op.output().getType().cast<ShapedType>();
62 
63     if (!reshape_type.hasStaticShape()) return failure();
64     ArrayRef<int64_t> reshape_shape = reshape_type.getShape();
65 
66     auto input_type = op.input().getType().cast<ShapedType>();
67     auto output_type = op.output().getType().cast<ShapedType>();
68 
69     if (!input_type.hasRank() || !output_type.hasRank()) return failure();
70 
71     // The pattern attempts to reduce the rank of the input to BroadcastTo.
72     // Thus, we fail to match if the consuming reshape rank is larger.
73     ArrayRef<int64_t> input_shape = input_type.getShape();
74     if (reshape_shape.size() > input_shape.size()) return failure();
75 
76     // Extend the input shape with leading 1s to match the broadcast shape.
77     ArrayRef<int64_t> broadcast_shape = output_type.getShape();
78     SmallVector<int64_t, 4> input_shape_extended;
79     input_shape_extended.append(broadcast_shape.size() - input_shape.size(), 1);
80     input_shape_extended.append(input_shape.begin(), input_shape.end());
81 
82     // Collect non-unit dims and corresponding dim in the input shape.
83     SmallVector<int64_t, 4> input_carryover_dims;
84     SmallVector<int64_t, 4> non_unit_dims;
85 
86     for (int i = 0; i < input_shape_extended.size(); i++) {
87       int64_t dim = broadcast_shape[i];
88       if (dim != 1) {
89         non_unit_dims.push_back(dim);
90         input_carryover_dims.push_back(input_shape_extended[i]);
91       }
92     }
93 
94     // If the reshape rank is less than the number of non-unit dimensions
95     // of the broadcast, then the reshape collapses non-unit dimensions.
96     // TODO(rahulsp) : Handle this case with more careful checks.
97     if (reshape_shape.size() < non_unit_dims.size()) return failure();
98 
99     SmallVector<int64_t, 4> old_reshape_non_unit_dims;
100     SmallVector<int64_t, 4> new_reshape_dims;
101     int new_reshape_dim_idx = 0;
102     for (int64_t dim : reshape_shape) {
103       int new_reshape_dim = 1;
104       if (dim != 1) {
105         old_reshape_non_unit_dims.push_back(dim);
106         if (new_reshape_dim_idx < input_carryover_dims.size()) {
107           new_reshape_dim = input_carryover_dims[new_reshape_dim_idx];
108           new_reshape_dim_idx++;
109         }
110       }
111       new_reshape_dims.push_back(new_reshape_dim);
112     }
113 
114     if (non_unit_dims != old_reshape_non_unit_dims) return failure();
115 
116     if (failed(VerifyShapeOfReshapeOp(new_reshape_dims))) return failure();
117 
118     Type el_ty = getElementTypeOrSelf(op.getType());
119     TF::ConstOp new_reshape_shape = GetI64ConstantTensor(
120         rewriter, ArrayRef<int64_t>(new_reshape_dims), op.getLoc());
121     auto new_reshape_type = RankedTensorType::get(new_reshape_dims, el_ty);
122     ReshapeOp new_reshape =
123         rewriter.create<ReshapeOp>(new_reshape_shape.getLoc(), new_reshape_type,
124                                    op.input(), new_reshape_shape);
125     TF::ConstOp new_broadcast_shape =
126         GetI64ConstantTensor(rewriter, reshape_shape, op.getLoc());
127     rewriter.replaceOpWithNewOp<BroadcastToOp>(
128         reshape_op, reshape_op.output().getType(), new_reshape,
129         new_broadcast_shape);
130     return success();
131   }
132 };
133 
134 // Canonicalize operations in functions.
135 struct TensorFlowOptimizePass
136     : public TensorFlowOptimizePassBase<TensorFlowOptimizePass> {
initializemlir::TF::__anon197ea67d0111::TensorFlowOptimizePass137   LogicalResult initialize(MLIRContext *context) override {
138     RewritePatternSet pattern_list(context);
139     populateWithGenerated(pattern_list);
140     pattern_list.add<SimplifyBroadcastReshape>(context);
141     patterns = std::move(pattern_list);
142     return success();
143   }
144 
runOnOperationmlir::TF::__anon197ea67d0111::TensorFlowOptimizePass145   void runOnOperation() override {
146     auto func = getOperation();
147     if (failed(applyPatternsAndFoldGreedily(func, patterns)))
148       signalPassFailure();
149   }
150 
151   FrozenRewritePatternSet patterns;
152 };
153 
154 }  // namespace
155 
CreateTFStandardPipeline(OpPassManager & pm,const StandardPipelineOptions & options)156 void CreateTFStandardPipeline(OpPassManager &pm,
157                               const StandardPipelineOptions &options) {
158   OpPassManager &func_pm = pm.nest<func::FuncOp>();
159 
160   // First operates on the executor dialect:
161   // - remove dead islands.
162   // - fuse islands as much as possible.
163   // - materialize the eventual "pass-through" ops by inlining their content.
164   func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
165   func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
166   func_pm.addPass(CreateMaterializePassthroughOpPass());
167   if (options.form_clusters)
168     func_pm.addPass(TFDevice::CreateClusterFormationPass());
169 
170   // Hopefully there is a single island left, or there wasn't any to begin with.
171   // We now run the optimizer which operates mostly inside islands.
172   func_pm.addPass(createCanonicalizerPass());
173   pm.addPass(CreateTFShapeInferencePass());
174   if (options.enable_inliner) {
175     pm.addPass(createInlinerPass());
176   }
177   pm.addPass(createSymbolDCEPass());
178   pm.addNestedPass<func::FuncOp>(CreateTFOptimizePass());
179   pm.addNestedPass<func::FuncOp>(createCSEPass());
180 }
181 
CreateTFOptimizePass()182 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFOptimizePass() {
183   return std::make_unique<TensorFlowOptimizePass>();
184 }
185 
RegisterTFOptimizePassPipeline()186 void RegisterTFOptimizePassPipeline() {
187   // Registers a pipeline builder function for the default
188   // canonicalize/optimizer.
189   static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
190       "tf-standard-pipeline",
191       "Run all the passes involved in transforming/optimizing the graph after "
192       "importing into MLIR, without any target specialization.",
193       CreateTFStandardPipeline);
194 }
195 
196 }  // namespace TF
197 }  // namespace mlir
198