/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the optimization pattern definition file for TensorFlow Lite.

include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
include "mlir/Dialect/Func/IR/FuncOps.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/utils/utils.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

// Checks if the param passed is a F32 ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
  "32 bit float constant tensor">;

// Checks if the param passed is a float ElementsAttr.
def FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
  "float constant tensor">;

// Checks if the param passed is of NoneType.
def IsNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;

def ExtractSingleElementAsFloat : NativeCodeCall<
  "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;

// Checks if the value has rank at most 'n'.
class HasRankAtMost<int n> : Constraint<
  CPred<"$0.getType().cast<ShapedType>().hasRank() && "
        "$0.getType().cast<ShapedType>().getRank() <= " # n>>;

// Checks if the value has rank 'n'.
class HasRank<int n> : Constraint<
  CPred<"$0.getType().cast<ShapedType>().hasRank() && "
        "$0.getType().cast<ShapedType>().getRank() == " # n>>;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
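// Schematically (illustrative example only, not taken verbatim from the
// source), a convolution whose only use feeds a relu:
//   %0 = "tfl.conv_2d"(...) {fused_activation_function = "NONE", ...}
//   %1 = "tfl.relu"(%0)
// is rewritten to a single convolution carrying the fused activation:
//   %0 = "tfl.conv_2d"(...) {fused_activation_function = "RELU", ...}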
multiclass FuseActFnIntoConvOpPat<Op ActFnOp, ConstantStrAttr ActFnAttr> {
  def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                  $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
                  $padding, $stride_h, $stride_w),
    [(HasOneUse $conv_out)]>;
  def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias, $h_factor,
                  $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w,
                  $multiplier)),
    (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
                  ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
    [(HasOneUse $conv_out)]>;
}

multiclass FuseActFnIntoPoolOpPat<Op ActFnOp, ConstantStrAttr ActFnAttr> {
  def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_AveragePool2DOp:$pool_out $input, $filter_height,
                  $filter_width, $padding, $stride_h, $stride_w, TFL_AF_None)),
    (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding,
                  $stride_h, $stride_w, ActFnAttr),
    [(HasOneUse $pool_out)]>;
  def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp (TFL_MaxPool2DOp:$pool_out $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, TFL_AF_None)),
    (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h,
                  $filter_width, $filter_height, ActFnAttr),
    [(HasOneUse $pool_out)]>;
}

// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot simply be translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6],
                     [TFL_Relu1Op, TFL_AF_Relu1]] in {
  defm : FuseActFnIntoConvOpPat<!cast<Op>(actFnPair[0]), !cast<ConstantStrAttr>(actFnPair[1])>;
  defm : FuseActFnIntoPoolOpPat<!cast<Op>(actFnPair[0]), !cast<ConstantStrAttr>(actFnPair[1])>;
}

class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;

// If we see a binary op (add, sub) adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand.
// The following patterns restrict to float constant values for now.
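// Schematically (illustrative example only):
//   %0 = "tfl.conv_2d"(%input, %filter, %cst_bias) {fused_activation_function = "NONE"}
//   %1 = tfl.add %0, %cst
// becomes a single convolution whose bias operand is the add of the two
// constants, which later constant-folding collapses into one constant:
//   "tfl.conv_2d"(%input, %filter, add(%cst_bias, %cst))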
multiclass FuseBinaryOpToPrecedingAffine<Op binaryOp> {
  def FuseBinaryOpWithConv#binaryOp : Pat<
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
                (Arith_ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
                TFL_AF_None, $padding, $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input, $filter,
              (binaryOp (Arith_ConstantOp $bias),
                        (Arith_ConstantOp $value), TFL_AF_None),
              $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
              (binaryOp (Arith_ConstantOp $bias), (Arith_ConstantOp $value), TFL_AF_None),
              $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
              $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                (Arith_ConstantOp FloatElementsAttr:$bias), $padding,
                $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
              (binaryOp (Arith_ConstantOp $bias),
                        (Arith_ConstantOp $value), TFL_AF_None),
              $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  // Fuse for TransposeConv with no bias.
  def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                $bias, $padding,
                $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
              (Arith_ConstantOp $value),
              $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp]<Op> in
  defm : FuseBinaryOpToPrecedingAffine<binaryOp>;

def ExpandTo4DForConv: NativeCodeCall<"ExpandTo4DForConv($0)">;

def ExpandTo4DForDepthwiseConv: NativeCodeCall<
  "ExpandTo4DForDepthwiseConv($0)">;

// If we see a div or mul op dividing/multiplying a constant value into a
// convolution op with constant filter and bias, we can fuse the div/mul into
// the convolution op by constant folding the filter/bias and the div/mul op's
// constant operand.
// The following patterns restrict to float constant values for now.
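// Schematically (illustrative example only): for a multiplication by a
// per-output-channel constant %c,
//   %0 = "tfl.conv_2d"(%input, %filter, %bias)
//   %1 = tfl.mul %0, %c
// folds into
//   "tfl.conv_2d"(%input, mul(%filter, broadcast(%c)), mul(%bias, %c))
// because the convolution output is linear in both the filter and the bias.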
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<Op BinaryOp> {
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
    (BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
                (Arith_ConstantOp FloatElementsAttr:$filter),
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input,
              (BinaryOp
                (Arith_ConstantOp $filter),
                (Arith_ConstantOp (ExpandTo4DForDepthwiseConv $value)),
                TFL_AF_None),
              (BinaryOp
                (Arith_ConstantOp $bias),
                (Arith_ConstantOp $value),
                TFL_AF_None),
              $h_factor, $w_factor, $act_fn, $padding, $stride_h,
              $stride_w, $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasRank<1> $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithConv#BinaryOp : Pat<
    (BinaryOp (TFL_Conv2DOp:$conv_output $input,
                (Arith_ConstantOp FloatElementsAttr:$filter),
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None,
                $padding, $stride_h, $stride_w),
              (Arith_ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input,
              (BinaryOp (Arith_ConstantOp $filter),
                        (Arith_ConstantOp (ExpandTo4DForConv $value)),
                        TFL_AF_None),
              (BinaryOp (Arith_ConstantOp $bias),
                        (Arith_ConstantOp $value),
                        TFL_AF_None),
              $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $conv_output)]>;
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                (Arith_ConstantOp FloatElementsAttr:$weights), $input,
                (Arith_ConstantOp FloatElementsAttr:$bias),
                $padding, $stride_h, $stride_w),
              (Arith_ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
              (BinaryOp (Arith_ConstantOp $weights),
                        (Arith_ConstantOp (ExpandTo4DForConv $value)),
                        TFL_AF_None),
              $input,
              (BinaryOp (Arith_ConstantOp $bias),
                        (Arith_ConstantOp $value),
                        TFL_AF_None),
              $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
                (Arith_ConstantOp FloatElementsAttr:$weights), $input,
                $bias,
                $padding, $stride_h, $stride_w),
              (Arith_ConstantOp $value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
              (BinaryOp (Arith_ConstantOp $weights),
                        (Arith_ConstantOp (ExpandTo4DForConv $value)),
                        TFL_AF_None),
              $input,
              $bias,
              $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}

foreach BinaryOp = [TFL_DivOp, TFL_MulOp]<Op> in
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<BinaryOp>;


// This pattern applies when the same quantize/dequantize ops have been applied
// twice with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
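// Schematically (illustrative example only): a dequantize immediately
// re-quantized to the same quantized type,
//   %1 = "tfl.dequantize"(%0)
//   %2 = "tfl.quantize"(%1) {qtype = <same quantized type as %0>}
// is a round trip through float, so %2 can simply be replaced by %0.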
def eliminate_dq_q_pairs : Pat<
  (TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
  (replaceWithValue $in),
  [(NotFromQuantOpOrSameQuantType $in, $qt)]>;

// Matching HardSwish.
def MatchHardSwishPattern1 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x, (TFL_AddOp
          $x,
          (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6),
     TFL_AF_None),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern2 : Pat<
  (TFL_MulOp
   $x,
   (TFL_MulOp
    (TFL_AddOp
     $x,
     (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
     TFL_AF_Relu6),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
   TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern3 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     $x,
     (TFL_AddOp
      $x,
      (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     TFL_AF_None),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

def MatchHardSwishPattern4 : Pat<
  (TFL_MulOp
    (TFL_MulOp
     (TFL_AddOp
      $x,
      (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
      TFL_AF_Relu6),
     (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
     TFL_AF_None),
    $x,
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
// incorrect placement during quantization-aware training.
def MatchHardSwishQuantized : Pat<
  (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
    (TFL_MulOp
     $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
          $x,
          (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
          TFL_AF_Relu6), $qattr2)),
     TFL_AF_None), $qattr1)),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;

// Constraint that the attribute value is less than 'n'.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
        "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
        "std::abs(*$0.cast<DenseElementsAttr>().getValues<float>().begin()) < "
        # n>>;

def L2NormValidReduceIndex : Constraint<CPred<
  "L2NormalizeReduceAxis($0, $1.cast<DenseElementsAttr>())">>;

// Currently L2Normalization doesn't support an activation function
// in TFLite.
// TODO(karimnosseir): Add constraints that the kernel code assumes,
// e.g. constraints on axis and depth.
multiclass L2NormalizePatterns<Op FirstOp, Op SecondOp> {
  // This pattern constructs L2NormalizationOp from
  // Mul->Rsqrt->Sum->Square or
  // Div->Sqrt->Sum->Square.
  def L2NormalizePattern1#FirstOp#SecondOp : Pat<
    (FirstOp $x,
     (SecondOp
      (TFL_SumOp
       (TFL_SquareOp:$sq_op $x),
       (Arith_ConstantOp I32ElementsAttr:$axis),
       $keep_dims)),
     TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis)]>;

  // The patterns below handle L2Normalize when there is an Add or Maximum
  // adding or clamping with a small constant scalar.
  def L2NormalizePattern2#FirstOp#SecondOp : Pat<
    (FirstOp $x,
     (SecondOp
      (TFL_AddOp
       (TFL_SumOp
        (TFL_SquareOp:$sq_op $x),
        (Arith_ConstantOp I32ElementsAttr:$axis),
        $keep_dims),
       (Arith_ConstantOp $epsilon), TFL_AF_None)),
     TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  def L2NormalizePattern3#FirstOp#SecondOp : Pat<
    (FirstOp $x,
     (SecondOp
      (TFL_MaximumOp
       (TFL_SumOp
        (TFL_SquareOp:$sq_op $x),
        (Arith_ConstantOp I32ElementsAttr:$axis),
        $keep_dims),
       (Arith_ConstantOp $epsilon))),
     TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

}

foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
  in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;

//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def AreBroadcastableTypes : Constraint<CPred<
  "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;

def OperandsBroadcastToOutputType : Constraint<CPred<
  "TFL::OperandsBroadcastToOutputType($0.getType(), $1.getType(), "
  "$2.getType())">>;

def IsTailOfShape : Constraint<CPred<
  "TFL::IsTailOfShape($0.getType(), $1.getType())">>;

def Flatten : NativeCodeCall<
  "$0.cast<DenseElementsAttr>()"
  ".reshape(RankedTensorType::get({$0.getType().cast<ShapedType>().getNumElements()}, "
  "$0.getType().cast<ShapedType>().getElementType()))">;

def IsLastDimEqualToNumElements : Constraint<CPred<
  "$0.getType().cast<ShapedType>().getRank() >= 1 && "
  "$0.getType().cast<ShapedType>().getDimSize($0.getType().cast<ShapedType>().getRank() - 1) == "
  "$1.getType().cast<ShapedType>().getNumElements()">>;

def IsDefinedByFullyConnectedOp : Constraint<CPred<
  "$0.getDefiningOp<TFL::FullyConnectedOp>() != nullptr">>;

// Pattern for skipping Tile if it is mainly for broadcasting and the
// op already supports broadcasting.
multiclass FuseTileBroadcastIntoFollowingBinary<Op BinaryOp> {
  def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
    (BinaryOp:$result (TFL_TileOp $input, (Arith_ConstantOp $tile)),
     $operand, $act_func),
    (BinaryOp $input, $operand, $act_func),
    [(OperandsBroadcastToOutputType $input, $operand, $result),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $operand)]>;

  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$result $operand,
     (TFL_TileOp $input, (Arith_ConstantOp $tile)), $act_func),
    (BinaryOp $operand, $input, $act_func),
    [(OperandsBroadcastToOutputType $operand, $input, $result),
     (HasRankAtMost<4> $operand),
     (HasRankAtMost<4> $input)]>;
}

// Multi-pattern consisting of matching stand-alone op or op followed by relu.
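// Schematically (illustrative example only): a tfl.add with
// fused_activation_function = "NONE" whose single use is a tfl.relu collapses
// into one tfl.add with fused_activation_function = "RELU"; the same applies
// for relu6/relu_n1_to_1 and the other instantiated binary ops.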
multiclass FusedBinaryActivationFuncOpPat<Op BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
      (actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
      (BinaryOp $lhs, $rhs, actFnPair[1]),
      [(HasOneUse $binary_out)]>;
  }
}

foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Instantiated FusedBinary patterns for the from-to pairs of ops.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, the shape is the
  // tail of the other operand, and the intermediate result isn't used by other
  // ops.
  // $rhs is required to be the tail shape of $lhs, so after the transformation
  // the shape of the binary op result is valid. For example, assume the shapes
  // of $input, $lhs and $rhs are [1600], [1,40,40] and [40x1]. After the
  // transformation, the shape of the binary op result is [40x1600], which
  // couldn't be reshaped to [1,40,40]. The `IsTailOfShape` constraint is added
  // to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
     (Arith_ConstantOp:$rhs $a), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed by the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (SameElementType $input, $rhs)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
  // This is valid only when both sides of the binary op are reshaped, and
  // the sizes are the same both before and after the reshape.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input1, (Arith_ConstantOp:$shape1 $s1)),
     (TFL_ReshapeOp:$rhs $input2, (Arith_ConstantOp:$shape2 $s2)),
     $act_fn),
    (TFL_ReshapeOp (BinaryOp $input1, $input2, $act_fn), $shape1),
    [(IsTailOfShape $rhs, $lhs),
     (IsTailOfShape $lhs, $rhs),
     (IsTailOfShape $input1, $input2),
     (IsTailOfShape $input2, $input1),
     (SameElementType $input1, $input2)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
  // This is valid only when the last dimension of lhs is equal to the
  // number of elements in constant rhs.
  // Therefore, after the transformation the broadcast of the binary op is
  // always applied to the last dimension of $input.
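  // For example (illustrative shapes only): if $input has shape [4, 40], $lhs
  // reshapes it to [4, 5, 8], and the constant $rhs has shape [5, 8]
  // (40 elements), the rewrite computes the binary op on $input and the
  // constant flattened to [40], broadcasting over the last dimension of
  // $input, and then reshapes the result back to [4, 5, 8].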
  def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
     (Arith_ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, (Arith_ConstantOp (Flatten $rhs_attr)),
                    $act_fn),
     $shape),
    [(AnyStaticShapeTensor $input),
     (IsTailOfShape $rhs, $lhs),
     (IsLastDimEqualToNumElements $input, $rhs),
     (HasOneUse $lhs),
     // Restrict operands to have at most rank 4 because the TFLite binary
     // kernel supports up to 4D broadcast.
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (IsDefinedByFullyConnectedOp $input)]>;
}

foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, the shape is the
  // tail of the other operand, and the intermediate result isn't used by other
  // ops.
  // $rhs is required to be the tail shape of $lhs, so after the transformation
  // the shape of the binary op result is valid. For example, assume the shapes
  // of $input, $lhs and $rhs are [1600], [1,40,40] and [40x1]. After the
  // transformation, the shape of the binary op result is [40x1600], which
  // couldn't be reshaped to [1,40,40]. The `IsTailOfShape` constraint is added
  // to make sure $rhs is the tail shape of $lhs.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
     (Arith_ConstantOp:$rhs $a)),
    (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are the same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed by the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (SameElementType $input, $rhs)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
  // This is valid only when both sides of the binary op are reshaped, and
  // the sizes are the same both before and after the reshape.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input1, (Arith_ConstantOp:$shape1 $s1)),
     (TFL_ReshapeOp:$rhs $input2, (Arith_ConstantOp:$shape2 $s2))),
    (TFL_ReshapeOp (BinaryOp $input1, $input2), $shape1),
    [(IsTailOfShape $rhs, $lhs),
     (IsTailOfShape $lhs, $rhs),
     (IsTailOfShape $input1, $input2),
     (IsTailOfShape $input2, $input1),
     (SameElementType $input1, $input2)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
  // This is valid only when the last dimension of lhs is equal to the
  // number of elements in constant rhs.
  // Therefore, after the transformation the broadcast of the binary op is
  // always applied to the last dimension of $input.
  def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)),
     (Arith_ConstantOp:$rhs ElementsAttr:$rhs_attr)),
    (TFL_ReshapeOp (BinaryOp $input, (Arith_ConstantOp (Flatten $rhs_attr))),
     $shape),
    [(AnyStaticShapeTensor $input),
     (IsTailOfShape $rhs, $lhs),
     (IsLastDimEqualToNumElements $input, $rhs),
     (HasOneUse $lhs),
     // Restrict operands to have at most rank 4 because the TFLite binary
     // kernel supports up to 4D broadcast.
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (IsDefinedByFullyConnectedOp $input)]>;
}

// Reorder the element-wise value operations and the element move operations,
// such that the value operation happens before the move operation.
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp, TFL_LogisticOp] in {
  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
                    TFL_ReshapeOp, TFL_TransposeOp] in {
    def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
      (ValueOp:$value (MoveOp:$move $input, $move_def)),
      (MoveOp (ValueOp $input), $move_def),
      [(SameElementType $input, $value), (HasOneUse $move)]>;
  }
}

// Returns the shape of a ranked tensor.
// If called on an unranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;

// Returns true if the operand type is a RankedTensorType with at most one
// dynamic dimension.
def HasValidRankedTensor : Constraint<CPred<
  "$0.getType().isa<RankedTensorType>() && "
  "$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;

def ConvertSqueezeToReshape : Pat<
  (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
  (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $squeeze_op))),
  [(HasValidRankedTensor $squeeze_op)]>;

// Convert expand_dims to reshape if possible.
def ConvertExpandDimsToReshape : Pat<
  (TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
  (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $expand_dims_op))),
  [(AnyStaticShapeTensor $expand_dims_op)]>;

class FloatValueEquals<string val> : Constraint<CPred<
  "FloatValueEquals($0, " # val # ")">>;

// ReLU patterns.
def MatchReluPattern : Pat<
  (TFL_MaximumOp $input, (Arith_ConstantOp $Zero)),
  (TFL_ReluOp $input),
  [(FloatValueEquals<"0"> $Zero)]>;

def MatchRelu1Pattern1 : Pat<
  (TFL_MinimumOp (TFL_MaximumOp $input, (Arith_ConstantOp $NegOne)),
   (Arith_ConstantOp $One)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchRelu1Pattern2 : Pat<
  (TFL_MaximumOp (TFL_MinimumOp $input, (Arith_ConstantOp $One)),
   (Arith_ConstantOp $NegOne)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

def MatchLeakyRelu : Pat<
  (TFL_MaximumOp
   (TFL_MulOp:$mul_out $x,
    (Arith_ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
   $x),
  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
  [(ConstDoubleValueLessThan<"1"> $alpha),
   (HasOneUse $mul_out)]>;

// Returns true if all users of this operation are in TF/TFL and don't need
// exact shape matching. This prevents removing a cast on return values, which
// could break the verifier due to a function type mismatch.
def AllUsersInTF : Constraint<CPred<[{
    llvm::all_of($0.getUsers(), [&](Operation *user) {
      auto name = user->getName().getDialectNamespace();
      return name == "tf" || name == "tfl";
    })
  }]>, "all users are TF/TFL operations.">;

def RemoveShapeOnlyCast : Pat<(TFL_CastOp:$output $input),
                              (replaceWithValue $input),
                              [(SameElementType $input, $output),
                               (AllUsersInTF $output)]>;


// Checks if operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
  CPred<"$0.getType().cast<ShapedType>().getRank() == "
        "$1.getType().cast<ShapedType>().getRank() - 1">>;

// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def MatchPRelu : Pat<
  (TFL_AddOp
   (TFL_ReluOp:$relu_out $x),
   (TFL_MulOp:$mul_out
    (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)),
    $neg_alpha,
    TFL_AF_None),
   TFL_AF_None),
  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
  [(PReluAlphaRankCheck $neg_alpha, $x),
   (HasOneUse $relu_out),
   (HasOneUse $mul_out),
   (HasOneUse $input_neg_out)]>;

// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def LegalizeConstOp : Pat<
  (TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
//    \--> $constantB
// To
// Add --> $input
//    \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
    (TFL_AddOp
     (TFL_AddOp:$first_output $input, (Arith_ConstantOp $a), TFL_AF_None),
     (Arith_ConstantOp $b), ActFun),
    (TFL_AddOp $input,
     (TFL_AddOp (Arith_ConstantOp $a), (Arith_ConstantOp $b), TFL_AF_None),
     ActFun),
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}

// We can eliminate the Relu in Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// The TFLite interpreter doesn't support Relu on int32 for now, so test cases
// would fail without the following pattern that optimizes the Relu away.
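// For example (illustrative values only): relu(squared_difference(3, 7)) ==
// relu(16) == 16 == squared_difference(3, 7), so dropping the Relu preserves
// the value for every input.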
def OptimizeReluSquaredDifference : Pat<
  (TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)),
  (TFL_SquaredDifferenceOp $l, $r)>;

// Optimize X^1 to X
def OptimizePow1ToIdentity : Pat<
  (TFL_PowOp $input,
   (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">)),
  (replaceWithValue $input)>;

// Optimize X^2 to X*X
def OptimizePow2ToSquare : Pat<
  (TFL_PowOp $input,
   (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "2.0f">)),
  (TFL_MulOp $input, $input, TFL_AF_None)>;

// Optimize X^(1/2) to √X
def OptimizePow2ToSqrt : Pat<
  (TFL_PowOp $input,
   (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.5f">)),
  (TFL_SqrtOp $input)>;

// Optimize X^(-1/2) to 1/√X == rsqrt(x)
def OptimizePow2ToRsqrt : Pat<
  (TFL_PowOp $input,
   (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "-0.5f">)),
  (TFL_RsqrtOp $input)>;

def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint<CPred<
  "TFL::CanOptimizeIdentityGatherNdOrScatterNdOp("
  "$0, $1.cast<DenseIntElementsAttr>(), $2.getType())">>;

def OptimizeIdentityGatherNdOp : Pat<
  (TFL_GatherNdOp:$output $params, (Arith_ConstantOp I32ElementsAttr: $indices)),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices, $output)]>;

def OptimizeIdentityScatterNdOp : Pat<
  (TFL_ScatterNdOp:$output (Arith_ConstantOp I32ElementsAttr: $indices), $params, $ignored),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices, $output)]>;

def ShapeMatchesReduceWithKeepAxes : Constraint<CPred<
  "ShapeMatchesReduceWithKeepAxes($0, $1, $2)">>;

// Fold reshapes that re-insert the reduced dimensions into the result of a
// reduction with `keep_dims=false` by changing the reduction to use
// `keep_dims=true`.
foreach ReduceOp = [TFL_MeanOp, TFL_ReduceMaxOp, TFL_ReduceMinOp,
                    TFL_ReduceProdOp, TFL_SumOp] in {
  def FoldReshapeTo#ReduceOp : Pat<
    (TFL_ReshapeOp
     (ReduceOp:$reduce $input, (Arith_ConstantOp I32ElementsAttr: $axes),
      ConstBoolAttrFalse),
     (Arith_ConstantOp I32ElementsAttr: $shape)),
    (ReduceOp $input, (Arith_ConstantOp $axes), ConstBoolAttrTrue),
    [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape),
     (HasOneUse $reduce)]>;
}


def IsSame : Constraint<CPred<"$0 == $1">>;
def HasTwoUse : Constraint<CPred<
  "std::distance($0.use_begin(), $0.use_end()) == 2">>;
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "($0.cast<DenseIntElementsAttr>().getValues<APInt>()[0] == "
  "$1.getType().cast<ShapedType>().getRank() - 1 || "
  "$0.cast<DenseIntElementsAttr>().getValues<int32_t>()[0] == -1)">>;

// Convert exp(x)/sum(exp(x)) into softmax.
def OptimizeToSoftmax : Pat<
  (TFL_DivOp (TFL_ExpOp:$exp $input),
   (TFL_SumOp:$sum $sum_input, (Arith_ConstantOp I32ElementsAttr: $axes),
    ConstBoolAttrTrue), TFL_AF_None),
  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
  [(IsSame $exp, $sum_input),
   (AxesIsLastDimension $axes, $sum_input),
   (HasTwoUse $exp),
   (HasOneUse $sum)]>;

// Convert softmax(x - max(x)) into softmax(x), as the softmax op already
// handles the max normalization.
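// softmax(x - c) == softmax(x) for any constant c broadcast along the reduced
// axis, because the shift cancels in exp(x_i - c) / sum_j exp(x_j - c), so
// dropping the explicit max subtraction is value-preserving.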
def FoldNormalizationIntoSoftmax : Pat<
  (TFL_SoftmaxOp
   (TFL_SubOp:$sub $input,
    (TFL_ReduceMaxOp:$max $max_input, (Arith_ConstantOp I32ElementsAttr: $axes),
     ConstBoolAttrTrue),
    TFL_AF_None),
   $beta),
  (TFL_SoftmaxOp $input, $beta),
  [(IsSame $input, $max_input),
   (AxesIsLastDimension $axes, $max_input),
   (HasOneUse $sub),
   (HasOneUse $max)]>;

def HaveSameType : Constraint<CPred<"($0.getType() == $1.getType())">>;

class AllElementsAreF32<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
  "[](float v){ return v == " #val# ";}))">>;

// Optimize X*1 to X
def OptimizeMul1ToIdentity : Pat<
  (TFL_MulOp:$result $input,
   (Arith_ConstantOp $constant),
   TFL_AF_None),
  (replaceWithValue $input),
  [(HaveSameType $input, $result),
   (AllElementsAreF32<"1.0f"> $constant)]>;

class AllElementsAreBool<string val> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
  "[](bool v){ return v == " #val# ";}))">>;

// Remove select operators when the result is known in advance.
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
  // select(true_tensor, A, B) -> A
  def Optimize#SelectOp#True : Pat<
    (SelectOp:$result (Arith_ConstantOp $constant),
     $input1,
     $input2),
    (replaceWithValue $input1),
    [(HaveSameType $input1, $result),
     (AllElementsAreBool<"true"> $constant)]>;
  // select(false_tensor, A, B) -> B
  def Optimize#SelectOp#False : Pat<
    (SelectOp:$result (Arith_ConstantOp $constant),
     $input1,
     $input2),
    (replaceWithValue $input2),
    [(HaveSameType $input2, $result),
     (AllElementsAreBool<"false"> $constant)]>;
  // select(logical_not(C), A, B) -> select(C, B, A)
  def Optimize#SelectOp#Not : Pat<
    (SelectOp (TFL_LogicalNotOp $condition), $input1, $input2),
    (SelectOp $condition, $input2, $input1)>;
}

def EliminateLogicalAndTrue : Pat<
  (TFL_LogicalAndOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $lhs),
  [(AllElementsAreBool<"true"> $constant), (HaveSameType $lhs, $result)]>;

def EliminateLogicalAndFalse : Pat<
  (TFL_LogicalAndOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $rhs),
  [(AllElementsAreBool<"false"> $constant), (HaveSameType $rhs, $result)]>;

def EliminateLogicalOrTrue : Pat<
  (TFL_LogicalOrOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $rhs),
  [(AllElementsAreBool<"true"> $constant), (HaveSameType $rhs, $result)]>;

def EliminateLogicalOrFalse : Pat<
  (TFL_LogicalOrOp:$result $lhs, (Arith_ConstantOp:$rhs $constant)),
  (replaceWithValue $lhs),
  [(AllElementsAreBool<"false"> $constant), (HaveSameType $lhs, $result)]>;

// Remove reductions that do nothing: input and output have the same size.
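// For example (illustrative only): tfl.sum over an axis of size 1 with
// keep_dims = true leaves both the shape and the values unchanged, so the
// reduction can be replaced by its input.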
foreach ReduceOp = [TFL_ReduceAnyOp, TFL_ReduceAllOp,
                    TFL_ReduceMinOp, TFL_ReduceMaxOp,
                    TFL_MeanOp, TFL_SumOp, TFL_ReduceProdOp] in {
  def EliminateNoOpReductionOp#ReduceOp : Pat<
    (ReduceOp:$output $input, $index, $keep_dims),
    (replaceWithValue $input),
    [(IsTailOfShape $input, $output),
     (IsTailOfShape $output, $input)]>;
}

// Remove (log-)softmax before arg-minmax, as (log-)softmax is monotonic.
foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
  def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
     (Arith_ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $softmax),
     (AxesIsLastDimension $axes, $logits)]>;

  def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_LogSoftmaxOp:$log_softmax $logits),
     (Arith_ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $log_softmax),
     (AxesIsLastDimension $axes, $logits)]>;
}

def CanOptimizeIdentitySliceOp : Constraint<CPred<
  "TFL::CanOptimizeIdentitySliceOp($0, $1, $2)">>;

// Remove Slice ops that slice the whole input tensor; they are effectively
// no-ops.
def OptimizeSliceOp : Pat<
  (TFL_SliceOp:$output $input, (Arith_ConstantOp $begin), (Arith_ConstantOp $size)),
  (replaceWithValue $input),
  [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>;

def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0)">;

def ReshapeValueDroppingLastDim : NativeCodeCall<
  "ReshapeValueDroppingLastDim($_builder, $0, $1)">;

def HasExactlyTwoElements : Constraint<CPred<
  "TFL::HasExactlyTwoElements($0)">>;

def IsLastElementEqualsOne : Constraint<CPred<
  "TFL::IsLastElementEqualsOne($0)">>;

def IsOneHotIndexAttribute : Constraint<CPred<
  "TFL::IsOneHotIndexAttribute($0)">>;

// Replace
//   Equal(Reshape(X, shape), indices)
// With
//   OneHot(Reshape(X, shape[:-1]), N, true, false, -1)
// where
// - shape has length 2 (unnecessary, just to be conservative)
// - the last value in shape is 1
// - indices is an incrementing series from 0 to N-1 (N elements total).
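// For example (illustrative shapes and values only): with X of shape [3]
// reshaped with shape = [3, 1] and indices = [0, 1, 2, 3], the Equal
// broadcasts to a [3, 4] boolean tensor whose row i is true exactly at
// column X[i]; OneHot(X, 4, true, false, -1) produces the same tensor without
// materializing the comparison.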
def ReshapeEqualOpToOneHotOp : Pat<
  (TFL_EqualOp (TFL_ReshapeOp $x, (Arith_ConstantOp $shape)),
   (Arith_ConstantOp $series)),
  (TFL_OneHotOp (ReshapeValueDroppingLastDim $x, $shape),
   (Arith_ConstantOp (GetNumElementsOrOne $series)),
   (Arith_ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "true">),
   (Arith_ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "false">),
   ConstantAttr<I32Attr, "-1">),
  [(HasExactlyTwoElements $shape),
   (IsLastElementEqualsOne $shape),
   (IsOneHotIndexAttribute $series)]>;

def F32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isF32()">,
  "32 bit float tensor">;
def I32ElementsVal : Constraint<CPred<
  "$0.getType().cast<TensorType>().getElementType().isInteger(32)">,
  "32 bit integer tensor">;

def ConvertSingleElementAttrToFloatAttr :
  NativeCodeCall<"ConvertSingleElementAttrToFloatAttr($0)">;

// Replace
//   (float)OneHot(index, depth, on_val, off_val, axis)
// With
//   OneHot(index, depth, (float)on_val, (float)off_val, axis)
def FuseOneHotAndCastToFloat : Pat<
  (TFL_CastOp:$output (TFL_OneHotOp $indices,
                       $depth,
                       (Arith_ConstantOp $on_val),
                       (Arith_ConstantOp $off_val),
                       $axis)),
  (TFL_OneHotOp $indices,
   $depth,
   (Arith_ConstantOp (ConvertSingleElementAttrToFloatAttr $on_val)),
   (Arith_ConstantOp (ConvertSingleElementAttrToFloatAttr $off_val)),
   $axis),
  [(F32ElementsVal $output)]>;

// Replace
//   OneHot(index, depth, on=1.0f, off=0.0f, axis=-1) * filter
// With
//   EmbeddingLookup(index, Transpose(filter))
//
// OneHot with on=1, off=0, axis=-1, where `index` is a single element tensor,
// creates a tensor of size depth whose values are all 0 except for the element
// at `index`, which is 1. Multiplying such a tensor with a 2D filter
// essentially returns the single column in filter as a 1D tensor. If the input
// has multiple elements, repeat this for every entry, forming the higher
// dimensions in the result tensor. For instance, if:
//   input = [1, 2]
//   depth = 4
//   filter = [[5, 7, 11, 13], [17, 19, 23, 29]]
// then:
//   onehot = [[0, 1, 0, 0], [0, 0, 1, 0]]
//   result = [[ 7, 19],   # == 1st column in filter
//             [11, 23]]   # == 2nd column in filter
// This is exactly what the EmbeddingLookup operator does, on the transposed
// matrix, without doing any arithmetic but only a memcpy.
def ReplaceOneHotFullyConnectedWithLookup : Pat<
  (TFL_FullyConnectedOp
   (TFL_OneHotOp
    $indices,
    (Arith_ConstantOp $depth),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
    (Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.0f">),
    ConstantAttr<I32Attr, "-1">),
   $filter,
   $bias,
   TFL_AF_None,
   TFL_FCWO_Default,
   ConstBoolAttrFalse,
   $asymmetric_quantize_inputs),
  (TFL_EmbeddingLookupOp
   $indices,
   (TFL_TransposeOp
    $filter,
    (Arith_ConstantOp ConstantAttr<RankedI32ElementsAttr<[2]>, "{1,0}">))),
  [(I32ElementsVal $indices),  // lookup is not implemented for i64
   (HasRank<1> $indices),      // lookup isn't implemented for any other rank
   (IsNoneType $bias)]>;       // Maybe folded into the lookup matrix later

def AreInputDimensionsOneInAxes : Constraint<CPred<
  "AreInputDimensionsOneInAxes($0, $1)">>;

// Eliminate cumulative summations when the input's dimension in `axis` is 1.
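// When the summed axis has size 1 and the sum is inclusive (exclusive =
// false), every output element is just the single input element along that
// axis, so the op can be replaced by its input regardless of `reverse`.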
def EliminateCumSumInclusive : Pat<
  (TFL_CumsumOp
   $input,
   (Arith_ConstantOp I32ElementsAttr:$axis),
   ConstBoolAttrFalse,
   $reverse),
  (replaceWithValue $input),
  [(AreInputDimensionsOneInAxes $input, $axis)]>;

// Fuse the raw computation of the GELU op into one native tfl_gelu op.
//
// Requires the constants to match exactly and all of the intermediate results
// to have only one use.
//
// For GeluApproximate, replaces
//   0.5 * x * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * pow( x, 3 ) ) ) )
def MatchGeluApproximate : Pat<
  (TFL_MulOp
   (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None),
   (TFL_AddOp:$add_out
    (TFL_TanhOp:$tanh_out
     (TFL_MulOp:$mul_out1
      (TFL_AddOp:$add_out1 $arg0,
       (TFL_MulOp:$mul_out2
        (TFL_PowOp:$pow_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_3)),
        (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None),
      (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)),
    (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None),
  (TFL_GeluOp $arg0, ConstBoolAttrTrue),
  [(FloatValueEquals<"0.5"> $Cst_1_2),
   (FloatValueEquals<"1"> $Cst_1),
   (FloatValueEquals<"3"> $Cst_3),
   (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi),
   (FloatValueEquals<"0.044715"> $Coeff),
   (HasOneUse $mul_out),
   (HasOneUse $add_out),
   (HasOneUse $tanh_out),
   (HasOneUse $mul_out1),
   (HasOneUse $add_out1),
   (HasOneUse $mul_out2),
   (HasOneUse $pow_out),
  ]>;

// Alternate pattern for GeluApproximate (note the different order of the mul
// ops); replaces
//   x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * pow( x, 3 ) ) ) ) )
def MatchGeluApproximate1 : Pat<
  (TFL_MulOp $arg0,
   (TFL_MulOp:$mul_out
    (TFL_AddOp:$add_out
     (TFL_TanhOp:$tanh_out
      (TFL_MulOp:$mul_out1
       (TFL_AddOp:$add_out1 $arg0,
        (TFL_MulOp:$mul_out2
         (TFL_PowOp:$pow_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_3)),
         (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None),
       (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)),
     (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None),
    (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None),
  (TFL_GeluOp $arg0, ConstBoolAttrTrue),
  [(FloatValueEquals<"0.5"> $Cst_1_2),
   (FloatValueEquals<"1"> $Cst_1),
   (FloatValueEquals<"3"> $Cst_3),
   (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi),
   (FloatValueEquals<"0.044715"> $Coeff),
   (HasOneUse $mul_out),
   (HasOneUse $add_out),
   (HasOneUse $tanh_out),
   (HasOneUse $mul_out1),
   (HasOneUse $add_out1),
   (HasOneUse $mul_out2),
   (HasOneUse $pow_out),
  ]>;

// For Gelu, replaces
//   0.5 * x * ( 1 + erf( x * sqrt_1_2 ) )
def MatchGelu : Pat<
  (TFL_MulOp
   (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None),
   (TFL_AddOp:$add_out
    (TF_ErfOp:$erf_out
     (TFL_MulOp:$mul_out1 $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_1_2), TFL_AF_None)),
    (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None),
  (TFL_GeluOp $arg0, ConstBoolAttrFalse),
  [(FloatValueEquals<"0.5"> $Cst_1_2),
   (FloatValueEquals<"1"> $Cst_1),
   (FloatValueEquals<"0.707106769"> $Cst_sqrt_1_2),
   (HasOneUse $mul_out),
   (HasOneUse $add_out),
   (HasOneUse $erf_out),
   (HasOneUse $mul_out1),
  ]>;

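// Note on the matched constants (for reference): 0.797884583 ≈ sqrt(2/pi) and
// 0.707106769 ≈ 1/sqrt(2), the factors appearing in the tanh-approximate and
// erf-based GELU formulas above. Since the constants must match exactly, only
// graphs that emit these exact float literals are fused.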
// Checks if the shape's last dimension equals 1.
def IsLastDimensionEqualOne : Constraint<CPred<"IsLastDimensionEqualOne($0)">>;

// Fetches the output of the FC op from the provided arguments.
def GetFcOutput : NativeCodeCall<
  "GetFcOutput(&$_builder, $0, $1, $2, $3, $4, $5, $6, $7)">;

// Verifies that all values in the provided argument are zero.
def AllValuesAreZero : Constraint<CPred<"AllValuesAreZero($0)">>;

def SimplifyDoubleSelectFCZerosLHS : Pat<
  (TFL_SelectV2Op $condition, $zeros_2,
   (TFL_FullyConnectedOp:$results
    (TFL_SelectV2Op $condition, $zeros_1, $input),
    $filter, $bias, $fused_activation_function, $weights_format,
    ConstBoolAttrTrue, $asymmetric_quantize_inputs)),
  (TFL_SelectV2Op $condition, $zeros_2,
   (GetFcOutput $results, $input, $filter, $bias, $fused_activation_function,
    $weights_format, ConstBoolAttrTrue, $asymmetric_quantize_inputs)),
  [(IsLastDimensionEqualOne $condition),
   (AllValuesAreZero $zeros_1),
   (AllValuesAreZero $zeros_2)
  ]>;

def SimplifyDoubleSelectFCZerosRHS : Pat<
  (TFL_SelectV2Op $condition,
   (TFL_FullyConnectedOp:$results
    (TFL_SelectV2Op $condition, $input, $zeros_1),
    $filter, $bias, $fused_activation_function, $weights_format,
    ConstBoolAttrTrue, $asymmetric_quantize_inputs),
   $zeros_2),
  (TFL_SelectV2Op $condition,
   (GetFcOutput $results, $input, $filter, $bias, $fused_activation_function,
    $weights_format, ConstBoolAttrTrue, $asymmetric_quantize_inputs),
   $zeros_2),
  [(IsLastDimensionEqualOne $condition),
   (AllValuesAreZero $zeros_1),
   (AllValuesAreZero $zeros_2)
  ]>;