/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass prepares for legalization to the TFLite dialect by
// converting operations in the TensorFlow dialect into operations that can be
// legalized to the TensorFlow Lite dialect with simple replacements.  The
// newly created operations stay in the TensorFlow dialect if they can be
// represented using a TensorFlow op; otherwise, a TensorFlow Lite dialect op
// is used.  For example, the TFLite Conv2D op expects filters in the OHWI
// data format, which cannot be expressed with a TensorFlow op because
// TensorFlow requires filters in the HWIO data format.
//
// The motivation for preparing before the actual legalization, rather than
// during it, is to exploit constant folding opportunities in any newly
// created ops by leveraging the constant folding support for TensorFlow ops.
// This way TFLite can be used purely as a serialization format and does not
// require access to the TFLite runtime for optimizations, as required by the
// TFLite team.

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/numeric/bits.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#define DEBUG_TYPE "tf-tfl-legalization"

namespace mlir {
namespace TFL {
namespace {
// Returns a TF_CastOp to I32. This function is used for CastOps that are
// intermediate nodes in a TableGen pattern result. In such a case, the
// destination type is not inferred and must be given explicitly.
//
// Preconditions: The given value must have a ShapedType.
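//
// Usage sketch (a hypothetical call site; `shape` stands for any shaped
// tensor value that must be cast to i32 before use):
//   Value shape_i32 =
//       CreateTFCastOpI32(&builder, loc, shape, builder.getBoolAttr(false));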
static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x,
                               BoolAttr truncate) {
  auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
  if (!x_type) llvm_unreachable("unsupported type");
  Type type = x_type.clone(builder->getI32Type());
  return builder->create<TF::CastOp>(loc, type, x, truncate);
}
}  // namespace

//===----------------------------------------------------------------------===//
// The actual PrepareTF Pass.
//
// TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
// pass.
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

// Prepare TF operations in functions for subsequent legalization.
class PrepareTFPass : public PrepareTFPassBase<PrepareTFPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareTFPass)

  PrepareTFPass() = default;
  PrepareTFPass(const PrepareTFPass &) {}
  explicit PrepareTFPass(bool unfold_batch_matmul,
                         bool allow_bf16_and_f16_type_legalization,
                         bool use_fake_quant_num_bits = false) {
    this->unfold_batch_matmul_ = unfold_batch_matmul;
    this->allow_bf16_and_f16_type_legalization_ =
        allow_bf16_and_f16_type_legalization;
    this->use_fake_quant_num_bits_ = use_fake_quant_num_bits;
  }

  void runOnOperation() override;
};

// Transient state for preserving data from match to rewrite
struct ConvertTFConvOpMatchState {
  IntegerAttr dilation_height_factor;
  IntegerAttr dilation_width_factor;
  StringAttr padding;
  IntegerAttr stride_height;
  IntegerAttr stride_width;
};

// Templated class for declaring a converter from some TensorFlow convolution
// op into its counterpart in TensorFlow Lite.
//
// The `ConcreteType` deriving from this template must provide the following
// method for constructing TensorFlow Lite op:
//
//   TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
//                         PatternRewriter &rewriter, Location loc,
//                         Type result_type, Value input,
//                         Value filter, Value bias) const;
//
// And also the following method for getting the dimension for bias tensor:
//
//  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
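//
// For instance, ConvertTFConv2D below returns the last filter dimension (the
// out_channels of an HWIO filter), while ConvertTFDepthwiseConv2dNative
// returns in_channels * channel_multiplier.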
template <typename ConcreteType, typename TFConvOpType>
class ConvertTFConvOp : public RewritePattern {
 public:
  ConvertTFConvOp(MLIRContext *context,
                  bool allow_bf16_and_f16_type_legalization)
      : RewritePattern(TFConvOpType::getOperationName(), 1, context),
        intAttrOne(Builder(context).getI32IntegerAttr(1)),
        allow_bf16_and_f16_type_legalization_(
            allow_bf16_and_f16_type_legalization) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Assumes the TensorFlow convolution op has already been verified to be
    // in valid form.

    // Match a TFConvOpType under the following conditions:
    // * The 'T' attribute must exist and be of value DT_FLOAT.
    // * The 'data_format' attribute must exist and be of value "NHWC".
    // * The 'strides' attribute must exist and be of the form [1, X, Y, 1].
    // * The 'dilations' attribute is optional, but if present it must be of
    //   the form [1, X, Y, 1].

    TFConvOpType tf_op = cast<TFConvOpType>(op);
    if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
        !(allow_bf16_and_f16_type_legalization_ &&
          TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
      return failure();

    if (!TFDataFormatIsNHWC(op)) return failure();

    IntegerAttr height, width;
    if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();

    ConvertTFConvOpMatchState state;
    state.stride_height = height;
    state.stride_width = width;

    if (TFIntListIs1XY1(op, "dilations", &height, &width)) {
      state.dilation_height_factor = height;
      state.dilation_width_factor = width;
    } else {
      // If the 'dilations' attribute is missing, we use the default value (1)
      // for both dilation height and width factor.
      state.dilation_height_factor = intAttrOne;
      state.dilation_width_factor = intAttrOne;
    }

    TFPaddingIsSameOrValid(op, &state.padding);

    // Additionally, we require the filter operand to be of 4-D tensor type so
    // that we can extract info from the shape (e.g., for constructing bias
    // tensor, for setting depth_multiplier attribute, etc.).
    auto filter = tf_op.filter();
    auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
    if (!filter_type || filter_type.getRank() != 4 ||
        !filter_type.hasStaticShape())
      return failure();

    Value input = tf_op.input();
    RankedTensorType input_type =
        input.getType().template dyn_cast<RankedTensorType>();
    // A rank-four input is already guaranteed by the tf.Conv2D operator
    // verification, so we only need to check for a static channel dimension.
    if (!input_type || input_type.isDynamicDim(3)) {
      return failure();
    }
    // Check if the given op is based on grouped convolution.
    // A dim size of zero is already rejected by the tf.Conv2D operator
    // verification.
    if (input_type.getDimSize(3) % filter_type.getDimSize(2) != 0) {
      return failure();
    }

    // The TensorFlow convolution op only has two inputs, while the TFLite one
    // has three, with the bias vector marked as optional. However, TOCO has a
    // dedicated pass, EnsureBiasVectors, to create default bias vectors for
    // all those missing. So we model the TFLite convolution op as requiring
    // three inputs to achieve the legalization task of EnsureBiasVectors.
    // This requires the filter tensor to have a static shape.

    // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)

    // Get a splat zero tensor with the expected dimension for the bias tensor
    auto elem_type = filter_type.getElementType();
    auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
        filter_type.getShape());
    auto bias_type = RankedTensorType::get({bias_dim}, elem_type);
    auto bias_attr = rewriter.getZeroAttr(bias_type);
    auto bias =
        rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);

    if (op->getAttrOfType<StringAttr>("padding").getValue() == "EXPLICIT") {
      // Add Const op for padding value.
      ArrayRef<Attribute> padding_attr_array =
          op->getAttrOfType<ArrayAttr>("explicit_paddings").getValue();

      auto get_int = [](Attribute attr) {
        return attr.template cast<IntegerAttr>().getInt();
      };

      SmallVector<int32_t> padding_values(padding_attr_array.size());
      for (int i = 0; i < padding_attr_array.size(); i++) {
        padding_values[i] =
            static_cast<int32_t>(get_int(padding_attr_array[i]));
      }

      RankedTensorType padding_attr_type = RankedTensorType::get(
          {filter_type.getRank(), 2}, rewriter.getIntegerType(32));
      auto padding_attr =
          mlir::DenseIntElementsAttr::get(padding_attr_type, padding_values);

      auto padding_const =
          rewriter.create<TF::ConstOp>(op->getLoc(), padding_attr);

      // Add Pad op.
      auto pad_output_type = UnrankedTensorType::get(elem_type);
      input = rewriter.create<TF::PadOp>(op->getLoc(), pad_output_type, input,
                                         padding_const);

      // Set Conv padding to `VALID` since padding has been handled by Pad op.
      state.padding = rewriter.getStringAttr("VALID");
    }
    auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
        &state, rewriter, op->getLoc(), tf_op.getType(), input, filter, bias);

    rewriter.replaceOp(op, conv_op.getResult());
    return success();
  }

  const IntegerAttr intAttrOne;

 private:
  bool allow_bf16_and_f16_type_legalization_;
};

class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
 public:
  using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;

  ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
      : BaseType(context, allow_bf16_type_legalization) {}

  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
    return filterShape.back();
  }

  TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
                            PatternRewriter &rewriter, Location loc,
                            Type result_type, Value input, Value filter,
                            Value bias) const {
    filter = legalizeFilter(rewriter, loc, filter);
    return rewriter.create<TFL::Conv2DOp>(
        loc, result_type, input, filter, bias,
        /*dilation_h_factor=*/state->dilation_height_factor,
        /*dilation_w_factor=*/state->dilation_width_factor,
        /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
        /*padding=*/state->padding,
        /*stride_h=*/state->stride_height,
        /*stride_w=*/state->stride_width);
  }

 private:
  // Legalize the given filter by converting it from the TensorFlow filter
  // data format HWIO to the TFLite Conv2D op filter data format OHWI, and
  // return the Value for the converted filter.  Requires that the match
  // method has verified that the filter is a 4-D RankedTensorType.
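  //
  // For example (the concrete shape is illustrative only), a filter of shape
  // [3, 3, 8, 16] (HWIO) becomes a tf.Transpose with permutation [3, 0, 1, 2]
  // producing shape [16, 3, 3, 8] (OHWI).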
  Value legalizeFilter(PatternRewriter &rewriter, Location loc,
                       Value filter) const {
    // Create a constant op for HWIO to OHWI transpose permutation.
    SmallVector<int, 4> perm = {3, 0, 1, 2};
    auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
                                           rewriter.getIntegerType(32));
    auto perm_attr =
        DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
    auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);

    // Create tensor type for the transpose result.
    auto filter_type = filter.getType().cast<RankedTensorType>();
    auto result_shape =
        llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
          return filter_type.getDimSize(dim);
        }));
    auto elem_type = filter_type.getElementType();
    auto result_type = RankedTensorType::get(result_shape, elem_type);

    return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op);
  }
};

class ConvertTFDepthwiseConv2dNative
    : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
                             TF::DepthwiseConv2dNativeOp> {
 public:
  using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
                                   TF::DepthwiseConv2dNativeOp>;

  ConvertTFDepthwiseConv2dNative(MLIRContext *context,
                                 bool allow_bf16_type_legalization)
      : BaseType(context, allow_bf16_type_legalization) {}

  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
    return filterShape[2] * filterShape[3];
  }

  TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
                                     PatternRewriter &rewriter, Location loc,
                                     Type result_type, Value input,
                                     Value filter, Value bias) const {
    // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
    // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
    // have a corresponding 'depth_multiplier' attribute; the multiplier is the
    // fourth dimension in the 4-D filter tensor. We query the multiplier from
    // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
    auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);

    filter = legalizeFilter(rewriter, loc, filter);
    return rewriter.create<TFL::DepthwiseConv2DOp>(
        loc, result_type, input, filter, bias,
        /*dilation_h_factor=*/state->dilation_height_factor,
        /*dilation_w_factor=*/state->dilation_width_factor,
        /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
        /*padding=*/state->padding,
        /*stride_h=*/state->stride_height,
        /*stride_w=*/state->stride_width,
        /*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier));
  }

 private:
  /// Legalize the given filter by converting it from the TensorFlow filter
  /// data format to the TFLite DepthwiseConv2D op filter data format, and
  /// return the Value for the converted filter.  The TensorFlow filter data
  /// format is [filter_height, filter_width, in_channels, channel_multiplier]
  /// and the TFLite filter data format is
  /// [1, filter_height, filter_width, out_channels].  Requires that the match
  /// method has verified that the filter is a 4-D RankedTensorType.
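  ///
  /// For example (the concrete shape is illustrative only), a filter of shape
  /// [3, 3, 8, 2] is reshaped to [1, 3, 3, 16], since
  /// out_channels = in_channels * channel_multiplier.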
  Value legalizeFilter(PatternRewriter &rewriter, Location loc,
                       Value filter) const {
    auto filter_type = filter.getType().cast<RankedTensorType>();
    auto filterShape = filter_type.getShape();
    SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
                                            filterShape[2] * filterShape[3]};
    auto elem_type = filter_type.getElementType();
    auto result_type = RankedTensorType::get(result_shape, elem_type);
    // The TensorFlow Lite `Reshape` op currently only supports int32 shape
    // tensors.
    auto shape_type = RankedTensorType::get({4}, rewriter.getIntegerType(32));
    SmallVector<Attribute, 4> result_shape_data(4);
    for (int i = 0; i < 4; ++i) {
      result_shape_data[i] =
          rewriter.getI32IntegerAttr(static_cast<int32_t>(result_shape[i]));
    }
    auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
    auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);

    return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
  }
};

// StridedSlice can have complicated attributes like begin_axis_mask,
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
// masks complicate the strided_slice computation logic; we can simplify the
// logic by inserting a reshape op to pad the inputs so the strided_slice is
// easier to handle.
//
// So the graph is transformed like below:
//   original_input -> strided_slice -> output
//      (transforms)
//   original_input -> reshape -> strided_slice -> output
//
// And the new shape is computed based on the masks.
//
// An example for new_axis_mask: say the new_axis_mask is 9, which in binary
// (read from bit 0 upwards) is [1 0 0 1]. That means we are inserting two new
// axes at dims 0 & 3, so if the original shape is [2, 3], we reshape it into
// [1, 2, 3, 1].
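//
// For that example the rewrite produces, roughly (a sketch, not verbatim IR):
//   %shape   = arith.constant dense<[1, 2, 3, 1]> : tensor<4xi32>
//   %reshape = "tf.Reshape"(%input, %shape)
//   %slice   = "tf.StridedSlice"(%reshape, ...) with begin_mask |= 9,
//              end_mask |= 9, new_axis_mask = 0, shrink_axis_mask &= ~9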
struct ConvertTFStridedSlice : public RewritePattern {
  explicit ConvertTFStridedSlice(MLIRContext *context)
      : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}

  LogicalResult RewriteNewAxisMask(Operation *op,
                                   PatternRewriter &rewriter) const {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
    uint64_t new_axis_mask = strided_slice_op.new_axis_mask();

    if (strided_slice_op.ellipsis_mask() != 0) {
      // The ellipsis mask should have been lowered away before invoking this
      // function.
      op->emitError() << "encountered a logical error";
      return failure();
    }

    // Insert a new reshape op.
    Value original_input = strided_slice_op.input();
    RankedTensorType original_input_type =
        original_input.getType().dyn_cast<RankedTensorType>();
    if (!original_input_type) {
      return failure();
    }

    const ArrayRef<int64_t> &original_input_shape =
        original_input_type.getShape();
    SmallVector<int64_t, 4> revised_shape;
    int index = 0;
    const int original_input_rank = original_input_shape.size();
    while (index < original_input_rank || new_axis_mask) {
      if (new_axis_mask & 1) {
        revised_shape.emplace_back(1);
      } else {
        revised_shape.emplace_back(original_input_shape[index++]);
      }
      new_axis_mask >>= 1;
    }

    if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();

    const int dim_size = revised_shape.size();
    Location loc = strided_slice_op.getLoc();
    auto shape_type =
        RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
    SmallVector<Attribute, 4> result_shape_data(dim_size);
    for (int i = 0; i < dim_size; ++i) {
      result_shape_data[i] =
          rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
    }

    auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
    auto shape =
        rewriter.create<arith::ConstantOp>(loc, shape_type, shape_attr);
    auto revised_output_type = RankedTensorType::get(
        revised_shape, original_input_type.getElementType());
    TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
        loc, revised_output_type, original_input, shape);

    // Replace the original strided_slice.
    uint64_t revised_begin_mask = strided_slice_op.begin_mask();
    uint64_t revised_end_mask = strided_slice_op.end_mask();
    // Since we expand the dims, we need to apply them to the begin_mask &
    // end_mask.
    revised_begin_mask |= strided_slice_op.new_axis_mask();
    revised_end_mask |= strided_slice_op.new_axis_mask();

    // Enforce operator precedence.
    uint64_t revised_shrink_axis_mask =
        strided_slice_op.shrink_axis_mask() & ~strided_slice_op.new_axis_mask();

    auto attribute_type = rewriter.getIntegerType(64);
    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
        op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
        strided_slice_op.end(), strided_slice_op.strides(),
        rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
        rewriter.getIntegerAttr(attribute_type, revised_end_mask),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.ellipsis_mask()),
        rewriter.getI64IntegerAttr(0),
        rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
    return success();
  }

  LogicalResult RewriteEllipsisMask(Operation *op,
                                    PatternRewriter &rewriter) const {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);

    uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
    uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
    uint64_t new_axis_mask = strided_slice_op.new_axis_mask();

    // Enforce operator precedence.
    shrink_axis_mask &= ~ellipsis_mask;
    new_axis_mask &= ~ellipsis_mask;

    DenseIntElementsAttr begin_dense_elem_attr;
    Value begin = strided_slice_op.begin();
    auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
    if (!begin_ranked_attr_type ||
        !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
      return failure();
    }

    DenseIntElementsAttr end_dense_elem_attr;
    Value end = strided_slice_op.end();
    auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
    if (!end_ranked_attr_type ||
        !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
      return failure();
    }

    DenseIntElementsAttr stride_dense_elem_attr;
    Value stride = strided_slice_op.strides();
    auto stride_ranked_attr_type =
        stride.getType().dyn_cast<RankedTensorType>();
    if (!stride_ranked_attr_type ||
        !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
      return failure();
    }

    Value input = strided_slice_op.input();
    RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
    if (!input_type) {
      return failure();
    }
    const ArrayRef<int64_t> input_shape = input_type.getShape();

    const int input_size = input_shape.size();

    RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
    const ArrayRef<int64_t> begin_shape = begin_type.getShape();
    const int begin_dim = begin_shape.size();

    if (begin_dim != 1) return failure();

    // The ellipsis fill might exceed the current output shape because we are
    // also taking into account any to-be-inserted new axes.
    const int ellipsis_filled_dim_size =
        input_size - begin_shape[0] + 1 + absl::popcount(new_axis_mask);
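    // E.g. (illustrative numbers only): with an input of rank 4, 3-element
    // begin/end/strides, and one new axis, the ellipsis expands to
    // 4 - 3 + 1 + 1 = 3 output dims.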

    int64_t begin_mask = strided_slice_op.begin_mask();
    int64_t end_mask = strided_slice_op.end_mask();
    int64_t revised_begin_mask = 0;
    int64_t revised_end_mask = 0;
    int64_t revised_shrink_axis_mask = 0;
    int64_t revised_new_axis_mask = 0;

    SmallVector<int32_t, 4> padded_begin;
    SmallVector<int32_t, 4> padded_end;
    SmallVector<int32_t, 4> padded_stride;

    // Before the ellipsis.
    int index = 0;
    int new_index = 0;
    while (((ellipsis_mask >> index) & 1) == 0) {
      padded_begin.push_back(begin_dense_elem_attr.getValues<int32_t>()[index]);
      padded_end.push_back(end_dense_elem_attr.getValues<int32_t>()[index]);
      padded_stride.push_back(
          stride_dense_elem_attr.getValues<int32_t>()[index]);
      if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
      if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
      if ((shrink_axis_mask >> index) & 1)
        revised_shrink_axis_mask |= (1 << new_index);

      if ((new_axis_mask >> index) & 1)
        revised_new_axis_mask |= (1 << new_index);

      ++index;
      ++new_index;
    }

    // Ellipsis.
    for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
      revised_begin_mask |= (1 << new_index);
      revised_end_mask |= (1 << new_index);

      // Mimic the begin/end/strides mask behavior.
      padded_begin.push_back(0);
      padded_end.push_back(0);
      padded_stride.push_back(1);
    }

    // Account for ellipsis mask.
    ++index;

    // After the ellipsis.
    for (; index < begin_shape[0];) {
      padded_begin.push_back(begin_dense_elem_attr.getValues<int32_t>()[index]);
      padded_end.push_back(end_dense_elem_attr.getValues<int32_t>()[index]);
      padded_stride.push_back(
          stride_dense_elem_attr.getValues<int32_t>()[index]);

      if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
      if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
      if ((shrink_axis_mask >> index) & 1)
        revised_shrink_axis_mask |= (1 << new_index);
      if ((new_axis_mask >> index) & 1)
        revised_new_axis_mask |= (1 << new_index);

      ++index;
      ++new_index;
    }

    auto attribute_type = rewriter.getIntegerType(64);

    int full_dim_count = padded_begin.size();
    auto type =
        RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));

    auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
    auto begin_op =
        rewriter.create<arith::ConstantOp>(op->getLoc(), type, begin_attr);
    auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
    auto end_op =
        rewriter.create<arith::ConstantOp>(op->getLoc(), type, end_attr);
    auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
    auto stride_op =
        rewriter.create<arith::ConstantOp>(op->getLoc(), type, stride_attr);

    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
        op, strided_slice_op.getType(), input, begin_op.getResult(),
        end_op.getResult(), stride_op.getResult(),
        rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
        rewriter.getIntegerAttr(attribute_type, revised_end_mask),
        /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
        rewriter.getIntegerAttr(attribute_type, revised_new_axis_mask),
        rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));

    return success();
  }

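  // Copies `dense_elem_attr` into both `val` and `padded_val`, then pads
  // `padded_val` up to the length of `padding_val` using its entries; every
  // padded index also gets its bit set in `mask` (when non-null), so the
  // default begin/end is used for that dim.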
  void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
                                     SmallVectorImpl<int32_t> &val,
                                     SmallVectorImpl<int32_t> &padded_val,
                                     ArrayRef<int32_t> padding_val,
                                     int *mask) const {
    for (const auto &idx : dense_elem_attr.getValues<APInt>()) {
      val.push_back(idx.getSExtValue());
      padded_val.push_back(idx.getSExtValue());
    }
    int attr_dim_count = val.size();
    int full_dim_count = padding_val.size();
    for (int i = attr_dim_count; i < full_dim_count; ++i) {
      padded_val.push_back(padding_val[i]);
      if (mask) *mask |= 1 << i;
    }
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);

    // Handle ellipsis mask.
    if (strided_slice_op.ellipsis_mask() != 0) {
      return RewriteEllipsisMask(strided_slice_op, rewriter);
    }

    // Handle new axis mask.
    if (strided_slice_op.new_axis_mask() != 0) {
      return RewriteNewAxisMask(strided_slice_op, rewriter);
    }

    auto ranked_input_type =
        strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
    if (!ranked_input_type) {
      return failure();
    }

    auto begin_attr = strided_slice_op.begin();
    auto end_attr = strided_slice_op.end();
    auto strides_attr = strided_slice_op.strides();

    auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
    auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
    auto strides_attr_type =
        strides_attr.getType().dyn_cast<RankedTensorType>();

    DenseIntElementsAttr begin_elem_attr;
    DenseIntElementsAttr end_elem_attr;
    DenseIntElementsAttr strides_elem_attr;

    if (!begin_attr_type ||
        !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
      return failure();
    }
    if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
      return failure();
    }
    if (!strides_attr_type ||
        !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
      return failure();
    }

    SmallVector<int32_t, 4> begin, end, strides;
    SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;

    int num_input_dims = ranked_input_type.getRank();
    SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
    auto input_shape = ranked_input_type.getShape();
    SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
    SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);

    int begin_mask = strided_slice_op.begin_mask();
    int end_mask = strided_slice_op.end_mask();

    PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
                                  padding_begin, &begin_mask);
    PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
                                  &end_mask);
    PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
                                  padding_strides, nullptr);

    if (begin == padded_begin && end == padded_end &&
        strides == padded_strides &&
        begin_mask == strided_slice_op.begin_mask() &&
        end_mask == strided_slice_op.end_mask()) {
      return failure();
    }

    auto begin_end_type =
        RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
    auto new_begin_attr = rewriter.create<arith::ConstantOp>(
        op->getLoc(), begin_end_type,
        DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
    auto new_end_attr = rewriter.create<arith::ConstantOp>(
        op->getLoc(), begin_end_type,
        DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
    auto strides_type =
        RankedTensorType::get({static_cast<long>(padded_strides.size())},
                              rewriter.getIntegerType(32));
    auto new_strides_attr = rewriter.create<arith::ConstantOp>(
        op->getLoc(), strides_type,
        DenseElementsAttr::get<int32_t>(strides_type, padded_strides));

    auto attribute_type = rewriter.getIntegerType(64);
    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
        op, strided_slice_op.output().getType(), strided_slice_op.input(),
        new_begin_attr, new_end_attr, new_strides_attr,
        rewriter.getIntegerAttr(attribute_type, begin_mask),
        rewriter.getIntegerAttr(attribute_type, end_mask),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.ellipsis_mask()),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.new_axis_mask()),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.shrink_axis_mask()));

    return success();
  }
};

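// Lowers tf.BroadcastTo into tf.Mul(input, tf.Fill(shape, 1)); the implicit
// broadcasting semantics of the multiplication materialize the broadcast.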
struct ConvertTFBroadcastTo : public RewritePattern {
  explicit ConvertTFBroadcastTo(MLIRContext *context)
      : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
    auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
    auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
    auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
    Type element_type = input_type.getElementType();

    // Allow lowering only when the input is low-dimensional (at most four
    // dims) and its element type is bf16, f32, i32, or i16.
    if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
          (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
           shape_type.getDimSize(0) <= 4)))
      return failure();

    if (!(element_type.isa<BFloat16Type, Float32Type>() ||
          element_type.isInteger(32) || element_type.isInteger(16)))
      return failure();

    auto status_or_const_op =
        CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
    if (!status_or_const_op.ok()) {
      return failure();
    }

    auto tf_fill_op = rewriter.create<TF::FillOp>(
        op->getLoc(), output_type, tf_broadcast_to_op.shape(),
        status_or_const_op.ValueOrDie());

    auto mul_op = rewriter.create<TF::MulOp>(
        op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
    rewriter.replaceOp(op, mul_op.getResult());
    return success();
  }
};

// The pattern below is equivalent to the DRR rule that follows it.
// The checks are dependent on generated values, so we can't add the checks
// on intermediate values; ideally we should find equivalent checks that
// guarantee the resultant ops are valid.
// The extra conditions are the broadcasting conditions.
//
// The pattern lowers FusedBatchNormV3 to arithmetic ops.
// Specifically, it performs the following calculation:
//
//   (x - mean) * scale / sqrt(variance + epsilon) + offset
//
// Let multiplier = scale / sqrt(variance + epsilon); computing
//   (x - mean) * scale / sqrt(variance + epsilon) + offset
// is then equivalent to computing
//   (x * multiplier) + (offset - mean * multiplier).
//
// def : Pattern<
//     (TF_FusedBatchNormV3Op:$root
//         $x, $scale, $offset, $mean, $variance,
//         F32Attr:$epsilon, $exponential_avg_factor,
//         $data_format, FalseBoolAttr:$is_training),
//     [(TF_AddOp
//         (TF_MulOp
//             $x,
//             (TF_MulOp:$multiplier
//                 $scale,
//                 (TF_RsqrtOp
//                     (TF_AddOp $variance,
//                               (TF_ConstOp $epsilon))))),
//         (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
//    // We already guaranteed that the last five results have no use so it does
//    // not matter what value we provide here for replacement.
//      /*batch_mean=*/(replaceWithValue $x),
//      /*batch_variance=*/(replaceWithValue $x),
//      /*reserve_space_1=*/(replaceWithValue $x),
//      /*reserve_space_2=*/(replaceWithValue $x),
//      /*reserve_space_3=*/(replaceWithValue $x)],
//     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
//      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
//      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
//
// When is_training is set to true, the given variance and mean are not used.
// In the above calculation, they are replaced by new values. The new mean and
// variance are calculated as follows:
//   new_mean = mean(x, axis=[0, 1, 2])
//   new_variance = mean(squared_difference(x, new_mean), axis=[0, 1, 2])
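// (For a rank-4 NHWC input, the reduction axes [0, 1, 2] are the batch and
// spatial dims, leaving one mean/variance value per channel.)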
//
// The DRR rule for the is_training == true case is as follows:
// def : Pattern<
//     (TF_FusedBatchNormV3Op:$root
//         $x, $scale, $offset, $mean, $variance,
//         F32Attr:$epsilon, $exponential_avg_factor,
//         $data_format, FalseBoolAttr:$is_training),
//     [(TF_AddOp
//         (TF_MulOp
//             $x,
//             (TF_MulOp:$multiplier
//                 $scale,
//                 (TF_RsqrtOp
//                     (TF_AddOp
//                         (TF_MeanOp
//                             (TF_SquaredDifferenceOp $x, $new_mean),
//                             (TF_ConstOp [0,1,2])),
//                         (TF_ConstOp $epsilon))))),
//         (TF_SubOp
//             $offset,
//             (TF_MulOp
//                 (TF_MeanOp $x, (TF_ConstOp [0,1,2])),
//                 $multiplier))),
//    // We already guaranteed that the last five results have no use so it does
//    // not matter what value we provide here for replacement.
//      /*batch_mean=*/(replaceWithValue $x),
//      /*batch_variance=*/(replaceWithValue $x),
//      /*reserve_space_1=*/(replaceWithValue $x),
//      /*reserve_space_2=*/(replaceWithValue $x),
//      /*reserve_space_3=*/(replaceWithValue $x)],
//     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
//      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
//      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;

struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
  explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern(
            "tf.FusedBatchNormV3", 1, context,
            {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}) {}

  ::mlir::LogicalResult matchAndRewrite(
      ::mlir::Operation *fused_batch_norm,
      ::mlir::PatternRewriter &rewriter) const override {
    // Variables for capturing values and attributes used for creating ops
    Operation::operand_range mean(fused_batch_norm->getOperands());
    ::mlir::FloatAttr exponential_avg_factor;
    ::mlir::TF::FusedBatchNormV3Op root;
    Operation::operand_range offset(fused_batch_norm->getOperands());
    Operation::operand_range x(fused_batch_norm->getOperands());
    Operation::operand_range scale(fused_batch_norm->getOperands());
    Operation::operand_range variance(fused_batch_norm->getOperands());
    ::mlir::FloatAttr epsilon;
    ::mlir::BoolAttr is_training;

    // Match
    auto fused_batch_norm_op =
        dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm);
    root = fused_batch_norm_op;
    x = fused_batch_norm_op.getODSOperands(0);
    scale = fused_batch_norm_op.getODSOperands(1);
    offset = fused_batch_norm_op.getODSOperands(2);
    mean = fused_batch_norm_op.getODSOperands(3);
    variance = fused_batch_norm_op.getODSOperands(4);

    ::mlir::Value mean_value = (*mean.begin());
    ::mlir::Value variance_value = (*variance.begin());

    if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();

    {
      epsilon =
          fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon");
      if (!epsilon)
        epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);

      if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
            ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) {
        return rewriter.notifyMatchFailure(
            fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
              diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to "
                      "satisfy constraint: 32-bit float attribute";
            });
      }
    }
    {
      exponential_avg_factor =
          fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>(
              "exponential_avg_factor");
      if (!exponential_avg_factor)
        exponential_avg_factor =
            rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
    }
    if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
        !TFDataFormatIsNDHWC(fused_batch_norm_op))
      return failure();

    if (!(((*root.getODSResults(1).begin()).use_empty()))) {
      return rewriter.notifyMatchFailure(
          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
            diag << "entities '' failed to satisfy constraint: has no use";
          });
    }

    if (!(((*root.getODSResults(2).begin()).use_empty()))) {
      return rewriter.notifyMatchFailure(
          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
            diag << "entities '' failed to satisfy constraint: has no use";
          });
    }

    if (!(((*root.getODSResults(3).begin()).use_empty()))) {
      return rewriter.notifyMatchFailure(
          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
            diag << "entities '' failed to satisfy constraint: has no use";
          });
    }

    if (!(((*root.getODSResults(4).begin()).use_empty()))) {
      return rewriter.notifyMatchFailure(
          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
            diag << "entities '' failed to satisfy constraint: has no use";
          });
    }

    if (!(((*root.getODSResults(5).begin()).use_empty()))) {
      return rewriter.notifyMatchFailure(
          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
            diag << "entities '' failed to satisfy constraint: has no use";
          });
    }

    is_training =
        fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
    auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});

    // We need to make sure input and output shapes are compatible.
    int64_t last_dim = -1;
    {
      auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
        auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
        if (!v_type) return true;
        int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
        if (v_last_dim == -1) return true;
        if (last_dim != -1 && v_last_dim != last_dim) return false;
        last_dim = v_last_dim;
        return true;
      };

      if (!is_last_dim_compatible(*x.begin(), last_dim) ||
          !is_last_dim_compatible(*scale.begin(), last_dim) ||
          !is_last_dim_compatible(*offset.begin(), last_dim)) {
        return rewriter.notifyMatchFailure(
            fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
              diag << "Shapes of scale and offset should be 1D and "
                      "compatible with x";
            });
      }

      if (!is_training.getValue()) {
        if (!is_last_dim_compatible(mean_value, last_dim) ||
            !is_last_dim_compatible(variance_value, last_dim)) {
          return rewriter.notifyMatchFailure(
              fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
                diag << "Shapes of mean and variance should be 1D and "
                        "compatible with x";
              });
        }
      }

      // Check if output shape and input shape are compatible.
      auto x_type = (*x.begin()).getType();
      auto y_type = (*root.getODSResults(0).begin()).getType();
      if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
        return rewriter.notifyMatchFailure(
            fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
              diag << "Shapes of x and the first output should be compatible";
            });
      }
    }

    // For training, the mean and variance are calculated from the input
    // values.
    if (is_training.getValue()) {
      auto input_type = fused_batch_norm_op.x()
                            .getType()
                            .dyn_cast_or_null<RankedTensorType>();
      if (!input_type || input_type.getRank() != 4) {
        return rewriter.notifyMatchFailure(
            fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
              diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
                      "True is only supported with input of rank 4";
            });
      }

      ::mlir::TF::ConstOp reduce_dim_op;
      {
        auto reduce_dim_type =
            ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(32));
        ::mlir::SmallVector<int32_t, 3> reduce_dim_values = {0, 1, 2};
        reduce_dim_op = rewriter.create<TF::ConstOp>(
            odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
                                                      reduce_dim_values));
      }

      auto new_mean_type =
          ::mlir::RankedTensorType::get({last_dim}, rewriter.getF32Type());
      ::mlir::TF::MeanOp mean_op_1;
      {
        ::mlir::Value x_value = (*x.begin());
        mean_op_1 = rewriter.create<TF::MeanOp>(
            odsLoc, new_mean_type, x_value, reduce_dim_op,
            /*keep_dims=*/rewriter.getBoolAttr(false));
      }

      ::mlir::TF::SquaredDifferenceOp square_diff_op;
      {
        ::mlir::Value tblgen_value_0 = (*x.begin());
        ::mlir::Value tblgen_value_1 = (*mean_op_1.getODSResults(0).begin());
        // If x has shape of [b, h, w, c], the result of mean_op_1 will have
        // shape of [c]. Therefore, their shapes are always compatible.
        square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
            odsLoc, tblgen_value_0, tblgen_value_1);
      }

      ::mlir::TF::MeanOp mean_op_2;
      {
        ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
        mean_op_2 = rewriter.create<TF::MeanOp>(
            odsLoc, new_mean_type, input_value, reduce_dim_op,
            /*keep_dims=*/rewriter.getBoolAttr(false));
      }

      mean_value = (*mean_op_1.getODSResults(0).begin());
      variance_value = (*mean_op_2.getODSResults(0).begin());
    }  // End is_training equals true if.

    ::llvm::SmallVector<::mlir::Value, 4> replace_values;
    ::mlir::TF::ConstOp epsilon_const_op;
    {
      epsilon_const_op =
          rewriter.create<::mlir::TF::ConstOp>(odsLoc,
                                               /*value=*/epsilon);
    }
    ::mlir::TF::AddOp add_op_1;
    {
      ::mlir::Value epsilon_value =
          (*epsilon_const_op.getODSResults(0).begin());
      // Adding a scalar constant; no need to check broadcastability.
      add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
                                                    /*x=*/variance_value,
                                                    /*y=*/epsilon_value);
    }
    ::mlir::TF::RsqrtOp rsqrt_op;
    {
      ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
      ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
      tblgen_values.push_back((*add_op_1.getODSResults(0).begin()));
      rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values,
                                                      tblgen_attrs);
    }
    ::mlir::TF::MulOp multiplier;
    {
      ::mlir::Value tblgen_value_0 = (*scale.begin());
      ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
      multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
                                                      /*x=*/tblgen_value_0,
                                                      /*y=*/tblgen_value_1);
    }
    ::mlir::TF::MulOp mul_op_1;
    {
      ::mlir::Value tblgen_value_0 = (*x.begin());
      ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
      mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
                                                    /*x=*/tblgen_value_0,
                                                    /*y=*/tblgen_value_1);
    }
    ::mlir::TF::MulOp mul_op_2;
    {
      ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
      mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
                                                    /*x=*/mean_value,
                                                    /*y=*/multiplier_value);
    }
    ::mlir::TF::SubOp sub_op;
    {
      ::mlir::Value tblgen_value_0 = (*offset.begin());
      ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin());
      sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
                                                  /*x=*/tblgen_value_0,
                                                  /*y=*/tblgen_value_1);
    }
    ::mlir::TF::AddOp add_op_2;
    {
      ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
      ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
      tblgen_values.push_back((*mul_op_1.getODSResults(0).begin()));
      tblgen_values.push_back((*sub_op.getODSResults(0).begin()));
      ::mlir::SmallVector<::mlir::Type, 4> tblgen_types;
      for (auto v : fused_batch_norm_op.getODSResults(0)) {
        tblgen_types.push_back(v.getType());
      }
      add_op_2 = rewriter.create<::mlir::TF::AddOp>(
          odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
    }
    for (auto v :
         ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
      replace_values.push_back(v);
    }
    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
      replace_values.push_back(v);
    }
    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
      replace_values.push_back(v);
    }
    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
      replace_values.push_back(v);
    }
    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
      replace_values.push_back(v);
    }
    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
      replace_values.push_back(v);
    }
    rewriter.replaceOp(fused_batch_norm, replace_values);
    return success();
  };
};

1183 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
1184 
// Returns success if all the operations in the `op`'s regions including `op`
// itself are legal in a TFLite pipeline.
LogicalResult ValidateOp(Operation *op) {
  bool has_illegal_ops = false;
  op->walk([&](Operation *op) {
    if (isa<TF::VariableV2Op>(op)) {
      has_illegal_ops = true;
      op->emitOpError() << "is illegal in a TFLite pipeline";
    }
  });

  return failure(has_illegal_ops);
}

// Converts a set of TF2XLA ops into pure TF ops for later legalization, since
// TF2XLA ops aren't supported by the later stages of the pipeline.
LogicalResult ConvertTf2XlaOps(func::FuncOp func, MLIRContext *context) {
  ConversionTarget target(*context);
  target.addLegalDialect<arith::ArithmeticDialect>();
  target.addLegalDialect<func::FuncDialect>();
  target.addLegalDialect<TF::TensorFlowDialect>();
  target.addLegalOp<ModuleOp>();
  target.addLegalOp<func::FuncOp>();
  target.addIllegalOp<TF::XlaConvV2Op>();
  target.addIllegalOp<TF::XlaGatherOp>();

  RewritePatternSet patterns(context);
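  // The TF2XLA fallback patterns lower the illegal ops to MHLO, and the
  // HLO-to-TF patterns raise the result back to the TF dialect, so the net
  // effect is a TF-to-TF rewrite that eliminates XlaConvV2 and XlaGather.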
  mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context);
  mhlo::PopulateLegalizeTfPatterns(context, &patterns);
  TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
  mhlo::GatherOp::getCanonicalizationPatterns(patterns, context);

  return applyPartialConversion(func, target, std::move(patterns));
}

// Converts rfft to rfft2d.
// The transformation pattern is as follows:
//
//    input     fft_len
//     \      /
//     rfft
//
//     ||
//     \/
//
//   input       fft_len
//    \            /
//   expand_dim    concat with [1] at the front
//      \         /
//     rfft_2d
//       |
//     squeeze
struct ConvertRfftToRfft2d : public RewritePattern {
  explicit ConvertRfftToRfft2d(MLIRContext *context)
      : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rfft_op = dyn_cast<TF::RFFTOp>(op);

    auto input = rfft_op.input();
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type) return failure();
    auto fft_len = rfft_op.fft_length();
    auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
    if (!fft_len_type) return failure();

    auto output_type =
        rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
    if (!output_type) return failure();

    // Expand the input by inserting a unit dimension at position -2.
    auto one_ele_type =
        mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
    auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
                                                  one_ele_type, -2);

    SmallVector<int64_t, 4> expanded_input_shape;
    SmallVector<int64_t, 4> expanded_output_shape;
    int expanded_rank = input_type.getRank() + 1;
    int r = 0;
    for (int i = 0; i < expanded_rank; ++i) {
      if (i == expanded_rank - 2) {
        expanded_input_shape.push_back(1);
        expanded_output_shape.push_back(1);
      } else {
        expanded_input_shape.push_back(input_type.getDimSize(r));
        expanded_output_shape.push_back(output_type.getDimSize(r));
        r++;
      }
    }
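    // For example, a [batch, samples] input expands to [batch, 1, samples],
    // and a [batch, fft_len / 2 + 1] output to [batch, 1, fft_len / 2 + 1].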

    auto expanded_input_type = mlir::RankedTensorType::get(
        expanded_input_shape, input_type.getElementType());
    TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
        rfft_op.getLoc(), expanded_input_type, input, minus_two->getResult());

    // Expanded fft_len.
    auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});

    auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);

    auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
                                             one_ele_type, 0);

    auto expanded_fft_len_type =
        mlir::RankedTensorType::get({2}, fft_len_type.getElementType());

    TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
        rfft_op.getLoc(), expanded_fft_len_type,
        SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());
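    // fft_length is now the pair [1, fft_len]: a length-1 FFT over the
    // inserted dimension followed by the original FFT over the last one.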

    // Insert the rfft_2d.
    auto rfft2d_out_type = mlir::RankedTensorType::get(
        expanded_output_shape, output_type.getElementType());
    TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
        rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
        expanded_fft_len.getResult());

    // Insert the squeeze op.
    auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
    TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
        rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);

    rewriter.replaceOp(op, squeeze.getResult());

    return success();
  }
};

// Replaces the Identity op with its input in either of the following
// scenarios: 1) the Identity op's input and output have the same type/shape,
// or 2) the result of the Identity op is only used by TF ops.
struct RemoveIdentity : public OpRewritePattern<TF::IdentityOp> {
  using OpRewritePattern<TF::IdentityOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::IdentityOp identity,
                                PatternRewriter &rewriter) const override {
    // Replace the op with the input if input and result have the same type.
    if (identity.input().getType() == identity.getType()) {
      rewriter.replaceOp(identity, identity.input());
      return success();
    }
    // Replace the op with the input if the output is only used by TF ops.
    // This is conservative: every consumer must be a TF op before the
    // pattern applies. We may revisit this in the future if it turns out to
    // be too restrictive.
    for (Operation *user : identity->getUsers()) {
      if (user->getDialect()->getNamespace() != "tf") {
        return failure();
      }
    }

    rewriter.replaceOp(identity, identity.input());
    return success();
  }
};

void PrepareTFPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  RewritePatternSet phase_2_patterns(ctx);
  auto func = getOperation();
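  // The pass runs in two phases: `patterns` handles the structural rewrites
  // (dilated conv, batch-norm expansion), and `phase_2_patterns` runs after
  // FakeQuant conversion so it can see the inserted quantization ops.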

  // Check for ops that are illegal in a TFLite pipeline (e.g. training-only
  // ops), since PrepareTFPass is the very first TFLite pass in the pipeline.
  // TODO(jingpu): It might be better to split this check into its own pass
  // to make things more modular.
  if (failed(ValidateOp(func))) {
    func.emitError() << "tfl-prepare-tf pass failed.";
    signalPassFailure();
    return;
  }

  if (failed(ConvertTf2XlaOps(func, ctx))) {
    signalPassFailure();
    return;
  }

  // These patterns identify and optimize dilated convolutions: patterns like
  // "SpaceToBatchND -> Conv2D -> BatchToSpaceND" are replaced with a single
  // Conv op that carries the dilation parameter.
  patterns.add<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
               ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);

  patterns.add<RemoveIdentity>(ctx);
  TFL::populateWithGenerated(patterns);
  // Remove redundant reshape ops.
  TF::ReshapeOp::getCanonicalizationPatterns(patterns, ctx);
  // TODO(karimnosseir): Split into a separate pass, probably after deciding
  // on a long-term plan for this optimization.
  // This allows optimizing any TF_Mul->TF_Conv in the graph, including those
  // expanded from FusedBatchNorm. It must run before TF_Conv is converted to
  // TFL_Conv.
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Remove the wrapper of the tf.FakeQuant* ops and also insert tfl.quantize
  // and tfl.dequantize ops to preserve the quantization parameters. This is
  // done after the first round of optimization to make sure all the min/max
  // operands of the tf.FakeQuant* ops are constants to be matched. The
  // following round of optimization will fold the unwrapped tf.FakeQuant*
  // ops with the weight constants.
  if (failed(ConvertFakeQuantOps(func, ctx, use_fake_quant_num_bits_))) {
    signalPassFailure();
    return;
  }

  // Load the generated patterns again so the new quantization pass-through
  // will be applied.
  TFL::populateWithGenerated(phase_2_patterns);
  if (unfold_batch_matmul_) {
    TF::PopulateUnrollTfBatchMatMul(ctx, phase_2_patterns);
  }
  phase_2_patterns
      .add<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFStridedSlice,
           ConvertRfftToRfft2d, RemoveIdentity>(ctx);
  phase_2_patterns.add<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
      ctx, allow_bf16_and_f16_type_legalization_);
  // Remove redundant reshape ops.
  TF::ReshapeOp::getCanonicalizationPatterns(phase_2_patterns, ctx);

  (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareTFPass(
    bool unfold_batch_matmul, bool allow_bf16_and_f16_type_legalization,
    bool use_fake_quant_num_bits) {
  return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
                                         allow_bf16_and_f16_type_legalization,
                                         use_fake_quant_num_bits);
}

// Creates an instance of the TensorFlow Lite dialect PrepareTF pass with
// default options.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareTFPass() {
  return std::make_unique<PrepareTFPass>();
}

}  // namespace TFL
}  // namespace mlir