#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/layer_norm.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/layer_norm_native.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_layer_norm.h>
#include <ATen/ops/native_layer_norm_backward_native.h>
#include <ATen/ops/native_layer_norm_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/rsqrt.h>
#include <ATen/ops/rms_norm.h>
#include <ATen/ops/zeros_like_native.h>
#endif

#include <array>
#include <tuple>
#include <vector>

namespace at::native {

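// Helper for the forward pass: runs the CPU LayerNorm kernel and then reshapes
// `mean`/`rstd` so they keep the leading (non-normalized) dimensions of `input`
// and have size 1 over the normalized axes, making them broadcastable against
// the input.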
static void layer_norm_with_mean_rstd_out(
    at::Tensor& out,
    at::Tensor& mean,
    at::Tensor& rstd,
    const at::Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N) {
  LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, &mean, &rstd);
  const auto input_shape = input.sizes();
  const size_t axis = input.dim() - normalized_shape.size();

  DimVector stat_shape;
  for (const auto idx : c10::irange(axis)) {
    stat_shape.emplace_back(input_shape[idx]);
  }
  for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) {
    stat_shape.emplace_back(1);
  }

  mean = mean.view(stat_shape);
  rstd = rstd.view(stat_shape);
}

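// Out-variant forward used when only the normalized output is needed; mean and
// rstd are not computed. With M <= 0 there are no rows to normalize, so the
// kernel call is skipped.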
void layer_norm_cpu_out(
    at::Tensor& out,
    const at::Tensor& input,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N) {
  if (M <= 0) {
    return;
  }
  LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, /*mean=*/nullptr, /*rstd=*/nullptr);
}

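// CPU forward for native_layer_norm: returns (output, mean, rstd). The mean and
// rstd tensors are allocated with param_scalar_type(input, mixed_type), so
// reduced-precision inputs paired with float parameters keep their statistics
// in float.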
std::tuple<Tensor, Tensor, Tensor> layer_norm_cpu(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  bool mixed_type = is_mixed_type(input, weight, bias);
  if (mixed_type) {
    check_mixed_data_type(input, weight, bias);
  }

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor Y = at::native::empty_like(
      *X,
      std::nullopt /* dtype */,
      std::nullopt /* layout */,
      std::nullopt /* device */,
      std::nullopt /* pin_memory */,
      at::MemoryFormat::Contiguous);
  const auto dtype = param_scalar_type(input, mixed_type);
  Tensor mean = at::empty({M}, X->options().dtype(dtype));
  Tensor rstd = at::empty({M}, X->options().dtype(dtype));

  layer_norm_with_mean_rstd_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N);
  return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
}

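// CPU backward for native_layer_norm. grad_input_mask selects which of
// (dX, dgamma, dbeta) to materialize; when M == 0 the parameter grads are
// returned as zeros and the backward kernel is not invoked.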
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
    const Tensor& dY,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  if (grad_input_mask[0]) {
    dX = at::native::empty_like(
        *X,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous)
                   : at::native::zeros_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = M > 0 ? at::native::empty_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous)
                  : at::native::zeros_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous);
  }
  if (M > 0) {
    LayerNormBackwardKernel(
        kCPU, dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
  }
  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
}

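// Composite layer_norm entry point: forwards to native_layer_norm_symint and
// returns only the normalized output. The trailing bool is the deprecated
// cudnn_enable flag and is ignored.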
Tensor layer_norm_symint(
    const Tensor& input,
    c10::SymIntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    double eps,
    bool /* cudnn_enable, deprecated */) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return std::get<0>(at::native_layer_norm_symint(input, normalized_shape, weight, bias, eps));
}

DEFINE_DISPATCH(LayerNormKernel);
DEFINE_DISPATCH(LayerNormBackwardKernel);

// Ported from pytorch/xla repo
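// Reference decomposition: the input is reshaped to (1, M, -1) and normalized
// with native_batch_norm in training mode, so each of the M rows gets its own
// mean/rstd; the element-wise weight/bias are applied afterwards and the
// statistics are reshaped to broadcast against the original input shape.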
std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();

  auto input_shape = input.sizes();
  const auto input_ndim = input.dim();
  const int normalized_ndim = normalized_shape.size();
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const int axis = input_ndim - normalized_ndim;

  // Properly handle zero-size inputs: the view(1, M, -1) call below breaks on this.
  if (input.numel() == 0) {
    auto result_type = c10::promoteTypes(input.scalar_type(), kFloat);
    return std::make_tuple(
      at::empty_like(input),
      at::empty_like(input, c10::TensorOptions().dtype(result_type)),
      at::empty_like(input, c10::TensorOptions().dtype(result_type))
    );
  }
  at::Tensor input_reshaped = input.reshape({1, M, -1});
  // Unlike Batch Normalization, which applies scalar scale and bias for each
  // entire channel/plane with the affine option, Layer Normalization applies
  // per-element scale and bias. E.g. For input {N, C, H, W}, weight for
  // batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
  auto outputs = at::native_batch_norm(
      input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
      /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
  at::Tensor out = std::get<0>(outputs);
  out = out.view(input_shape);
  if (weight.defined() && bias.defined()) {
    out = bias.addcmul(out, weight, 1);
  } else if (weight.defined()) {
    out = out.mul(weight);
  } else if (bias.defined()) {
    out = out.add(bias);
  }
  at::Tensor mean = std::get<1>(outputs);
  at::Tensor rstd = std::get<2>(outputs);
  std::vector<int64_t> stat_shape;
  for (const auto idx : c10::irange(axis)) {
    stat_shape.push_back(input_shape[idx]);
  }
  for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) {
    stat_shape.push_back(1);
  }
  mean = mean.view(stat_shape);
  rstd = rstd.view(stat_shape);
  return std::make_tuple(out, mean, rstd);
}

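// RMSNorm: out = input * rsqrt(mean(input^2, normalized dims) + eps), optionally
// scaled by `weight`. The reduction runs over the trailing
// normalized_shape.size() dimensions; when eps is not given, the epsilon of the
// dispatched scalar type (its real value type for complex inputs) is used.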
Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    std::optional<double> eps) {

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  auto bias_opt = std::optional<Tensor>();
  const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt);
  (void) _check_layer_norm_inputs(input, normalized_shape, weight, bias);

  std::vector<int64_t> dims_to_reduce;
  for (const auto i : c10::irange(normalized_shape.size())) {
    dims_to_reduce.push_back(input.dim() - i - 1);
  }
  IntArrayRef dims_to_reduce_ref = IntArrayRef(dims_to_reduce);

  auto result = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "rms_norm",
        [&] {
    scalar_t eps_val;
    if (!eps.has_value()) {
      eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
    } else {
      eps_val = eps.value();
    }

    auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));

    if (weight_opt.has_value()) {
      result = result.mul(weight_opt.value());
    }

    return result;
  });

  return result;
}
} // namespace at::native