/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cmath>
#include <tuple>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

namespace {

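// Computes group normalization over a contiguous [N, C, H*W] input:
//   y = (x - E[x]) / sqrt(Var[x] + eps) * weight + bias,
// with statistics taken per (batch, group) slice. The per-group mean and
// reciprocal standard deviation are also written to `mean` and `rstd`.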
template <typename CTYPE>
void group_norm(
    const Tensor& input,
    const optional<Tensor>& weight,
    const optional<Tensor>& bias,
    int64_t sN,
    int64_t sC,
    int64_t sHxW,
    int64_t group,
    CTYPE eps,
    Tensor& out,
    Tensor& mean,
    Tensor& rstd) {
  size_t N = static_cast<size_t>(sN); // NOLINT
  size_t C = static_cast<size_t>(sC); // NOLINT
  size_t HxW = static_cast<size_t>(sHxW); // NOLINT
  size_t G = static_cast<size_t>(group); // NOLINT

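  // View the input as `leading` (N * G) independent groups, each covering
  // D = C / G channels of HxW elements.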
  size_t leading = N * G;
  size_t D = C / G;
  size_t inner_size = D * HxW;

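  // Nothing to normalize when the batch or group count is zero.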
  if (leading == 0) {
    return;
  }

  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
  CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
  CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();

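  // Empty groups have an undefined variance: report a zero mean and a NaN
  // reciprocal standard deviation for each group.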
  if (inner_size == 0) {
    for (size_t i = 0; i < leading; ++i) {
      mean_data[i] = static_cast<CTYPE>(0);
      rstd_data[i] = static_cast<CTYPE>(NAN);
    }
    return;
  }

  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
  const CTYPE* weight_data;
  if (weight.has_value()) {
    weight_data = weight.value().const_data_ptr<CTYPE>();
  } else {
    weight_data = nullptr;
  }
  const CTYPE* bias_data;
  if (bias.has_value()) {
    bias_data = bias.value().const_data_ptr<CTYPE>();
  } else {
    bias_data = nullptr;
  }

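  // Normalize each (batch, group) slice independently.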
  for (size_t i = 0; i < leading; ++i) {
    const CTYPE* x = input_data + i * inner_size;

    // compute E[X] and Var[X] = E[X^2] - E[X]^2
    CTYPE sum = reduce_add(x, inner_size);
    CTYPE sq_sum = vec_powerf(x, inner_size);
    CTYPE mean_value = sum / inner_size;
    CTYPE variance = sq_sum / inner_size - mean_value * mean_value;
    CTYPE std = std::sqrt(variance + eps);
    CTYPE rstd_value = 1.0 / std;

    // Calculate the elements of output
    if (weight_data == nullptr && bias_data == nullptr) {
      CTYPE* y = out_data + i * inner_size;
      for (size_t j = 0; j < inner_size; j++) {
        y[j] = (x[j] - mean_value) * rstd_value;
      }
    } else {
      const size_t g = i % G;
      for (size_t j = 0; j < D; j++) {
        const size_t ch = g * D + j;
        const CTYPE scale =
            rstd_value * (weight_data == nullptr ? 1.0 : weight_data[ch]);
        const CTYPE beta =
            -scale * mean_value + (bias_data == nullptr ? 0.0 : bias_data[ch]);
        x = input_data + (i * D + j) * HxW;
        CTYPE* y = out_data + (i * D + j) * HxW;
        for (size_t k = 0; k < HxW; k++) {
          y[k] = scale * x[k] + beta;
        }
      }
    }

    mean_data[i] = mean_value;
    rstd_data[i] = rstd_value;
  }
}

} // namespace

std::tuple<Tensor&, Tensor&, Tensor&> native_group_norm_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const exec_aten::optional<Tensor>& weight,
    const exec_aten::optional<Tensor>& bias,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps,
    Tensor& out,
    Tensor& mean_out,
    Tensor& rstd_out) {
  (void)ctx;

  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);

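  // Check that the arguments are mutually consistent before doing any work.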
  ET_KERNEL_CHECK(
      ctx,
      check_group_norm_args(
          input, weight, bias, N, C, HxW, group, out, mean_out, rstd_out),
      InvalidArgument,
      ret_val);

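  // mean_out and rstd_out carry one value per (batch, group) pair.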
  Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
  mean_rstd_sizes[0] = N;
  mean_rstd_sizes[1] = group;

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, input.sizes()) == Error::Ok,
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(mean_out, {mean_rstd_sizes, 2}) == Error::Ok,
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(rstd_out, {mean_rstd_sizes, 2}) == Error::Ok,
      InvalidArgument,
      ret_val);

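  // All tensors must use the default dim order.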
  ET_KERNEL_CHECK(
      ctx, tensor_is_default_dim_order(input), InvalidArgument, ret_val);

  ET_KERNEL_CHECK(
      ctx,
      tensors_have_same_dim_order(input, out, mean_out, rstd_out),
      InvalidArgument,
      ret_val);

  if (weight.has_value()) {
    ET_KERNEL_CHECK(
        ctx,
        tensors_have_same_dim_order(input, weight.value()),
        InvalidArgument,
        ret_val);
  }

  if (bias.has_value()) {
    ET_KERNEL_CHECK(
        ctx,
        tensors_have_same_dim_order(input, bias.value()),
        InvalidArgument,
        ret_val);
  }

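  // Dispatch on the input's floating-point dtype and run the kernel.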
  constexpr auto name = "native_group_norm.out";

  ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
    group_norm<CTYPE>(
        input, weight, bias, N, C, HxW, group, eps, out, mean_out, rstd_out);
  });

  return ret_val;
}

} // namespace native
} // namespace executor
} // namespace torch