xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/instancenorm/generic/neon/impl.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2019-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/cpu/kernels/instancenorm/generic/neon/impl.h"
25 #include "src/core/NEON/wrapper/wrapper.h"
26 
27 namespace arm_compute
28 {
29 class ITensor;
30 class Window;
31 namespace cpu
32 {
33 template <typename InputType, typename AccType>
vector_float_sum(AccType & result,AccType & result_square,const InputType & inputs)34 void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
35 {
36     result        = wrapper::vadd(result, inputs);
37     result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
38 }
39 
40 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
41 template <>
vector_float_sum(float32x4_t & result,float32x4_t & result_square,const float16x8_t & inputs)42 inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
43 {
44     vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
45     vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
46 }
47 template <>
vector_float_norm(const float16x8_t & inputs,const float32x4_t & vec_mean,const float32x4_t & vec_multip,const float32x4_t & vec_beta)48 inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
49 {
50     const auto  input_low   = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
51     const auto  input_high  = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
52     const auto  result_low  = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
53     const auto  result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
54     float16x8_t result      = wrapper::vcombine(result_low, result_high);
55 
56     return result;
57 }
58 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
59 
60 template <typename InputType, typename AccType>
vector_float_norm(const InputType & inputs,const AccType & vec_mean,const AccType & vec_multip,const AccType & vec_beta)61 InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
62 {
63     return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
64 }
65 
66 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
67 
68 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
69 template <typename T, typename AccType>
instance_normalization_nchw(ITensor * input,ITensor * output,float gamma,float beta,float epsilon,const Window & window)70 void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
71 {
72     /** SIMD vector tag type. */
73     using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
74 
75     // Clear X/Y dimensions on execution window as we handle the planes manually
76     Window win = window;
77     win.set(Window::DimX, Window::Dimension(0, 1, 1));
78     win.set(Window::DimY, Window::Dimension(0, 1, 1));
79 
80     constexpr int      window_step_x  = 16 / sizeof(T);
81     const unsigned int elements_plane = input->info()->dimension(0) * output->info()->dimension(1);
82 
83     Iterator input_it(input, win);
84     execute_window_loop(win, [&](const Coordinates & id)
85     {
86         Window win_plane = window;
87         win_plane.set(Window::DimX, Window::Dimension(0, 1, 1));
88         win_plane.set(Window::DimZ, Window::Dimension(id[2], id[2] + 1, 1));
89         win_plane.set(3, Window::Dimension(id[3], id[3] + 1, 1));
90 
91         Iterator input_plane_it(input, win_plane);
92         Iterator output_plane_it(output, win_plane);
93 
94         auto sum_h_w         = static_cast<AccType>(0.f);
95         auto sum_squares_h_w = static_cast<AccType>(0.f);
96 
97         execute_window_loop(win_plane, [&](const Coordinates &)
98         {
99             const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
100 
101             auto vec_sum_h_w         = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
102             auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
103 
104             // Compute S elements per iteration
105             int x = window.x().start();
106             for(; x <= (window.x().end() - window_step_x); x += window_step_x)
107             {
108                 auto vec_input_val = wrapper::vloadq(input_ptr + x);
109                 vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
110             }
111 
112             auto vec2_sum_h_w         = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
113             auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));
114 
115             vec2_sum_h_w         = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
116             vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
117 
118             sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0);
119             sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);
120 
121             // Compute left-over elements
122             for(; x < window.x().end(); ++x)
123             {
124                 const auto value = static_cast<AccType>(*(input_ptr + x));
125                 sum_h_w += value;
126                 sum_squares_h_w += value * value;
127             }
128         },
129         input_plane_it, output_plane_it);
130 
131         const auto mean_h_w = sum_h_w / elements_plane;
132         const auto var_h_w  = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
133 
134         const auto multip_h_w     = gamma / std::sqrt(var_h_w + epsilon);
135         const auto vec_mean_h_w   = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
136         const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
137         const auto vec_beta       = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
138 
139         execute_window_loop(win_plane, [&](const Coordinates &)
140         {
141             auto input_ptr  = reinterpret_cast<T *>(input_plane_it.ptr());
142             auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
143 
144             // Compute S elements per iteration
145             int x = window.x().start();
146             //auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
147             for(; x <= (window.x().end() - window_step_x); x += window_step_x)
148             {
149                 const auto vec_val        = wrapper::vloadq(input_ptr + x);
150                 const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
151                 wrapper::vstore(output_ptr + x, normalized_vec);
152             }
153 
154             // Compute left-over elements
155             for(; x < window.x().end(); ++x)
156             {
157                 const auto val    = static_cast<AccType>(*(input_ptr + x));
158                 *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
159             }
160         },
161         input_plane_it, output_plane_it);
162     },
163     input_it);
164 }
165 
166 template void instance_normalization_nchw<float>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
167 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
168 template void instance_normalization_nchw<float16_t, float>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
169 template void instance_normalization_nchw<float16_t>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
170 #endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
171 } // namespace cpu
172 } // namespace arm_compute
173