xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2019-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
#include "src/cpu/kernels/meanstddevnorm/generic/neon/impl.h"
#include "src/core/NEON/wrapper/wrapper.h"

#include <cmath>
27 
28 namespace arm_compute
29 {
30 namespace cpu
31 {
32 template <typename ScalarType, int size>
mean_stddev_normalization(ITensor * input,ITensor * output,float epsilon,const Window & window)33 void mean_stddev_normalization(ITensor *input, ITensor *output, float epsilon, const Window &window)
34 {
35     using ExactTagType = typename wrapper::traits::neon_vector<ScalarType, size>::tag_type;
36 
37     // Set build options
38     Window win = window;
39     win.set(Window::DimX, Window::Dimension(0, 1, 1));
40 
41     const int  window_step_x  = size;
42     const auto window_start_x = static_cast<int>(window.x().start());
43     const auto window_end_x   = static_cast<int>(window.x().end());
44 
45     Iterator input_itr(input, win);
46     Iterator output_itr(output, win);
47 
48     execute_window_loop(win, [&](const Coordinates &)
49     {
50         int  x       = window_start_x;
51         auto in_ptr  = reinterpret_cast<const ScalarType *>(input_itr.ptr());
52         auto out_ptr = reinterpret_cast<ScalarType *>(output_itr.ptr());
53 
54         auto sum_vec    = wrapper::vdup_n(static_cast<ScalarType>(0.f), ExactTagType{});
55         auto sum_sq_vec = wrapper::vdup_n(static_cast<ScalarType>(0.f), ExactTagType{});
56 
57         for(; x <= (window_end_x - window_step_x); x += window_step_x)
58         {
59             auto data  = wrapper::vloadq(in_ptr + x);
60             sum_vec    = wrapper::vadd(sum_vec, data);
61             sum_sq_vec = wrapper::vadd(sum_sq_vec, wrapper::vmul(data, data));
62         }
63 
64         auto sum_carry_res    = wrapper::vpadd(wrapper::vgethigh(sum_vec), wrapper::vgetlow(sum_vec));
65         auto sum_sq_carry_res = wrapper::vpadd(wrapper::vgethigh(sum_sq_vec), wrapper::vgetlow(sum_sq_vec));
66         for(int i = 0; i < size / 4; ++i)
67         {
68             sum_carry_res    = wrapper::vpadd(sum_carry_res, sum_carry_res);
69             sum_sq_carry_res = wrapper::vpadd(sum_sq_carry_res, sum_sq_carry_res);
70         }
71 
72         auto sum    = wrapper::vgetlane(sum_carry_res, 0);
73         auto sum_sq = wrapper::vgetlane(sum_sq_carry_res, 0);
74 
75         // Compute left-over elements
76         for(; x < window_end_x; ++x)
77         {
78             ScalarType data = *(in_ptr + x);
79             sum += data;
80             sum_sq += data * data;
81         }
82 
83         ScalarType mean       = sum / input->info()->dimension(0);
84         ScalarType var        = (sum_sq / input->info()->dimension(0)) - (mean * mean);
85         ScalarType stddev_inv = 1.f / sqrt(var + epsilon);
86 
87         auto mean_vec       = wrapper::vdup_n(mean, ExactTagType{});
88         auto stddev_inv_vec = wrapper::vdup_n(stddev_inv, ExactTagType{});
89         for(x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x)
90         {
91             auto data = wrapper::vloadq(in_ptr + x);
92             auto res  = wrapper::vmul(wrapper::vsub(data, mean_vec), stddev_inv_vec);
93             // Store results
94             wrapper::vstore(out_ptr + x, res);
95         }
96         for(; x < window_end_x; ++x)
97         {
98             *(out_ptr + x) = (*(in_ptr + x) - mean) * stddev_inv;
99         }
100     },
101     input_itr, output_itr);
102 }
103 template void mean_stddev_normalization<float, 4>(ITensor *input, ITensor *output, float epsilon, const Window &window);
104 
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
/** FP16 specialization: normalize each row of @p input to zero mean and unit variance.
 *
 * Data is loaded and stored as 8-lane float16 vectors. Both the running sum and the running
 * sum of squares are accumulated in float32. An fp16 accumulator overflows to inf above 65504,
 * and a lane's partial sum stops growing once it exceeds 2048 (11-bit significand), which
 * corrupts the mean of long rows.
 *
 * mean = sum / N and var = max(sum_sq / N - mean^2, 0) are computed in float. Only the final
 * mean and 1 / sqrt(var + epsilon) are rounded to fp16 for the output pass.
 *
 * @param[in]  input   Source tensor (float16 elements).
 * @param[out] output  Destination tensor, addressed with the same X offsets as @p input.
 * @param[in]  epsilon Value added to the variance before taking the square root.
 * @param[in]  window  Execution window; its X range selects the elements read and written.
 *
 * @note N is input->info()->dimension(0), not the width of the window's X range
 *       (NOTE(review): presumably identical; confirm against the kernel's window configuration).
 */
template <>
void mean_stddev_normalization<float16_t, 8>(ITensor *input, ITensor *output, float epsilon, const Window &window)
{
    // Collapse X: the lambda below walks the X range of each row itself.
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x  = 8;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    // Loop-invariant element count used to normalize the statistics (was re-queried per row).
    const float num_elements = static_cast<float>(input->info()->dimension(0));

    Iterator input_itr(input, win);
    Iterator output_itr(output, win);

    execute_window_loop(win, [&](const Coordinates &)
    {
        int  x       = window_start_x;
        auto in_ptr  = reinterpret_cast<const float16_t *>(input_itr.ptr());
        auto out_ptr = reinterpret_cast<float16_t *>(output_itr.ptr());

        // fp32 accumulators for both statistics (see function comment).
        float32x4_t sum_vec    = vdupq_n_f32(0.0f);
        float32x4_t sum_sq_vec = vdupq_n_f32(0.0f);

        // Pass 1 (vector part): widen each half of the fp16 vector to fp32 and accumulate.
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float16x8_t data = vld1q_f16(in_ptr + x);
            const float32x4_t dl   = vcvt_f32_f16(vget_low_f16(data));
            const float32x4_t dh   = vcvt_f32_f16(vget_high_f16(data));
            sum_vec                = vaddq_f32(vaddq_f32(sum_vec, dl), dh);
            sum_sq_vec             = vaddq_f32(sum_sq_vec, vmulq_f32(dl, dl));
            sum_sq_vec             = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
        }

        // Horizontal reduction: two pairwise adds leave the total of all 4 lanes in lane 0.
        float32x4_t sum_carry_res = vpaddq_f32(sum_vec, sum_vec);
        sum_carry_res             = vpaddq_f32(sum_carry_res, sum_carry_res);

        float32x4_t sum_sq_carry_res = vpaddq_f32(sum_sq_vec, sum_sq_vec);
        sum_sq_carry_res             = vpaddq_f32(sum_sq_carry_res, sum_sq_carry_res);

        float sum    = vgetq_lane_f32(sum_carry_res, 0);
        float sum_sq = vgetq_lane_f32(sum_sq_carry_res, 0);

        // Pass 1 (scalar tail): elements left over after the last full vector.
        for(; x < window_end_x; ++x)
        {
            const float data = static_cast<float>(*(in_ptr + x));
            sum += data;
            sum_sq += data * data;
        }

        const float mean = sum / num_elements;
        float       var  = (sum_sq / num_elements) - (mean * mean);
        // E[x^2] - E[x]^2 can come out slightly negative through cancellation. Clamp it so
        // sqrt never sees a negative argument when epsilon is tiny.
        if(var < 0.f)
        {
            var = 0.f;
        }
        const float16_t stddev_inv = static_cast<float16_t>(1.f / std::sqrt(var + epsilon));
        const float16_t mean_f16   = static_cast<float16_t>(mean);

        // Pass 2: write (x - mean) * (1 / stddev) in fp16, vector part then scalar tail.
        const float16x8_t mean_vec       = vdupq_n_f16(mean_f16);
        const float16x8_t stddev_inv_vec = vdupq_n_f16(stddev_inv);

        for(x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float16x8_t data = vld1q_f16(in_ptr + x);
            const float16x8_t res  = vmulq_f16(vsubq_f16(data, mean_vec), stddev_inv_vec);
            // Store results
            vst1q_f16(out_ptr + x, res);
        }
        for(; x < window_end_x; ++x)
        {
            *(out_ptr + x) = (*(in_ptr + x) - mean_f16) * stddev_inv;
        }
    },
    input_itr, output_itr);
}
#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
180 
181 } // namespace cpu
182 } // namespace arm_compute
183