/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <type_traits>

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8_t {
  kNone,
  kRelu6,
  kRelu1,
  kRelu
};
enum class PaddingType : uint8_t { kNone, kSame, kValid };

struct PaddingValues {
  int16_t width;
  int16_t height;
  // offset is used for calculating "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
};

struct Padding3DValues {
  int16_t width;
  int16_t height;
  int16_t depth;
  // offset is used for calculating "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
  // Same as width_offset except it's over the depth dimension.
  int16_t depth_offset;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8_t {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched, alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty, that each value loaded from memory is used only in
  // one kernel iteration, making this operation memory-intensive, hints at
  // the solution, of shuffling the weights so that they are stored in the
  // exact order as the kernel needs to load them, so that the memory
  // accesses made by the kernel are trivial. This solves (2) because the
  // trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and this
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, thus don't alias each other in the cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weights value encoded as a signed int8_t value,
  // even if the data type of the weights buffer is uint8_t.  This is intended
  // to save runtime kernels the effort of having to XOR the top bit of these
  // bytes before using them in signed arithmetic; see this file for more
  // explanation of the 'signed int8_t trick' in matrix multiplication kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32_t zero_point = 0;
  double scale = 0.0;
};
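
// Illustration only (hypothetical helper, not part of this header's API): a
// minimal sketch of how these parameters are typically applied to dequantize
// a single uint8_t value.
//
//   inline float DequantizeValue(uint8_t q, const QuantizationParams& qp) {
//     // real_value = scale * (quantized_value - zero_point)
//     return static_cast<float>(qp.scale *
//                               (static_cast<int32_t>(q) - qp.zero_point));
//   }
//
// For example, with scale = 0.5 and zero_point = 128, the quantized value 130
// maps to the real value 0.5 * (130 - 128) = 1.0.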

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}

// Quantization parameters for each channel, determining the mapping of
// quantized values to real values. See QuantizationParams for a single set of
// parameters per tensor. This has one set of parameters per channel.
//
// The correspondence is as follows:
//
//   real_value = scale[channel] * (quantized_value - zero_point[channel]);
//
struct PerChannelQuantizationParams {
  // The following members typically point to the corresponding members of a
  // TfLiteAffineQuantization struct.
  const float* scale;
  const int32_t* zero_point;
  int32_t quantized_dimension;
};
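
// Illustration only (hypothetical helper, not part of this header's API): a
// minimal sketch of per-channel dequantization of one value, given the index
// of the channel it belongs to along `quantized_dimension`.
//
//   inline float PerChannelDequantizeValue(
//       int8_t q, int channel, const PerChannelQuantizationParams& params) {
//     return params.scale[channel] *
//            (static_cast<int32_t>(q) - params.zero_point[channel]);
//   }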

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
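
// Illustration only: a minimal usage sketch of NextIndex(). Starting from the
// all-zeros index, it enumerates every index of a {2, 3} array in row-major
// order, returning false once the index wraps around after the last element.
//
//   const int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // Visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).
//   } while (NextIndex(2, dims, index));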

// Gets the offset of an index when reducing on an axis. When reducing, the
// flattened offset will not change if the input index changes on the given
// axis. For example, if you have a 3D tensor and you are reducing to 2D by
// eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) will map to the
// same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // Check whether this axis needs to be skipped.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
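
// Illustration only: for a tensor of shape {2, 3, 4} reduced over axis 0
// (num_axis = 1, axis = {0}), the offset is computed from the remaining
// dimensions {3, 4} only, so indices (0, 1, 2) and (1, 1, 2) both map to
// 1 * 4 + 2 = 6.
//
//   const int dims[] = {2, 3, 4};
//   const int index[] = {1, 1, 2};
//   const int axis[] = {0};
//   const size_t offset = ReducedOutputOffset(3, dims, index, 1, axis);  // 6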

// Since tensors with '0' in their shape are valid in TF, these offset functions
// allow that as long as the corresponding index is also 0. It is up to the
// calling ops to ensure that they perform verification checks on tensor shapes
// if they don't support a particular behavior.

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK((i0 == 0 && dims.sizes[0] == 0) ||
                (i0 >= 0 && i0 < dims.sizes[0]));
  TFLITE_DCHECK((i1 == 0 && dims.sizes[1] == 0) ||
                (i1 >= 0 && i1 < dims.sizes[1]));
  TFLITE_DCHECK((i2 == 0 && dims.sizes[2] == 0) ||
                (i2 >= 0 && i2 < dims.sizes[2]));
  TFLITE_DCHECK((i3 == 0 && dims.sizes[3] == 0) ||
                (i3 >= 0 && i3 < dims.sizes[3]));
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}
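
// Illustration only: with packed strides (see ComputeStrides() below), e.g.
// sizes = {2, 3, 4, 5} and hence strides = {1, 2, 6, 24}, the flat offset of
// index (1, 2, 3, 4) is
//
//   const int flat = Offset(dims, 1, 2, 3, 4);
//   // 1 * 1 + 2 * 2 + 3 * 6 + 4 * 24 = 119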

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return std::min(shape1.Dims(index1), shape2.Dims(index2));
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}
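
// Illustration only: a typical kernel-side usage sketch of MatchingDim(),
// checking (in debug builds) that the batch dimension of two 4D shapes agrees
// and returning it.
//
//   const RuntimeShape input_shape({2, 8, 8, 3});
//   const RuntimeShape output_shape({2, 8, 8, 16});
//   const int batches = MatchingDim(input_shape, 0, output_shape, 0);  // 2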

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  return size_1;
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  const int size_3 = check_shape_1.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  TFLITE_CHECK_EQ(size_2, size_3);
  return size_1;
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Flat size calculation, checking if their extended shapes match.
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0) {
  const int shape_dims = shape.DimensionsCount();
  const int check_shape_0_dims = check_shape_0.DimensionsCount();
  const int min_dims = std::min(shape_dims, check_shape_0_dims);

  for (int i = 0; i < min_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i),
                     check_shape_0.Dims(check_shape_0_dims - 1 - i));
  }
  for (int i = min_dims; i < shape_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i), 1);
  }
  for (int i = min_dims; i < check_shape_0_dims; ++i) {
    TFLITE_DCHECK_EQ(check_shape_0.Dims(check_shape_0_dims - 1 - i), 1);
  }
  return shape.FlatSize();
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1),
                   flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(
      MatchingExtendedShapeFlatSize(shape, check_shape_1, check_shape_2),
      flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2,
                                         const RuntimeShape& check_shape_3) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1,
                                                 check_shape_2, check_shape_3),
                   flat_size);
  return flat_size;
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
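
// Illustration only: for a shape {2, 3, 4, 5}, FlatSizeSkipDim(shape, 3)
// returns 2 * 3 * 4 = 24, i.e. the element count with the skipped (here the
// innermost) dimension treated as 1.
//
//   const RuntimeShape shape({2, 3, 4, 5});
//   const int outer_size = FlatSizeSkipDim(shape, 3);  // 24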

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
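
// Illustration only: a minimal sketch of how ComputeStrides() and
// IsPackedWithoutStrides() relate. After ComputeStrides(), strides[d] is the
// product of all lower-index sizes, which is exactly the packed layout that
// IsPackedWithoutStrides() checks for.
//
//   Dims<4> dims;
//   dims.sizes[0] = 2;  // index 0 is the fastest-varying dimension
//   dims.sizes[1] = 3;
//   dims.sizes[2] = 4;
//   dims.sizes[3] = 5;
//   ComputeStrides(&dims);  // strides = {1, 2, 6, 24}
//   const bool packed = IsPackedWithoutStrides(dims);  // true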

enum class BroadcastableOpCategory : uint8_t {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
};

struct ReluParams : public ActivationParams {
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a Pad
// op for pattern-specific optimization.
enum class ResizingCategory : uint8_t {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8_t inference params.
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8_t inference params.
  int left_shift;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_multiplier;
  int input2_shift;

  // TODO(b/158622529): Union the following activation params.
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // int64_t activation params.
  int64_t int64_activation_min;
  int64_t int64_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8_t inference params.
  int left_shift;
  int32_t input1_offset;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_offset;
  int32_t input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct Conv3DParams {
  Padding3DValues padding_values;
  int stride_width;
  int stride_height;
  int stride_depth;
  int dilation_width;
  int dilation_height;
  int dilation_depth;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

typedef Conv3DParams Conv3DTransposeParams;

struct DepthToSpaceParams {
  int32_t block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  int16_t depth_multiplier;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  const int32_t* output_multiplier_per_channel;
  const int32_t* output_shift_per_channel;
};

struct DequantizationParams {
  double scale;
  int32_t zero_point;
};

struct PerChannelDequantizationParams {
  const float* scale;
  const int32_t* zero_point;
  int32_t quantized_dimension;
};

struct FakeQuantParams {
  MinMax minmax;
  int32_t num_bits;
};

struct FullyConnectedParams {
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // Mark the operands as cacheable if they are unchanging, e.g. weights.
  bool lhs_cacheable;
  bool rhs_cacheable;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16_t axis;
  int16_t batch_dims;
};

struct L2NormalizationParams {
  // uint8_t inference params.
  int32_t input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32_t range;
  double bias;
  double alpha;
  double beta;
};

struct HardSwishParams {
  // zero_point of the input activations.
  int16_t input_zero_point;
  // zero_point of the output activations.
  int16_t output_zero_point;
  // 16bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // "relu-ish scale", which is 3.0/32768.
  // See the implementation of HardSwishPrepare.
  int16_t reluish_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int reluish_multiplier_exponent;
  // 16bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // output scale.
  // See the implementation of HardSwishPrepare.
  int16_t output_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int output_multiplier_exponent;
};

struct LogisticParams {
  // uint8_t inference params.
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32_t weights_zero_point;
  int32_t accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8_t axis_count;
  int16_t axis[4];
};

struct PackParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8_t left_padding_count;
  int32_t left_padding[5];
  int8_t right_padding_count;
  int32_t right_padding[5];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32_t input_offset;
  int32_t alpha_offset;
  int32_t output_offset;
  int32_t output_multiplier_1;
  int output_shift_1;
  int32_t output_multiplier_2;
  int output_shift_2;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8_t shape_count;
  int32_t shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
  // half_pixel_centers assumes pixels are of half the actual dimensions, and
  // yields more accurate resizes. Corresponds to the same argument for the
  // original TensorFlow op in TF2.0.
  bool half_pixel_centers;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
  bool half_pixel_centers;
};

struct SliceParams {
  int8_t begin_count;
  int32_t begin[5];
  int8_t size_count;
  int32_t size[5];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8_t inference params.  Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float* table;
  // int16 LUT for exp(x), where x is uniformly distributed in [-10.0, 0.0].
  int16_t* exp_lut;
  // int16 LUT for 1 / (1 + x), where x is uniformly distributed in [0.0, 1.0].
  int16_t* one_over_one_plus_x_lut;
  uint8_t* uint8_table1;
  uint8_t* uint8_table2;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8_t means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams {
  int32_t block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered.  The indices in
  // OperatorEdges are of type uint16_t.
  uint16_t num_split;
  int16_t axis;
};

struct SqueezeParams {
  int8_t squeeze_dims_count;
  int32_t squeeze_dims[4];
};

struct StridedSliceParams {
  int8_t start_indices_count;
  int32_t start_indices[5];
  int8_t stop_indices_count;
  int32_t stop_indices[5];
  int8_t strides_count;
  int32_t strides[5];

  uint16_t begin_mask;
  uint16_t ellipsis_mask;
  uint16_t end_mask;
  uint16_t new_axis_mask;
  uint16_t shrink_axis_mask;
};

struct TanhParams {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8_t perm_count;
  int32_t perm[5];
};

struct UnpackParams {
  uint16_t num_split;
  int16_t axis;
};

struct LeakyReluParams {
  float alpha;
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier_alpha;
  int32_t output_shift_alpha;
  int32_t output_multiplier_identity;
  int32_t output_shift_identity;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32_t min, int32_t max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
  params->int64_activation_min = min;
  params->int64_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32_t* min, int32_t* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
  *min = params.int64_activation_min;
  *max = params.int64_activation_max;
}
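
// Illustration only: a usage sketch of the Set/GetActivationParams overloads.
// The float overloads read and write the float_activation_* fields, the
// int32_t overloads the quantized_activation_* fields, and the int64_t
// overloads the int64_activation_* fields, so the same call site works for
// any params struct that has the matching members.
//
//   ArithmeticParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);  // float range, e.g. Relu6
//   float act_min, act_max;
//   GetActivationParams(op_params, &act_min, &act_max);  // 0.0f, 6.0f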

// Type trait to check whether a given type has a size smaller than 4 bytes.
template <typename T>
struct is_small_integer
    : public std::integral_constant<bool,
                                    std::is_same<T, int8_t>::value ||
                                        std::is_same<T, uint8_t>::value ||
                                        std::is_same<T, int16_t>::value ||
                                        std::is_same<T, uint16_t>::value> {};

// Type trait to check whether a given type is int32 or int64.
template <typename T>
struct is_int32_or_int64
    : public std::integral_constant<bool, std::is_same<T, int32_t>::value ||
                                              std::is_same<T, int64_t>::value> {
};
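
// Illustration only: these traits can be used to constrain or dispatch kernel
// templates. The kernel name below is hypothetical.
//
//   static_assert(is_small_integer<int8_t>::value, "");
//   static_assert(!is_small_integer<int32_t>::value, "");
//   static_assert(is_int32_or_int64<int64_t>::value, "");
//
//   template <typename T>
//   typename std::enable_if<is_small_integer<T>::value, void>::type
//   SomeQuantizedOnlyKernel(const T* input, T* output);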

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_