xref: /aosp_15_r20/external/ComputeLibrary/arm_compute/core/Utils.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2016-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_UTILS_H
25 #define ARM_COMPUTE_UTILS_H
26 
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/PixelValue.h"
29 #include "arm_compute/core/Rounding.h"
30 #include "arm_compute/core/Types.h"
31 #include "arm_compute/core/Version.h"
32 
33 #include <algorithm>
34 #include <cstdint>
35 #include <cstdlib>
36 #include <iomanip>
37 #include <numeric>
38 #include <sstream>
39 #include <string>
40 #include <type_traits>
41 #include <unordered_map>
42 #include <utility>
43 #include <vector>
44 
45 namespace arm_compute
46 {
47 class ITensor;
48 class ITensorInfo;
49 
/** Division of @p val by @p m with the quotient rounded up.
 *
 * The result is evaluated as (val + m - 1) / m, which yields the ceiling of the
 * quotient for a non-negative @p val and a positive @p m.
 *
 * @param[in] val Value to divide and round up.
 * @param[in] m   Value to divide by.
 *
 * @return The rounded up quotient, in the type the expression promotes to.
 */
template <typename S, typename T>
constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
{
    return ((val + m) - 1) / m;
}
62 
/** Round @p value up to the nearest multiple of @p divisor.
 *
 * @note An error is asserted when @p value is negative or @p divisor is not positive.
 *
 * @param[in] value   Lower bound value
 * @param[in] divisor Value to compute multiple of.
 *
 * @return The smallest multiple of @p divisor that is larger or equal to @p value.
 */
template <typename S, typename T>
inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);

    // Number of whole divisors needed to cover value, rounded up
    const auto num_multiples = DIV_CEIL(value, divisor);
    return num_multiples * divisor;
}
76 
/** Round @p value down to the nearest multiple of @p divisor.
 *
 * @note An error is asserted when @p value is negative or @p divisor is not positive.
 *
 * @param[in] value   Upper bound value
 * @param[in] divisor Value to compute multiple of.
 *
 * @return The largest multiple of @p divisor that is smaller or equal to @p value.
 */
template <typename S, typename T>
inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor) * divisor)
{
    ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);

    // Number of whole divisors that fit in value
    const auto num_multiples = value / divisor;
    return num_multiples * divisor;
}
90 
/** Read the whole content of a file into a string.
 *
 * @param[in] filename Name of the file to read.
 * @param[in] binary   True if the file is a binary file.
 *
 * @return The content of the file.
 */
std::string read_file(const std::string &filename,
                      bool               binary);
99 
100 /** The size in bytes of the data type
101  *
102  * @param[in] data_type Input data type
103  *
104  * @return The size in bytes of the data type
105  */
data_size_from_type(DataType data_type)106 inline size_t data_size_from_type(DataType data_type)
107 {
108     switch(data_type)
109     {
110         case DataType::U8:
111         case DataType::S8:
112         case DataType::QSYMM8:
113         case DataType::QASYMM8:
114         case DataType::QASYMM8_SIGNED:
115         case DataType::QSYMM8_PER_CHANNEL:
116             return 1;
117         case DataType::U16:
118         case DataType::S16:
119         case DataType::QSYMM16:
120         case DataType::QASYMM16:
121         case DataType::BFLOAT16:
122         case DataType::F16:
123             return 2;
124         case DataType::F32:
125         case DataType::U32:
126         case DataType::S32:
127             return 4;
128         case DataType::F64:
129         case DataType::U64:
130         case DataType::S64:
131             return 8;
132         case DataType::SIZET:
133             return sizeof(size_t);
134         default:
135             ARM_COMPUTE_ERROR("Invalid data type");
136             return 0;
137     }
138 }
139 
140 /** The size in bytes of the pixel format
141  *
142  * @param[in] format Input format
143  *
144  * @return The size in bytes of the pixel format
145  */
pixel_size_from_format(Format format)146 inline size_t pixel_size_from_format(Format format)
147 {
148     switch(format)
149     {
150         case Format::U8:
151             return 1;
152         case Format::U16:
153         case Format::S16:
154         case Format::BFLOAT16:
155         case Format::F16:
156         case Format::UV88:
157         case Format::YUYV422:
158         case Format::UYVY422:
159             return 2;
160         case Format::RGB888:
161             return 3;
162         case Format::RGBA8888:
163             return 4;
164         case Format::U32:
165         case Format::S32:
166         case Format::F32:
167             return 4;
168         //Doesn't make sense for planar formats:
169         case Format::NV12:
170         case Format::NV21:
171         case Format::IYUV:
172         case Format::YUV444:
173         default:
174             ARM_COMPUTE_ERROR("Undefined pixel size for given format");
175             return 0;
176     }
177 }
178 
179 /** The size in bytes of the data type
180  *
181  * @param[in] dt Input data type
182  *
183  * @return The size in bytes of the data type
184  */
element_size_from_data_type(DataType dt)185 inline size_t element_size_from_data_type(DataType dt)
186 {
187     switch(dt)
188     {
189         case DataType::S8:
190         case DataType::U8:
191         case DataType::QSYMM8:
192         case DataType::QASYMM8:
193         case DataType::QASYMM8_SIGNED:
194         case DataType::QSYMM8_PER_CHANNEL:
195             return 1;
196         case DataType::U16:
197         case DataType::S16:
198         case DataType::QSYMM16:
199         case DataType::QASYMM16:
200         case DataType::BFLOAT16:
201         case DataType::F16:
202             return 2;
203         case DataType::U32:
204         case DataType::S32:
205         case DataType::F32:
206             return 4;
207         default:
208             ARM_COMPUTE_ERROR("Undefined element size for given data type");
209             return 0;
210     }
211 }
212 
213 /** Return the data type used by a given single-planar pixel format
214  *
215  * @param[in] format Input format
216  *
217  * @return The size in bytes of the pixel format
218  */
data_type_from_format(Format format)219 inline DataType data_type_from_format(Format format)
220 {
221     switch(format)
222     {
223         case Format::U8:
224         case Format::UV88:
225         case Format::RGB888:
226         case Format::RGBA8888:
227         case Format::YUYV422:
228         case Format::UYVY422:
229             return DataType::U8;
230         case Format::U16:
231             return DataType::U16;
232         case Format::S16:
233             return DataType::S16;
234         case Format::U32:
235             return DataType::U32;
236         case Format::S32:
237             return DataType::S32;
238         case Format::BFLOAT16:
239             return DataType::BFLOAT16;
240         case Format::F16:
241             return DataType::F16;
242         case Format::F32:
243             return DataType::F32;
244         //Doesn't make sense for planar formats:
245         case Format::NV12:
246         case Format::NV21:
247         case Format::IYUV:
248         case Format::YUV444:
249         default:
250             ARM_COMPUTE_ERROR("Not supported data_type for given format");
251             return DataType::UNKNOWN;
252     }
253 }
254 
255 /** Return the plane index of a given channel given an input format.
256  *
257  * @param[in] format  Input format
258  * @param[in] channel Input channel
259  *
260  * @return The plane index of the specific channel of the specific format
261  */
plane_idx_from_channel(Format format,Channel channel)262 inline int plane_idx_from_channel(Format format, Channel channel)
263 {
264     switch(format)
265     {
266         // Single planar formats have a single plane
267         case Format::U8:
268         case Format::U16:
269         case Format::S16:
270         case Format::U32:
271         case Format::S32:
272         case Format::BFLOAT16:
273         case Format::F16:
274         case Format::F32:
275         case Format::UV88:
276         case Format::RGB888:
277         case Format::RGBA8888:
278         case Format::YUYV422:
279         case Format::UYVY422:
280             return 0;
281         // Multi planar formats
282         case Format::NV12:
283         case Format::NV21:
284         {
285             // Channel U and V share the same plane of format UV88
286             switch(channel)
287             {
288                 case Channel::Y:
289                     return 0;
290                 case Channel::U:
291                 case Channel::V:
292                     return 1;
293                 default:
294                     ARM_COMPUTE_ERROR("Not supported channel");
295                     return 0;
296             }
297         }
298         case Format::IYUV:
299         case Format::YUV444:
300         {
301             switch(channel)
302             {
303                 case Channel::Y:
304                     return 0;
305                 case Channel::U:
306                     return 1;
307                 case Channel::V:
308                     return 2;
309                 default:
310                     ARM_COMPUTE_ERROR("Not supported channel");
311                     return 0;
312             }
313         }
314         default:
315             ARM_COMPUTE_ERROR("Not supported format");
316             return 0;
317     }
318 }
319 
320 /** Return the channel index of a given channel given an input format.
321  *
322  * @param[in] format  Input format
323  * @param[in] channel Input channel
324  *
325  * @return The channel index of the specific channel of the specific format
326  */
channel_idx_from_format(Format format,Channel channel)327 inline int channel_idx_from_format(Format format, Channel channel)
328 {
329     switch(format)
330     {
331         case Format::RGB888:
332         {
333             switch(channel)
334             {
335                 case Channel::R:
336                     return 0;
337                 case Channel::G:
338                     return 1;
339                 case Channel::B:
340                     return 2;
341                 default:
342                     ARM_COMPUTE_ERROR("Not supported channel");
343                     return 0;
344             }
345         }
346         case Format::RGBA8888:
347         {
348             switch(channel)
349             {
350                 case Channel::R:
351                     return 0;
352                 case Channel::G:
353                     return 1;
354                 case Channel::B:
355                     return 2;
356                 case Channel::A:
357                     return 3;
358                 default:
359                     ARM_COMPUTE_ERROR("Not supported channel");
360                     return 0;
361             }
362         }
363         case Format::YUYV422:
364         {
365             switch(channel)
366             {
367                 case Channel::Y:
368                     return 0;
369                 case Channel::U:
370                     return 1;
371                 case Channel::V:
372                     return 3;
373                 default:
374                     ARM_COMPUTE_ERROR("Not supported channel");
375                     return 0;
376             }
377         }
378         case Format::UYVY422:
379         {
380             switch(channel)
381             {
382                 case Channel::Y:
383                     return 1;
384                 case Channel::U:
385                     return 0;
386                 case Channel::V:
387                     return 2;
388                 default:
389                     ARM_COMPUTE_ERROR("Not supported channel");
390                     return 0;
391             }
392         }
393         case Format::NV12:
394         {
395             switch(channel)
396             {
397                 case Channel::Y:
398                     return 0;
399                 case Channel::U:
400                     return 0;
401                 case Channel::V:
402                     return 1;
403                 default:
404                     ARM_COMPUTE_ERROR("Not supported channel");
405                     return 0;
406             }
407         }
408         case Format::NV21:
409         {
410             switch(channel)
411             {
412                 case Channel::Y:
413                     return 0;
414                 case Channel::U:
415                     return 1;
416                 case Channel::V:
417                     return 0;
418                 default:
419                     ARM_COMPUTE_ERROR("Not supported channel");
420                     return 0;
421             }
422         }
423         case Format::YUV444:
424         case Format::IYUV:
425         {
426             switch(channel)
427             {
428                 case Channel::Y:
429                     return 0;
430                 case Channel::U:
431                     return 0;
432                 case Channel::V:
433                     return 0;
434                 default:
435                     ARM_COMPUTE_ERROR("Not supported channel");
436                     return 0;
437             }
438         }
439         default:
440             ARM_COMPUTE_ERROR("Not supported format");
441             return 0;
442     }
443 }
444 
445 /** Return the number of planes for a given format
446  *
447  * @param[in] format Input format
448  *
449  * @return The number of planes for a given image format.
450  */
num_planes_from_format(Format format)451 inline size_t num_planes_from_format(Format format)
452 {
453     switch(format)
454     {
455         case Format::U8:
456         case Format::S16:
457         case Format::U16:
458         case Format::S32:
459         case Format::U32:
460         case Format::BFLOAT16:
461         case Format::F16:
462         case Format::F32:
463         case Format::RGB888:
464         case Format::RGBA8888:
465         case Format::YUYV422:
466         case Format::UYVY422:
467             return 1;
468         case Format::NV12:
469         case Format::NV21:
470             return 2;
471         case Format::IYUV:
472         case Format::YUV444:
473             return 3;
474         default:
475             ARM_COMPUTE_ERROR("Not supported format");
476             return 0;
477     }
478 }
479 
480 /** Return the number of channels for a given single-planar pixel format
481  *
482  * @param[in] format Input format
483  *
484  * @return The number of channels for a given image format.
485  */
num_channels_from_format(Format format)486 inline size_t num_channels_from_format(Format format)
487 {
488     switch(format)
489     {
490         case Format::U8:
491         case Format::U16:
492         case Format::S16:
493         case Format::U32:
494         case Format::S32:
495         case Format::BFLOAT16:
496         case Format::F16:
497         case Format::F32:
498             return 1;
499         // Because the U and V channels are subsampled
500         // these formats appear like having only 2 channels:
501         case Format::YUYV422:
502         case Format::UYVY422:
503             return 2;
504         case Format::UV88:
505             return 2;
506         case Format::RGB888:
507             return 3;
508         case Format::RGBA8888:
509             return 4;
510         //Doesn't make sense for planar formats:
511         case Format::NV12:
512         case Format::NV21:
513         case Format::IYUV:
514         case Format::YUV444:
515         default:
516             return 0;
517     }
518 }
519 
520 /** Return the promoted data type of a given data type.
521  *
522  * @note If promoted data type is not supported an error will be thrown
523  *
524  * @param[in] dt Data type to get the promoted type of.
525  *
526  * @return Promoted data type
527  */
get_promoted_data_type(DataType dt)528 inline DataType get_promoted_data_type(DataType dt)
529 {
530     switch(dt)
531     {
532         case DataType::U8:
533             return DataType::U16;
534         case DataType::S8:
535             return DataType::S16;
536         case DataType::U16:
537             return DataType::U32;
538         case DataType::S16:
539             return DataType::S32;
540         case DataType::QSYMM8:
541         case DataType::QASYMM8:
542         case DataType::QASYMM8_SIGNED:
543         case DataType::QSYMM8_PER_CHANNEL:
544         case DataType::QSYMM16:
545         case DataType::QASYMM16:
546         case DataType::BFLOAT16:
547         case DataType::F16:
548         case DataType::U32:
549         case DataType::S32:
550         case DataType::F32:
551             ARM_COMPUTE_ERROR("Unsupported data type promotions!");
552         default:
553             ARM_COMPUTE_ERROR("Undefined data type!");
554     }
555     return DataType::UNKNOWN;
556 }
557 
558 /** Compute the mininum and maximum values a data type can take
559  *
560  * @param[in] dt Data type to get the min/max bounds of
561  *
562  * @return A tuple (min,max) with the minimum and maximum values respectively wrapped in PixelValue.
563  */
get_min_max(DataType dt)564 inline std::tuple<PixelValue, PixelValue> get_min_max(DataType dt)
565 {
566     PixelValue min{};
567     PixelValue max{};
568     switch(dt)
569     {
570         case DataType::U8:
571         case DataType::QASYMM8:
572         {
573             min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest()));
574             max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::max()));
575             break;
576         }
577         case DataType::S8:
578         case DataType::QSYMM8:
579         case DataType::QASYMM8_SIGNED:
580         case DataType::QSYMM8_PER_CHANNEL:
581         {
582             min = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::lowest()));
583             max = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::max()));
584             break;
585         }
586         case DataType::U16:
587         case DataType::QASYMM16:
588         {
589             min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest()));
590             max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::max()));
591             break;
592         }
593         case DataType::S16:
594         case DataType::QSYMM16:
595         {
596             min = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::lowest()));
597             max = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::max()));
598             break;
599         }
600         case DataType::U32:
601         {
602             min = PixelValue(std::numeric_limits<uint32_t>::lowest());
603             max = PixelValue(std::numeric_limits<uint32_t>::max());
604             break;
605         }
606         case DataType::S32:
607         {
608             min = PixelValue(std::numeric_limits<int32_t>::lowest());
609             max = PixelValue(std::numeric_limits<int32_t>::max());
610             break;
611         }
612         case DataType::BFLOAT16:
613         {
614             min = PixelValue(bfloat16::lowest());
615             max = PixelValue(bfloat16::max());
616             break;
617         }
618         case DataType::F16:
619         {
620             min = PixelValue(std::numeric_limits<half>::lowest());
621             max = PixelValue(std::numeric_limits<half>::max());
622             break;
623         }
624         case DataType::F32:
625         {
626             min = PixelValue(std::numeric_limits<float>::lowest());
627             max = PixelValue(std::numeric_limits<float>::max());
628             break;
629         }
630         default:
631             ARM_COMPUTE_ERROR("Undefined data type!");
632     }
633     return std::make_tuple(min, max);
634 }
635 
636 /** Return true if the given format has horizontal subsampling.
637  *
638  * @param[in] format Format to determine subsampling.
639  *
640  * @return True if the format can be subsampled horizontaly.
641  */
has_format_horizontal_subsampling(Format format)642 inline bool has_format_horizontal_subsampling(Format format)
643 {
644     return (format == Format::YUYV422 || format == Format::UYVY422 || format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
645 }
646 
647 /** Return true if the given format has vertical subsampling.
648  *
649  * @param[in] format Format to determine subsampling.
650  *
651  * @return True if the format can be subsampled verticaly.
652  */
has_format_vertical_subsampling(Format format)653 inline bool has_format_vertical_subsampling(Format format)
654 {
655     return (format == Format::NV12 || format == Format::NV21 || format == Format::IYUV || format == Format::UV88) ? true : false;
656 }
657 
658 /** Adjust tensor shape size if width or height are odd for a given multi-planar format. No modification is done for other formats.
659  *
660  * @note Adding here a few links discussing the issue of odd size and sharing the same solution:
661  *       <a href="https://android.googlesource.com/platform/frameworks/base/+/refs/heads/master/graphics/java/android/graphics/YuvImage.java">Android Source</a>
662  *       <a href="https://groups.google.com/a/webmproject.org/forum/#!topic/webm-discuss/LaCKpqiDTXM">WebM</a>
663  *       <a href="https://bugs.chromium.org/p/libyuv/issues/detail?id=198&amp;can=1&amp;q=odd%20width">libYUV</a>
664  *       <a href="https://sourceforge.net/p/raw-yuvplayer/bugs/1/">YUVPlayer</a> *
665  *
666  * @param[in, out] shape  Tensor shape of 2D size
667  * @param[in]      format Format of the tensor
668  *
669  * @return The adjusted tensor shape.
670  */
adjust_odd_shape(const TensorShape & shape,Format format)671 inline TensorShape adjust_odd_shape(const TensorShape &shape, Format format)
672 {
673     TensorShape output{ shape };
674 
675     // Force width to be even for formats which require subsampling of the U and V channels
676     if(has_format_horizontal_subsampling(format))
677     {
678         output.set(0, (output.x() + 1) & ~1U);
679     }
680 
681     // Force height to be even for formats which require subsampling of the U and V channels
682     if(has_format_vertical_subsampling(format))
683     {
684         output.set(1, (output.y() + 1) & ~1U);
685     }
686 
687     return output;
688 }
689 
690 /** Calculate subsampled shape for a given format and channel
691  *
692  * @param[in] shape   Shape of the tensor to calculate the extracted channel.
693  * @param[in] format  Format of the tensor.
694  * @param[in] channel Channel to create tensor shape to be extracted.
695  *
696  * @return The subsampled tensor shape.
697  */
698 inline TensorShape calculate_subsampled_shape(const TensorShape &shape, Format format, Channel channel = Channel::UNKNOWN)
699 {
700     TensorShape output{ shape };
701 
702     // Subsample shape only for U or V channel
703     if(Channel::U == channel || Channel::V == channel || Channel::UNKNOWN == channel)
704     {
705         // Subsample width for the tensor shape when channel is U or V
706         if(has_format_horizontal_subsampling(format))
707         {
708             output.set(0, output.x() / 2U);
709         }
710 
711         // Subsample height for the tensor shape when channel is U or V
712         if(has_format_vertical_subsampling(format))
713         {
714             output.set(1, output.y() / 2U);
715         }
716     }
717 
718     return output;
719 }
720 
721 /** Permutes the given dimensions according the permutation vector
722  *
723  * @param[in,out] dimensions Dimensions to be permuted.
724  * @param[in]     perm       Vector describing the permutation.
725  *
726  */
727 template <typename T>
permute_strides(Dimensions<T> & dimensions,const PermutationVector & perm)728 inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
729 {
730     const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
731     for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
732     {
733         T dimension_val = old_dim[i];
734         dimensions.set(perm[i], dimension_val);
735     }
736 }
737 
738 /** Calculate padding requirements in case of SAME padding
739  *
740  * @param[in] input_shape   Input shape
741  * @param[in] weights_shape Weights shape
742  * @param[in] conv_info     Convolution information (containing strides)
743  * @param[in] data_layout   (Optional) Data layout of the input and weights tensor
744  * @param[in] dilation      (Optional) Dilation factor used in the convolution.
745  * @param[in] rounding_type (Optional) Dimension rounding type when down-scaling.
746  *
747  * @return PadStrideInfo for SAME padding
748  */
749 PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout = DataLayout::NCHW, const Size2D &dilation = Size2D(1u, 1u),
750                                  const DimensionRoundingType &rounding_type = DimensionRoundingType::FLOOR);
751 
752 /** Returns expected width and height of the deconvolution's output tensor.
753  *
754  * @param[in] in_width        Width of input tensor (Number of columns)
755  * @param[in] in_height       Height of input tensor (Number of rows)
756  * @param[in] kernel_width    Kernel width.
757  * @param[in] kernel_height   Kernel height.
758  * @param[in] pad_stride_info Pad and stride information.
759  *
760  * @return A pair with the new width in the first position and the new height in the second.
761  */
762 std::pair<unsigned int, unsigned int> deconvolution_output_dimensions(unsigned int in_width, unsigned int in_height,
763                                                                       unsigned int kernel_width, unsigned int kernel_height,
764                                                                       const PadStrideInfo &pad_stride_info);
765 
766 /** Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
767  *
768  * @param[in] width           Width of input tensor (Number of columns)
769  * @param[in] height          Height of input tensor (Number of rows)
770  * @param[in] kernel_width    Kernel width.
771  * @param[in] kernel_height   Kernel height.
772  * @param[in] pad_stride_info Pad and stride information.
773  * @param[in] dilation        (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
774  *
775  * @return A pair with the new width in the first position and the new height in the second.
776  */
777 std::pair<unsigned int, unsigned int> scaled_dimensions(int width, int height,
778                                                         int kernel_width, int kernel_height,
779                                                         const PadStrideInfo &pad_stride_info,
780                                                         const Size2D        &dilation = Size2D(1U, 1U));
781 
782 /** Returns calculated width and height of output scaled tensor depending on dimensions rounding mode.
783  *
784  * @param[in] width           Width of input tensor (Number of columns)
785  * @param[in] height          Height of input tensor (Number of rows)
786  * @param[in] kernel_width    Kernel width.
787  * @param[in] kernel_height   Kernel height.
788  * @param[in] pad_stride_info Pad and stride information.
789  *
790  * @return A pair with the new width in the first position and the new height in the second, returned values can be < 1
791  */
792 std::pair<int, int> scaled_dimensions_signed(int width, int height,
793                                              int kernel_width, int kernel_height,
794                                              const PadStrideInfo &pad_stride_info);
795 
796 /** Returns calculated width, height and depth of output scaled tensor depending on dimensions rounding mode.
797  *
798  * @param[in] width         Width of input tensor
799  * @param[in] height        Height of input tensor
800  * @param[in] depth         Depth of input tensor
801  * @param[in] kernel_width  Kernel width.
802  * @param[in] kernel_height Kernel height.
803  * @param[in] kernel_depth  Kernel depth.
804  * @param[in] pool3d_info   Pad and stride and round information for 3d pooling
805  *
806  * @return A tuple with the new width in the first position, the new height in the second, and the new depth in the third.
807  *         Returned values can be < 1
808  */
809 std::tuple<int, int, int> scaled_3d_dimensions_signed(int width, int height, int depth,
810                                                       int kernel_width, int kernel_height, int kernel_depth,
811                                                       const Pooling3dLayerInfo &pool3d_info);
812 
813 /** Check if the given reduction operation should be handled in a serial way.
814  *
815  * @param[in] op   Reduction operation to perform
816  * @param[in] dt   Data type
817  * @param[in] axis Axis along which to reduce
818  *
819  * @return True if the given reduction operation should be handled in a serial way.
820  */
821 bool needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis);
822 
823 /** Returns output quantization information for softmax layer
824  *
825  * @param[in] input_type The data type of the input tensor
826  * @param[in] is_log     True for log softmax
827  *
828  * @return Quantization information for the output tensor
829  */
830 QuantizationInfo get_softmax_output_quantization_info(DataType input_type, bool is_log);
831 
832 /** Returns a pair of minimum and maximum values for a quantized activation
833  *
834  * @param[in] act_info  The information for activation
835  * @param[in] data_type The used data type
836  * @param[in] oq_info   The output quantization information
837  *
838  * @return The pair with minimum and maximum values
839  */
840 std::pair<int32_t, int32_t> get_quantized_activation_min_max(ActivationLayerInfo act_info, DataType data_type, UniformQuantizationInfo oq_info);
841 
842 /** Convert a tensor format into a string.
843  *
844  * @param[in] format @ref Format to be translated to string.
845  *
846  * @return The string describing the format.
847  */
848 const std::string &string_from_format(Format format);
849 
850 /** Convert a channel identity into a string.
851  *
852  * @param[in] channel @ref Channel to be translated to string.
853  *
854  * @return The string describing the channel.
855  */
856 const std::string &string_from_channel(Channel channel);
857 /** Convert a data layout identity into a string.
858  *
859  * @param[in] dl @ref DataLayout to be translated to string.
860  *
861  * @return The string describing the data layout.
862  */
863 const std::string &string_from_data_layout(DataLayout dl);
864 /** Convert a data type identity into a string.
865  *
866  * @param[in] dt @ref DataType to be translated to string.
867  *
868  * @return The string describing the data type.
869  */
870 const std::string &string_from_data_type(DataType dt);
871 /** Translates a given activation function to a string.
872  *
873  * @param[in] act @ref ActivationLayerInfo::ActivationFunction to be translated to string.
874  *
875  * @return The string describing the activation function.
876  */
877 const std::string &string_from_activation_func(ActivationLayerInfo::ActivationFunction act);
878 /** Translates a given interpolation policy to a string.
879  *
880  * @param[in] policy @ref InterpolationPolicy to be translated to string.
881  *
882  * @return The string describing the interpolation policy.
883  */
884 const std::string &string_from_interpolation_policy(InterpolationPolicy policy);
885 /** Translates a given border mode policy to a string.
886  *
887  * @param[in] border_mode @ref BorderMode to be translated to string.
888  *
889  * @return The string describing the border mode.
890  */
891 const std::string &string_from_border_mode(BorderMode border_mode);
892 /** Translates a given normalization type to a string.
893  *
894  * @param[in] type @ref NormType to be translated to string.
895  *
896  * @return The string describing the normalization type.
897  */
898 const std::string &string_from_norm_type(NormType type);
899 /** Translates a given pooling type to a string.
900  *
901  * @param[in] type @ref PoolingType to be translated to string.
902  *
903  * @return The string describing the pooling type.
904  */
905 const std::string &string_from_pooling_type(PoolingType type);
906 /** Check if the pool region is entirely outside the input tensor
907  *
908  * @param[in] info @ref PoolingLayerInfo to be checked.
909  *
910  * @return True if the pool region is entirely outside the input tensor, False otherwise.
911  */
912 bool is_pool_region_entirely_outside_input(const PoolingLayerInfo &info);
913 /** Check if the 3d pool region is entirely outside the input tensor
914  *
915  * @param[in] info @ref Pooling3dLayerInfo to be checked.
916  *
917  * @return True if the pool region is entirely outside the input tensor, False otherwise.
918  */
919 bool is_pool_3d_region_entirely_outside_input(const Pooling3dLayerInfo &info);
920 /** Check if the 3D padding is symmetric i.e. padding in each opposite sides are euqal (left=right, top=bottom and front=back)
921  *
922  * @param[in] info @ref Padding3D input 3D padding object to check if it is symmetric
923  *
924  * @return True if padding is symmetric
925  */
is_symmetric(const Padding3D & info)926 inline bool is_symmetric(const Padding3D& info)
927 {
928     return ((info.left == info.right) && (info.top == info.bottom) && (info.front == info.back));
929 }
930 /** Translates a given GEMMLowp output stage to a string.
931  *
932  * @param[in] output_stage @ref GEMMLowpOutputStageInfo to be translated to string.
933  *
934  * @return The string describing the GEMMLowp output stage
935  */
936 const std::string &string_from_gemmlowp_output_stage(GEMMLowpOutputStageType output_stage);
937 /** Convert a PixelValue to a string, represented through the specific data type
938  *
939  * @param[in] value     The PixelValue to convert
940  * @param[in] data_type The type to be used to convert the @p value
941  *
942  * @return String representation of the PixelValue through the given data type.
943  */
944 std::string string_from_pixel_value(const PixelValue &value, const DataType data_type);
945 /** Convert a string to DataType
946  *
947  * @param[in] name The name of the data type
948  *
949  * @return DataType
950  */
951 DataType data_type_from_name(const std::string &name);
952 /** Stores padding information before configuring a kernel
953  *
954  * @param[in] infos list of tensor infos to store the padding info for
955  *
956  * @return An unordered map where each tensor info pointer is paired with its original padding info
957  */
958 std::unordered_map<const ITensorInfo *, PaddingSize> get_padding_info(std::initializer_list<const ITensorInfo *> infos);
959 /** Stores padding information before configuring a kernel
960  *
961  * @param[in] tensors list of tensors to store the padding info for
962  *
963  * @return An unordered map where each tensor info pointer is paired with its original padding info
964  */
965 std::unordered_map<const ITensorInfo *, PaddingSize> get_padding_info(std::initializer_list<const ITensor *> tensors);
966 /** Check if the previously stored padding info has changed after configuring a kernel
967  *
968  * @param[in] padding_map an unordered map where each tensor info pointer is paired with its original padding info
969  *
970  * @return true if any of the tensor infos has changed its paddings
971  */
972 bool has_padding_changed(const std::unordered_map<const ITensorInfo *, PaddingSize> &padding_map);
973 
974 /** Input Stream operator for @ref DataType
975  *
976  * @param[in]  stream    Stream to parse
977  * @param[out] data_type Output data type
978  *
979  * @return Updated stream
980  */
981 inline ::std::istream &operator>>(::std::istream &stream, DataType &data_type)
982 {
983     std::string value;
984     stream >> value;
985     data_type = data_type_from_name(value);
986     return stream;
987 }
988 /** Lower a given string.
989  *
990  * @param[in] val Given string to lower.
991  *
992  * @return The lowered string
993  */
994 std::string lower_string(const std::string &val);
995 
996 /** Raise a given string to upper case
997  *
998  * @param[in] val Given string to lower.
999  *
1000  * @return The upper case string
1001  */
1002 std::string upper_string(const std::string &val);
1003 
1004 /** Check if a given data type is of floating point type
1005  *
1006  * @param[in] dt Input data type.
1007  *
1008  * @return True if data type is of floating point type, else false.
1009  */
is_data_type_float(DataType dt)1010 inline bool is_data_type_float(DataType dt)
1011 {
1012     switch(dt)
1013     {
1014         case DataType::F16:
1015         case DataType::F32:
1016             return true;
1017         default:
1018             return false;
1019     }
1020 }
1021 
1022 /** Check if a given data type is of quantized type
1023  *
1024  * @note Quantized is considered a super-set of fixed-point and asymmetric data types.
1025  *
1026  * @param[in] dt Input data type.
1027  *
1028  * @return True if data type is of quantized type, else false.
1029  */
is_data_type_quantized(DataType dt)1030 inline bool is_data_type_quantized(DataType dt)
1031 {
1032     switch(dt)
1033     {
1034         case DataType::QSYMM8:
1035         case DataType::QASYMM8:
1036         case DataType::QASYMM8_SIGNED:
1037         case DataType::QSYMM8_PER_CHANNEL:
1038         case DataType::QSYMM16:
1039         case DataType::QASYMM16:
1040             return true;
1041         default:
1042             return false;
1043     }
1044 }
1045 
1046 /** Check if a given data type is of asymmetric quantized type
1047  *
1048  * @param[in] dt Input data type.
1049  *
1050  * @return True if data type is of asymmetric quantized type, else false.
1051  */
is_data_type_quantized_asymmetric(DataType dt)1052 inline bool is_data_type_quantized_asymmetric(DataType dt)
1053 {
1054     switch(dt)
1055     {
1056         case DataType::QASYMM8:
1057         case DataType::QASYMM8_SIGNED:
1058         case DataType::QASYMM16:
1059             return true;
1060         default:
1061             return false;
1062     }
1063 }
1064 
1065 /** Check if a given data type is of asymmetric quantized signed type
1066  *
1067  * @param[in] dt Input data type.
1068  *
1069  * @return True if data type is of asymmetric quantized signed type, else false.
1070  */
is_data_type_quantized_asymmetric_signed(DataType dt)1071 inline bool is_data_type_quantized_asymmetric_signed(DataType dt)
1072 {
1073     switch(dt)
1074     {
1075         case DataType::QASYMM8_SIGNED:
1076             return true;
1077         default:
1078             return false;
1079     }
1080 }
1081 
1082 /** Check if a given data type is of symmetric quantized type
1083  *
1084  * @param[in] dt Input data type.
1085  *
1086  * @return True if data type is of symmetric quantized type, else false.
1087  */
is_data_type_quantized_symmetric(DataType dt)1088 inline bool is_data_type_quantized_symmetric(DataType dt)
1089 {
1090     switch(dt)
1091     {
1092         case DataType::QSYMM8:
1093         case DataType::QSYMM8_PER_CHANNEL:
1094         case DataType::QSYMM16:
1095             return true;
1096         default:
1097             return false;
1098     }
1099 }
1100 
1101 /** Check if a given data type is of per channel type
1102  *
1103  * @param[in] dt Input data type.
1104  *
1105  * @return True if data type is of per channel type, else false.
1106  */
is_data_type_quantized_per_channel(DataType dt)1107 inline bool is_data_type_quantized_per_channel(DataType dt)
1108 {
1109     switch(dt)
1110     {
1111         case DataType::QSYMM8_PER_CHANNEL:
1112             return true;
1113         default:
1114             return false;
1115     }
1116 }
1117 
/** Create a string with the float in full precision.
 *
 * The value is printed with std::numeric_limits<float>::max_digits10 significant digits.
 * An "f" suffix is appended unless the value is a whole number that fits in an int.
 *
 * @note NaN, infinities and magnitudes outside the int range always receive the suffix.
 *       They are never converted to int, because that conversion is undefined behaviour
 *       for values that int cannot represent.
 *
 * @param val Floating point value
 *
 * @return String with the floating point value.
 */
inline std::string float_to_string_with_full_precision(float val)
{
    std::stringstream ss;
    ss.precision(std::numeric_limits<float>::max_digits10);
    ss << val;

    // int's minimum is a power of two and therefore exactly representable as float; its negation is the
    // first float above int's maximum, which makes it a safe exclusive upper bound.
    const float int_lower_bound = static_cast<float>(std::numeric_limits<int>::min());
    const float int_upper_bound = -int_lower_bound;

    // Every comparison involving NaN is false, so NaN is treated as "does not fit" as well.
    const bool fits_in_int = (val >= int_lower_bound) && (val < int_upper_bound);

    // The float -> int conversion is only evaluated when it is well defined (short-circuit).
    const bool is_whole_number = fits_in_int && (val == static_cast<float>(static_cast<int>(val)));

    if(!is_whole_number)
    {
        ss << "f";
    }

    return ss.str();
}
1137 
1138 /** Returns the number of elements required to go from start to end with the wanted step
1139  *
1140  * @param[in] start start value
1141  * @param[in] end   end value
1142  * @param[in] step  step value between each number in the wanted sequence
1143  *
1144  * @return number of elements to go from start value to end value using the wanted step
1145  */
num_of_elements_in_range(const float start,const float end,const float step)1146 inline size_t num_of_elements_in_range(const float start, const float end, const float step)
1147 {
1148     ARM_COMPUTE_ERROR_ON_MSG(step == 0, "Range Step cannot be 0");
1149     return size_t(std::ceil((end - start) / step));
1150 }
1151 
1152 /** Returns true if the value can be represented by the given data type
1153  *
1154  * @param[in] val   value to be checked
1155  * @param[in] dt    data type that is checked
1156  * @param[in] qinfo (Optional) quantization info if the data type is QASYMM8
1157  *
1158  * @return true if the data type can hold the value.
1159  */
1160 template <typename T>
1161 bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = QuantizationInfo())
1162 {
1163     switch(dt)
1164     {
1165         case DataType::U8:
1166         {
1167             const auto val_u8 = static_cast<uint8_t>(val);
1168             return ((val_u8 == val) && val >= std::numeric_limits<uint8_t>::lowest() && val <= std::numeric_limits<uint8_t>::max());
1169         }
1170         case DataType::QASYMM8:
1171         {
1172             double min = static_cast<double>(dequantize_qasymm8(0, qinfo));
1173             double max = static_cast<double>(dequantize_qasymm8(std::numeric_limits<uint8_t>::max(), qinfo));
1174             return ((double)val >= min && (double)val <= max);
1175         }
1176         case DataType::S8:
1177         {
1178             const auto val_s8 = static_cast<int8_t>(val);
1179             return ((val_s8 == val) && val >= std::numeric_limits<int8_t>::lowest() && val <= std::numeric_limits<int8_t>::max());
1180         }
1181         case DataType::U16:
1182         {
1183             const auto val_u16 = static_cast<uint16_t>(val);
1184             return ((val_u16 == val) && val >= std::numeric_limits<uint16_t>::lowest() && val <= std::numeric_limits<uint16_t>::max());
1185         }
1186         case DataType::S16:
1187         {
1188             const auto val_s16 = static_cast<int16_t>(val);
1189             return ((val_s16 == val) && val >= std::numeric_limits<int16_t>::lowest() && val <= std::numeric_limits<int16_t>::max());
1190         }
1191         case DataType::U32:
1192         {
1193             const auto val_d64 = static_cast<double>(val);
1194             const auto val_u32 = static_cast<uint32_t>(val);
1195             return ((val_u32 == val_d64) && val_d64 >= std::numeric_limits<uint32_t>::lowest() && val_d64 <= std::numeric_limits<uint32_t>::max());
1196         }
1197         case DataType::S32:
1198         {
1199             const auto val_d64 = static_cast<double>(val);
1200             const auto val_s32 = static_cast<int32_t>(val);
1201             return ((val_s32 == val_d64) && val_d64 >= std::numeric_limits<int32_t>::lowest() && val_d64 <= std::numeric_limits<int32_t>::max());
1202         }
1203         case DataType::BFLOAT16:
1204             return (val >= bfloat16::lowest() && val <= bfloat16::max());
1205         case DataType::F16:
1206             return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
1207         case DataType::F32:
1208             return (val >= std::numeric_limits<float>::lowest() && val <= std::numeric_limits<float>::max());
1209         default:
1210             ARM_COMPUTE_ERROR("Data type not supported");
1211             return false;
1212     }
1213 }
1214 
1215 /** Returns the adjusted vector size in case it is less than the input's first dimension, getting rounded down to its closest valid vector size
1216  *
1217  * @param[in] vec_size vector size to be adjusted
1218  * @param[in] dim0     size of the first dimension
1219  *
1220  * @return the number of element processed along the X axis per thread
1221  */
adjust_vec_size(unsigned int vec_size,size_t dim0)1222 inline unsigned int adjust_vec_size(unsigned int vec_size, size_t dim0)
1223 {
1224     ARM_COMPUTE_ERROR_ON(vec_size > 16);
1225 
1226     if((vec_size >= dim0) && (dim0 == 3))
1227     {
1228         return dim0;
1229     }
1230 
1231     while(vec_size > dim0)
1232     {
1233         vec_size >>= 1;
1234     }
1235 
1236     return vec_size;
1237 }
1238 
1239 /** Returns the suffix string of CPU kernel implementation names based on the given data type
1240  *
1241  * @param[in] data_type The data type the CPU kernel implemetation uses
1242  *
1243  * @return the suffix string of CPU kernel implementations
1244  */
cpu_impl_dt(const DataType & data_type)1245 inline std::string cpu_impl_dt(const DataType &data_type)
1246 {
1247     std::string ret = "";
1248 
1249     switch(data_type)
1250     {
1251         case DataType::F32:
1252             ret = "fp32";
1253             break;
1254         case DataType::F16:
1255             ret = "fp16";
1256             break;
1257         case DataType::U8:
1258             ret = "u8";
1259             break;
1260         case DataType::S16:
1261             ret = "s16";
1262             break;
1263         case DataType::S32:
1264             ret = "s32";
1265             break;
1266         case DataType::QASYMM8:
1267             ret = "qu8";
1268             break;
1269         case DataType::QASYMM8_SIGNED:
1270             ret = "qs8";
1271             break;
1272         case DataType::QSYMM16:
1273             ret = "qs16";
1274             break;
1275         case DataType::QSYMM8_PER_CHANNEL:
1276             ret = "qp8";
1277             break;
1278         case DataType::BFLOAT16:
1279             ret = "bf16";
1280             break;
1281         default:
1282             ARM_COMPUTE_ERROR("Unsupported.");
1283     }
1284 
1285     return ret;
1286 }
1287 
1288 #ifdef ARM_COMPUTE_ASSERTS_ENABLED
1289 /** Print consecutive elements to an output stream.
1290  *
1291  * @param[out] s             Output stream to print the elements to.
1292  * @param[in]  ptr           Pointer to print the elements from.
1293  * @param[in]  n             Number of elements to print.
1294  * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1295  * @param[in]  element_delim (Optional) Delimeter among the consecutive elements. Defaults to space delimeter
1296  */
1297 template <typename T>
1298 void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int n, int stream_width = 0, const std::string &element_delim = " ")
1299 {
1300     using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1301     std::ios stream_status(nullptr);
1302     stream_status.copyfmt(s);
1303 
1304     for(unsigned int i = 0; i < n; ++i)
1305     {
1306         // Set stream width as it is not a "sticky" stream manipulator
1307         if(stream_width != 0)
1308         {
1309             s.width(stream_width);
1310         }
1311 
1312         if(std::is_same<typename std::decay<T>::type, half>::value)
1313         {
1314             // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1315             s << std::right << static_cast<T>(ptr[i]) << element_delim;
1316         }
1317         else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
1318         {
1319             // We use T instead of print_type here is because the std::is_floating_point<bfloat16> returns false and then the print_type becomes int.
1320             s << std::right << float(ptr[i]) << element_delim;
1321         }
1322         else
1323         {
1324             s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
1325         }
1326     }
1327 
1328     // Restore output stream flags
1329     s.copyfmt(stream_status);
1330 }
1331 
1332 /** Identify the maximum width of n consecutive elements.
1333  *
1334  * @param[in] s   The output stream which will be used to print the elements. Used to extract the stream format.
1335  * @param[in] ptr Pointer to the elements.
1336  * @param[in] n   Number of elements.
1337  *
1338  * @return The maximum width of the elements.
1339  */
1340 template <typename T>
max_consecutive_elements_display_width_impl(std::ostream & s,const T * ptr,unsigned int n)1341 int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, unsigned int n)
1342 {
1343     using print_type = typename std::conditional<std::is_floating_point<T>::value, T, int>::type;
1344 
1345     int max_width = -1;
1346     for(unsigned int i = 0; i < n; ++i)
1347     {
1348         std::stringstream ss;
1349         ss.copyfmt(s);
1350 
1351         if(std::is_same<typename std::decay<T>::type, half>::value)
1352         {
1353             // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
1354             ss << static_cast<T>(ptr[i]);
1355         }
1356         else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
1357         {
1358             // We use T instead of print_type here is because the std::is_floating_point<bfloat> returns false and then the print_type becomes int.
1359             ss << float(ptr[i]);
1360         }
1361         else
1362         {
1363             ss << static_cast<print_type>(ptr[i]);
1364         }
1365 
1366         max_width = std::max<int>(max_width, ss.str().size());
1367     }
1368     return max_width;
1369 }
1370 
1371 /** Print consecutive elements to an output stream.
1372  *
1373  * @param[out] s             Output stream to print the elements to.
1374  * @param[in]  dt            Data type of the elements
1375  * @param[in]  ptr           Pointer to print the elements from.
1376  * @param[in]  n             Number of elements to print.
1377  * @param[in]  stream_width  (Optional) Width of the stream. If set to 0 the element's width is used. Defaults to 0.
1378  * @param[in]  element_delim (Optional) Delimeter among the consecutive elements. Defaults to space delimeter
1379  */
1380 void print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim = " ");
1381 
1382 /** Identify the maximum width of n consecutive elements.
1383  *
1384  * @param[in] s   Output stream to print the elements to.
1385  * @param[in] dt  Data type of the elements
1386  * @param[in] ptr Pointer to print the elements from.
1387  * @param[in] n   Number of elements to print.
1388  *
1389  * @return The maximum width of the elements.
1390  */
1391 int max_consecutive_elements_display_width(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n);
1392 #endif /* ARM_COMPUTE_ASSERTS_ENABLED */
1393 }
1394 #endif /*ARM_COMPUTE_UTILS_H */
1395