1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H 17 #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H 18 19 #include "mlir/IR/PatternMatch.h" // from @llvm-project 20 #include "mlir/Support/LLVM.h" // from @llvm-project 21 22 // This file contains legalizations common to mapping both TensorFlow and 23 // TensorFlow Lite to TOSA. 24 // 25 // Conversion functions return None on a failure or result value on success. 26 // Callers must check and return a LogicalResult failure on nullptr. 27 // 28 // For these functions, the framework-specific operands/attributes/defaults 29 // are already extracted and placed in a common form for lowering. 30 31 namespace mlir { 32 namespace tosa { 33 34 // Lowers the Pack operator to TOSA. 35 llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op, 36 Value result_value, 37 SmallVectorImpl<Value>& inputs, 38 int32_t axis); 39 40 // Lowers the Unpack operator to TOSA. 41 llvm::Optional<SmallVector<Value>> convertUnpackOp(PatternRewriter& rewriter, 42 Operation* op, 43 Value input_value, 44 int32_t axis); 45 46 // Lowers the Select operator to TOSA. 47 llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op, 48 Value result_value, Value condition_value, 49 Value x_value, Value y_value); 50 51 // Lowers the ZerosLike operator to TOSA by creating a constant 52 // of the desired type and shape. 53 llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter, 54 Operation* op, Value result, 55 Value input); 56 57 // Lowers the Mul operator to TOSA. For quantized types, this requires 58 // inserting rescale operators before and after the operation. 59 llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter, 60 Operation* op, Value output_val, 61 Value input_lhs_val, 62 Value input_rhs_val); 63 64 // Lowers the SquaredDifference operator to TOSA. 65 llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter, 66 Operation* op, Value result, 67 Value x, Value y); 68 69 // Lowers the Round operator to TOSA. 70 llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op, 71 Value result, Value input); 72 73 // Lowers ConcatV2 to TOSA. 74 llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter, 75 Operation* op, Value result_value, 76 SmallVectorImpl<Value>& values, 77 int32_t axis); 78 79 // Lowers SpaceToBatchND to TOSA. 80 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter, 81 Operation* op, Value result_value, 82 Value input_value, 83 Value block_shape_value, 84 Value paddings_value); 85 86 // Lowers BatchToSpaceND to TOSA. 87 llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter, 88 Operation* op, Value result_value, 89 Value input_value, 90 Value block_shape_value, 91 Value crops_value); 92 93 // Lowers ExpandDims to TOSA. 94 llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter, 95 Operation* op, Value result_value, 96 Value input_value, Value dim_value); 97 98 // Lowers Squeeze to TOSA. 99 llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op, 100 Value result_value, Value input_value, 101 SmallVectorImpl<int32_t>& squeeze_dims); 102 103 // Lowers ELU to a sequence of TOSA ops. 104 llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op, 105 Value result_value, Value features_value); 106 107 // Lowers Softmax to a sequence of TOSA ops. 108 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, 109 Value result_value, Value logits_value, 110 double beta); 111 112 // Lowers LogSoftmax to a sequence of TOSA ops. 113 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter, 114 Operation* op, Value result_value, 115 Value logits_value); 116 117 // Lowers SpaceToDepth to a sequence of TOSA ops. Supports NHWC. 118 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter, 119 Operation* op, Value result_value, 120 Value input_value, 121 IntegerAttr block_size_attr, 122 StringAttr data_format); 123 124 // Lowers DepthToSpace to a sequence of TOSA ops. Supports NHWC. 125 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter, 126 Operation* op, Value result_value, 127 Value input_value, 128 IntegerAttr block_size_attr, 129 StringAttr data_format); 130 131 // Lowers Split to a sequence of TOSA ops. 132 llvm::Optional<SmallVector<Value>> convertSplitOp( 133 PatternRewriter& rewriter, Operation* op, Value result_value, 134 Value input_value, int32_t num_split, int32_t axis); 135 136 // Lowers SplitV to a sequence of TOSA ops. 137 llvm::Optional<SmallVector<Value>> convertSplitVOp( 138 PatternRewriter& rewriter, Operation* op, Value result_value, 139 Value input_value, SmallVectorImpl<int32_t>& size_split, int32_t axis); 140 141 // Lowers StridedSlice to a sequence of TOSA ops. 142 llvm::Optional<Value> convertStridedSliceOp( 143 PatternRewriter& rewriter, Operation* op, Value result_value, 144 Value input_value, Value begin_value, Value end_value, Value strides_value, 145 int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, 146 int32_t new_axis_mask, int32_t shrink_axis_mask); 147 148 // Lowers FloorDiv to a sequence of TOSA operators. 149 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter, 150 Operation* op, Value result_value, 151 Value lhs_value, Value rhs_value); 152 153 // Lowers FloorMod to a sequence of TOSA operators. 154 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter, 155 Operation* op, Value result_value, 156 Value lhs_value, Value rhs_value); 157 158 // Lowers FusedActivation to a sequence of TOSA ops. 159 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter, 160 Operation* op, Value input_value, 161 StringAttr fused_activation_fn); 162 163 // Helper function for implementing quantized divide by power-of-two in TOSA 164 // ops. 165 llvm::Optional<Value> convertRoundingDivideByPOT(PatternRewriter& rewriter, 166 Operation* op, 167 Value input_value, 168 Value rshift_value); 169 170 // Lowers ReduceAll to a sequence of TOSA ops. 171 llvm::Optional<Value> convertReduceAllOp(PatternRewriter& rewriter, 172 Operation* op, 173 RankedTensorType output_type, 174 Value input_value, 175 ElementsAttr axes_elems); 176 177 // Lowers ReduceAny to a sequence of TOSA ops. 178 llvm::Optional<Value> convertReduceAnyOp(PatternRewriter& rewriter, 179 Operation* op, 180 RankedTensorType output_type, 181 Value input_value, 182 ElementsAttr axes_elems); 183 184 // Lowers ReduceMin to a sequence of TOSA ops. 185 llvm::Optional<Value> convertReduceMinOp(PatternRewriter& rewriter, 186 Operation* op, 187 RankedTensorType output_type, 188 Value input_value, 189 ElementsAttr axes_elems); 190 191 // Lowers ReduceMax to a sequence of TOSA ops. 192 llvm::Optional<Value> convertReduceMaxOp(PatternRewriter& rewriter, 193 Operation* op, 194 RankedTensorType output_type, 195 Value input_value, 196 ElementsAttr axes_elems); 197 198 // Lowers ReduceProd to a sequence of TOSA ops. 199 llvm::Optional<Value> convertReduceProdOp(PatternRewriter& rewriter, 200 Operation* op, 201 RankedTensorType output_type, 202 Value input_value, 203 ElementsAttr axes_elems); 204 205 // Lowers ReduceSum to a sequence of TOSA ops. 206 llvm::Optional<Value> convertReduceSumOp(PatternRewriter& rewriter, 207 Operation* op, 208 RankedTensorType output_type, 209 Value input_value, 210 ElementsAttr axes_elems); 211 212 // Lowers ReduceMean to a sequence of TOSA ops. 213 llvm::Optional<Value> convertReduceMeanOp(PatternRewriter& rewriter, 214 Operation* op, 215 RankedTensorType output_type, 216 Value input_value, 217 ElementsAttr axes_elem); 218 219 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize. 220 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op, 221 RankedTensorType output_type, 222 Value input_value, StringRef mode, 223 bool align_corners, 224 bool half_pixel_centers); 225 226 // Lowers Quantize to a sequence of TOSA quantization ops. 227 llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter, 228 Operation* op, ShapedType output_type, 229 Value input_value, double scale, 230 int64_t zeropoint); 231 232 // Lowers Dequantize to a sequence of TOSA dequantization ops. 233 llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter, 234 Operation* op, ShapedType output_type, 235 Value input_value, 236 ArrayRef<float> scale, 237 ArrayRef<float> zeropoint, 238 int64_t dim); 239 240 // Lowers FakeQuant to a sequence of TOSA quantization ops. 241 llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter, 242 Operation* op, ShapedType output_type, 243 Value input_value, double min, 244 double max, int64_t num_bits, 245 bool narrow_range); 246 247 // Lowers TensorFlow Conv2D to a sequence of TOSA quantization ops. 248 llvm::Optional<Value> convertTFConv2DCommon( 249 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, 250 Value input, Value filter, Value bias, ArrayAttr strides_attr, 251 ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr, 252 StringRef padding_ref, StringRef data_format_ref); 253 254 // Lowers Gather operator to a sequence of TOSA ops. 255 llvm::Optional<Value> convertGatherOp(PatternRewriter& rewriter, Operation* op, 256 Value result_value, Value params_value, 257 Value indices_value, int32_t batch_dims, 258 int32_t axis); 259 260 // Lowers GatherNd operator to a sequence of TOSA ops. 261 llvm::Optional<Value> convertGatherNdOp(PatternRewriter& rewriter, 262 Operation* op, Value result_value, 263 Value params_value, 264 Value indices_value); 265 266 // Lowers OneHot operator to a sequence of TOSA ops. 267 llvm::Optional<Value> convertOneHotOp(PatternRewriter& rewriter, Operation* op, 268 Value result_value, Value indices_value, 269 Value on_value, Value off_value, 270 int32_t depth, int32_t axis); 271 272 }; // namespace tosa 273 }; // namespace mlir 274 275 #endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H 276