/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

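// Validates the arguments of a batch_norm-style op: every optional parameter
// tensor and every output tensor must share `in`'s dtype, and the optional
// weight/bias/running_mean/running_var tensors must be 1-D with length equal
// to `in`'s channel dimension.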
bool check_batch_norm_args(
    const Tensor& in,
    const exec_aten::optional<Tensor>& weight,
    const exec_aten::optional<Tensor>& bias,
    const exec_aten::optional<Tensor>& running_mean,
    const exec_aten::optional<Tensor>& running_var,
    double momentum,
    double eps,
    Tensor& out,
    Tensor& mean_out,
    Tensor& var_out) {
  // All tensors must be the same dtype
  if (weight.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
  }
  if (bias.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
  }
  if (running_mean.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_dtype(in, running_mean.value()));
  }
  if (running_var.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_dtype(in, running_var.value()));
  }
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out));

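  // For an input of rank >= 1 the channels dimension is dim 1 (an
  // (N, C, ...) layout); fall back to dim 0 for a 0-D input.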
  size_t C_dim = in.dim() >= 1 ? 1 : 0;
  // All parameter tensors must be of dim 1 and have length equal to the
  // channels dim of in
  if (weight.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight.value(), 1));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_size_at_dims(weight.value(), 0, in, C_dim));
  }
  if (bias.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(bias.value(), 1));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_size_at_dims(bias.value(), 0, in, C_dim));
  }
  if (running_mean.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean.value(), 1));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_size_at_dims(running_mean.value(), 0, in, C_dim));
  }
  if (running_var.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var.value(), 1));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_size_at_dims(running_var.value(), 0, in, C_dim));
  }

  return true;
}

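// Validates the arguments of a layer_norm-style op: normalized_shape must be
// non-empty and match the rightmost dims of `in`, optional weight/bias must
// have exactly normalized_shape's sizes, and all tensors must share `in`'s
// dtype.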
bool check_layer_norm_args(
    const Tensor& in,
    IntArrayRef normalized_shape,
    const exec_aten::optional<Tensor>& weight,
    const exec_aten::optional<Tensor>& bias,
    Tensor& out,
    Tensor& mean_out,
    Tensor& rstd_out) {
  size_t ndim = normalized_shape.size();
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element.");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.dim() >= ndim,
      "Expected input tensor to have rank >= the length of normalized_shape.");
  size_t shift = in.dim() - ndim;
  for (size_t d = 0; d < ndim; ++d) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        in.size(d + shift) == normalized_shape[d],
        "Expected normalized_shape to match the sizes of input's rightmost dimensions.");
  }
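  // Materialize normalized_shape as SizesType so the optional weight/bias
  // sizes can be checked against it below.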
  exec_aten::SizesType shape[ndim];
  for (size_t i = 0; i < ndim; ++i) {
    shape[i] = static_cast<exec_aten::SizesType>(normalized_shape[i]);
  }

  if (weight.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensor_has_expected_size(weight.value(), {shape, ndim}));
  }
  if (bias.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensor_has_expected_size(bias.value(), {shape, ndim}));
  }
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, rstd_out));
  return true;
}

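// Computes the target sizes of layer_norm's mean and rstd outputs: they keep
// `in`'s rank, copy the leading (non-normalized) dims from `in`, and use
// size 1 for every normalized dim.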
void get_layer_norm_out_target_size(
    const Tensor& in,
    IntArrayRef normalized_shape,
    Tensor::SizesType* mean_rstd_sizes,
    size_t* mean_rstd_ndim) {
  *mean_rstd_ndim = in.dim();

  for (size_t d = 0; d < in.dim(); ++d) {
    if (d < in.dim() - normalized_shape.size()) {
      mean_rstd_sizes[d] = in.size(d);
    } else {
      mean_rstd_sizes[d] = 1;
    }
  }
}

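// Validates the arguments of a group_norm-style op: `in` must have batch size
// N, C channels, and N * C * HxW total elements, the group count must be
// positive and divide C, optional weight/bias must be 1-D vectors of length
// C, and all tensors must share `in`'s dtype.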
bool check_group_norm_args(
    const Tensor& in,
    const exec_aten::optional<Tensor>& weight,
    const exec_aten::optional<Tensor>& bias,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    Tensor& out,
    Tensor& mean_out,
    Tensor& rstd_out) {
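  // `in` is expected to have batch size N, C channels, and HxW elements per
  // (batch, channel) pair.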
  ET_LOG_AND_RETURN_IF_FALSE(in.size(0) == N);
  ET_LOG_AND_RETURN_IF_FALSE(in.size(1) == C);
  ET_LOG_AND_RETURN_IF_FALSE(in.numel() == N * C * HxW);
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      group > 0, "Expected number of groups to be greater than 0");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      C % group == 0,
      "Expected number of channels in input to be divisible by number of groups");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      !weight.has_value() ||
          (weight.value().dim() == 1 && weight.value().size(0) == C),
      "Expected weight to be a vector of size equal to the number of channels in input");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      !bias.has_value() ||
          (bias.value().dim() == 1 && bias.value().size(0) == C),
      "Expected bias to be a vector of size equal to the number of channels in input");

  if (weight.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
  }
  if (bias.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
  }
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, rstd_out));
  return true;
}

} // namespace executor
} // namespace torch