xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
17 #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
18 
19 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
20 #include "mlir/Support/LLVM.h"  // from @llvm-project
21 
22 // This file contains legalizations common to mapping both TensorFlow and
23 // TensorFlow Lite to TOSA.
24 //
// Conversion functions return llvm::None on failure, or the result value on
// success. Callers must check for llvm::None and return a LogicalResult failure.
27 //
28 // For these functions, the framework-specific operands/attributes/defaults
29 // are already extracted and placed in a common form for lowering.
30 
31 namespace mlir {
32 namespace tosa {
33 
34 // Lowers the Pack operator to TOSA.
35 llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
36                                     Value result_value,
37                                     SmallVectorImpl<Value>& inputs,
38                                     int32_t axis);
39 
40 // Lowers the Unpack operator to TOSA.
41 llvm::Optional<SmallVector<Value>> convertUnpackOp(PatternRewriter& rewriter,
42                                                    Operation* op,
43                                                    Value input_value,
44                                                    int32_t axis);
45 
46 // Lowers the Select operator to TOSA.
47 llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
48                                       Value result_value, Value condition_value,
49                                       Value x_value, Value y_value);
50 
51 // Lowers the ZerosLike operator to TOSA by creating a constant
52 // of the desired type and shape.
53 llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
54                                          Operation* op, Value result,
55                                          Value input);
56 
57 // Lowers the Mul operator to TOSA.  For quantized types, this requires
58 // inserting rescale operators before and after the operation.
59 llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
60                                         Operation* op, Value output_val,
61                                         Value input_lhs_val,
62                                         Value input_rhs_val);
63 
64 // Lowers the SquaredDifference operator to TOSA.
65 llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
66                                                  Operation* op, Value result,
67                                                  Value x, Value y);
68 
69 // Lowers the Round operator to TOSA.
70 llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
71                                      Value result, Value input);
72 
73 // Lowers ConcatV2 to TOSA.
74 llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
75                                         Operation* op, Value result_value,
76                                         SmallVectorImpl<Value>& values,
77                                         int32_t axis);
78 
79 // Lowers SpaceToBatchND to TOSA.
80 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
81                                               Operation* op, Value result_value,
82                                               Value input_value,
83                                               Value block_shape_value,
84                                               Value paddings_value);
85 
86 // Lowers BatchToSpaceND to TOSA.
87 llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
88                                               Operation* op, Value result_value,
89                                               Value input_value,
90                                               Value block_shape_value,
91                                               Value crops_value);
92 
93 // Lowers ExpandDims to TOSA.
94 llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
95                                           Operation* op, Value result_value,
96                                           Value input_value, Value dim_value);
97 
98 // Lowers Squeeze to TOSA.
99 llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
100                                        Value result_value, Value input_value,
101                                        SmallVectorImpl<int32_t>& squeeze_dims);
102 
103 // Lowers ELU to a sequence of TOSA ops.
104 llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
105                                    Value result_value, Value features_value);
106 
107 // Lowers Softmax to a sequence of TOSA ops.
108 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
109                                        Value result_value, Value logits_value,
110                                        double beta);
111 
112 // Lowers LogSoftmax to a sequence of TOSA ops.
113 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
114                                           Operation* op, Value result_value,
115                                           Value logits_value);
116 
117 // Lowers SpaceToDepth to a sequence of TOSA ops.  Supports NHWC.
118 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
119                                             Operation* op, Value result_value,
120                                             Value input_value,
121                                             IntegerAttr block_size_attr,
122                                             StringAttr data_format);
123 
124 // Lowers DepthToSpace to a sequence of TOSA ops.  Supports NHWC.
125 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
126                                             Operation* op, Value result_value,
127                                             Value input_value,
128                                             IntegerAttr block_size_attr,
129                                             StringAttr data_format);
130 
131 // Lowers Split to a sequence of TOSA ops.
132 llvm::Optional<SmallVector<Value>> convertSplitOp(
133     PatternRewriter& rewriter, Operation* op, Value result_value,
134     Value input_value, int32_t num_split, int32_t axis);
135 
136 // Lowers SplitV to a sequence of TOSA ops.
137 llvm::Optional<SmallVector<Value>> convertSplitVOp(
138     PatternRewriter& rewriter, Operation* op, Value result_value,
139     Value input_value, SmallVectorImpl<int32_t>& size_split, int32_t axis);
140 
141 // Lowers StridedSlice to a sequence of TOSA ops.
142 llvm::Optional<Value> convertStridedSliceOp(
143     PatternRewriter& rewriter, Operation* op, Value result_value,
144     Value input_value, Value begin_value, Value end_value, Value strides_value,
145     int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
146     int32_t new_axis_mask, int32_t shrink_axis_mask);
147 
148 // Lowers FloorDiv to a sequence of TOSA operators.
149 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
150                                         Operation* op, Value result_value,
151                                         Value lhs_value, Value rhs_value);
152 
153 // Lowers FloorMod to a sequence of TOSA operators.
154 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
155                                         Operation* op, Value result_value,
156                                         Value lhs_value, Value rhs_value);
157 
158 // Lowers FusedActivation to a sequence of TOSA ops.
159 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
160                                              Operation* op, Value input_value,
161                                              StringAttr fused_activation_fn);
162 
163 // Helper function for implementing quantized divide by power-of-two in TOSA
164 // ops.
165 llvm::Optional<Value> convertRoundingDivideByPOT(PatternRewriter& rewriter,
166                                                  Operation* op,
167                                                  Value input_value,
168                                                  Value rshift_value);
169 
170 // Lowers ReduceAll to a sequence of TOSA ops.
171 llvm::Optional<Value> convertReduceAllOp(PatternRewriter& rewriter,
172                                          Operation* op,
173                                          RankedTensorType output_type,
174                                          Value input_value,
175                                          ElementsAttr axes_elems);
176 
177 // Lowers ReduceAny to a sequence of TOSA ops.
178 llvm::Optional<Value> convertReduceAnyOp(PatternRewriter& rewriter,
179                                          Operation* op,
180                                          RankedTensorType output_type,
181                                          Value input_value,
182                                          ElementsAttr axes_elems);
183 
184 // Lowers ReduceMin to a sequence of TOSA ops.
185 llvm::Optional<Value> convertReduceMinOp(PatternRewriter& rewriter,
186                                          Operation* op,
187                                          RankedTensorType output_type,
188                                          Value input_value,
189                                          ElementsAttr axes_elems);
190 
191 // Lowers ReduceMax to a sequence of TOSA ops.
192 llvm::Optional<Value> convertReduceMaxOp(PatternRewriter& rewriter,
193                                          Operation* op,
194                                          RankedTensorType output_type,
195                                          Value input_value,
196                                          ElementsAttr axes_elems);
197 
198 // Lowers ReduceProd to a sequence of TOSA ops.
199 llvm::Optional<Value> convertReduceProdOp(PatternRewriter& rewriter,
200                                           Operation* op,
201                                           RankedTensorType output_type,
202                                           Value input_value,
203                                           ElementsAttr axes_elems);
204 
205 // Lowers ReduceSum to a sequence of TOSA ops.
206 llvm::Optional<Value> convertReduceSumOp(PatternRewriter& rewriter,
207                                          Operation* op,
208                                          RankedTensorType output_type,
209                                          Value input_value,
210                                          ElementsAttr axes_elems);
211 
212 // Lowers ReduceMean to a sequence of TOSA ops.
213 llvm::Optional<Value> convertReduceMeanOp(PatternRewriter& rewriter,
214                                           Operation* op,
215                                           RankedTensorType output_type,
216                                           Value input_value,
217                                           ElementsAttr axes_elem);
218 
219 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
220 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
221                                       RankedTensorType output_type,
222                                       Value input_value, StringRef mode,
223                                       bool align_corners,
224                                       bool half_pixel_centers);
225 
226 // Lowers Quantize to a sequence of TOSA quantization ops.
227 llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
228                                         Operation* op, ShapedType output_type,
229                                         Value input_value, double scale,
230                                         int64_t zeropoint);
231 
232 // Lowers Dequantize to a sequence of TOSA dequantization ops.
233 llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
234                                           Operation* op, ShapedType output_type,
235                                           Value input_value,
236                                           ArrayRef<float> scale,
237                                           ArrayRef<float> zeropoint,
238                                           int64_t dim);
239 
240 // Lowers FakeQuant to a sequence of TOSA quantization ops.
241 llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
242                                          Operation* op, ShapedType output_type,
243                                          Value input_value, double min,
244                                          double max, int64_t num_bits,
245                                          bool narrow_range);
246 
247 // Lowers TensorFlow Conv2D to a sequence of TOSA quantization ops.
248 llvm::Optional<Value> convertTFConv2DCommon(
249     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
250     Value input, Value filter, Value bias, ArrayAttr strides_attr,
251     ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
252     StringRef padding_ref, StringRef data_format_ref);
253 
254 // Lowers Gather operator to a sequence of TOSA ops.
255 llvm::Optional<Value> convertGatherOp(PatternRewriter& rewriter, Operation* op,
256                                       Value result_value, Value params_value,
257                                       Value indices_value, int32_t batch_dims,
258                                       int32_t axis);
259 
260 // Lowers GatherNd operator to a sequence of TOSA ops.
261 llvm::Optional<Value> convertGatherNdOp(PatternRewriter& rewriter,
262                                         Operation* op, Value result_value,
263                                         Value params_value,
264                                         Value indices_value);
265 
266 // Lowers OneHot operator to a sequence of TOSA ops.
267 llvm::Optional<Value> convertOneHotOp(PatternRewriter& rewriter, Operation* op,
268                                       Value result_value, Value indices_value,
269                                       Value on_value, Value off_value,
270                                       int32_t depth, int32_t axis);
271 
272 };  // namespace tosa
273 };  // namespace mlir
274 
275 #endif  // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H
276