xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/normalization_ops_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstring>
10 
11 #include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
12 
13 namespace torch {
14 namespace executor {
15 
16 using Tensor = exec_aten::Tensor;
17 
check_batch_norm_args(const Tensor & in,const exec_aten::optional<Tensor> & weight,const exec_aten::optional<Tensor> & bias,const exec_aten::optional<Tensor> & running_mean,const exec_aten::optional<Tensor> & running_var,double momentum,double eps,Tensor & out,Tensor & mean_out,Tensor & var_out)18 bool check_batch_norm_args(
19     const Tensor& in,
20     const exec_aten::optional<Tensor>& weight,
21     const exec_aten::optional<Tensor>& bias,
22     const exec_aten::optional<Tensor>& running_mean,
23     const exec_aten::optional<Tensor>& running_var,
24     double momentum,
25     double eps,
26     Tensor& out,
27     Tensor& mean_out,
28     Tensor& var_out) {
29   // All tensors must be the same dtype
30   if (weight.has_value()) {
31     ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
32   }
33   if (bias.has_value()) {
34     ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
35   }
36   if (running_mean.has_value()) {
37     ET_LOG_AND_RETURN_IF_FALSE(
38         tensors_have_same_dtype(in, running_mean.value()));
39   }
40   if (running_mean.has_value()) {
41     ET_LOG_AND_RETURN_IF_FALSE(
42         tensors_have_same_dtype(in, running_var.value()));
43   }
44   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
45   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
46   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out));
47 
48   size_t C_dim = in.dim() >= 1 ? 1 : 0;
49   // All parameter tensors must be of dim 1 and have length equal to the
50   // channels dim of in
51   if (weight.has_value()) {
52     ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight.value(), 1));
53     ET_LOG_AND_RETURN_IF_FALSE(
54         tensors_have_same_size_at_dims(weight.value(), 0, in, C_dim));
55   }
56   if (bias.has_value()) {
57     ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(bias.value(), 1));
58     ET_LOG_AND_RETURN_IF_FALSE(
59         tensors_have_same_size_at_dims(bias.value(), 0, in, C_dim));
60   }
61   if (running_mean.has_value()) {
62     ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean.value(), 1));
63     ET_LOG_AND_RETURN_IF_FALSE(
64         tensors_have_same_size_at_dims(running_mean.value(), 0, in, C_dim));
65   }
66   if (running_var.has_value()) {
67     ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var.value(), 1));
68     ET_LOG_AND_RETURN_IF_FALSE(
69         tensors_have_same_size_at_dims(running_var.value(), 0, in, C_dim));
70   }
71 
72   return true;
73 }
74 
check_layer_norm_args(const Tensor & in,IntArrayRef normalized_shape,const exec_aten::optional<Tensor> & weight,const exec_aten::optional<Tensor> & bias,Tensor & out,Tensor & mean_out,Tensor & rstd_out)75 bool check_layer_norm_args(
76     const Tensor& in,
77     IntArrayRef normalized_shape,
78     const exec_aten::optional<Tensor>& weight,
79     const exec_aten::optional<Tensor>& bias,
80     Tensor& out,
81     Tensor& mean_out,
82     Tensor& rstd_out) {
83   size_t ndim = normalized_shape.size();
84   ET_LOG_MSG_AND_RETURN_IF_FALSE(
85       ndim >= 1,
86       "Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element.");
87   ET_LOG_MSG_AND_RETURN_IF_FALSE(
88       in.dim() >= ndim,
89       "Expected input tensor to have rank >= the length of normalized_shape.");
90   size_t shift = in.dim() - ndim;
91   for (size_t d = 0; d < ndim; ++d) {
92     ET_LOG_MSG_AND_RETURN_IF_FALSE(
93         in.size(d + shift) == normalized_shape[d],
94         "Expected normalized_shape to match the sizes of input's rightmost dimensions.");
95   }
96   exec_aten::SizesType shape[ndim];
97   for (size_t i = 0; i < ndim; ++i) {
98     shape[i] = static_cast<exec_aten::SizesType>(normalized_shape[i]);
99   }
100 
101   if (weight.has_value()) {
102     ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
103     ET_LOG_AND_RETURN_IF_FALSE(
104         tensor_has_expected_size(weight.value(), {shape, ndim}));
105   }
106   if (bias.has_value()) {
107     ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
108     ET_LOG_AND_RETURN_IF_FALSE(
109         tensor_has_expected_size(bias.value(), {shape, ndim}));
110   }
111   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
112   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
113   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, rstd_out));
114   return true;
115 }
116 
get_layer_norm_out_target_size(const Tensor & in,IntArrayRef normalized_shape,Tensor::SizesType * mean_rstd_sizes,size_t * mean_rstd_ndim)117 void get_layer_norm_out_target_size(
118     const Tensor& in,
119     IntArrayRef normalized_shape,
120     Tensor::SizesType* mean_rstd_sizes,
121     size_t* mean_rstd_ndim) {
122   *mean_rstd_ndim = in.dim();
123 
124   for (size_t d = 0; d < in.dim(); ++d) {
125     if (d < in.dim() - normalized_shape.size()) {
126       mean_rstd_sizes[d] = in.size(d);
127     } else {
128       mean_rstd_sizes[d] = 1;
129     }
130   }
131 }
132 
check_group_norm_args(const Tensor & in,const exec_aten::optional<Tensor> & weight,const exec_aten::optional<Tensor> & bias,int64_t N,int64_t C,int64_t HxW,int64_t group,Tensor & out,Tensor & mean_out,Tensor & rstd_out)133 bool check_group_norm_args(
134     const Tensor& in,
135     const exec_aten::optional<Tensor>& weight,
136     const exec_aten::optional<Tensor>& bias,
137     int64_t N,
138     int64_t C,
139     int64_t HxW,
140     int64_t group,
141     Tensor& out,
142     Tensor& mean_out,
143     Tensor& rstd_out) {
144   ET_LOG_AND_RETURN_IF_FALSE(in.size(0) == N);
145   ET_LOG_AND_RETURN_IF_FALSE(in.size(1) == C);
146   ET_LOG_AND_RETURN_IF_FALSE(in.numel() == N * C * HxW);
147   ET_LOG_MSG_AND_RETURN_IF_FALSE(
148       group > 0, "Expected number of groups to be greater than 0");
149   ET_LOG_MSG_AND_RETURN_IF_FALSE(
150       C % group == 0,
151       "Expected number of channels in input to be divisible by number of groups");
152   ET_LOG_MSG_AND_RETURN_IF_FALSE(
153       !weight.has_value() ||
154           (weight.value().dim() == 1 && weight.value().size(0) == C),
155       "Expected weight to be a vector of size equal to the number of channels in input");
156   ET_LOG_MSG_AND_RETURN_IF_FALSE(
157       !bias.has_value() ||
158           (bias.value().dim() == 1 && bias.value().size(0) == C),
159       "Expected bias to be a vector of size equal to the number of channels in input");
160 
161   if (weight.has_value()) {
162     ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
163   }
164   if (bias.has_value()) {
165     ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
166   }
167   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
168   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
169   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, rstd_out));
170   return true;
171 }
172 
173 } // namespace executor
174 } // namespace torch
175