/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <type_traits>

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8_t {
  kNone,
  kRelu6,
  kRelu1,
  kRelu
};
enum class PaddingType : uint8_t { kNone, kSame, kValid };

struct PaddingValues {
  int16_t width;
  int16_t height;
  // offset is used for calculating "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset, except over the height dimension.
  int16_t height_offset;
};
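
// Illustrative sketch (not part of this header's API): under the convention
// described above, the "leading" padding along a dimension equals the base
// value, while the "trailing" padding additionally includes the offset. The
// variable names below are hypothetical.
//
//   PaddingValues padding;
//   padding.width = 1;
//   padding.width_offset = 1;
//   const int padding_left = padding.width;                          // 1
//   const int padding_right = padding.width + padding.width_offset;  // 2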

struct Padding3DValues {
  int16_t width;
  int16_t height;
  int16_t depth;
  // offset is used for calculating "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset, except over the height dimension.
  int16_t height_offset;
  // Same as width_offset, except over the depth dimension.
  int16_t depth_offset;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8_t {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // a large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), and a very small batch size
  // (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty, that each value loaded from memory is used only in
  // one kernel iteration, making this operation memory-intensive, hints at
  // the solution, of shuffling the weights so that they are stored in the
  // exact order as the kernel needs to load them, so that the memory
  // accesses made by the kernel are trivial. This solves (2) because the
  // trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and this
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, thus don't alias each other in the cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weights value encoded as a signed int8_t value,
  // even if the data type of the weights buffer is uint8_t. This is intended
  // to spare runtime kernels the effort of having to XOR the top bit of these
  // bytes before using them in signed arithmetic. See this file for more
  // explanations on the 'signed int8_t trick' in matrix multiplication
  // kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
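
// Illustrative sketch of the 'signed int8_t trick' mentioned above
// (a simplified example, not the actual kernel code): for a weight byte
// stored as uint8_t, flipping the top bit reinterprets it as the signed
// value shifted down by 128, i.e. int8_t(w ^ 0x80) == int(w) - 128. Kernels
// consuming kShuffled4x16Int8 weights rely on this having been done ahead of
// time. The variable names below are hypothetical.
//
//   const uint8_t stored_weight = 0x03;  // 3
//   const int8_t signed_weight = static_cast<int8_t>(stored_weight ^ 0x80);
//   // signed_weight == 3 - 128 == -125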

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32_t zero_point = 0;
  double scale = 0.0;
};

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}
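
// Illustrative sketch (hypothetical values, not part of this header's API):
// dequantizing a single value with the formula above.
//
//   QuantizationParams qp;
//   qp.zero_point = 128;
//   qp.scale = 0.5;
//   const uint8_t quantized_value = 130;
//   const double real_value = qp.scale * (quantized_value - qp.zero_point);
//   // real_value == 0.5 * (130 - 128) == 1.0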

// Quantization parameters for each channel, determining the mapping of
// quantized values to real values. See QuantizationParams for a single set of
// parameters per tensor. This has one set of parameters per channel.
//
// The correspondence is as follows:
//
//   real_value = scale[channel] * (quantized_value - zero_point[channel]);
//
struct PerChannelQuantizationParams {
  // The following members typically point to the corresponding members of a
  // TfLiteAffineQuantization struct.
  const float* scale;
  const int32_t* zero_point;
  int32_t quantized_dimension;
};

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
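
// Illustrative usage sketch (hypothetical variables): visiting every index of
// a 2 x 3 array. Starting from an all-zero index, the loop body runs once per
// element; NextIndex() returns false after the last index wraps around.
//
//   int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // ... use index[0], index[1] ...
//   } while (NextIndex(2, dims, index));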

// Gets the offset of an index when reducing on an axis. When reducing, the
// flattened offset will not change if the input index changes on the given
// axis. For example, if you have a 3D tensor and you are reducing to 2D by
// eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) will map to the
// same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // if we need to skip this axis
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
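
// Illustrative sketch (hypothetical values): reducing a {2, 3, 4} tensor over
// axis 0. Indices that differ only on the reduced axis map to the same
// flattened output offset.
//
//   const int dims[] = {2, 3, 4};
//   const int axis[] = {0};
//   const int index_a[] = {0, 1, 2};
//   const int index_b[] = {1, 1, 2};
//   ReducedOutputOffset(3, dims, index_a, 1, axis);  // == 1 * 4 + 2 == 6
//   ReducedOutputOffset(3, dims, index_b, 1, axis);  // == 6 as well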

// Since tensors with '0' in their shape are valid in TF, these offset
// functions allow that as long as the corresponding index is also 0. It is up
// to the calling ops to ensure that they perform verification checks on tensor
// shapes if they don't support a particular behavior.

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK((i0 == 0 && dims.sizes[0] == 0) ||
                (i0 >= 0 && i0 < dims.sizes[0]));
  TFLITE_DCHECK((i1 == 0 && dims.sizes[1] == 0) ||
                (i1 >= 0 && i1 < dims.sizes[1]));
  TFLITE_DCHECK((i2 == 0 && dims.sizes[2] == 0) ||
                (i2 >= 0 && i2 < dims.sizes[2]));
  TFLITE_DCHECK((i3 == 0 && dims.sizes[3] == 0) ||
                (i3 >= 0 && i3 < dims.sizes[3]));
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return std::min(shape1.Dims(index1), shape2.Dims(index2));
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}
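
// Illustrative usage sketch (hypothetical 4-D NHWC shapes): a kernel checks
// that corresponding dimensions of its input and output agree and reads the
// common value in one step.
//
//   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
//   const int depth = MatchingDim(input_shape, 3, output_shape, 3);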

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  return size_1;
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  const int size_3 = check_shape_1.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  TFLITE_CHECK_EQ(size_2, size_3);
  return size_1;
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Flat size calculation, checking if their extended shapes match.
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0) {
  const int shape_dims = shape.DimensionsCount();
  const int check_shape_0_dims = check_shape_0.DimensionsCount();
  const int min_dims = std::min(shape_dims, check_shape_0_dims);

  for (int i = 0; i < min_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i),
                     check_shape_0.Dims(check_shape_0_dims - 1 - i));
  }
  for (int i = min_dims; i < shape_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i), 1);
  }
  for (int i = min_dims; i < check_shape_0_dims; ++i) {
    TFLITE_DCHECK_EQ(check_shape_0.Dims(check_shape_0_dims - 1 - i), 1);
  }
  return shape.FlatSize();
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1),
                   flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(
      MatchingExtendedShapeFlatSize(shape, check_shape_1, check_shape_2),
      flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2,
                                         const RuntimeShape& check_shape_3) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1,
                                                 check_shape_2, check_shape_3),
                   flat_size);
  return flat_size;
}
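
// Illustrative sketch (hypothetical shapes): the "extended" comparison aligns
// shapes on their trailing dimensions and treats missing leading dimensions
// as 1, so a {3, 4} shape and a {1, 3, 4} shape are considered matching and
// the returned flat size is 12.
//
//   const RuntimeShape a({3, 4});
//   const RuntimeShape b({1, 3, 4});
//   const int flat_size = MatchingExtendedShapeFlatSize(a, b);  // == 12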

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}
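
// Illustrative usage sketch (hypothetical shapes): a softmax-style kernel can
// view its input as `outer_size` rows of length `depth` by skipping the
// trailing dimension, while also checking that the output shape agrees.
//
//   const int trailing_dim = input_shape.DimensionsCount() - 1;
//   const int outer_size =
//       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
//   const int depth =
//       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);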

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
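
// Illustrative sketch (hypothetical values): strides are computed from the
// innermost dimension outwards, so a freshly computed Dims<4> is packed.
//
//   Dims<4> dims;
//   dims.sizes[0] = 2;
//   dims.sizes[1] = 3;
//   dims.sizes[2] = 4;
//   dims.sizes[3] = 5;
//   ComputeStrides(&dims);         // strides == {1, 2, 6, 24}
//   IsPackedWithoutStrides(dims);  // == true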

enum class BroadcastableOpCategory : uint8_t {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
};

struct ReluParams : public ActivationParams {
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a
// Pad op for pattern-specific optimization.
enum class ResizingCategory : uint8_t {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8_t inference params.
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8_t inference params.
  int left_shift;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_multiplier;
  int input2_shift;

  // TODO(b/158622529): Union the following activation params.
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // int64_t activation params.
  int64_t int64_activation_min;
  int64_t int64_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  //   broadcast_shape[4] = b0 = a0.
  //   broadcast_shape[3] = b1; a1 = 1.
  //   broadcast_shape[2] = b2 = a2.
  //   broadcast_shape[1] = a3; b3 = 1.
  //   broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8_t inference params.
  int left_shift;
  int32_t input1_offset;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_offset;
  int32_t input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct Conv3DParams {
  Padding3DValues padding_values;
  int stride_width;
  int stride_height;
  int stride_depth;
  int dilation_width;
  int dilation_height;
  int dilation_depth;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

typedef Conv3DParams Conv3DTransposeParams;

struct DepthToSpaceParams {
  int32_t block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  int16_t depth_multiplier;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  const int32_t* output_multiplier_per_channel;
  const int32_t* output_shift_per_channel;
};

struct DequantizationParams {
  double scale;
  int32_t zero_point;
};

struct PerChannelDequantizationParams {
  const float* scale;
  const int32_t* zero_point;
  int32_t quantized_dimension;
};

struct FakeQuantParams {
  MinMax minmax;
  int32_t num_bits;
};

struct FullyConnectedParams {
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // Mark the operands as cacheable if they are unchanging, e.g. weights.
  bool lhs_cacheable;
  bool rhs_cacheable;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16_t axis;
  int16_t batch_dims;
};

struct L2NormalizationParams {
  // uint8_t inference params.
  int32_t input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32_t range;
  double bias;
  double alpha;
  double beta;
};

struct HardSwishParams {
  // zero_point of the input activations.
  int16_t input_zero_point;
  // zero_point of the output activations.
  int16_t output_zero_point;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // "relu-ish scale", which is 3.0/32768.
  // See the implementation of HardSwishPrepare.
  int16_t reluish_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int reluish_multiplier_exponent;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // output scale.
  // See the implementation of HardSwishPrepare.
  int16_t output_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int output_multiplier_exponent;
};

struct LogisticParams {
  // uint8_t inference params.
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32_t weights_zero_point;
  int32_t accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8_t axis_count;
  int16_t axis[4];
};

struct PackParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8_t left_padding_count;
  int32_t left_padding[5];
  int8_t right_padding_count;
  int32_t right_padding[5];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32_t input_offset;
  int32_t alpha_offset;
  int32_t output_offset;
  int32_t output_multiplier_1;
  int output_shift_1;
  int32_t output_multiplier_2;
  int output_shift_2;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8_t shape_count;
  int32_t shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
  // half_pixel_centers assumes pixels are of half the actual dimensions, and
  // yields more accurate resizes. Corresponds to the same argument for the
  // original TensorFlow op in TF2.0.
  bool half_pixel_centers;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
  bool half_pixel_centers;
};

struct SliceParams {
  int8_t begin_count;
  int32_t begin[5];
  int8_t size_count;
  int32_t size[5];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8_t inference params. Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float* table;
  // int16 LUT for exp(x), where x is uniformly distributed in [-10.0, 0.0].
  int16_t* exp_lut;
  // int16 LUT for 1 / (1 + x), where x is uniformly distributed in [0.0, 1.0].
  int16_t* one_over_one_plus_x_lut;
  uint8_t* uint8_table1;
  uint8_t* uint8_table2;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8_t means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams {
  int32_t block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered. The indices in
  // OperatorEdges are of type uint16_t.
  uint16_t num_split;
  int16_t axis;
};

struct SqueezeParams {
  int8_t squeeze_dims_count;
  int32_t squeeze_dims[4];
};

struct StridedSliceParams {
  int8_t start_indices_count;
  int32_t start_indices[5];
  int8_t stop_indices_count;
  int32_t stop_indices[5];
  int8_t strides_count;
  int32_t strides[5];

  uint16_t begin_mask;
  uint16_t ellipsis_mask;
  uint16_t end_mask;
  uint16_t new_axis_mask;
  uint16_t shrink_axis_mask;
};

struct TanhParams {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8_t perm_count;
  int32_t perm[5];
};

struct UnpackParams {
  uint16_t num_split;
  int16_t axis;
};

struct LeakyReluParams {
  float alpha;
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier_alpha;
  int32_t output_shift_alpha;
  int32_t output_multiplier_identity;
  int32_t output_shift_identity;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32_t min, int32_t max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
  params->int64_activation_min = min;
  params->int64_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32_t* min, int32_t* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
  *min = params.int64_activation_min;
  *max = params.int64_activation_max;
}
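
// Illustrative usage sketch (hypothetical values): these helpers let kernels
// set and read activation bounds without caring which op-params struct they
// are working with, as long as it has the matching members (for example,
// ArithmeticParams above has both the float and the quantized fields).
//
//   ArithmeticParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);  // fills float_activation_*
//   float act_min, act_max;
//   GetActivationParams(op_params, &act_min, &act_max);
//   // act_min == 0.0f, act_max == 6.0f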

// Type trait to check whether a given type has a size smaller than 4 bytes.
template <typename T>
struct is_small_integer
    : public std::integral_constant<bool,
                                    std::is_same<T, int8_t>::value ||
                                        std::is_same<T, uint8_t>::value ||
                                        std::is_same<T, int16_t>::value ||
                                        std::is_same<T, uint16_t>::value> {};

// Type trait to check whether a given type is int32 or int64.
template <typename T>
struct is_int32_or_int64
    : public std::integral_constant<bool, std::is_same<T, int32_t>::value ||
                                              std::is_same<T, int64_t>::value> {
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_