/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the optimization pattern definition file for TensorFlow Lite.

include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
include "mlir/Dialect/Func/IR/FuncOps.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/utils/utils.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

// Checks if the param passed is a F32 ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
        "32 bit float constant tensor">;

// Checks if the param passed is a float ElementsAttr.
def FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
        "float constant tensor">;

// Checks if the param passed is of NoneType.
def IsNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;

def ExtractSingleElementAsFloat : NativeCodeCall<
    "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;

// Checks if the value has rank at most 'n'.
class HasRankAtMost<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() <= " # n>>;

// Checks if the value has rank 'n'.
class HasRank<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() == " # n>>;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern matching a stand-alone convolution op followed by an
// activation op, fusing the activation into the convolution.
multiclass FuseActFnIntoConvOpPat<Op ActFnOp, ConstantStrAttr ActFnAttr> {
  def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                 $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
        $padding, $stride_h, $stride_w),
    [(HasOneUse $conv_out)]>;
  def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w,
                $multiplier)),
    (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
        ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
    [(HasOneUse $conv_out)]>;
}

multiclass FuseActFnIntoPoolOpPat<Op ActFnOp, ConstantStrAttr ActFnAttr> {
  def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_AveragePool2DOp:$pool_out $input, $filter_height,
                  $filter_width, $padding, $stride_h, $stride_w, TFL_AF_None)),
    (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding,
        $stride_h, $stride_w, ActFnAttr),
    [(HasOneUse $pool_out)]>;
  def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_MaxPool2DOp:$pool_out $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, TFL_AF_None)),
    (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h,
        $filter_width, $filter_height, ActFnAttr),
    [(HasOneUse $pool_out)]>;
}

// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6],
                     [TFL_Relu1Op, TFL_AF_Relu1]] in {
  defm : FuseActFnIntoConvOpPat<!cast<Op>(actFnPair[0]), !cast<ConstantStrAttr>(actFnPair[1])>;
  defm : FuseActFnIntoPoolOpPat<!cast<Op>(actFnPair[0]), !cast<ConstantStrAttr>(actFnPair[1])>;
}
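
// For example (illustrative, in the before => after notation used elsewhere in
// this file): relu(conv_2d(x, w, b) {fused_activation_function = NONE})
//   => conv_2d(x, w, b) {fused_activation_function = RELU},
// provided the intermediate convolution result has no other uses.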

class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;

// If we see a binary op (add, sub) adding a constant value to a convolution
// op with a constant bias, we can fuse the binary op into the convolution op
// by constant folding the bias and the binary op's constant operand. The
// following patterns are restricted to float constant values for now.
multiclass FuseBinaryOpToPrecedingAffine<Op binaryOp> {
  def FuseBinaryOpWithConv#binaryOp : Pat<
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
                (Arith_ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
                TFL_AF_None, $padding, $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input, $filter,
      (binaryOp (Arith_ConstantOp $bias),
         (Arith_ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
      (binaryOp (Arith_ConstantOp $bias), (Arith_ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
      $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                (Arith_ConstantOp FloatElementsAttr:$bias), $padding,
                $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
      (binaryOp (Arith_ConstantOp $bias),
         (Arith_ConstantOp $value), TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  // Fuse for TransposeConv with no bias
  def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                $bias, $padding,
                $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
      (Arith_ConstantOp $value),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp]<Op> in
  defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
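
// For example (illustrative): add(conv_2d(x, w, bias_cst), add_cst)
//   => conv_2d(x, w, add(bias_cst, add_cst)),
// where the new bias is constant folded; sub is handled the same way.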

def ExpandTo4DForConv: NativeCodeCall<"ExpandTo4DForConv($0)">;

def ExpandTo4DForDepthwiseConv: NativeCodeCall<
  "ExpandTo4DForDepthwiseConv($0)">;

// If we see a div or mul op dividing/multiplying the result of a convolution
// op that has a constant filter and bias by a constant value, we can fuse the
// div/mul into the convolution op by constant folding the filter/bias with
// the div/mul op's constant operand.
// The following patterns are restricted to float constant values for now.

multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<Op BinaryOp> {
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
    (BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
                (Arith_ConstantOp FloatElementsAttr:$filter),
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input,
      (BinaryOp
        (Arith_ConstantOp $filter),
        (Arith_ConstantOp (ExpandTo4DForDepthwiseConv $value)),
        TFL_AF_None),
      (BinaryOp
        (Arith_ConstantOp $bias),
        (Arith_ConstantOp $value),
        TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h,
      $stride_w, $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithConv#BinaryOp : Pat<
    (BinaryOp (TFL_Conv2DOp:$conv_output $input,
                (Arith_ConstantOp FloatElementsAttr:$filter),
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None,
                $padding, $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input,
      (BinaryOp (Arith_ConstantOp $filter),
        (Arith_ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      (BinaryOp (Arith_ConstantOp $bias),
        (Arith_ConstantOp $value),
        TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $conv_output)]>;
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                (Arith_ConstantOp FloatElementsAttr:$weights), $input,
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $padding, $stride_h, $stride_w),
              (Arith_ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
      (BinaryOp (Arith_ConstantOp $weights),
        (Arith_ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      $input,
      (BinaryOp (Arith_ConstantOp $bias),
        (Arith_ConstantOp $value),
        TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                (Arith_ConstantOp FloatElementsAttr:$weights), $input,
                $bias,
                $padding, $stride_h, $stride_w),
              (Arith_ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
      (BinaryOp (Arith_ConstantOp $weights),
        (Arith_ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      $input,
      $bias,
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}

foreach BinaryOp = [TFL_DivOp, TFL_MulOp]<Op> in
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<BinaryOp>;
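
// For example (illustrative): mul(conv_2d(x, filter_cst, bias_cst), scale_cst)
//   => conv_2d(x, mul(filter_cst, scale_cst), mul(bias_cst, scale_cst)),
// where the scale is expanded to 4D for the filter operand and both products
// are constant folded; div is handled the same way.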


// This pattern applies when the same quantize/dequantize pair has been used
// twice with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def eliminate_dq_q_pairs : Pat<
  (TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
  (replaceWithValue $in),
  [(NotFromQuantOpOrSameQuantType $in, $qt)]>;
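
// For example (illustrative): quantize(dequantize(%q), qtype) => %q when the
// round trip does not change the quantized type, so the dequantize/quantize
// pair is a no-op.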

// Matching HardSwish: hard_swish(x) = x * relu6(x + 3) * 1/6.
def MatchHardSwishPattern1 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x, (TFL_AddOp
          $x,
          (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6),
     TFL_AF_None),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern2 : Pat<
  (TFL_MulOp
    $x,
    (TFL_MulOp
     (TFL_AddOp
      $x,
      (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
     TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern3 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x,
     (TFL_AddOp
      $x,
      (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     TFL_AF_None),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern4 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     (TFL_AddOp
      $x,
      (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
    $x,
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Matching HardSwish with extra FakeQuant. These FakeQuant ops result from
// incorrect placement during quantization-aware training.
def MatchHardSwishQuantized : Pat<
  (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
    (TFL_MulOp
     $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
          $x,
          (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6), $qattr2)),
     TFL_AF_None), $qattr1)),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Constraint that the single-element attribute's absolute value is less than 'n'.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
  "std::abs(*$0.cast<DenseElementsAttr>().getValues<float>().begin()) < "
  # n>>;

def L2NormValidReduceIndex : Constraint<CPred<
  "L2NormalizeReduceAxis($0, $1.cast<DenseElementsAttr>())">>;

// Currently L2Normalization doesn't support a fused activation function
// in TFLite.
// TODO(karimnosseir): Add the constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
multiclass L2NormalizePatterns<Op FirstOp, Op SecondOp> {
  // This pattern constructs L2NormalizationOp from
  // Mul->Rsqrt->Sum->Square or
  // Div->Sqrt->Sum->Square.
  def L2NormalizePattern1#FirstOp#SecondOp : Pat<
                  (FirstOp $x,
                     (SecondOp
                        (TFL_SumOp
                           (TFL_SquareOp:$sq_op $x),
                           (Arith_ConstantOp I32ElementsAttr:$axis),
                           $keep_dims)),
                     TFL_AF_None),
           (TFL_L2NormalizationOp $x, TFL_AF_None),
           [(L2NormValidReduceIndex $sq_op, $axis)]>;

  // The patterns below handle L2Normalize when there is an Add or Maximum
  // adding or clamping with a small constant scalar (epsilon).
  def L2NormalizePattern2#FirstOp#SecondOp : Pat<
                    (FirstOp $x,
                     (SecondOp
                      (TFL_AddOp
                       (TFL_SumOp
                        (TFL_SquareOp:$sq_op $x),
                        (Arith_ConstantOp I32ElementsAttr:$axis),
                        $keep_dims),
                       (Arith_ConstantOp $epsilon), TFL_AF_None)),
           TFL_AF_None),
           (TFL_L2NormalizationOp $x, TFL_AF_None),
           [(L2NormValidReduceIndex $sq_op, $axis),
            (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  def L2NormalizePattern3#FirstOp#SecondOp : Pat<
                    (FirstOp $x,
                     (SecondOp
                      (TFL_MaximumOp
                       (TFL_SumOp
                        (TFL_SquareOp:$sq_op $x),
                        (Arith_ConstantOp I32ElementsAttr:$axis),
                        $keep_dims),
                       (Arith_ConstantOp $epsilon))),
           TFL_AF_None),
           (TFL_L2NormalizationOp $x, TFL_AF_None),
           [(L2NormValidReduceIndex $sq_op, $axis),
            (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

}

foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
  in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
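
// For example (illustrative): div(x, sqrt(sum(square(x), axis))) or
//   mul(x, rsqrt(sum(square(x), axis))) => l2_normalization(x),
// provided the reduction axis is one the L2Normalization kernel supports.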

//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def AreBroadcastableTypes : Constraint<CPred<
  "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;

def OperandsBroadcastToOutputType : Constraint<CPred<
  "TFL::OperandsBroadcastToOutputType($0.getType(), $1.getType(), "
                                     "$2.getType())">>;

def IsTailOfShape : Constraint<CPred<
  "TFL::IsTailOfShape($0.getType(), $1.getType())">>;

def Flatten : NativeCodeCall<
  "$0.cast<DenseElementsAttr>()"
    ".reshape(RankedTensorType::get({$0.getType().cast<ShapedType>().getNumElements()}, "
                                   "$0.getType().cast<ShapedType>().getElementType()))">;

def IsLastDimEqualToNumElements : Constraint<CPred<
  "$0.getType().cast<ShapedType>().getRank() >= 1 && "
  "$0.getType().cast<ShapedType>().getDimSize($0.getType().cast<ShapedType>().getRank() - 1) == "
  "$1.getType().cast<ShapedType>().getNumElements()">>;

def IsDefinedByFullyConnectedOp : Constraint<CPred<
  "$0.getDefiningOp<TFL::FullyConnectedOp>() != nullptr">>;

// Patterns for skipping Tile when it is used only for broadcasting and the
// op already supports broadcasting.
multiclass FuseTileBroadcastIntoFollowingBinary<Op BinaryOp> {
  def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
    (BinaryOp:$result (TFL_TileOp $input, (Arith_ConstantOp $tile)),
     $operand, $act_func),
    (BinaryOp $input, $operand, $act_func),
  [(OperandsBroadcastToOutputType $input, $operand, $result),
   (HasRankAtMost<4> $input),
   (HasRankAtMost<4> $operand)]>;

  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$result $operand,
      (TFL_TileOp $input, (Arith_ConstantOp $tile)), $act_func),
    (BinaryOp $operand, $input, $act_func),
  [(OperandsBroadcastToOutputType $operand, $input, $result),
   (HasRankAtMost<4> $operand),
   (HasRankAtMost<4> $input)]>;
}
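
// For example (illustrative): add(tile(x, multiples), y) => add(x, y) when x
// and y already broadcast to the result type, making the explicit Tile
// redundant.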

// Multi-pattern matching a binary op followed by an activation op (relu
// family) and fusing the activation into the binary op.
multiclass FusedBinaryActivationFuncOpPat<Op BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
      (actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
      (BinaryOp $lhs, $rhs, actFnPair[1]),
    [(HasOneUse $binary_out)]>;
  }
}

foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Instantiated FusedBinary patterns for the from-to pairs of ops.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, its shape is the
  // tail of the other operand's shape, and the intermediate result isn't used
  // by other ops.
  // $rhs is required to be the tail shape of $lhs, so that after the
  // transformation the shape of the binary op result is valid. For example, if
  // the shapes of $input, $lhs and $rhs were [1600], [1,40,40] and [40,1], the
  // binary op result after the transformation would have shape [40,1600],
  // which could not be reshaped to [1,40,40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
      (Arith_ConstantOp:$rhs $a), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower dimensions, and
    // the higher dimensions are the same, so we know the result and input of
    // the "BinaryOp" in the source pattern have the same shape, which is
    // defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as `input`.
     // In other words, the shape of the `Reshape` op is not changed after the
     // transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (SameElementType $input, $rhs)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
    // This is valid only when both sides of the binary op are reshaped, and
    // the operand sizes are the same both before and after the reshape.
    def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input1, (Arith_ConstantOp:$shape1 $s1)),
                (TFL_ReshapeOp:$rhs $input2, (Arith_ConstantOp:$shape2 $s2)),
                $act_fn),
      (TFL_ReshapeOp (BinaryOp $input1, $input2, $act_fn), $shape1),
      [(IsTailOfShape $rhs, $lhs),
       (IsTailOfShape $lhs, $rhs),
       (IsTailOfShape $input1, $input2),
       (IsTailOfShape $input2, $input1),
       (SameElementType $input1, $input2)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
    // This is valid only when the last dimension of lhs is equal to the
    // number of elements in constant rhs. Therefore, after the transformation
    // the broadcast of the binary op is always applied to the last dimension
    // of $input.
    def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
                (Arith_ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn),
      (TFL_ReshapeOp (BinaryOp $input, (Arith_ConstantOp (Flatten $rhs_attr)),
                               $act_fn),
                     $shape),
      [(AnyStaticShapeTensor $input),
       (IsTailOfShape $rhs, $lhs),
       (IsLastDimEqualToNumElements $input, $rhs),
       (HasOneUse $lhs),
       // Restrict operands to have at most rank 4 because TFLite binary
       // kernel supports up to 4D broadcast.
       (HasRankAtMost<4> $input),
       (HasRankAtMost<4> $lhs),
       (HasRankAtMost<4> $rhs),
       (IsDefinedByFullyConnectedOp $input)]>;
}

foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, its shape is the
  // tail of the other operand's shape, and the intermediate result isn't used
  // by other ops.
  // $rhs is required to be the tail shape of $lhs, so that after the
  // transformation the shape of the binary op result is valid. For example, if
  // the shapes of $input, $lhs and $rhs were [1600], [1,40,40] and [40,1], the
  // binary op result after the transformation would have shape [40,1600],
  // which could not be reshaped to [1,40,40]. The `IsTailOfShape` constraint
  // is added to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
      (Arith_ConstantOp:$rhs $a)),
    (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower dimensions, and
    // the higher dimensions are the same, so we know the result and input of
    // the "BinaryOp" in the source pattern have the same shape, which is
    // defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as `input`.
     // In other words, the shape of the `Reshape` op is not changed after the
     // transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (SameElementType $input, $rhs)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
    // This is valid only when both sides of the binary op are reshaped, and
    // the operand sizes are the same both before and after the reshape.
    def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input1, (Arith_ConstantOp:$shape1 $s1)),
                (TFL_ReshapeOp:$rhs $input2, (Arith_ConstantOp:$shape2 $s2))),
      (TFL_ReshapeOp (BinaryOp $input1, $input2), $shape1),
      [(IsTailOfShape $rhs, $lhs),
       (IsTailOfShape $lhs, $rhs),
       (IsTailOfShape $input1, $input2),
       (IsTailOfShape $input2, $input1),
       (SameElementType $input1, $input2)]>;

    // Move binary op before reshape:
    // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
    // This is valid only when the last dimension of lhs is equal to the
    // number of elements in constant rhs. Therefore, after the transformation
    // the broadcast of the binary op is always applied to the last dimension
    // of $input.
    def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
      (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
                (Arith_ConstantOp:$rhs ElementsAttr:$rhs_attr)),
      (TFL_ReshapeOp (BinaryOp $input, (Arith_ConstantOp (Flatten $rhs_attr))),
                     $shape),
      [(AnyStaticShapeTensor $input),
       (IsTailOfShape $rhs, $lhs),
       (IsLastDimEqualToNumElements $input, $rhs),
       (HasOneUse $lhs),
       // Restrict operands to have at most rank 4 because TFLite binary
       // kernel supports up to 4D broadcast.
       (HasRankAtMost<4> $input),
       (HasRankAtMost<4> $lhs),
       (HasRankAtMost<4> $rhs),
       (IsDefinedByFullyConnectedOp $input)]>;
}

// Reorder the element-wise value operations and the element move operations,
// such that the value operation happens before the move operation.
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp, TFL_LogisticOp] in {
  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
                   TFL_ReshapeOp, TFL_TransposeOp] in {
    def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
      (ValueOp:$value (MoveOp:$move $input, $move_def)),
      (MoveOp (ValueOp $input), $move_def),
      [(SameElementType $input, $value), (HasOneUse $move)]>;
  }
}
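
// For example (illustrative): exp(transpose(x, perm)) => transpose(exp(x), perm);
// the move op only rearranges elements, so applying the element-wise op before
// the move produces the same result.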

// Returns the shape of a ranked tensor.
// If called on an unranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;

// Returns true if the operand type is a RankedTensorType with at most one
// dynamic dimension.
def HasValidRankedTensor : Constraint<CPred<
  "$0.getType().isa<RankedTensorType>() && "
  "$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;

def ConvertSqueezeToReshape : Pat<
  (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
  (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $squeeze_op))),
  [(HasValidRankedTensor $squeeze_op)]>;

// Convert expand_dims to reshape if possible.
def ConvertExpandDimsToReshape : Pat<
  (TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
  (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $expand_dims_op))),
  [(AnyStaticShapeTensor $expand_dims_op)]>;

class FloatValueEquals<string val> : Constraint<CPred<
  "FloatValueEquals($0, " # val # ")">>;

// ReLU patterns
def MatchReluPattern : Pat<
  (TFL_MaximumOp $input, (Arith_ConstantOp $Zero)),
  (TFL_ReluOp $input),
  [(FloatValueEquals<"0"> $Zero)]>;

def MatchRelu1Pattern1 : Pat<
  (TFL_MinimumOp (TFL_MaximumOp $input, (Arith_ConstantOp $NegOne)),
    (Arith_ConstantOp $One)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchRelu1Pattern2 : Pat<
  (TFL_MaximumOp (TFL_MinimumOp $input, (Arith_ConstantOp $One)),
    (Arith_ConstantOp $NegOne)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

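// Matches max(mul(x, alpha), x) with a single-element constant |alpha| < 1 and
// rewrites it to leaky_relu(x, alpha) (an illustrative reading of the pattern
// below).
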
def MatchLeakyRelu : Pat<
  (TFL_MaximumOp
    (TFL_MulOp:$mul_out $x,
     (Arith_ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
    $x),
  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
  [(ConstDoubleValueLessThan<"1"> $alpha),
   (HasOneUse $mul_out)]>;

// Returns true if all users of this operation are in the TF/TFL dialects and
// don't need exact shape matching. This prevents removing a cast on return
// values, which could break the verifier on a function type mismatch.
def AllUsersInTF : Constraint<CPred<[{
  llvm::all_of($0.getUsers(), [&](Operation *user) {
    auto name = user->getName().getDialectNamespace();
    return name == "tf" || name == "tfl";
  })
  }]>, "all users are TF/TFL operations.">;

def RemoveShapeOnlyCast : Pat<(TFL_CastOp:$output $input),
                            (replaceWithValue $input),
                            [(SameElementType $input, $output),
                             (AllUsersInTF $output)]>;


// Checks if operand 0's rank is one less than operand 1's rank.
def PReluAlphaRankCheck : Constraint<
  CPred<"$0.getType().cast<ShapedType>().getRank() == "
  "$1.getType().cast<ShapedType>().getRank() - 1">>;

// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def MatchPRelu : Pat<
  (TFL_AddOp
   (TFL_ReluOp:$relu_out $x),
   (TFL_MulOp:$mul_out
    (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
    $neg_alpha,
    TFL_AF_None),
   TFL_AF_None),
  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
  [(PReluAlphaRankCheck $neg_alpha, $x),
   (HasOneUse $relu_out),
   (HasOneUse $mul_out),
   (HasOneUse $input_neg_out)]>;

// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def LegalizeConstOp : Pat<
  (TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
//    \--> $constantB
// To
// Add --> $input
//    \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
    (TFL_AddOp
     (TFL_AddOp:$first_output $input, (Arith_ConstantOp $a), TFL_AF_None),
     (Arith_ConstantOp $b), ActFun),
    (TFL_AddOp $input,
     (TFL_AddOp (Arith_ConstantOp $a), (Arith_ConstantOp $b), TFL_AF_None),
     ActFun),
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}

// We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// The TFLite interpreter doesn't support Relu+int32 for now, so test cases
// fail unless the following pattern optimizes the Relu away.
def OptimizeReluSquaredDifference : Pat<
  (TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)),
  (TFL_SquaredDifferenceOp $l, $r)>;

// Optimize X^1 to X
def OptimizePow1ToIdentity : Pat<
  (TFL_PowOp $input,
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">)),
  (replaceWithValue $input)>;

// Optimize X^2 to X*X
def OptimizePow2ToSquare : Pat<
  (TFL_PowOp $input,
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "2.0f">)),
  (TFL_MulOp $input, $input, TFL_AF_None)>;

// Optimize X^(1/2) to √X
def OptimizePow2ToSqrt : Pat<
  (TFL_PowOp $input,
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.5f">)),
  (TFL_SqrtOp $input)>;

// Optimize X^(-1/2) to 1/√X == rsqrt(x)
def OptimizePow2ToRsqrt : Pat<
  (TFL_PowOp $input,
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "-0.5f">)),
  (TFL_RsqrtOp $input)>;

def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint<CPred<
  "TFL::CanOptimizeIdentityGatherNdOrScatterNdOp("
  "$0, $1.cast<DenseIntElementsAttr>(), $2.getType())">>;

def OptimizeIdentityGatherNdOp : Pat<
  (TFL_GatherNdOp:$output $params, (Arith_ConstantOp I32ElementsAttr: $indices)),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices, $output)]>;

def OptimizeIdentityScatterNdOp : Pat<
  (TFL_ScatterNdOp:$output (Arith_ConstantOp I32ElementsAttr: $indices), $params, $ignored),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices, $output)]>;

def ShapeMatchesReduceWithKeepAxes : Constraint<CPred<
  "ShapeMatchesReduceWithKeepAxes($0, $1, $2)">>;

// Fold reshapes re-inserting reduced dimensions into the results of a reduction
// with `keep_dims=false` by changing it to one using `keep_dims=true`.
foreach ReduceOp = [TFL_MeanOp, TFL_ReduceMaxOp, TFL_ReduceMinOp,
                    TFL_ReduceProdOp, TFL_SumOp] in {
  def FoldReshapeTo#ReduceOp : Pat<
    (TFL_ReshapeOp
      (ReduceOp:$reduce $input, (Arith_ConstantOp I32ElementsAttr: $axes),
                        ConstBoolAttrFalse),
      (Arith_ConstantOp I32ElementsAttr: $shape)),
    (ReduceOp $input, (Arith_ConstantOp $axes), ConstBoolAttrTrue),
    [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape),
     (HasOneUse $reduce)]>;
}
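
// For example (illustrative): reshape(mean(x, axes, keep_dims=false), shape)
//   => mean(x, axes, keep_dims=true) when `shape` is the input shape with the
// reduced axes restored as size-1 dimensions.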


def IsSame : Constraint<CPred<"$0 == $1">>;
def HasTwoUse : Constraint<CPred<
  "std::distance($0.use_begin(), $0.use_end()) == 2">>;
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "($0.cast<DenseIntElementsAttr>().getValues<APInt>()[0] == "
  "$1.getType().cast<ShapedType>().getRank() - 1 || $0.cast<DenseIntElementsAttr>().getValues<int32_t>()[0] == -1)">>;

// Convert exp(x)/sum(exp(x)) into softmax.
def OptimizeToSoftmax : Pat<
  (TFL_DivOp (TFL_ExpOp:$exp $input),
             (TFL_SumOp:$sum $sum_input, (Arith_ConstantOp I32ElementsAttr: $axes),
                             ConstBoolAttrTrue), TFL_AF_None),
  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
  [(IsSame $exp, $sum_input),
   (AxesIsLastDimension $axes, $sum_input),
   (HasTwoUse $exp),
   (HasOneUse $sum)]>;
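
// Note (illustrative): the sum must reduce over the last axis with
// keep_dims=true, and $exp is expected to have exactly two uses (the div and
// the sum), so the rewrite leaves no dangling exp/sum ops behind.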

// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
// with the max normalization.
def FoldNormalizationIntoSoftmax : Pat<
  (TFL_SoftmaxOp
    (TFL_SubOp:$sub $input,
      (TFL_ReduceMaxOp:$max $max_input, (Arith_ConstantOp I32ElementsAttr: $axes),
                            ConstBoolAttrTrue),
    TFL_AF_None),
    $beta),
  (TFL_SoftmaxOp $input, $beta),
  [(IsSame $input, $max_input),
   (AxesIsLastDimension $axes, $max_input),
   (HasOneUse $sub),
   (HasOneUse $max)]>;

def HaveSameType : Constraint<CPred<"($0.getType() == $1.getType())">>;

class AllElementsAreF32<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
   "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
   "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
               "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
               "[](float v){ return v == " #val# ";}))">>;

// Optimize X*1 to X
def OptimizeMul1ToIdentity : Pat<
  (TFL_MulOp:$result $input,
             (Arith_ConstantOp $constant),
             TFL_AF_None),
  (replaceWithValue $input),
  [(HaveSameType $input, $result),
   (AllElementsAreF32<"1.0f"> $constant)]>;

class AllElementsAreBool<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
   "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
   "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
               "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
               "[](bool v){ return v == " #val# ";}))">>;

// Remove select operators when the result is known in advance.
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
  // select(true_tensor, A, B) -> A
  def Optimize#SelectOp#True : Pat<
    (SelectOp:$result (Arith_ConstantOp $constant),
                      $input1,
                      $input2),
    (replaceWithValue $input1),
    [(HaveSameType $input1, $result),
     (AllElementsAreBool<"true"> $constant)]>;
  // select(false_tensor, A, B) -> B
  def Optimize#SelectOp#False : Pat<
    (SelectOp:$result (Arith_ConstantOp $constant),
                      $input1,
                      $input2),
    (replaceWithValue $input2),
    [(HaveSameType $input2, $result),
     (AllElementsAreBool<"false"> $constant)]>;
  // select(logical_not(C), A, B) -> select(C, B, A)
  def Optimize#SelectOp#Not : Pat<
    (SelectOp (TFL_LogicalNotOp $condition), $input1, $input2),
    (SelectOp $condition, $input2, $input1)>;
}

def EliminateLogicalAndTrue : Pat<
  (TFL_LogicalAndOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $lhs),
  [(AllElementsAreBool<"true"> $constant), (HaveSameType $lhs, $result)]>;

def EliminateLogicalAndFalse : Pat<
  (TFL_LogicalAndOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $rhs),
  [(AllElementsAreBool<"false"> $constant), (HaveSameType $rhs, $result)]>;

def EliminateLogicalOrTrue : Pat<
  (TFL_LogicalOrOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $rhs),
  [(AllElementsAreBool<"true"> $constant), (HaveSameType $rhs, $result)]>;

def EliminateLogicalOrFalse : Pat<
  (TFL_LogicalOrOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $lhs),
  [(AllElementsAreBool<"false"> $constant), (HaveSameType $lhs, $result)]>;

// Remove reductions that do nothing: input and output have the same size.
foreach ReduceOp = [TFL_ReduceAnyOp, TFL_ReduceAllOp,
                    TFL_ReduceMinOp, TFL_ReduceMaxOp,
                    TFL_MeanOp, TFL_SumOp, TFL_ReduceProdOp] in {
  def EliminateNoOpReductionOp#ReduceOp : Pat<
    (ReduceOp:$output $input, $index, $keep_dims),
    (replaceWithValue $input),
    [(IsTailOfShape $input, $output),
     (IsTailOfShape $output, $input)]>;
}

// Remove (log-)softmax before arg-minmax as (log-)softmax is monotonic.
foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
  def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
                 (Arith_ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $softmax),
     (AxesIsLastDimension $axes, $logits)]>;

  def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_LogSoftmaxOp:$log_softmax $logits),
                 (Arith_ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $log_softmax),
     (AxesIsLastDimension $axes, $logits)]>;
}

def CanOptimizeIdentitySliceOp : Constraint<CPred<
  "TFL::CanOptimizeIdentitySliceOp($0, $1, $2)">>;

// Remove Slice ops slicing the whole input tensor, effectively a no-op.
def OptimizeSliceOp : Pat<
  (TFL_SliceOp:$output $input, (Arith_ConstantOp $begin), (Arith_ConstantOp $size)),
  (replaceWithValue $input),
  [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>;

def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0)">;

def ReshapeValueDroppingLastDim : NativeCodeCall<
  "ReshapeValueDroppingLastDim($_builder, $0, $1)">;

def HasExactlyTwoElements : Constraint<CPred<
  "TFL::HasExactlyTwoElements($0)">>;

def IsLastElementEqualsOne : Constraint<CPred<
  "TFL::IsLastElementEqualsOne($0)">>;

def IsOneHotIndexAttribute : Constraint<CPred<
  "TFL::IsOneHotIndexAttribute($0)">>;

// Replace
//   Equal(Reshape(X, shape), indices)
// With
//   OneHot(Reshape(X, shape[:-1]), N, true, false, -1)
// where
//  - shape has length 2 (unnecessary, just to be conservative)
//  - the last value in shape is 1
//  - indices is an incrementing series from 0 to N-1 (N elements total).
def ReshapeEqualOpToOneHotOp : Pat<
  (TFL_EqualOp (TFL_ReshapeOp $x, (Arith_ConstantOp $shape)),
               (Arith_ConstantOp $series)),
  (TFL_OneHotOp (ReshapeValueDroppingLastDim $x, $shape),
                (Arith_ConstantOp (GetNumElementsOrOne $series)),
                (Arith_ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "true">),
                (Arith_ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "false">),
                ConstantAttr<I32Attr, "-1">),
  [(HasExactlyTwoElements $shape),
   (IsLastElementEqualsOne $shape),
   (IsOneHotIndexAttribute $series)]>;
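
// For example (illustrative): with shape = [2, 1] and indices = [0, 1, 2, 3],
//   equal(reshape(x, [2, 1]), [0, 1, 2, 3])
//   => one_hot(reshape(x, [2]), 4, true, false, axis = -1),
// i.e. each output row is a boolean one-hot vector selecting x's value for
// that row.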

def F32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isF32()">,
  "32 bit float tensor">;
def I32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isInteger(32)">,
  "32 bit integer tensor">;

def ConvertSingleElementAttrToFloatAttr :
  NativeCodeCall<"ConvertSingleElementAttrToFloatAttr($0)">;

// Replace
//   (float)OneHot(index, depth, on_val, off_val, axis)
// With
//   OneHot(index, depth, (float)on_val, (float)off_val, axis)
def FuseOneHotAndCastToFloat : Pat<
  (TFL_CastOp:$output (TFL_OneHotOp $indices,
                                    $depth,
                                    (Arith_ConstantOp $on_val),
                                    (Arith_ConstantOp $off_val),
                                    $axis)),
  (TFL_OneHotOp $indices,
                $depth,
                (Arith_ConstantOp (ConvertSingleElementAttrToFloatAttr $on_val)),
                (Arith_ConstantOp (ConvertSingleElementAttrToFloatAttr $off_val)),
                $axis),
  [(F32ElementsVal $output)]>;

// Replace
//   OneHot(index, depth, on=1.0f, off=0.0f, axis=-1) * filter
// With
//   EmbeddingLookup(index, Transpose(filter))
//
// OneHot with on=1 off=0 axis=-1, where `index` is a single element tensor,
// creates a tensor of size depth, and all values are 0, except for the element
// at `index`, which is 1. Multiplying such a tensor with a 2D filter essentially
// returns the single column in filter as a 1D tensor. If the input has multiple
// elements, repeat this for every entry, forming the higher dimensions in the
// result tensor. For instance, if:
//   input = [1, 2]
//   depth = 4
//   filter = [[5, 7, 11, 13], [17, 19, 23, 29]]
// then:
//   onehot = [[0, 1, 0, 0], [0, 0, 1, 0]]
//   result = [[ 7, 19],   # == filter column at index 1
//             [11, 23]]   # == filter column at index 2
// This is exactly what the EmbeddingLookup operator does, on the transposed
// matrix, without doing any arithmetic but only memcpy.
def ReplaceOneHotFullyConnectedWithLookup : Pat<
  (TFL_FullyConnectedOp
    (TFL_OneHotOp
      $indices,
      (Arith_ConstantOp $depth),
      (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
      (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.0f">),
      ConstantAttr<I32Attr, "-1">),
    $filter,
    $bias,
    TFL_AF_None,
    TFL_FCWO_Default,
    ConstBoolAttrFalse,
    $asymmetric_quantize_inputs),
  (TFL_EmbeddingLookupOp
    $indices,
    (TFL_TransposeOp
      $filter,
      (Arith_ConstantOp ConstantAttr<RankedI32ElementsAttr<[2]>, "{1,0}"> ))),
  [(I32ElementsVal $indices),     // lookup is not implemented for i64
   (HasRank<1> $indices),  // lookup isn't implemented for any other rank
   (IsNoneType $bias)]>;          // Maybe folded into the lookup matrix later

def AreInputDimensionsOneInAxes : Constraint<CPred<
  "AreInputDimensionsOneInAxes($0, $1)">>;

// Eliminate cumulative summations if the input's dimension in the given axis
// is 1.
def EliminateCumSumInclusive : Pat<
  (TFL_CumsumOp
     $input,
     (Arith_ConstantOp I32ElementsAttr:$axis),
     ConstBoolAttrFalse,
     $reverse),
  (replaceWithValue $input),
  [(AreInputDimensionsOneInAxes $input, $axis)]>;
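
// Rationale (illustrative): an inclusive cumsum over an axis of size 1 returns
// every element unchanged, so the op can be replaced by its input.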

// Fuse the raw computation of the GELU op into a single native tfl_gelu op.
//
// Requires the constants to match exactly and all of the intermediate results
// to have only one use.
//
// For GeluApproximate, replaces
//   0.5 * x * ( 1 + tanh( sqrt_2dPi  * ( x + 0.044715 * pow( x, 3 ) ) ) )
def MatchGeluApproximate : Pat<
  (TFL_MulOp
   (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None),
   (TFL_AddOp:$add_out
    (TFL_TanhOp:$tanh_out
     (TFL_MulOp:$mul_out1
      (TFL_AddOp:$add_out1 $arg0,
       (TFL_MulOp:$mul_out2
        (TFL_PowOp:$pow_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_3)),
        (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None),
      (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)),
    (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None),
  (TFL_GeluOp $arg0, ConstBoolAttrTrue),
  [(FloatValueEquals<"0.5"> $Cst_1_2),
   (FloatValueEquals<"1"> $Cst_1),
   (FloatValueEquals<"3"> $Cst_3),
   (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi),
   (FloatValueEquals<"0.044715"> $Coeff),
   (HasOneUse $mul_out),
   (HasOneUse $add_out),
   (HasOneUse $tanh_out),
   (HasOneUse $mul_out1),
   (HasOneUse $add_out1),
   (HasOneUse $mul_out2),
   (HasOneUse $pow_out),
  ]>;

// Alternate pattern for GeluApproximate (note the different order of the
// muls); replaces
//   x * ( 0.5 * ( 1 + tanh( sqrt_2dPi  * ( x + 0.044715 * pow( x, 3 ) ) ) ) )
def MatchGeluApproximate1 : Pat<
  (TFL_MulOp $arg0,
   (TFL_MulOp:$mul_out
    (TFL_AddOp:$add_out
     (TFL_TanhOp:$tanh_out
      (TFL_MulOp:$mul_out1
       (TFL_AddOp:$add_out1 $arg0,
        (TFL_MulOp:$mul_out2
         (TFL_PowOp:$pow_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_3)),
         (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None),
       (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)),
     (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None),
  (TFL_GeluOp $arg0, ConstBoolAttrTrue),
  [(FloatValueEquals<"0.5"> $Cst_1_2),
   (FloatValueEquals<"1"> $Cst_1),
   (FloatValueEquals<"3"> $Cst_3),
   (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi),
   (FloatValueEquals<"0.044715"> $Coeff),
   (HasOneUse $mul_out),
   (HasOneUse $add_out),
   (HasOneUse $tanh_out),
   (HasOneUse $mul_out1),
   (HasOneUse $add_out1),
   (HasOneUse $mul_out2),
   (HasOneUse $pow_out),
  ]>;

// For Gelu, replaces
//   0.5 * x * ( 1 + erf( x * sqrt_1_2 ) )
def MatchGelu : Pat<
  (TFL_MulOp
   (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None),
   (TFL_AddOp:$add_out
    (TF_ErfOp:$erf_out
     (TFL_MulOp:$mul_out1 $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_1_2), TFL_AF_None)),
    (Arith_ConstantOp  F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None),
  (TFL_GeluOp $arg0, ConstBoolAttrFalse),
  [(FloatValueEquals<"0.5"> $Cst_1_2),
   (FloatValueEquals<"1"> $Cst_1),
   (FloatValueEquals<"0.707106769"> $Cst_sqrt_1_2),
   (HasOneUse $mul_out),
   (HasOneUse $add_out),
   (HasOneUse $erf_out),
   (HasOneUse $mul_out1),
  ]>;


// Checks if the value's shape has a last dimension equal to 1.
def IsLastDimensionEqualOne : Constraint<CPred<"IsLastDimensionEqualOne($0)">>;

// Fetches the output of an FC op created from the provided arguments.
def GetFcOutput : NativeCodeCall<
  "GetFcOutput(&$_builder, $0, $1, $2, $3, $4, $5, $6, $7)">;

// Verifies all values in the provided argument are zero.
def AllValuesAreZero :  Constraint<CPred<"AllValuesAreZero($0)">>;

def SimplifyDoubleSelectFCZerosLHS : Pat<
  (TFL_SelectV2Op $condition, $zeros_2,
   (TFL_FullyConnectedOp:$results
    (TFL_SelectV2Op $condition, $zeros_1, $input),
    $filter, $bias, $fused_activation_function, $weights_format,
    ConstBoolAttrTrue, $asymmetric_quantize_inputs)),
  (TFL_SelectV2Op $condition, $zeros_2,
   (GetFcOutput $results, $input, $filter, $bias, $fused_activation_function,
    $weights_format, ConstBoolAttrTrue, $asymmetric_quantize_inputs)),
  [(IsLastDimensionEqualOne $condition),
   (AllValuesAreZero $zeros_1),
   (AllValuesAreZero $zeros_2)
  ]>;

def SimplifyDoubleSelectFCZerosRHS : Pat<
  (TFL_SelectV2Op $condition,
   (TFL_FullyConnectedOp:$results
    (TFL_SelectV2Op $condition, $input, $zeros_1),
    $filter, $bias, $fused_activation_function, $weights_format,
    ConstBoolAttrTrue, $asymmetric_quantize_inputs),
   $zeros_2),
  (TFL_SelectV2Op $condition,
   (GetFcOutput $results, $input, $filter, $bias, $fused_activation_function,
    $weights_format, ConstBoolAttrTrue, $asymmetric_quantize_inputs),
   $zeros_2),
  [(IsLastDimensionEqualOne $condition),
   (AllValuesAreZero $zeros_1),
   (AllValuesAreZero $zeros_2)
  ]>;
1156