1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass prepares for legalization to the TFLite dialect by
// converting operations in the TensorFlow dialect into operations that can be
// legalized to the TensorFlow Lite dialect with simple replacements. The newly
// created operations stay in the TensorFlow dialect if the operation can be
// represented using a TensorFlow op; otherwise, a TensorFlow Lite dialect op
// is used. For example, the filter layout used by Conv2D in TFLite (OHWI)
// cannot be expressed with a TensorFlow op, because TensorFlow requires
// filters in the HWIO data format.
//
// The motivation for preparing before the actual legalization is to exploit
// constant-folding opportunities in any newly created ops by leveraging the
// constant-folding support for TensorFlow ops. This way TFLite can be used
// purely as a serialization format and does not require access to the TFLite
// runtime for optimizations, as required by the TFLite team.
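//
// As an illustrative sketch (types and most operands elided), a float
// tf.Conv2D with an HWIO filter
//
//   %y = "tf.Conv2D"(%input, %filter_hwio) {strides = [1, 1, 1, 1], ...}
//
// is prepared here as a transpose of the filter to OHWI plus a TFLite conv
// with an explicit zero bias:
//
//   %filter_ohwi = "tf.Transpose"(%filter_hwio, %perm)  // perm = [3, 0, 1, 2]
//   %bias = "tf.Const"() {value = dense<0.0> : ...}
//   %y = "tfl.conv_2d"(%input, %filter_ohwi, %bias) {...}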
31
32 #include <climits>
33 #include <cstdint>
34 #include <utility>
35
36 #include "absl/memory/memory.h"
37 #include "absl/numeric/bits.h"
38 #include "llvm/ADT/ArrayRef.h"
39 #include "llvm/ADT/STLExtras.h"
40 #include "llvm/ADT/StringSwitch.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Support/Debug.h"
43 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project
44 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
45 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
46 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
47 #include "mlir/IR/Attributes.h" // from @llvm-project
48 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
49 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
50 #include "mlir/IR/MLIRContext.h" // from @llvm-project
51 #include "mlir/IR/Operation.h" // from @llvm-project
52 #include "mlir/Pass/Pass.h" // from @llvm-project
53 #include "mlir/Support/LLVM.h" // from @llvm-project
54 #include "mlir/Support/LogicalResult.h" // from @llvm-project
55 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
56 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
57 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
58 #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
59 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
60 #include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h"
61 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
62 #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
63 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
64 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
65 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
66 #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
67 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
69 #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
70 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
71 #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
72 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
73 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
74 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
75
76 #define DEBUG_TYPE "tf-tfl-legalization"
77
78 namespace mlir {
79 namespace TFL {
80 namespace {
81 // Returns a TF_CastOp to I32. This function is used for CastOps that are
82 // intermediate nodes in a TableGen pattern result. In such a case, the
83 // destination type is not inferred and must be given explicitly.
84 //
85 // Preconditions: The given value must have a ShapedType.
static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x,
                               BoolAttr truncate) {
  auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
  if (!x_type) llvm_unreachable("unsupported type");
  Type type = x_type.clone(builder->getI32Type());
  return builder->create<TF::CastOp>(loc, type, x, truncate);
}
93 } // namespace
94
95 //===----------------------------------------------------------------------===//
96 // The actual PrepareTF Pass.
97 //
98 // TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
99 // pass.
100 namespace {
101 #define GEN_PASS_CLASSES
102 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
103
104 // Prepare TF operations in functions for subsequent legalization.
105 class PrepareTFPass : public PrepareTFPassBase<PrepareTFPass> {
106 public:
107 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareTFPass)
108
109 PrepareTFPass() = default;
  PrepareTFPass(const PrepareTFPass &) {}
  explicit PrepareTFPass(bool unfold_batch_matmul,
                         bool allow_bf16_and_f16_type_legalization,
                         bool use_fake_quant_num_bits = false) {
    this->unfold_batch_matmul_ = unfold_batch_matmul;
    this->allow_bf16_and_f16_type_legalization_ =
        allow_bf16_and_f16_type_legalization;
    this->use_fake_quant_num_bits_ = use_fake_quant_num_bits;
  }
119
120 void runOnOperation() override;
121 };
122
123 // Transient state for preserving data from match to rewrite
124 struct ConvertTFConvOpMatchState {
125 IntegerAttr dilation_height_factor;
126 IntegerAttr dilation_width_factor;
127 StringAttr padding;
128 IntegerAttr stride_height;
129 IntegerAttr stride_width;
130 };
131
132 // Templated class for declaring a converter from some TensorFlow convolution
133 // op into its counterpart in TensorFlow Lite.
134 //
135 // The `ConcreteType` deriving from this template must provide the following
136 // method for constructing TensorFlow Lite op:
137 //
138 // TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
139 // PatternRewriter &rewriter, Location loc,
140 // Type result_type, Value input,
141 // Value filter, Value bias) const;
142 //
143 // And also the following method for getting the dimension for bias tensor:
144 //
145 // int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
146 template <typename ConcreteType, typename TFConvOpType>
147 class ConvertTFConvOp : public RewritePattern {
148 public:
  ConvertTFConvOp(MLIRContext *context,
                  bool allow_bf16_and_f16_type_legalization)
      : RewritePattern(TFConvOpType::getOperationName(), 1, context),
        intAttrOne(Builder(context).getI32IntegerAttr(1)),
        allow_bf16_and_f16_type_legalization_(
            allow_bf16_and_f16_type_legalization) {}
155
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
158 // Assumes TensorFlow convolution op is already verified to be
159 // in valid form.
160
161 // Match a TFConvOpType under the following conditions:
162 // * The 'T' attribute must exist and be of value DT_FLOAT.
163 // * The 'data_format' attribute must exist and be of value "NHWC".
    //   * The 'strides' attribute must exist and be of the form [1, X, Y, 1].
    //   * The 'dilations' attribute is optional, but if it exists, it must be
    //     of the form [1, X, Y, 1].
167
168 TFConvOpType tf_op = cast<TFConvOpType>(op);
169 if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
170 !(allow_bf16_and_f16_type_legalization_ &&
171 TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
172 return failure();
173
174 if (!TFDataFormatIsNHWC(op)) return failure();
175
176 IntegerAttr height, width;
177 if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
178
179 ConvertTFConvOpMatchState state;
180 state.stride_height = height;
181 state.stride_width = width;
182
183 if (TFIntListIs1XY1(op, "dilations", &height, &width)) {
184 state.dilation_height_factor = height;
185 state.dilation_width_factor = width;
186 } else {
187 // If the 'dilations' attribute is missing, we use the default value (1)
188 // for both dilation height and width factor.
189 state.dilation_height_factor = intAttrOne;
190 state.dilation_width_factor = intAttrOne;
191 }
192
193 TFPaddingIsSameOrValid(op, &state.padding);
194
195 // Additionally, we require the filter operand to be of 4-D tensor type so
196 // that we can extract info from the shape (e.g., for constructing bias
197 // tensor, for setting depth_multiplier attribute, etc.).
198 auto filter = tf_op.filter();
199 auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
200 if (!filter_type || filter_type.getRank() != 4 ||
201 !filter_type.hasStaticShape())
202 return failure();
203
204 Value input = tf_op.input();
205 RankedTensorType input_type =
206 input.getType().template dyn_cast<RankedTensorType>();
    // The tf.Conv2D operator verification guarantees that the input has rank
    // four; we additionally require the channel dimension to be static.
209 if (!input_type || input_type.isDynamicDim(3)) {
210 return failure();
211 }
    // Check whether the given op is based on a grouped convolution. A zero
    // dimension size is already rejected by the tf.Conv2D operator
    // verification.
214 if (input_type.getDimSize(3) % filter_type.getDimSize(2) != 0) {
215 return failure();
216 }
217
    // The TensorFlow convolution op has only two inputs, while the TFLite one
    // has three, with the bias vector marked as optional. However, TOCO has a
    // dedicated pass, EnsureBiasVectors, to create default bias vectors for
    // all those missing. So we model the TFLite convolution op as requiring
    // three inputs to achieve the legalization task of EnsureBiasVectors.
    // This requires the filter tensor to have a static shape.
224
225 // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)
226
227 // Get a splat zero tensor with the expected dimension for the bias tensor
228 auto elem_type = filter_type.getElementType();
229 auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
230 filter_type.getShape());
231 auto bias_type = RankedTensorType::get({bias_dim}, elem_type);
232 auto bias_attr = rewriter.getZeroAttr(bias_type);
233 auto bias =
234 rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
235
236 if (op->getAttrOfType<StringAttr>("padding").getValue() == "EXPLICIT") {
237 // Add Const op for padding value.
238 ArrayRef<Attribute> padding_attr_array =
239 op->getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
240
241 auto get_int = [](Attribute attr) {
242 return attr.template cast<IntegerAttr>().getInt();
243 };
244
245 SmallVector<int32_t> padding_values(padding_attr_array.size());
246 for (int i = 0; i < padding_attr_array.size(); i++) {
247 padding_values[i] =
248 static_cast<int32_t>(get_int(padding_attr_array[i]));
249 }
250
251 RankedTensorType padding_attr_type = RankedTensorType::get(
252 {filter_type.getRank(), 2}, rewriter.getIntegerType(32));
253 auto padding_attr =
254 mlir::DenseIntElementsAttr::get(padding_attr_type, padding_values);
255
256 auto padding_const =
257 rewriter.create<TF::ConstOp>(op->getLoc(), padding_attr);
258
259 // Add Pad op.
260 auto pad_output_type = UnrankedTensorType::get(elem_type);
261 input = rewriter.create<TF::PadOp>(op->getLoc(), pad_output_type, input,
262 padding_const);
263
264 // Set Conv padding to `VALID` since padding has been handled by Pad op.
265 state.padding = rewriter.getStringAttr("VALID");
266 }
267 auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
268 &state, rewriter, op->getLoc(), tf_op.getType(), input, filter, bias);
269
270 rewriter.replaceOp(op, conv_op.getResult());
271 return success();
272 }
273
274 const IntegerAttr intAttrOne;
275
276 private:
277 bool allow_bf16_and_f16_type_legalization_;
278 };
279
280 class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
281 public:
282 using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;
283
  ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
      : BaseType(context, allow_bf16_type_legalization) {}
286
  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
    return filterShape.back();
  }
290
  TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
                            PatternRewriter &rewriter, Location loc,
                            Type result_type, Value input, Value filter,
                            Value bias) const {
295 filter = legalizeFilter(rewriter, loc, filter);
296 return rewriter.create<TFL::Conv2DOp>(
297 loc, result_type, input, filter, bias,
298 /*dilation_h_factor=*/state->dilation_height_factor,
299 /*dilation_w_factor=*/state->dilation_width_factor,
300 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
301 /*padding=*/state->padding,
302 /*stride_h=*/state->stride_height,
303 /*stride_w=*/state->stride_width);
304 }
305
306 private:
  // Legalize the given filter by converting it from the TensorFlow filter
  // data format HWIO to the TFLite Conv2D op filter data format OHWI and
  // return the Value for the converted filter. Requires that the filter has
  // already been verified by the match method to be a 4-D RankedTensorType.
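  // For example (illustrative shapes), a [3, 3, 8, 16] HWIO filter is
  // transposed with permutation [3, 0, 1, 2] into a [16, 3, 3, 8] OHWI
  // filter.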
  Value legalizeFilter(PatternRewriter &rewriter, Location loc,
                       Value filter) const {
313 // Create a constant op for HWIO to OHWI transpose permutation.
314 SmallVector<int, 4> perm = {3, 0, 1, 2};
315 auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
316 rewriter.getIntegerType(32));
317 auto perm_attr =
318 DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
319 auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
320
321 // Create tensor type for the transpose result.
322 auto filter_type = filter.getType().cast<RankedTensorType>();
323 auto result_shape =
324 llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
325 return filter_type.getDimSize(dim);
326 }));
327 auto elem_type = filter_type.getElementType();
328 auto result_type = RankedTensorType::get(result_shape, elem_type);
329
330 return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op);
331 }
332 };
333
334 class ConvertTFDepthwiseConv2dNative
335 : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
336 TF::DepthwiseConv2dNativeOp> {
337 public:
338 using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
339 TF::DepthwiseConv2dNativeOp>;
340
  ConvertTFDepthwiseConv2dNative(MLIRContext *context,
                                 bool allow_bf16_type_legalization)
      : BaseType(context, allow_bf16_type_legalization) {}
344
  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
    return filterShape[2] * filterShape[3];
  }
348
  TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
                                     PatternRewriter &rewriter, Location loc,
                                     Type result_type, Value input,
                                     Value filter, Value bias) const {
353 // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
354 // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
355 // have a corresponding 'depth_multiplier' attribute; the multiplier is the
356 // fourth dimension in the 4-D filter tensor. We query the multiplier from
357 // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
358 auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
359
360 filter = legalizeFilter(rewriter, loc, filter);
361 return rewriter.create<TFL::DepthwiseConv2DOp>(
362 loc, result_type, input, filter, bias,
363 /*dilation_h_factor=*/state->dilation_height_factor,
364 /*dilation_w_factor=*/state->dilation_width_factor,
365 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
366 /*padding=*/state->padding,
367 /*stride_h=*/state->stride_height,
368 /*stride_w=*/state->stride_width,
369 /*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier));
370 }
371
372 private:
  /// Legalize the given filter by converting it from the TensorFlow filter
  /// data format to the TFLite DepthwiseConv2D op filter data format and
  /// return the Value for the converted filter. The TensorFlow filter data
  /// format is [filter_height, filter_width, in_channels, channel_multiplier]
  /// and the TFLite filter data format is
  /// [1, filter_height, filter_width, out_channels]. Requires that the filter
  /// has already been verified by the match method to be a 4-D
  /// RankedTensorType.
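  /// For example (illustrative shapes), a [3, 3, 8, 2] TensorFlow filter is
  /// reshaped into a [1, 3, 3, 16] TFLite filter, and the depth multiplier
  /// becomes 2.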
  Value legalizeFilter(PatternRewriter &rewriter, Location loc,
                       Value filter) const {
382 auto filter_type = filter.getType().cast<RankedTensorType>();
383 auto filterShape = filter_type.getShape();
384 SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
385 filterShape[2] * filterShape[3]};
386 auto elem_type = filter_type.getElementType();
387 auto result_type = RankedTensorType::get(result_shape, elem_type);
    // The TensorFlow Lite `Reshape` op currently only supports an int32 shape
    // tensor.
389 auto shape_type = RankedTensorType::get({4}, rewriter.getIntegerType(32));
390 SmallVector<Attribute, 4> result_shape_data(4);
391 for (int i = 0; i < 4; ++i) {
392 result_shape_data[i] =
393 rewriter.getI32IntegerAttr(static_cast<int32_t>(result_shape[i]));
394 }
395 auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
396 auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);
397
398 return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
399 }
400 };
401
// StridedSlice can have complicated attributes like begin_axis_mask,
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, and shrink_axis_mask.
// These masks complicate the strided_slice computation logic; we can simplify
// it by inserting a reshape op that pads the inputs so the strided_slice is
// easier to handle.
//
// The graph transformation looks like this:
//   original_input -> strided_slice -> output
//     (transforms into)
//   original_input -> reshape -> strided_slice -> output
//
// The new shape is computed based on the masks.
//
// An example for new_axis_mask: say new_axis_mask is 9, which in binary is
// [1 0 0 1]; that means we are inserting new axes at dims 0 and 3, so if the
// original shape is [2, 3], we reshape it to [1, 2, 3, 1].
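//
// A rough IR sketch of that example (begin/end/strides operands and most
// attributes elided):
//
//   %out = "tf.StridedSlice"(%in, ...) {new_axis_mask = 9 : i64, ...}
//       : (tensor<2x3xf32>, ...) -> tensor<1x2x3x1xf32>
//
// becomes
//
//   %r = "tf.Reshape"(%in, %shape) : (tensor<2x3xf32>, tensor<4xi32>)
//       -> tensor<1x2x3x1xf32>
//   %out = "tf.StridedSlice"(%r, ...) {new_axis_mask = 0 : i64, ...}
//       : (tensor<1x2x3x1xf32>, ...) -> tensor<1x2x3x1xf32>
//
// with the old new_axis_mask bits OR-ed into begin_mask and end_mask.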
418 struct ConvertTFStridedSlice : public RewritePattern {
  explicit ConvertTFStridedSlice(MLIRContext *context)
      : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
421
  LogicalResult RewriteNewAxisMask(Operation *op,
                                   PatternRewriter &rewriter) const {
424 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
425 uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
426
427 if (strided_slice_op.ellipsis_mask() != 0) {
428 // Ellipsis mask should have been lowered-away prior to invoking this
429 // function.
430 op->emitError() << "encountered a logical error";
431 return failure();
432 }
433
434 // Insert a new reshape op.
435 Value original_input = strided_slice_op.input();
436 RankedTensorType original_input_type =
437 original_input.getType().dyn_cast<RankedTensorType>();
438 if (!original_input_type) {
439 return failure();
440 }
441
442 const ArrayRef<int64_t> &original_input_shape =
443 original_input_type.getShape();
444 SmallVector<int64_t, 4> revised_shape;
445 int index = 0;
446 const int original_input_rank = original_input_shape.size();
447 while (index < original_input_rank || new_axis_mask) {
448 if (new_axis_mask & 1) {
449 revised_shape.emplace_back(1);
450 } else {
451 revised_shape.emplace_back(original_input_shape[index++]);
452 }
453 new_axis_mask >>= 1;
454 }
455
456 if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
457
458 const int dim_size = revised_shape.size();
459 Location loc = strided_slice_op.getLoc();
460 auto shape_type =
461 RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
462 SmallVector<Attribute, 4> result_shape_data(dim_size);
463 for (int i = 0; i < dim_size; ++i) {
464 result_shape_data[i] =
465 rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
466 }
467
468 auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
469 auto shape =
470 rewriter.create<arith::ConstantOp>(loc, shape_type, shape_attr);
471 auto revised_output_type = RankedTensorType::get(
472 revised_shape, original_input_type.getElementType());
473 TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
474 loc, revised_output_type, original_input, shape);
475
476 // Replace the original strided_slice.
477 uint64_t revised_begin_mask = strided_slice_op.begin_mask();
478 uint64_t revised_end_mask = strided_slice_op.end_mask();
479 // Since we expand the dims, we need to apply them to the begin_mask &
480 // end_mask.
481 revised_begin_mask |= strided_slice_op.new_axis_mask();
482 revised_end_mask |= strided_slice_op.new_axis_mask();
483
484 // Enforce operator precedence.
485 uint64_t revised_shrink_axis_mask =
486 strided_slice_op.shrink_axis_mask() & ~strided_slice_op.new_axis_mask();
487
488 auto attribute_type = rewriter.getIntegerType(64);
489 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
490 op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
491 strided_slice_op.end(), strided_slice_op.strides(),
492 rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
493 rewriter.getIntegerAttr(attribute_type, revised_end_mask),
494 rewriter.getIntegerAttr(attribute_type,
495 strided_slice_op.ellipsis_mask()),
496 rewriter.getI64IntegerAttr(0),
497 rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
498 return success();
499 }
500
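  // Expands the ellipsis in a strided_slice into explicit begin/end/strides
  // entries. For example (illustrative), with a rank-4 input, begin/end/
  // strides of length 2, and ellipsis_mask = 1 (ellipsis at dim 0), the
  // ellipsis covers three dims: padded begin/end become [0, 0, 0, begin[1]]
  // and [0, 0, 0, end[1]], padded strides become [1, 1, 1, strides[1]], and
  // bits 0..2 are set in the revised begin_mask and end_mask.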
  LogicalResult RewriteEllipsisMask(Operation *op,
                                    PatternRewriter &rewriter) const {
503 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
504
505 uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
506 uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
507 uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
508
509 // Enforce operator precedence.
510 shrink_axis_mask &= ~ellipsis_mask;
511 new_axis_mask &= ~ellipsis_mask;
512
513 DenseIntElementsAttr begin_dense_elem_attr;
514 Value begin = strided_slice_op.begin();
515 auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
516 if (!begin_ranked_attr_type ||
517 !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
518 return failure();
519 }
520
521 DenseIntElementsAttr end_dense_elem_attr;
522 Value end = strided_slice_op.end();
523 auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
524 if (!end_ranked_attr_type ||
525 !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
526 return failure();
527 }
528
529 DenseIntElementsAttr stride_dense_elem_attr;
530 Value stride = strided_slice_op.strides();
531 auto stride_ranked_attr_type =
532 stride.getType().dyn_cast<RankedTensorType>();
533 if (!stride_ranked_attr_type ||
534 !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
535 return failure();
536 }
537
538 Value input = strided_slice_op.input();
539 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
540 if (!input_type) {
541 return failure();
542 }
543 const ArrayRef<int64_t> input_shape = input_type.getShape();
544
545 const int input_size = input_shape.size();
546
547 RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
548 const ArrayRef<int64_t> begin_shape = begin_type.getShape();
549 const int begin_dim = begin_shape.size();
550
551 if (begin_dim != 1) return failure();
552
    // The ellipsis fill might exceed the current output shape because we are
    // also taking into account any to-be-inserted new axes.
555 const int ellipsis_filled_dim_size =
556 input_size - begin_shape[0] + 1 + absl::popcount(new_axis_mask);
557
558 int64_t begin_mask = strided_slice_op.begin_mask();
559 int64_t end_mask = strided_slice_op.end_mask();
560 int64_t revised_begin_mask = 0;
561 int64_t revised_end_mask = 0;
562 int64_t revised_shrink_axis_mask = 0;
563 int64_t revised_new_axis_mask = 0;
564
565 SmallVector<int32_t, 4> padded_begin;
566 SmallVector<int32_t, 4> padded_end;
567 SmallVector<int32_t, 4> padded_stride;
568
569 // Before the ellipsis.
570 int index = 0;
571 int new_index = 0;
572 while (((ellipsis_mask >> index) & 1) == 0) {
573 padded_begin.push_back(begin_dense_elem_attr.getValues<int32_t>()[index]);
574 padded_end.push_back(end_dense_elem_attr.getValues<int32_t>()[index]);
575 padded_stride.push_back(
576 stride_dense_elem_attr.getValues<int32_t>()[index]);
577 if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
578 if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
579 if ((shrink_axis_mask >> index) & 1)
580 revised_shrink_axis_mask |= (1 << new_index);
581
582 if ((new_axis_mask >> index) & 1)
583 revised_new_axis_mask |= (1 << new_index);
584
585 ++index;
586 ++new_index;
587 }
588
589 // Ellipsis.
590 for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
591 revised_begin_mask |= (1 << new_index);
592 revised_end_mask |= (1 << new_index);
593
594 // Mimic the begin/end/strides mask behavior.
595 padded_begin.push_back(0);
596 padded_end.push_back(0);
597 padded_stride.push_back(1);
598 }
599
600 // Account for ellipsis mask.
601 ++index;
602
603 // After the ellipsis.
604 for (; index < begin_shape[0];) {
605 padded_begin.push_back(begin_dense_elem_attr.getValues<int32_t>()[index]);
606 padded_end.push_back(end_dense_elem_attr.getValues<int32_t>()[index]);
607 padded_stride.push_back(
608 stride_dense_elem_attr.getValues<int32_t>()[index]);
609
610 if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
611 if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
612 if ((shrink_axis_mask >> index) & 1)
613 revised_shrink_axis_mask |= (1 << new_index);
614 if ((new_axis_mask >> index) & 1)
615 revised_new_axis_mask |= (1 << new_index);
616
617 ++index;
618 ++new_index;
619 }
620
621 auto attribute_type = rewriter.getIntegerType(64);
622
623 int full_dim_count = padded_begin.size();
624 auto type =
625 RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
626
627 auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
628 auto begin_op =
629 rewriter.create<arith::ConstantOp>(op->getLoc(), type, begin_attr);
630 auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
631 auto end_op =
632 rewriter.create<arith::ConstantOp>(op->getLoc(), type, end_attr);
633 auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
634 auto stride_op =
635 rewriter.create<arith::ConstantOp>(op->getLoc(), type, stride_attr);
636
637 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
638 op, strided_slice_op.getType(), input, begin_op.getResult(),
639 end_op.getResult(), stride_op.getResult(),
640 rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
641 rewriter.getIntegerAttr(attribute_type, revised_end_mask),
642 /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
643 rewriter.getIntegerAttr(attribute_type, revised_new_axis_mask),
644 rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
645
646 return success();
647 }
648
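  // Copies the values of `dense_elem_attr` into both `val` and `padded_val`,
  // then extends `padded_val` with entries from `padding_val` up to the full
  // rank; for every padded dimension, the corresponding bit is set in `mask`
  // (when `mask` is non-null).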
  void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
                                     SmallVectorImpl<int32_t> &val,
                                     SmallVectorImpl<int32_t> &padded_val,
                                     ArrayRef<int32_t> padding_val,
                                     int *mask) const {
654 for (const auto &idx : dense_elem_attr.getValues<APInt>()) {
655 val.push_back(idx.getSExtValue());
656 padded_val.push_back(idx.getSExtValue());
657 }
658 int attr_dim_count = val.size();
659 int full_dim_count = padding_val.size();
660 for (int i = attr_dim_count; i < full_dim_count; ++i) {
661 padded_val.push_back(padding_val[i]);
662 if (mask) *mask |= 1 << i;
663 }
664 }
665
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
668 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
669
670 // Handle ellipsis mask.
671 if (strided_slice_op.ellipsis_mask() != 0) {
672 return RewriteEllipsisMask(strided_slice_op, rewriter);
673 }
674
675 // Handle new axis mask.
676 if (strided_slice_op.new_axis_mask() != 0) {
677 return RewriteNewAxisMask(strided_slice_op, rewriter);
678 }
679
680 auto ranked_input_type =
681 strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
682 if (!ranked_input_type) {
683 return failure();
684 }
685
686 auto begin_attr = strided_slice_op.begin();
687 auto end_attr = strided_slice_op.end();
688 auto strides_attr = strided_slice_op.strides();
689
690 auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
691 auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
692 auto strides_attr_type =
693 strides_attr.getType().dyn_cast<RankedTensorType>();
694
695 DenseIntElementsAttr begin_elem_attr;
696 DenseIntElementsAttr end_elem_attr;
697 DenseIntElementsAttr strides_elem_attr;
698
699 if (!begin_attr_type ||
700 !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
701 return failure();
702 }
703 if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
704 return failure();
705 }
706 if (!strides_attr_type ||
707 !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
708 return failure();
709 }
710
711 SmallVector<int32_t, 4> begin, end, strides;
712 SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
713
714 int num_input_dims = ranked_input_type.getRank();
715 SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
716 auto input_shape = ranked_input_type.getShape();
717 SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
718 SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
719
720 int begin_mask = strided_slice_op.begin_mask();
721 int end_mask = strided_slice_op.end_mask();
722
723 PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
724 padding_begin, &begin_mask);
725 PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
726 &end_mask);
727 PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
728 padding_strides, nullptr);
729
730 if (begin == padded_begin && end == padded_end &&
731 strides == padded_strides &&
732 begin_mask == strided_slice_op.begin_mask() &&
733 end_mask == strided_slice_op.end_mask()) {
734 return failure();
735 }
736
737 auto begin_end_type =
738 RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
739 auto new_begin_attr = rewriter.create<arith::ConstantOp>(
740 op->getLoc(), begin_end_type,
741 DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
742 auto new_end_attr = rewriter.create<arith::ConstantOp>(
743 op->getLoc(), begin_end_type,
744 DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
745 auto strides_type =
746 RankedTensorType::get({static_cast<long>(padded_strides.size())},
747 rewriter.getIntegerType(32));
748 auto new_strides_attr = rewriter.create<arith::ConstantOp>(
749 op->getLoc(), strides_type,
750 DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
751
752 auto attribute_type = rewriter.getIntegerType(64);
753 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
754 op, strided_slice_op.output().getType(), strided_slice_op.input(),
755 new_begin_attr, new_end_attr, new_strides_attr,
756 rewriter.getIntegerAttr(attribute_type, begin_mask),
757 rewriter.getIntegerAttr(attribute_type, end_mask),
758 rewriter.getIntegerAttr(attribute_type,
759 strided_slice_op.ellipsis_mask()),
760 rewriter.getIntegerAttr(attribute_type,
761 strided_slice_op.new_axis_mask()),
762 rewriter.getIntegerAttr(attribute_type,
763 strided_slice_op.shrink_axis_mask()));
764
765 return success();
766 }
767 };
768
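// Lowers tf.BroadcastTo by materializing a tensor of ones of the target shape
// with tf.Fill and multiplying the input by it, i.e. (roughly)
//   broadcast_to(x, shape)  ==>  x * fill(shape, 1)
// which relies on tf.Mul's implicit broadcasting.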
769 struct ConvertTFBroadcastTo : public RewritePattern {
  explicit ConvertTFBroadcastTo(MLIRContext *context)
      : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
772
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
775 auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
776 auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
777 auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
778 auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
779 Type element_type = input_type.getElementType();
780
    // Allow lowering only when the broadcast is low-dimensional (output rank
    // or shape size at most 4) and the element type is BF16, F32, I32, or
    // I16.
783 if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
784 (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
785 shape_type.getDimSize(0) <= 4)))
786 return failure();
787
788 if (!(element_type.isa<BFloat16Type, Float32Type>() ||
789 element_type.isInteger(32) || element_type.isInteger(16)))
790 return failure();
791
792 auto status_or_const_op =
793 CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
794 if (!status_or_const_op.ok()) {
795 return failure();
796 }
797
798 auto tf_fill_op = rewriter.create<TF::FillOp>(
799 op->getLoc(), output_type, tf_broadcast_to_op.shape(),
800 status_or_const_op.ValueOrDie());
801
802 auto mul_op = rewriter.create<TF::MulOp>(
803 op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
804 rewriter.replaceOp(op, mul_op.getResult());
805 return success();
806 }
807 };
808
// The pattern below is equivalent to the DRR rule that follows. The checks
// depend on generated values, so we cannot attach them to intermediate
// values; ideally we should find equivalent checks that guarantee the
// resulting ops are valid. The extra conditions are the broadcasting
// conditions.
//
// The pattern lowers FusedBatchNormV3 to arithmetic ops.
// Specifically, it performs the following calculation:
817 //
818 // (x - mean) * scale / sqrt(variance + epsilon) + offset
819 //
820 // Let multiplier = scale / sqrt(variance + epsilon),
821 // to compute
822 // (x - mean) * scale / sqrt(variance + epsilon) + offset,
823 // is then to compute
824 // (x * multiplier) + (offset - mean * multiplier).
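//
// For instance (illustrative numbers), with scale = 0.5, variance = 3.0,
// epsilon = 1.0, mean = 2.0 and offset = 1.0:
//   multiplier = 0.5 / sqrt(3.0 + 1.0) = 0.25
//   y = x * 0.25 + (1.0 - 2.0 * 0.25) = 0.25 * x + 0.5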
825 //
826 // def : Pattern<
827 // (TF_FusedBatchNormV3Op:$root
828 // $x, $scale, $offset, $mean, $variance,
829 // F32Attr:$epsilon, $exponential_avg_factor,
830 // $data_format, FalseBoolAttr:$is_training),
831 // [(TF_AddOp
832 // (TF_MulOp
833 // $x,
834 // (TF_MulOp:$multiplier
835 // $scale,
836 // (TF_RsqrtOp
837 // (TF_AddOp $variance,
838 // (TF_ConstOp $epsilon))))),
839 // (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
840 // // We already guaranteed that the last five results have no use so it does
841 // // not matter what value we provide here for replacement.
842 // /*batch_mean=*/(replaceWithValue $x),
843 // /*batch_variance=*/(replaceWithValue $x),
844 // /*reserve_space_1=*/(replaceWithValue $x),
845 // /*reserve_space_2=*/(replaceWithValue $x),
846 // /*reserve_space_3=*/(replaceWithValue $x)],
847 // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
848 // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
849 // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
850 //
// When is_training is set to true, the given variance and mean are not used.
// In the above calculation, they are replaced by new values. The new mean and
// variance are calculated as follows:
//   new_mean = mean(x, axis=[0, 1, 2])
//   new_variance = mean(squared_difference(x, new_mean), axis=[0, 1, 2])
856 //
// The DRR rule for the is_training == true case is as follows:
858 // def : Pattern<
859 // (TF_FusedBatchNormV3Op:$root
860 // $x, $scale, $offset, $mean, $variance,
861 // F32Attr:$epsilon, $exponential_avg_factor,
862 // $data_format, FalseBoolAttr:$is_training),
863 // [(TF_AddOp
864 // (TF_MulOp
865 // $x,
866 // (TF_MulOp:$multiplier
867 // $scale,
868 // (TF_RsqrtOp
869 // (TF_AddOp
870 // (TF_MeanOp
871 // (TF_SquaredDifferenceOp $x, $new_mean),
872 // (TF_ConstOp [0,1,2])),
873 // (TF_ConstOp $epsilon))))),
874 // (TF_SubOp
875 // $offset,
876 // (TF_MulOp
877 // (TF_MeanOp $x, (TF_ConstOp [0,1,2])),
878 // $multiplier))),
879 // // We already guaranteed that the last five results have no use so it does
880 // // not matter what value we provide here for replacement.
881 // /*batch_mean=*/(replaceWithValue $x),
882 // /*batch_variance=*/(replaceWithValue $x),
883 // /*reserve_space_1=*/(replaceWithValue $x),
884 // /*reserve_space_2=*/(replaceWithValue $x),
885 // /*reserve_space_3=*/(replaceWithValue $x)],
886 // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
887 // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
888 // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
889
890 struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
  explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern(
            "tf.FusedBatchNormV3", 1, context,
            {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}) {}
895
  ::mlir::LogicalResult matchAndRewrite(
      ::mlir::Operation *fused_batch_norm,
      ::mlir::PatternRewriter &rewriter) const override {
899 // Variables for capturing values and attributes used for creating ops
900 Operation::operand_range mean(fused_batch_norm->getOperands());
901 ::mlir::FloatAttr exponential_avg_factor;
902 ::mlir::TF::FusedBatchNormV3Op root;
903 Operation::operand_range offset(fused_batch_norm->getOperands());
904 Operation::operand_range x(fused_batch_norm->getOperands());
905 Operation::operand_range scale(fused_batch_norm->getOperands());
906 Operation::operand_range variance(fused_batch_norm->getOperands());
907 ::mlir::FloatAttr epsilon;
908 ::mlir::BoolAttr is_training;
909
910 // Match
911 auto fused_batch_norm_op =
912 dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm);
913 root = fused_batch_norm_op;
914 x = fused_batch_norm_op.getODSOperands(0);
915 scale = fused_batch_norm_op.getODSOperands(1);
916 offset = fused_batch_norm_op.getODSOperands(2);
917 mean = fused_batch_norm_op.getODSOperands(3);
918 variance = fused_batch_norm_op.getODSOperands(4);
919
920 ::mlir::Value mean_value = (*mean.begin());
921 ::mlir::Value variance_value = (*variance.begin());
922
923 if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
924
925 {
926 epsilon =
927 fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon");
928 if (!epsilon)
929 epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);
930
931 if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
932 ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) {
933 return rewriter.notifyMatchFailure(
934 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
935 diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to "
936 "satisfy constraint: 32-bit float attribute";
937 });
938 }
939 }
940 {
941 exponential_avg_factor =
942 fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>(
943 "exponential_avg_factor");
944 if (!exponential_avg_factor)
945 exponential_avg_factor =
946 rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
947 }
948 if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
949 !TFDataFormatIsNDHWC(fused_batch_norm_op))
950 return failure();
951
952 if (!(((*root.getODSResults(1).begin()).use_empty()))) {
953 return rewriter.notifyMatchFailure(
954 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
955 diag << "entities '' failed to satisfy constraint: has no use";
956 });
957 }
958
959 if (!(((*root.getODSResults(2).begin()).use_empty()))) {
960 return rewriter.notifyMatchFailure(
961 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
962 diag << "entities '' failed to satisfy constraint: has no use";
963 });
964 }
965
966 if (!(((*root.getODSResults(3).begin()).use_empty()))) {
967 return rewriter.notifyMatchFailure(
968 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
969 diag << "entities '' failed to satisfy constraint: has no use";
970 });
971 }
972
973 if (!(((*root.getODSResults(4).begin()).use_empty()))) {
974 return rewriter.notifyMatchFailure(
975 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
976 diag << "entities '' failed to satisfy constraint: has no use";
977 });
978 }
979
980 if (!(((*root.getODSResults(5).begin()).use_empty()))) {
981 return rewriter.notifyMatchFailure(
982 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
983 diag << "entities '' failed to satisfy constraint: has no use";
984 });
985 }
986
987 is_training =
988 fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
989 auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
990
991 // We need to make sure input and output shapes are compatible.
992 int64_t last_dim = -1;
993 {
994 auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
995 auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
996 if (!v_type) return true;
997 int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
998 if (v_last_dim == -1) return true;
999 if (last_dim != -1 && v_last_dim != last_dim) return false;
1000 last_dim = v_last_dim;
1001 return true;
1002 };
1003
1004 if (!is_last_dim_compatible(*x.begin(), last_dim) ||
1005 !is_last_dim_compatible(*scale.begin(), last_dim) ||
1006 !is_last_dim_compatible(*offset.begin(), last_dim)) {
1007 return rewriter.notifyMatchFailure(
1008 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1009 diag << "Shapes of scale and offset should be 1D and "
1010 "compatible with x";
1011 });
1012 }
1013
1014 if (!is_training.getValue()) {
1015 if (!is_last_dim_compatible(mean_value, last_dim) ||
1016 !is_last_dim_compatible(variance_value, last_dim)) {
1017 return rewriter.notifyMatchFailure(
1018 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1019 diag << "Shapes of mean and variance should be 1D and "
1020 "compatible with x";
1021 });
1022 }
1023 }
1024
1025 // Check if output shape and input shape are compatible.
1026 auto x_type = (*x.begin()).getType();
1027 auto y_type = (*root.getODSResults(0).begin()).getType();
1028 if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
1029 return rewriter.notifyMatchFailure(
1030 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1031 diag << "Shapes of x and the first output should be compatible";
1032 });
1033 }
1034 }
1035
    // For training, the mean and variance are calculated from the input
    // values.
1037 if (is_training.getValue()) {
1038 auto input_type = fused_batch_norm_op.x()
1039 .getType()
1040 .dyn_cast_or_null<RankedTensorType>();
1041 if (!input_type || input_type.getRank() != 4) {
1042 return rewriter.notifyMatchFailure(
1043 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1044 diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
1045 "True is only supported with input of rank 4";
1046 });
1047 }
1048
1049 ::mlir::TF::ConstOp reduce_dim_op;
1050 {
1051 auto reduce_dim_type =
1052 ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(32));
1053 ::mlir::SmallVector<int32_t, 3> reduce_dim_values = {0, 1, 2};
1054 reduce_dim_op = rewriter.create<TF::ConstOp>(
1055 odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
1056 reduce_dim_values));
1057 }
1058
1059 auto new_mean_type =
1060 ::mlir::RankedTensorType::get({last_dim}, rewriter.getF32Type());
1061 ::mlir::TF::MeanOp mean_op_1;
1062 {
1063 ::mlir::Value x_value = (*x.begin());
1064 mean_op_1 = rewriter.create<TF::MeanOp>(
1065 odsLoc, new_mean_type, x_value, reduce_dim_op,
1066 /*keep_dims=*/rewriter.getBoolAttr(false));
1067 }
1068
1069 ::mlir::TF::SquaredDifferenceOp square_diff_op;
1070 {
1071 ::mlir::Value tblgen_value_0 = (*x.begin());
1072 ::mlir::Value tblgen_value_1 = (*mean_op_1.getODSResults(0).begin());
1073 // If x has shape of [b, h, w, c], the result of mean_op_1 will have
1074 // shape of [c]. Therefore, their shapes are always compatible.
1075 square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
1076 odsLoc, tblgen_value_0, tblgen_value_1);
1077 }
1078
1079 ::mlir::TF::MeanOp mean_op_2;
1080 {
1081 ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
1082 mean_op_2 = rewriter.create<TF::MeanOp>(
1083 odsLoc, new_mean_type, input_value, reduce_dim_op,
1084 /*keep_dims=*/rewriter.getBoolAttr(false));
1085 }
1086
1087 mean_value = (*mean_op_1.getODSResults(0).begin());
1088 variance_value = (*mean_op_2.getODSResults(0).begin());
1089 } // End is_training equals true if.
1090
1091 ::llvm::SmallVector<::mlir::Value, 4> replace_values;
1092 ::mlir::TF::ConstOp epsilon_const_op;
1093 {
1094 epsilon_const_op =
1095 rewriter.create<::mlir::TF::ConstOp>(odsLoc,
1096 /*value=*/epsilon);
1097 }
1098 ::mlir::TF::AddOp add_op_1;
1099 {
1100 ::mlir::Value epsilon_value =
1101 (*epsilon_const_op.getODSResults(0).begin());
      // Adding a scalar constant; no need to check broadcastability.
1103 add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
1104 /*x=*/variance_value,
1105 /*y=*/epsilon_value);
1106 }
1107 ::mlir::TF::RsqrtOp rsqrt_op;
1108 {
1109 ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1110 ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1111 tblgen_values.push_back((*add_op_1.getODSResults(0).begin()));
1112 rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values,
1113 tblgen_attrs);
1114 }
1115 ::mlir::TF::MulOp multiplier;
1116 {
1117 ::mlir::Value tblgen_value_0 = (*scale.begin());
1118 ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
1119 multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1120 /*x=*/tblgen_value_0,
1121 /*y=*/tblgen_value_1);
1122 }
1123 ::mlir::TF::MulOp mul_op_1;
1124 {
1125 ::mlir::Value tblgen_value_0 = (*x.begin());
1126 ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
1127 mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1128 /*x=*/tblgen_value_0,
1129 /*y=*/tblgen_value_1);
1130 }
1131 ::mlir::TF::MulOp mul_op_2;
1132 {
1133 ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
1134 mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1135 /*x=*/mean_value,
1136 /*y=*/multiplier_value);
1137 }
1138 ::mlir::TF::SubOp sub_op;
1139 {
1140 ::mlir::Value tblgen_value_0 = (*offset.begin());
1141 ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin());
1142 sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
1143 /*x=*/tblgen_value_0,
1144 /*y=*/tblgen_value_1);
1145 }
1146 ::mlir::TF::AddOp add_op_2;
1147 {
1148 ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1149 ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1150 tblgen_values.push_back((*mul_op_1.getODSResults(0).begin()));
1151 tblgen_values.push_back((*sub_op.getODSResults(0).begin()));
1152 ::mlir::SmallVector<::mlir::Type, 4> tblgen_types;
1153 for (auto v : fused_batch_norm_op.getODSResults(0)) {
1154 tblgen_types.push_back(v.getType());
1155 }
1156 add_op_2 = rewriter.create<::mlir::TF::AddOp>(
1157 odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
1158 }
1159 for (auto v :
1160 ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
1161 replace_values.push_back(v);
1162 }
1163 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1164 replace_values.push_back(v);
1165 }
1166 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1167 replace_values.push_back(v);
1168 }
1169 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1170 replace_values.push_back(v);
1171 }
1172 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1173 replace_values.push_back(v);
1174 }
1175 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1176 replace_values.push_back(v);
1177 }
1178 rewriter.replaceOp(fused_batch_norm, replace_values);
1179 return success();
1180 };
1181 };
1182
1183 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
1184
1185 // Returns success if all the operations in the `op`'s regions including `op`
1186 // itself are legal in a TFLite pipeline.
LogicalResult ValidateOp(Operation *op) {
1188 bool has_illegal_ops = false;
1189 op->walk([&](Operation *op) {
1190 if (isa<TF::VariableV2Op>(op)) {
1191 has_illegal_ops = true;
1192 op->emitOpError() << "is illegal in a TFLite pipeline";
1193 }
1194 });
1195
1196 return failure(has_illegal_ops);
1197 }
1198
1199 // Converts a set of TF2XLA ops into pure TF ops for future legalizations as
1200 // TF2XLA ops aren't supported by later stages.
LogicalResult ConvertTf2XlaOps(func::FuncOp func, MLIRContext *context) {
1202 ConversionTarget target(*context);
1203 target.addLegalDialect<arith::ArithmeticDialect>();
1204 target.addLegalDialect<func::FuncDialect>();
1205 target.addLegalDialect<TF::TensorFlowDialect>();
1206 target.addLegalOp<ModuleOp>();
1207 target.addLegalOp<func::FuncOp>();
1208 target.addIllegalOp<TF::XlaConvV2Op>();
1209 target.addIllegalOp<TF::XlaGatherOp>();
1210
1211 RewritePatternSet patterns(context);
1212 mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context);
1213 mhlo::PopulateLegalizeTfPatterns(context, &patterns);
1214 TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
1215 mhlo::GatherOp::getCanonicalizationPatterns(patterns, context);
1216
1217 return applyPartialConversion(func, target, std::move(patterns));
1218 }
1219
// Convert rfft to rfft2d.
// The transformation pattern looks like the following:
1222 //
1223 // input fft_len
1224 // \ /
1225 // rfft
1226 //
1227 // ||
1228 // \/
1229 //
1230 // input fft_len
1231 // \ /
1232 // expand_dim concat with [1] at the front
1233 // \ /
1234 // rfft_2d
1235 // |
1236 // squeeze
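//
// For example (illustrative shapes), an RFFT over input of shape [8] with
// fft_length = [8] produces shape [5]; after this rewrite it becomes an
// expand_dims to [1, 8], an RFFT2D with fft_length [1, 8] producing [1, 5],
// and a squeeze back to [5].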
1237 struct ConvertRfftToRfft2d : public RewritePattern {
  explicit ConvertRfftToRfft2d(MLIRContext *context)
      : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}
1240
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
1243 auto rfft_op = dyn_cast<TF::RFFTOp>(op);
1244
1245 auto input = rfft_op.input();
1246 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
1247 if (!input_type) return failure();
1248 auto fft_len = rfft_op.fft_length();
1249 auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
1250 if (!fft_len_type) return failure();
1251
1252 auto output_type =
1253 rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
1254 if (!output_type) return failure();
1255
1256 // Expanded inputs.
1257 // Insert at -2 location.
1258 auto one_ele_type =
1259 mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
1260 auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1261 one_ele_type, -2);
1262
1263 SmallVector<int64_t, 4> expanded_input_shape;
1264 SmallVector<int64_t, 4> expanded_output_shape;
1265 int expanded_rank = input_type.getRank() + 1;
1266 int r = 0;
1267 for (int i = 0; i < expanded_rank; ++i) {
1268 if (i == expanded_rank - 2) {
1269 expanded_input_shape.push_back(1);
1270 expanded_output_shape.push_back(1);
1271 } else {
1272 expanded_input_shape.push_back(input_type.getDimSize(r));
1273 expanded_output_shape.push_back(output_type.getDimSize(r));
1274 r++;
1275 }
1276 }
1277
    auto expanded_input_type = mlir::RankedTensorType::get(
        expanded_input_shape, input_type.getElementType());
    TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
        rfft_op.getLoc(), expanded_input_type, input, minus_two->getResult());
1282
1283 // Expanded fft_len.
1284 auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});
1285
1286 auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);
1287
1288 auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1289 one_ele_type, 0);
1290
1291 auto expanded_fft_len_type =
1292 mlir::RankedTensorType::get({2}, fft_len_type.getElementType());
1293
1294 TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
1295 rfft_op.getLoc(), expanded_fft_len_type,
1296 SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());
1297
1298 // Insert the rfft_2d.
1299 auto rfft2d_out_type = mlir::RankedTensorType::get(
1300 expanded_output_shape, output_type.getElementType());
1301 TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
1302 rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
1303 expanded_fft_len.getResult());
1304
1305 // Insert the squeeze op.
1306 auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
1307 TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
1308 rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);
1309
1310 rewriter.replaceOp(op, squeeze.getResult());
1311
1312 return success();
1313 }
1314 };
1315
// Replaces the Identity op with its input in either of the following
// scenarios:
//   1) The Identity op's input and output have the same type/shape.
//   2) The result of the Identity op is only used by TF ops.
1319 struct RemoveIdentity : public OpRewritePattern<TF::IdentityOp> {
1320 using OpRewritePattern<TF::IdentityOp>::OpRewritePattern;
1321
  LogicalResult matchAndRewrite(TF::IdentityOp identity,
                                PatternRewriter &rewriter) const override {
1324 // Replace the op with the input if input and result have the same type.
1325 if (identity.input().getType() == identity.getType()) {
1326 rewriter.replaceOp(identity, identity.input());
1327 return success();
1328 }
    // Replace the op with the input if the output is only used by TF ops.
    // Currently this is on the conservative side since we need to ensure
    // every consumer op is a TF op before applying this pattern. We can
    // revisit this in the future if it turns out to be too restrictive.
1334 for (Operation *user : identity->getUsers()) {
1335 if (user->getDialect()->getNamespace() != "tf") {
1336 return failure();
1337 }
1338 }
1339
1340 rewriter.replaceOp(identity, identity.input());
1341 return success();
1342 }
1343 };
1344
void PrepareTFPass::runOnOperation() {
1346 MLIRContext *ctx = &getContext();
1347 RewritePatternSet patterns(ctx);
1348 RewritePatternSet phase_2_patterns(ctx);
1349 auto func = getOperation();
1350
  // Check for ops that are illegal in a TFLite pipeline (e.g. training-only
  // ops), since PrepareTFPass is the very first TFLite pass in the pipeline.
1353 // TODO(jingpu): It might be better to split this check into its own pass
1354 // to make things more modular.
1355 if (failed(ValidateOp(func))) {
1356 func.emitError() << "tfl-prepare-tf pass failed.";
1357 signalPassFailure();
1358 return;
1359 }
1360
1361 if (failed(ConvertTf2XlaOps(func, ctx))) {
1362 signalPassFailure();
1363 return;
1364 }
1365
1366 // This pattern will try to identify and optimize for dilated convolution.
1367 // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
1368 // replaced with a single Conv op with dilation parameter.
1369 patterns.add<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
1370 ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
1371
1372 patterns.add<RemoveIdentity>(ctx);
1373 TFL::populateWithGenerated(patterns);
1374 // Remove redundant reshape ops.
1375 TF::ReshapeOp::getCanonicalizationPatterns(patterns, ctx);
  // TODO(karimnosseir): Split into a separate pass, probably after deciding
  // on a long-term plan for this optimization.
  // This will allow optimizing any TF_Mul->TF_Conv in the graph, including
  // those expanded from FusedBatchNorm. We need to do this before converting
  // TF_Conv to TFL_Conv.
1381 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
1382
1383 // Remove the wrapper of the tf.FakeQuant* ops and also insert the
1384 // tfl.quantize and tfl.dequantize to preserve the quantization parameters.
1385 // This is done after the first round of optimization to make sure all the
  // min/max operands of the tf.FakeQuant* ops are constants to be matched.
  // The following round of optimization will fold the unwrapped
  // tf.FakeQuant* ops with the weight constants.
1389 if (failed(ConvertFakeQuantOps(func, ctx, use_fake_quant_num_bits_))) {
1390 signalPassFailure();
1391 return;
1392 }
1393
  // Load the generated patterns again, so the new quantization pass-through
  // will be applied.
1396 TFL::populateWithGenerated(phase_2_patterns);
1397 if (unfold_batch_matmul_) {
1398 TF::PopulateUnrollTfBatchMatMul(ctx, phase_2_patterns);
1399 }
1400 phase_2_patterns
1401 .add<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFStridedSlice,
1402 ConvertRfftToRfft2d, RemoveIdentity>(ctx);
1403 phase_2_patterns.add<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
1404 ctx, allow_bf16_and_f16_type_legalization_);
1405 // Remove redundant reshape ops.
1406 TF::ReshapeOp::getCanonicalizationPatterns(phase_2_patterns, ctx);
1407
1408 (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
1409 }
1410
1411 } // namespace
1412
1413 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareTFPass(
    bool unfold_batch_matmul, bool allow_bf16_and_f16_type_legalization,
    bool use_fake_quant_num_bits) {
1417 return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
1418 allow_bf16_and_f16_type_legalization,
1419 use_fake_quant_num_bits);
1420 }
1421
1422 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareTFPass() {
1424 return std::make_unique<PrepareTFPass>();
1425 }
1426
1427 } // namespace TFL
1428 } // namespace mlir
1429