xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2019-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEROIAlignLayerKernel.h"
25 
26 #include "arm_compute/core/Helpers.h"
27 #include "arm_compute/core/TensorInfo.h"
28 #include "arm_compute/core/Utils.h"
29 #include "arm_compute/core/Window.h"
30 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
31 #include "arm_compute/core/utils/misc/Utility.h"
32 #include "src/core/CPP/Validate.h"
33 #include "src/core/common/Registrars.h"
34 #include "src/core/helpers/AutoConfiguration.h"
35 #include "src/core/helpers/WindowHelpers.h"
36 #include "src/cpu/kernels/roialign/list.h"
37 
38 #include <arm_neon.h>
39 
40 using namespace arm_compute::misc::shape_calculator;
41 
42 namespace arm_compute
43 {
44 namespace
45 {
46 struct ROIAlignSelectorData
47 {
48     DataType dt;
49 };
50 
51 using ROIAlignSelctorPtr = std::add_pointer<bool(const ROIAlignSelectorData &data)>::type;
52 using ROIAlignUKernelPtr = std::add_pointer<void(const ITensor *input, ITensor *output, const ITensor *rois, ROIPoolingLayerInfo pool_info, const Window &window, const ThreadInfo &info)>::type;
53 
54 struct ROIAlignKernel
55 {
56     const char              *name;
57     const ROIAlignSelctorPtr is_selected;
58     ROIAlignUKernelPtr       ukernel;
59 };
60 
61 static const ROIAlignKernel available_kernels[] =
62 {
63     {
64         "fp32_neon_roialign",
__anona16cef500202() 65         [](const ROIAlignSelectorData & data) { return data.dt == DataType::F32; },
66         REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_roialign)
67     },
68 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
69     {
70         "fp16_neon_roialign",
__anona16cef500302() 71         [](const ROIAlignSelectorData & data) { return data.dt == DataType::F16; },
72         REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_roialign)
73     },
74 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
75 #if defined(ARM_COMPUTE_ENABLE_NEON)
76     {
77         "qu8_neon_roialign",
__anona16cef500402() 78         [](const ROIAlignSelectorData & data) { return data.dt == DataType::QASYMM8; },
79         REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qu8_roialign)
80     },
81     {
82         "qs8_neon_roialign",
__anona16cef500502() 83         [](const ROIAlignSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; },
84         REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qs8_roialign)
85     },
86 #endif //defined(ARM_COMPUTE_ENABLE_NEON)
87 };
88 
89 /** Micro-kernel selector
90  *
91  * @param[in] data Selection data passed to help pick the appropriate micro-kernel
92  *
93  * @return A matching micro-kernel else nullptr
94  */
get_implementation(const ROIAlignSelectorData & data)95 const ROIAlignKernel *get_implementation(const ROIAlignSelectorData &data)
96 {
97     for(const auto &uk : available_kernels)
98     {
99         if(uk.is_selected(data))
100         {
101             return &uk;
102         }
103     }
104     return nullptr;
105 }
106 
validate_arguments(const ITensorInfo * input,const ITensorInfo * rois,ITensorInfo * output,const ROIPoolingLayerInfo & pool_info)107 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *rois, ITensorInfo *output, const ROIPoolingLayerInfo &pool_info)
108 {
109     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, rois, output);
110     ARM_COMPUTE_RETURN_ERROR_ON(rois->dimension(0) != 5);
111     ARM_COMPUTE_RETURN_ERROR_ON(rois->num_dimensions() > 2);
112     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F32, DataType::F16);
113     ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NHWC, DataLayout::NCHW);
114     ARM_COMPUTE_RETURN_ERROR_ON((pool_info.pooled_width() == 0) || (pool_info.pooled_height() == 0));
115     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
116 
117     if(output->total_size() != 0)
118     {
119         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
120         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
121         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(compute_roi_align_shape(*input, *rois, pool_info), output->tensor_shape());
122     }
123 
124     if(input->data_type() == DataType::QASYMM8 || input->data_type() == DataType::QASYMM8_SIGNED)
125     {
126         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(rois, 1, DataType::QASYMM16);
127 
128         const UniformQuantizationInfo rois_qinfo = rois->quantization_info().uniform();
129         ARM_COMPUTE_RETURN_ERROR_ON(rois_qinfo.scale != 0.125f);
130         ARM_COMPUTE_RETURN_ERROR_ON(rois_qinfo.offset != 0);
131     }
132     else
133     {
134         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, rois);
135     }
136 
137     return Status{};
138 }
139 } // namespace
140 
NEROIAlignLayerKernel()141 NEROIAlignLayerKernel::NEROIAlignLayerKernel()
142     : _input(nullptr), _output(nullptr), _rois(nullptr), _pool_info(0, 0, 0.f)
143 {
144 }
145 
configure(const ITensor * input,const ITensor * rois,ITensor * output,const ROIPoolingLayerInfo & pool_info)146 void NEROIAlignLayerKernel::configure(const ITensor *input, const ITensor *rois, ITensor *output, const ROIPoolingLayerInfo &pool_info)
147 {
148     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, rois);
149     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), rois->info(), output->info(), pool_info));
150     // Output auto inizialitation if not yet initialized
151     const TensorShape output_shape = compute_roi_align_shape(*input->info(), *rois->info(), pool_info);
152     auto_init_if_empty((*output->info()), output_shape, 1, input->info()->data_type(), input->info()->quantization_info());
153     output->info()->set_data_layout(input->info()->data_layout());
154 
155     // Configure kernel window
156     const unsigned int num_rois = rois->info()->dimension(1);
157     Window             window;
158     window.set(Window::DimX, Window::Dimension(0, num_rois));
159     window.set(Window::DimY, Window::Dimension(0, 1));
160 
161     // Set instance variables
162     _input     = input;
163     _rois      = rois;
164     _output    = output;
165     _pool_info = pool_info;
166 
167     INEKernel::configure(window);
168 }
169 
validate(const ITensorInfo * input,const ITensorInfo * rois,ITensorInfo * output,const ROIPoolingLayerInfo & pool_info)170 Status NEROIAlignLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *rois, ITensorInfo *output, const ROIPoolingLayerInfo &pool_info)
171 {
172     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, rois, output, pool_info));
173     return Status{};
174 }
175 
run(const Window & window,const ThreadInfo & info)176 void NEROIAlignLayerKernel::run(const Window &window, const ThreadInfo &info)
177 {
178     const DataLayout data_layout = _input->info()->data_layout();
179     if(data_layout == DataLayout::NCHW || data_layout == DataLayout::NHWC)
180     {
181         const auto *uk = get_implementation(ROIAlignSelectorData{ _input->info()->data_type() });
182         ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
183 
184         uk->ukernel(_input, _output, _rois, _pool_info, window, info);
185     }
186     else
187     {
188         ARM_COMPUTE_ERROR("Invalid layout");
189     }
190 }
191 } // namespace arm_compute
192