xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_native_group_norm.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
10 #include <executorch/kernels/portable/cpu/vec_ops.h>
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 #include <cmath>
13 #include <tuple>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using Tensor = exec_aten::Tensor;
20 
21 namespace {
22 
23 template <typename CTYPE>
group_norm(const Tensor & input,const optional<Tensor> & weight,const optional<Tensor> & bias,int64_t sN,int64_t sC,int64_t sHxW,int64_t group,CTYPE eps,Tensor & out,Tensor & mean,Tensor & rstd)24 void group_norm(
25     const Tensor& input,
26     const optional<Tensor>& weight,
27     const optional<Tensor>& bias,
28     int64_t sN,
29     int64_t sC,
30     int64_t sHxW,
31     int64_t group,
32     CTYPE eps,
33     Tensor& out,
34     Tensor& mean,
35     Tensor& rstd) {
36   size_t N = static_cast<size_t>(sN); // NOLINT
37   size_t C = static_cast<size_t>(sC); // NOLINT
38   size_t HxW = static_cast<size_t>(sHxW); // NOLINT
39   size_t G = static_cast<size_t>(group); // NOLINT
40 
41   size_t leading = N * G;
42   size_t D = C / G;
43   size_t inner_size = D * HxW;
44 
45   if (leading == 0) {
46     return;
47   }
48 
49   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
50   CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
51   CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
52 
53   if (inner_size == 0) {
54     for (int i = 0; i < leading; ++i) {
55       mean_data[i] = static_cast<CTYPE>(0);
56       rstd_data[i] = static_cast<CTYPE>(NAN);
57     }
58     return;
59   }
60 
61   const CTYPE* input_data = input.const_data_ptr<CTYPE>();
62   const CTYPE* weight_data;
63   if (weight.has_value()) {
64     weight_data = weight.value().const_data_ptr<CTYPE>();
65   } else {
66     weight_data = nullptr;
67   }
68   const CTYPE* bias_data;
69   if (bias.has_value()) {
70     bias_data = bias.value().const_data_ptr<CTYPE>();
71   } else {
72     bias_data = nullptr;
73   }
74 
75   for (int i = 0; i < leading; ++i) {
76     const CTYPE* x = input_data + i * inner_size;
77 
78     // compute E[X] and Var[x] = E[x^2] - E[x]^2
79     CTYPE sum = reduce_add(x, inner_size);
80     CTYPE sq_sum = vec_powerf(x, inner_size);
81     CTYPE mean_value = sum / inner_size;
82     CTYPE variance = sq_sum / inner_size - mean_value * mean_value;
83     CTYPE std = std::sqrt(variance + eps);
84     CTYPE rstd_value = 1.0 / std;
85 
86     // Calculate the elements of output
87     if (weight_data == nullptr && bias_data == nullptr) {
88       CTYPE* y = out_data + i * inner_size;
89       for (size_t j = 0; j < inner_size; j++) {
90         y[j] = (x[j] - mean_value) * rstd_value;
91       }
92     } else {
93       const size_t g = i % G;
94       for (size_t j = 0; j < D; j++) {
95         const size_t ch = g * D + j;
96         const CTYPE scale =
97             rstd_value * (weight_data == nullptr ? 1.0 : weight_data[ch]);
98         const CTYPE beta =
99             -scale * mean_value + (bias_data == nullptr ? 0.0 : bias_data[ch]);
100         x = input_data + (i * D + j) * HxW;
101         CTYPE* y = out_data + (i * D + j) * HxW;
102         for (size_t k = 0; k < HxW; k++) {
103           y[k] = scale * x[k] + beta;
104         }
105       }
106     }
107 
108     mean_data[i] = mean_value;
109     rstd_data[i] = rstd_value;
110   }
111 }
112 
113 } // namespace
114 
native_group_norm_out(KernelRuntimeContext & ctx,const Tensor & input,const exec_aten::optional<Tensor> & weight,const exec_aten::optional<Tensor> & bias,int64_t N,int64_t C,int64_t HxW,int64_t group,double eps,Tensor & out,Tensor & mean_out,Tensor & rstd_out)115 std::tuple<Tensor&, Tensor&, Tensor&> native_group_norm_out(
116     KernelRuntimeContext& ctx,
117     const Tensor& input,
118     const exec_aten::optional<Tensor>& weight,
119     const exec_aten::optional<Tensor>& bias,
120     int64_t N,
121     int64_t C,
122     int64_t HxW,
123     int64_t group,
124     double eps,
125     Tensor& out,
126     Tensor& mean_out,
127     Tensor& rstd_out) {
128   (void)ctx;
129 
130   std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
131 
132   ET_KERNEL_CHECK(
133       ctx,
134       check_group_norm_args(
135           input, weight, bias, N, C, HxW, group, out, mean_out, rstd_out),
136       InvalidArgument,
137       ret_val);
138 
139   Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
140   mean_rstd_sizes[0] = N;
141   mean_rstd_sizes[1] = group;
142 
143   ET_KERNEL_CHECK(
144       ctx,
145       resize_tensor(out, input.sizes()) == Error::Ok,
146       InvalidArgument,
147       ret_val);
148 
149   ET_KERNEL_CHECK(
150       ctx,
151       resize_tensor(mean_out, {mean_rstd_sizes, 2}) == Error::Ok,
152       InvalidArgument,
153       ret_val);
154 
155   ET_KERNEL_CHECK(
156       ctx,
157       resize_tensor(rstd_out, {mean_rstd_sizes, 2}) == Error::Ok,
158       InvalidArgument,
159       ret_val);
160 
161   ET_KERNEL_CHECK(
162       ctx, tensor_is_default_dim_order(input), InvalidArgument, ret_val);
163 
164   ET_KERNEL_CHECK(
165       ctx,
166       tensors_have_same_dim_order(input, out, mean_out, rstd_out),
167       InvalidArgument,
168       ret_val);
169 
170   if (weight.has_value()) {
171     ET_KERNEL_CHECK(
172         ctx,
173         tensors_have_same_dim_order(input, weight.value()),
174         InvalidArgument,
175         ret_val);
176   }
177 
178   if (bias.has_value()) {
179     ET_KERNEL_CHECK(
180         ctx,
181         tensors_have_same_dim_order(input, bias.value()),
182         InvalidArgument,
183         ret_val);
184   }
185 
186   constexpr auto name = "native_group_norm.out";
187 
188   ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
189     group_norm<CTYPE>(
190         input, weight, bias, N, C, HxW, group, eps, out, mean_out, rstd_out);
191   });
192 
193   return ret_val;
194 }
195 
196 } // namespace native
197 } // namespace executor
198 } // namespace torch
199