1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "BaseIterator.hpp"
9 
10 #include <armnnUtils/FloatingPointConverter.hpp>
11 #include <armnnUtils/TensorUtils.hpp>
12 
13 #include <armnn/utility/Assert.hpp>
14 
15 namespace armnn
16 {
17 
18 namespace
19 {
20 
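
// The helpers in this anonymous namespace have internal linkage; they are
// implementation details of the MakeDecoder specializations below.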
inline std::unique_ptr<Decoder<float>> MakeSigned32PerAxisDecoder(const TensorInfo& info, const void* data)
{
    return std::make_unique<ScaledInt32PerAxisDecoder>(static_cast<const int32_t*>(data), info);
}

inline std::unique_ptr<Decoder<float>> MakeSigned32Decoder(const TensorInfo& info, const void* data)
{
    if (info.HasMultipleQuantizationScales())
    {
        // NOTE: If we have multiple quantization scales, we create a ScaledInt32PerAxisDecoder.
        // This will be used to decode per-axis quantized convolution biases.
        return MakeSigned32PerAxisDecoder(info, data);
    }
    else
    {
        if (info.GetQuantizationDim().has_value())
        {
            // NOTE: Even though we only have a single quantization scale, if the quantization
            // dimension is set, the tensor has per-axis quantization and we need to create a
            // ScaledInt32PerAxisDecoder.
            return MakeSigned32PerAxisDecoder(info, data);
        }

        const float scale = info.GetQuantizationScale();
        if (scale == 0.f)
        {
            // NOTE: If no quantization scale is set, we create an Int32Decoder, which simply
            // casts the int value to float. This will be used for any INT32 data other than
            // convolution biases.
            return std::make_unique<Int32Decoder>(static_cast<const int32_t*>(data));
        }

        // NOTE: If we only have a single (non-zero) quantization scale and no quantization
        // dimension is specified, we need to create a ScaledInt32Decoder. This will be used
        // to decode per-tensor quantized convolution biases.
        return std::make_unique<ScaledInt32Decoder>(static_cast<const int32_t*>(data), scale);
    }
}
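
// In summary, MakeSigned32Decoder resolves as follows:
//  - multiple quantization scales           -> ScaledInt32PerAxisDecoder
//  - single scale with quantization dim set -> ScaledInt32PerAxisDecoder
//  - scale == 0.f (no scale set)            -> Int32Decoder (plain cast to float)
//  - single non-zero scale, no quant dim    -> ScaledInt32Decoder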

} // anonymous namespace
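
// MakeDecoder<T> creates a Decoder<T> matching the tensor's DataType. Each
// specialization below asserts and returns nullptr for unsupported data types.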
template<typename T>
inline std::unique_ptr<Decoder<T>> MakeDecoder(const TensorInfo& info, const void* data = nullptr);

template<>
inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const void* data)
{
    switch(info.GetDataType())
    {
        case DataType::QAsymmS8:
        {
            return std::make_unique<QASymmS8Decoder>(
                static_cast<const int8_t*>(data),
                info.GetQuantizationScale(),
                info.GetQuantizationOffset());
        }
        case DataType::QAsymmU8:
        {
            return std::make_unique<QASymm8Decoder>(
                static_cast<const uint8_t*>(data),
                info.GetQuantizationScale(),
                info.GetQuantizationOffset());
        }
        case DataType::QSymmS16:
        {
            return std::make_unique<QSymm16Decoder>(
                static_cast<const int16_t*>(data),
                info.GetQuantizationScale(),
                info.GetQuantizationOffset());
        }
        case DataType::Float16:
        {
            return std::make_unique<Float16Decoder>(static_cast<const Half*>(data));
        }
        case DataType::Float32:
        {
            return std::make_unique<Float32Decoder>(static_cast<const float*>(data));
        }
        case DataType::Signed32:
        {
            return MakeSigned32Decoder(info, data);
        }
        case DataType::QSymmS8:
        {
            if (info.HasPerAxisQuantization())
            {
                // NOTE: The per-axis decoder reads the quantization axis and scales directly
                // from the TensorInfo, so they do not need to be extracted here.
                return std::make_unique<QSymm8PerAxisDecoder>(static_cast<const int8_t*>(data), info);
            }
            else
            {
                return std::make_unique<QSymmS8Decoder>(
                    static_cast<const int8_t*>(data),
                    info.GetQuantizationScale(),
                    info.GetQuantizationOffset());
            }
        }
        case armnn::DataType::Boolean:
        {
            return std::make_unique<BooleanDecoder>(static_cast<const uint8_t*>(data));
        }
        default:
        {
            ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
            break;
        }
    }
    return nullptr;
}
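
// Example usage (a minimal sketch; the QAsymmU8 tensor shape and values below are
// illustrative and not taken from this file):
//
//     armnn::TensorInfo info({ 1, 4 }, armnn::DataType::QAsymmU8, 0.5f, 10);
//     std::vector<uint8_t> quantized = { 10, 12, 14, 16 };
//     auto decoder = armnn::MakeDecoder<float>(info, quantized.data());
//     float first = decoder->Get(); // dequantized: (10 - 10) * 0.5f == 0.0f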

template<>
inline std::unique_ptr<Decoder<bool>> MakeDecoder(const TensorInfo& info, const void* data)
{
    switch(info.GetDataType())
    {
        case DataType::Boolean:
        {
            return std::make_unique<BooleanDecoderBool>(static_cast<const uint8_t*>(data));
        }
        default:
        {
            ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
            break;
        }
    }
    return nullptr;
}

template<>
inline std::unique_ptr<Decoder<int32_t>> MakeDecoder(const TensorInfo& info, const void* data)
{
    switch(info.GetDataType())
    {
        case DataType::Signed32:
        {
            return std::make_unique<Int32ToInt32tDecoder>(static_cast<const int32_t*>(data));
        }
        default:
        {
            ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
            break;
        }
    }
    return nullptr;
}

} // namespace armnn