#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/layer_norm.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/layer_norm_native.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_layer_norm.h>
#include <ATen/ops/native_layer_norm_backward_native.h>
#include <ATen/ops/native_layer_norm_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/rsqrt.h>
#include <ATen/ops/rms_norm.h>
#include <ATen/ops/zeros_like_native.h>
#endif

#include <array>
#include <tuple>
#include <vector>

namespace at::native {

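// Computes layer norm into `out` via the CPU kernel and reshapes the per-row
// statistics so they broadcast against the input: `mean` and `rstd` are viewed
// as [input_shape[0], ..., input_shape[axis - 1], 1, ..., 1], where `axis` is
// the first normalized dimension.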
static void layer_norm_with_mean_rstd_out(
    at::Tensor& out,
    at::Tensor& mean,
    at::Tensor& rstd,
    const at::Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N) {
  LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, &mean, &rstd);
  const auto input_shape = input.sizes();
  const size_t axis = input.dim() - normalized_shape.size();

  DimVector stat_shape;
  for (const auto idx : c10::irange(axis)) {
    stat_shape.emplace_back(input_shape[idx]);
  }
  for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) {
    stat_shape.emplace_back(1);
  }

  mean = mean.view(stat_shape);
  rstd = rstd.view(stat_shape);
}

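// Out-variant forward that skips the statistics. M is the number of rows being
// normalized (the product of the leading, non-normalized dims) and N is the
// number of elements per row (the product of normalized_shape); when M <= 0
// there is nothing to compute.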
void layer_norm_cpu_out(
    at::Tensor& out,
    const at::Tensor& input,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N) {
  if (M <= 0) {
    return;
  }
  LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, /*mean=*/nullptr, /*rstd=*/nullptr);
}

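// CPU forward: returns (output, mean, rstd). Mixed-type inputs (e.g. a
// reduced-precision input with float weight/bias) are validated up front, and
// the statistics are allocated in the parameter scalar type.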
std::tuple<Tensor, Tensor, Tensor> layer_norm_cpu(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  bool mixed_type = is_mixed_type(input, weight, bias);
  if (mixed_type) {
    check_mixed_data_type(input, weight, bias);
  }

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor Y = at::native::empty_like(
      *X,
      std::nullopt /* dtype */,
      std::nullopt /* layout */,
      std::nullopt /* device */,
      std::nullopt /* pin_memory */,
      at::MemoryFormat::Contiguous);
  const auto dtype = param_scalar_type(input, mixed_type);
  Tensor mean = at::empty({M}, X->options().dtype(dtype));
  Tensor rstd = at::empty({M}, X->options().dtype(dtype));

  layer_norm_with_mean_rstd_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N);
  return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
}

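// CPU backward: grad_input_mask selects which of (dX, dgamma, dbeta) to
// materialize. When M == 0 the kernel is skipped and the requested parameter
// grads are returned zero-filled.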
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
    const Tensor& dY,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  if (grad_input_mask[0]) {
    dX = at::native::empty_like(
        *X,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous)
                   : at::native::zeros_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = M > 0 ? at::native::empty_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous)
                  : at::native::zeros_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous);
  }
  if (M > 0) {
    LayerNormBackwardKernel(
        kCPU, dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
  }
  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
}

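// Composite entry point behind at::layer_norm: runs native_layer_norm and
// returns only the normalized output, discarding mean/rstd. The trailing bool
// (cudnn_enable) is deprecated and ignored.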
Tensor layer_norm_symint(
    const Tensor& input,
    c10::SymIntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    double eps,
    bool /* cudnn_enable, deprecated */) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return std::get<0>(at::native_layer_norm_symint(input, normalized_shape, weight, bias, eps));
}

DEFINE_DISPATCH(LayerNormKernel);
DEFINE_DISPATCH(LayerNormBackwardKernel);

// Ported from pytorch/xla repo
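// Decomposed reference implementation: layer norm over normalized_shape is
// computed as batch norm over the input reshaped to {1, M, -1} (M "channels",
// no affine parameters, training mode so batch statistics are used); the
// elementwise weight/bias transform is applied afterwards.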
std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto X = input.expect_contiguous();
  auto gamma = weight.expect_contiguous();

  auto input_shape = input.sizes();
  const auto input_ndim = input.dim();
  const int normalized_ndim = normalized_shape.size();
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const int axis = input_ndim - normalized_ndim;

  // Properly handle zero-size inputs: the reshape({1, M, -1}) call below
  // breaks on this.
  if (input.numel() == 0) {
    auto result_type = c10::promoteTypes(input.scalar_type(), kFloat);
    return std::make_tuple(
        at::empty_like(input),
        at::empty_like(input, c10::TensorOptions().dtype(result_type)),
        at::empty_like(input, c10::TensorOptions().dtype(result_type))
    );
  }
  at::Tensor input_reshaped = input.reshape({1, M, -1});
  // Unlike Batch Normalization, which applies scalar scale and bias for each
  // entire channel/plane with the affine option, Layer Normalization applies
  // per-element scale and bias. E.g. for input {N, C, H, W}, the weight for
  // batchnorm has shape {C} while the weight for layernorm has shape {H, W}
  // or {W}.
  auto outputs = at::native_batch_norm(
      input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
      /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
  at::Tensor out = std::get<0>(outputs);
  out = out.view(input_shape);
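  // Apply the elementwise affine transform: out = bias + out * weight
  // (addcmul fuses the multiply and add when both parameters are defined).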
  if (weight.defined() && bias.defined()) {
    out = bias.addcmul(out, weight, 1);
  } else if (weight.defined()) {
    out = out.mul(weight);
  } else if (bias.defined()) {
    out = out.add(bias);
  }
  at::Tensor mean = std::get<1>(outputs);
  at::Tensor rstd = std::get<2>(outputs);
  std::vector<int64_t> stat_shape;
  for (const auto idx : c10::irange(axis)) {
    stat_shape.push_back(input_shape[idx]);
  }
  for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) {
    stat_shape.push_back(1);
  }
  mean = mean.view(stat_shape);
  rstd = rstd.view(stat_shape);
  return std::make_tuple(out, mean, rstd);
}

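// RMSNorm (Zhang & Sennrich, 2019, https://arxiv.org/abs/1910.07467):
//   output = input * rsqrt(mean(input^2, normalized dims) + eps) [* weight]
// When eps is not provided, it defaults to the machine epsilon of the input's
// (real-valued) scalar type.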
Tensor rms_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    std::optional<double> eps) {

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  // rms_norm has no bias; pass an undefined Tensor through the shared
  // input-shape check.
  auto bias_opt = std::optional<Tensor>();
  const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt);
  (void) _check_layer_norm_inputs(input, normalized_shape, weight, bias);

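  // Reduce over the trailing normalized_shape.size() dimensions of the input.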
  std::vector<int64_t> dims_to_reduce;
  for (const auto i : c10::irange(normalized_shape.size())) {
    dims_to_reduce.push_back(input.dim() - i - 1);
  }
  IntArrayRef dims_to_reduce_ref = IntArrayRef(dims_to_reduce);

  auto result = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "rms_norm",
      [&] {
        scalar_t eps_val;
        if (!eps.has_value()) {
          // Default eps: machine epsilon of the real value type (so complex
          // inputs use the epsilon of their underlying real type).
          eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
        } else {
          eps_val = eps.value();
        }

        auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val)));

        if (weight_opt.has_value()) {
          result = result.mul(weight_opt.value());
        }

        return result;
      });

  return result;
}

} // namespace at::native