xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/roialign/generic/neon/impl.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2019-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/cpu/kernels/roialign/generic/neon/impl.h"
25 #include "src/core/NEON/INEKernel.h"
26 namespace arm_compute
27 {
28 namespace cpu
29 {
30 /** Average pooling over an aligned window */
31 template <typename input_data_type>
roi_align_1x1(const ITensor * input,unsigned int roi_batch,float region_start_x,float bin_size_x,int grid_size_x,float region_end_x,float region_start_y,float bin_size_y,int grid_size_y,float region_end_y,int pz)32 inline input_data_type roi_align_1x1(const ITensor *input,
33                                      unsigned int   roi_batch,
34                                      float          region_start_x,
35                                      float          bin_size_x,
36                                      int            grid_size_x,
37                                      float          region_end_x,
38                                      float          region_start_y,
39                                      float          bin_size_y,
40                                      int            grid_size_y,
41                                      float          region_end_y,
42                                      int            pz)
43 {
44     if((region_end_x <= region_start_x) || (region_end_y <= region_start_y))
45     {
46         return input_data_type(0);
47     }
48     else
49     {
50         const DataLayout data_layout = input->info()->data_layout();
51         float            avg         = 0;
52         // Iterate through the aligned pooling region
53         for(int iy = 0; iy < grid_size_y; ++iy)
54         {
55             for(int ix = 0; ix < grid_size_x; ++ix)
56             {
57                 // Align the window in the middle of every bin
58                 float y = region_start_y + (iy + 0.5) * bin_size_y / float(grid_size_y);
59                 float x = region_start_x + (ix + 0.5) * bin_size_x / float(grid_size_x);
60 
61                 // Interpolation in the [0,0] [0,1] [1,0] [1,1] square
62                 const int y_low  = y;
63                 const int x_low  = x;
64                 const int y_high = y_low + 1;
65                 const int x_high = x_low + 1;
66 
67                 const float ly = y - y_low;
68                 const float lx = x - x_low;
69                 const float hy = 1. - ly;
70                 const float hx = 1. - lx;
71 
72                 const float w1 = hy * hx;
73                 const float w2 = hy * lx;
74                 const float w3 = ly * hx;
75                 const float w4 = ly * lx;
76                 if(data_layout == DataLayout::NCHW)
77                 {
78                     const auto data1 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch)));
79                     const auto data2 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch)));
80                     const auto data3 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch)));
81                     const auto data4 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch)));
82                     avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
83                 }
84                 else
85                 {
86                     const auto data1 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch)));
87                     const auto data2 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch)));
88                     const auto data3 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch)));
89                     const auto data4 = *reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch)));
90                     avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
91                 }
92             }
93         }
94 
95         avg /= grid_size_x * grid_size_y;
96         return input_data_type(avg);
97     }
98 }
99 
100 /** Average pooling over an aligned window */
101 template <typename input_data_type>
roi_align_1x1_qasymm8(const ITensor * input,unsigned int roi_batch,float region_start_x,float bin_size_x,int grid_size_x,float region_end_x,float region_start_y,float bin_size_y,int grid_size_y,float region_end_y,int pz,const QuantizationInfo & out_qinfo)102 inline input_data_type roi_align_1x1_qasymm8(const ITensor          *input,
103                                              unsigned int            roi_batch,
104                                              float                   region_start_x,
105                                              float                   bin_size_x,
106                                              int                     grid_size_x,
107                                              float                   region_end_x,
108                                              float                   region_start_y,
109                                              float                   bin_size_y,
110                                              int                     grid_size_y,
111                                              float                   region_end_y,
112                                              int                     pz,
113                                              const QuantizationInfo &out_qinfo)
114 {
115     if((region_end_x <= region_start_x) || (region_end_y <= region_start_y))
116     {
117         return input_data_type(out_qinfo.uniform().offset);
118     }
119     else
120     {
121         float                         avg              = 0;
122         const UniformQuantizationInfo input_qinfo      = input->info()->quantization_info().uniform();
123         const bool                    is_qasymm_signed = is_data_type_quantized_asymmetric_signed(input->info()->data_type());
124         const DataLayout              data_layout      = input->info()->data_layout();
125 
126         // Iterate through the aligned pooling region
127         for(int iy = 0; iy < grid_size_y; ++iy)
128         {
129             for(int ix = 0; ix < grid_size_x; ++ix)
130             {
131                 // Align the window in the middle of every bin
132                 float y = region_start_y + (iy + 0.5) * bin_size_y / float(grid_size_y);
133                 float x = region_start_x + (ix + 0.5) * bin_size_x / float(grid_size_x);
134 
135                 // Interpolation in the [0,0] [0,1] [1,0] [1,1] square
136                 const int y_low  = y;
137                 const int x_low  = x;
138                 const int y_high = y_low + 1;
139                 const int x_high = x_low + 1;
140 
141                 const float ly = y - y_low;
142                 const float lx = x - x_low;
143                 const float hy = 1. - ly;
144                 const float hx = 1. - lx;
145 
146                 const float w1 = hy * hx;
147                 const float w2 = hy * lx;
148                 const float w3 = ly * hx;
149                 const float w4 = ly * lx;
150 
151                 if(data_layout == DataLayout::NCHW)
152                 {
153                     if(is_qasymm_signed)
154                     {
155                         float data1 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
156                         float data2 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
157                         float data3 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
158                         float data4 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
159                         avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
160                     }
161                     else
162                     {
163                         float data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
164                         float data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
165                         float data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
166                         float data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
167                         avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
168                     }
169                 }
170                 else
171                 {
172                     if(is_qasymm_signed)
173                     {
174                         const auto data1 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
175                         const auto data2 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
176                         const auto data3 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
177                         const auto data4 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
178                         avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
179                     }
180                     else
181                     {
182                         const auto data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
183                         const auto data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
184                         const auto data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
185                         const auto data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
186                         avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
187                     }
188                 }
189             }
190         }
191 
192         avg /= grid_size_x * grid_size_y;
193 
194         input_data_type res = 0;
195         if(is_qasymm_signed)
196         {
197             res = quantize_qasymm8_signed(avg, out_qinfo);
198         }
199         else
200         {
201             res = quantize_qasymm8(avg, out_qinfo);
202         }
203         return res;
204     }
205 }
compute_region_coordinate(int p,float bin_size,float roi_anchor,float max_value)206 inline float compute_region_coordinate(int p, float bin_size, float roi_anchor, float max_value)
207 {
208     const float region_start = p * bin_size + roi_anchor;
209     return utility::clamp(region_start, 0.0f, max_value);
210 }
211 
212 template <typename input_data_type, typename roi_data_type>
roi_align(const ITensor * input,ITensor * output,const ITensor * rois,ROIPoolingLayerInfo pool_info,const Window & window,const ThreadInfo & info)213 void roi_align(const ITensor *input, ITensor *output, const ITensor *rois, ROIPoolingLayerInfo pool_info, const Window &window, const ThreadInfo &info)
214 {
215     ARM_COMPUTE_UNUSED(info);
216 
217     const DataLayout data_layout    = input->info()->data_layout();
218     const size_t     values_per_roi = rois->info()->dimension(0);
219 
220     const int roi_list_start = window.x().start();
221     const int roi_list_end   = window.x().end();
222 
223     const unsigned int idx_width  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
224     const unsigned int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
225     const unsigned int idx_depth  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
226 
227     const int input_width   = input->info()->dimension(idx_width);
228     const int input_height  = input->info()->dimension(idx_height);
229     const int input_chanels = input->info()->dimension(idx_depth);
230     const int pooled_w      = pool_info.pooled_width();
231     const int pooled_h      = pool_info.pooled_height();
232 
233     const DataType data_type = input->info()->data_type();
234     const bool     is_qasymm = is_data_type_quantized_asymmetric(data_type);
235 
236     const auto             *rois_ptr   = reinterpret_cast<const roi_data_type *>(rois->buffer());
237     const QuantizationInfo &rois_qinfo = rois->info()->quantization_info();
238     for(int roi_indx = roi_list_start; roi_indx < roi_list_end; ++roi_indx)
239     {
240         const unsigned int roi_batch = rois_ptr[values_per_roi * roi_indx];
241 
242         roi_data_type qx1 = rois_ptr[values_per_roi * roi_indx + 1];
243         roi_data_type qy1 = rois_ptr[values_per_roi * roi_indx + 2];
244         roi_data_type qx2 = rois_ptr[values_per_roi * roi_indx + 3];
245         roi_data_type qy2 = rois_ptr[values_per_roi * roi_indx + 4];
246         float         x1(qx1);
247         float         x2(qx2);
248         float         y1(qy1);
249         float         y2(qy2);
250         if(is_qasymm)
251         {
252             x1 = dequantize_qasymm16(qx1, rois_qinfo);
253             x2 = dequantize_qasymm16(qx2, rois_qinfo);
254             y1 = dequantize_qasymm16(qy1, rois_qinfo);
255             y2 = dequantize_qasymm16(qy2, rois_qinfo);
256         }
257         const float roi_anchor_x = x1 * pool_info.spatial_scale();
258         const float roi_anchor_y = y1 * pool_info.spatial_scale();
259         const float roi_dims_x   = std::max((x2 - x1) * pool_info.spatial_scale(), 1.0f);
260         const float roi_dims_y   = std::max((y2 - y1) * pool_info.spatial_scale(), 1.0f);
261         float       bin_size_x   = roi_dims_x / pool_info.pooled_width();
262         float       bin_size_y   = roi_dims_y / pool_info.pooled_height();
263 
264         // Iterate through all feature maps
265         for(int ch = 0; ch < input_chanels; ++ch)
266         {
267             // Iterate through all output pixels
268             for(int py = 0; py < pooled_h; ++py)
269             {
270                 for(int px = 0; px < pooled_w; ++px)
271                 {
272                     const float     region_start_x = compute_region_coordinate(px, bin_size_x, roi_anchor_x, input_width);
273                     const float     region_start_y = compute_region_coordinate(py, bin_size_y, roi_anchor_y, input_height);
274                     const float     region_end_x   = compute_region_coordinate(px + 1, bin_size_x, roi_anchor_x, input_width);
275                     const float     region_end_y   = compute_region_coordinate(py + 1, bin_size_y, roi_anchor_y, input_height);
276                     const int       roi_bin_grid_x = (pool_info.sampling_ratio() > 0) ? pool_info.sampling_ratio() : int(ceil(bin_size_x));
277                     const int       roi_bin_grid_y = (pool_info.sampling_ratio() > 0) ? pool_info.sampling_ratio() : int(ceil(bin_size_y));
278                     input_data_type out_val(0);
279                     if(is_qasymm)
280                     {
281                         out_val = roi_align_1x1_qasymm8<input_data_type>(
282                                       input, roi_batch, region_start_x, bin_size_x,
283                                       roi_bin_grid_x, region_end_x, region_start_y, bin_size_y,
284                                       roi_bin_grid_y, region_end_y, ch, output->info()->quantization_info());
285                     }
286                     else
287                     {
288                         out_val = roi_align_1x1<input_data_type>(
289                                       input, roi_batch, region_start_x, bin_size_x,
290                                       roi_bin_grid_x, region_end_x, region_start_y, bin_size_y,
291                                       roi_bin_grid_y, region_end_y, ch);
292                     }
293 
294                     if(data_layout == DataLayout::NCHW)
295                     {
296                         auto out_ptr = reinterpret_cast<input_data_type *>(output->ptr_to_element(Coordinates(px, py, ch, roi_indx)));
297                         *out_ptr     = out_val;
298                     }
299                     else
300                     {
301                         auto out_ptr = reinterpret_cast<input_data_type *>(output->ptr_to_element(Coordinates(ch, px, py, roi_indx)));
302                         *out_ptr     = out_val;
303                     }
304                 }
305             }
306         }
307     }
308 }
309 template void roi_align<float, float>(const ITensor *input, ITensor *output, const ITensor *rois, ROIPoolingLayerInfo pool_info, const Window &window, const ThreadInfo &info);
310 template void roi_align<uint8_t, uint16_t>(const ITensor *input, ITensor *output, const ITensor *rois, ROIPoolingLayerInfo pool_info, const Window &window, const ThreadInfo &info);
311 template void roi_align<int8_t, uint16_t>(const ITensor *input, ITensor *output, const ITensor *rois, ROIPoolingLayerInfo pool_info, const Window &window, const ThreadInfo &info);
312 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
313 template void roi_align<float16_t, float16_t>(const ITensor *input, ITensor *output, const ITensor *rois, ROIPoolingLayerInfo pool_info, const Window &window, const ThreadInfo &info);
314 #endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
315 } // namespace cpu
316 } // namespace arm_compute
317