1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <armnn/utility/IgnoreUnused.hpp>
9 #include <armnn/utility/NumericCast.hpp>
10 #include <armnn/TypesUtils.hpp>
11
12 #include <BFloat16.hpp>
13 #include <Half.hpp>
14
15 #include <initializer_list>
16 #include <iterator>
17 #include <vector>
18
19 namespace armnnUtils
20 {
21
22 template<typename T, bool DoQuantize=true>
23 struct SelectiveQuantizer
24 {
QuantizearmnnUtils::SelectiveQuantizer25 static T Quantize(float value, float scale, int32_t offset)
26 {
27 return armnn::Quantize<T>(value, scale, offset);
28 }
29
DequantizearmnnUtils::SelectiveQuantizer30 static float Dequantize(T value, float scale, int32_t offset)
31 {
32 return armnn::Dequantize(value, scale, offset);
33 }
34 };
35
36 template<typename T>
37 struct SelectiveQuantizer<T, false>
38 {
QuantizearmnnUtils::SelectiveQuantizer39 static T Quantize(float value, float scale, int32_t offset)
40 {
41 armnn::IgnoreUnused(scale, offset);
42 return value;
43 }
44
DequantizearmnnUtils::SelectiveQuantizer45 static float Dequantize(T value, float scale, int32_t offset)
46 {
47 armnn::IgnoreUnused(scale, offset);
48 return value;
49 }
50 };
51
52 template<>
53 struct SelectiveQuantizer<armnn::Half, false>
54 {
QuantizearmnnUtils::SelectiveQuantizer55 static armnn::Half Quantize(float value, float scale, int32_t offset)
56 {
57 armnn::IgnoreUnused(scale, offset);
58 return armnn::Half(value);
59 }
60
DequantizearmnnUtils::SelectiveQuantizer61 static float Dequantize(armnn::Half value, float scale, int32_t offset)
62 {
63 armnn::IgnoreUnused(scale, offset);
64 return value;
65 }
66 };
67
68 template<>
69 struct SelectiveQuantizer<armnn::BFloat16, false>
70 {
QuantizearmnnUtils::SelectiveQuantizer71 static armnn::BFloat16 Quantize(float value, float scale, int32_t offset)
72 {
73 armnn::IgnoreUnused(scale, offset);
74 return armnn::BFloat16(value);
75 }
76
DequantizearmnnUtils::SelectiveQuantizer77 static float Dequantize(armnn::BFloat16 value, float scale, int32_t offset)
78 {
79 armnn::IgnoreUnused(scale, offset);
80 return value;
81 }
82 };
83
84 template<typename T>
SelectiveQuantize(float value,float scale,int32_t offset)85 T SelectiveQuantize(float value, float scale, int32_t offset)
86 {
87 return SelectiveQuantizer<T, armnn::IsQuantizedType<T>()>::Quantize(value, scale, offset);
88 };
89
90 template<typename T>
SelectiveDequantize(T value,float scale,int32_t offset)91 float SelectiveDequantize(T value, float scale, int32_t offset)
92 {
93 return SelectiveQuantizer<T, armnn::IsQuantizedType<T>()>::Dequantize(value, scale, offset);
94 };
95
/// Type trait: `value` is true when the value_type reported by
/// std::iterator_traits for ItType is a floating-point type.
template<typename ItType>
struct IsFloatingPointIterator
{
private:
    // Element type the iterator yields, as reported by std::iterator_traits.
    using ElementType = typename std::iterator_traits<ItType>::value_type;

public:
    static constexpr bool value = std::is_floating_point<ElementType>::value;
};
101
102 template <typename T, typename FloatIt,
103 typename std::enable_if<IsFloatingPointIterator<FloatIt>::value, int>::type=0 // Makes sure fp iterator is valid.
104 >
QuantizedVector(FloatIt first,FloatIt last,float qScale,int32_t qOffset)105 std::vector<T> QuantizedVector(FloatIt first, FloatIt last, float qScale, int32_t qOffset)
106 {
107 std::vector<T> quantized;
108 quantized.reserve(armnn::numeric_cast<size_t>(std::distance(first, last)));
109
110 for (auto it = first; it != last; ++it)
111 {
112 auto f = *it;
113 T q = SelectiveQuantize<T>(f, qScale, qOffset);
114 quantized.push_back(q);
115 }
116
117 return quantized;
118 }
119
/// Convenience overload: converts every element of a std::vector<float> by
/// delegating to the iterator-based QuantizedVector.
///
/// @tparam T       Element type of the returned vector.
/// @param array    Input values.
/// @param qScale   Scale forwarded to the iterator overload (default 1.f).
/// @param qOffset  Offset forwarded to the iterator overload (default 0).
/// @return         The converted elements, in input order.
template<typename T>
std::vector<T> QuantizedVector(const std::vector<float>& array, float qScale = 1.f, int32_t qOffset = 0)
{
    const auto rangeBegin = array.begin();
    const auto rangeEnd   = array.end();
    return QuantizedVector<T>(rangeBegin, rangeEnd, qScale, qOffset);
}
125
/// Convenience overload: converts every element of a braced list of floats by
/// delegating to the iterator-based QuantizedVector.
///
/// @tparam T       Element type of the returned vector.
/// @param array    Input values.
/// @param qScale   Scale forwarded to the iterator overload (default 1.f).
/// @param qOffset  Offset forwarded to the iterator overload (default 0).
/// @return         The converted elements, in input order.
template<typename T>
std::vector<T> QuantizedVector(std::initializer_list<float> array, float qScale = 1.f, int32_t qOffset = 0)
{
    const auto rangeBegin = array.begin();
    const auto rangeEnd   = array.end();
    return QuantizedVector<T>(rangeBegin, rangeEnd, qScale, qOffset);
}
131
132 } // namespace armnnUtils
133