// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at::functorch {

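// Batching rules (vmap support) for the normalization operators:
// batch_norm (native/cudnn/miopen), group_norm, and layer_norm, plus their
// backward variants. The recurring pattern below is to call the underlying
// kernel without weight/bias (folding the vmap dimension into an existing
// dimension where the kernel requires it) and then to apply weight and bias
// by hand with broadcasting, since weight/bias may themselves be batched.
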
static bool is_empty_tensor(const Tensor& tensor) {
  const auto shape = tensor.sizes();
  return shape.size() == 1 && shape[0] == 0;
}

static std::optional<int64_t> compute_stat_bdim(
    std::optional<int64_t> input_bdim,
    const Tensor& stat) {
  // There's a weird case where mean, rstd can both have shape (0,).
  // It's possible that this is a bug on the PyTorch side.
  // When that happens we don't want to return a BatchedTensor.
  if (input_bdim.has_value() && !is_empty_tensor(stat)) {
    return 0;
  }
  return std::nullopt;
}

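// Example: padRight on a weight of shape [C] with logical_rank 4 returns a
// view of shape [C, 1, 1, 1]; with a batch dim at dim 0, [B0, C] becomes
// [B0, C, 1, 1, 1] (the batch dim does not count towards the logical rank).
// The padded view then broadcasts against the normalized result.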
static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank) {
  // NB: Batch dim, if it exists, is assumed to be the first dim
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.push_back(1);
  }
  return tensor.view(new_sizes);
}

template<typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
batch_norm_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
    bool training, double momentum, double eps) {
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const auto& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const auto& running_var = *running_var_maybe_owned;
  TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))),
      "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ",
      "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False. ",
      "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ",
      "how to patch your module to work with vmap");
  std::optional<int64_t> bdim_size;
  Tensor result0;
  Tensor mean;
  Tensor rstd;
  if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
    const auto dummy_weight = at::ones(input.size(1), input.options());  // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input.size(1), input.options());   // without this, get "strides() called on undefined Tensor" on cuda
    const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1);          // [C, B, *]
    mean = std::get<1>(result);
    rstd = std::get<2>(result);
  } else {
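    // At least one of input / running_mean / running_var carries the vmap
    // dim B0. Fold B0 into the channels dim so a single kernel call handles
    // all vmapped examples: input [B0, B, C, *] -> [B, (B0, C), *] and the
    // running stats [B0, C] -> [(B0, C)]. The per-channel outputs come back
    // as [(B0, C), ...] and are unflattened into [B0, C, ...] below.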
    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
    auto input_ = moveBatchDimToFront(input, input_bdim);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value());
    input_ = reshape_dim_into(0, /*channels dim*/1, input_);

    std::optional<Tensor> running_mean_;
    std::optional<Tensor> running_var_;
    if (running_mean.defined()) {
      running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
      running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
      running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
    }
    if (running_var.defined()) {
      running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
      running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
      running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
    }

    const auto dummy_weight = at::ones(input_.size(1), input_.options());  // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input_.size(1), input_.options());   // without this, get "strides() called on undefined Tensor" on cuda
    const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1);                // [(B0, C), B, *]
    result0 = reshape_dim_outof(0, bdim_size.value(), result0);   // [B0, C, B, *]
    mean = std::get<1>(result);
    mean = reshape_dim_outof(0, bdim_size.value(), mean);         // [B0, C]
    rstd = std::get<2>(result);
    rstd = reshape_dim_outof(0, bdim_size.value(), rstd);         // [B0, C]
  }

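  // The kernel above was run with a dummy weight/bias; apply the real weight
  // and bias here so that batched (per-B0) weight/bias are handled via
  // broadcasting against result0, which is laid out as [(B0,) C, B, *].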
  const auto stats_bdim = compute_stat_bdim(bdim_size, mean);
  if (weight.defined()) {
    const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
    auto weight_ = moveBatchDimToFront(weight, weight_bdim);
    weight_ = padRight(weight_, weight_bdim, input_logical_rank);
    result0 = result0 * weight_;
  }
  if (bias.defined()) {
    const auto result_logical_rank = rankWithoutBatchDim(
        result0,
        bdim_size.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
    auto bias_ = moveBatchDimToFront(bias, bias_bdim);
    bias_ = padRight(bias_, bias_bdim, result_logical_rank);
    result0 = result0 + bias_;
  }
  result0 = result0.transpose(1, 2);  // [B0, B, C, *], because some arg must have been batched, the output must be batched
  return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim);
}

template<typename F, F Func>
std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
    const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
    const at::Tensor & input, std::optional<int64_t> input_bdim,
    const std::optional<at::Tensor> & running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<at::Tensor> & running_var_opt, std::optional<int64_t> running_var_bdim,
    const at::Tensor & mean, std::optional<int64_t> mean_bdim,
    const at::Tensor & rstd, std::optional<int64_t> rstd_bdim,
    bool training, double eps) {
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const Tensor& running_var = *running_var_maybe_owned;

  if (!grad_out_bdim.has_value() && !input_bdim.has_value() && !running_mean_bdim.has_value() && !running_var_bdim.has_value()) {
    // for mean or rstd to have a bdim, the input, running_mean, or running_var must have had one
    TORCH_INTERNAL_ASSERT(!mean_bdim);
    TORCH_INTERNAL_ASSERT(!rstd_bdim);
    const auto dummy_weight = at::ones(input.size(1), input.options());
    const auto result = Func(
        grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }

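  // At least one argument carries the vmap dim B0. Mirror the forward rule:
  // give everything a bdim, fold B0 into the channel dim of input/mean/rstd
  // and into the running stats, and reshape grad_out from [B0, B, C, *] to
  // [B, (B0, C), *] so a single native_batch_norm_backward call covers all
  // vmapped examples.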
  auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto mean_ = moveBatchDimToFront(mean, mean_bdim);
  auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);

  // ensure all inputs have bdim.
  const auto bdim_size = get_bdim_size4(grad_out, grad_out_bdim, input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
  mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
  rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

  std::optional<Tensor> running_mean_;
  std::optional<Tensor> running_var_;
  if (running_mean.defined()) {
    running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
    running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size);
    running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
  }
  if (running_var.defined()) {
    running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
    running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size);
    running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
  }

  input_ = reshape_dim_into(0, /*channels dim*/1, input_);
  TORCH_INTERNAL_ASSERT(mean_.dim() == 2);
  TORCH_INTERNAL_ASSERT(rstd_.dim() == 2);
  mean_ = reshape_dim_into(0, 0, mean_);
  rstd_ = reshape_dim_into(0, 0, rstd_);
  grad_out_ = grad_out_.transpose(0, 1).flatten(1, 2); // [B0, B, C, *] -> [B, (B0, C), *]

  const auto dummy_weight = at::ones(input_.size(1), input_.options());
  auto result = at::native_batch_norm_backward(
      grad_out_.contiguous(),
      input_.contiguous(),
      dummy_weight,
      running_mean_,  // contiguous called if there is a tensor given
      running_var_,   // contiguous called if there is a tensor given
      mean_.contiguous(),
      rstd_.contiguous(),
      training, eps, {true, false, false});
  auto result0 = std::get<0>(result);
  result0 = reshape_dim_outof(1, bdim_size, result0); // [B, B0, C, *]
  result0 = result0.transpose(0, 1); // [B0, B, C, *]
  return std::make_tuple(result0, 0);
}

template<typename F, F Func>
std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const std::optional<at::Tensor> & weight_opt,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const Tensor& running_var = *running_var_maybe_owned;
  // NB: not sure why these are optional...these are required from the forward
  const Tensor& save_mean = *save_mean_opt;
  const Tensor& save_rstd = *save_rstd_opt;
  TORCH_INTERNAL_ASSERT(save_mean.defined());
  TORCH_INTERNAL_ASSERT(save_rstd.defined());

  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "batch_norm_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level);
  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  Tensor mean_value;
  std::optional<Tensor> weight_value;
  std::optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  std::optional<Tensor> running_mean_value;
  std::optional<int64_t> running_mean_bdim;
  if (running_mean.defined()) {
    std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean, cur_level);
  }
  std::optional<Tensor> running_var_value;
  std::optional<int64_t> running_var_bdim;
  if (running_var.defined()) {
    std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var, cur_level);
  }
  auto [save_mean_value, save_mean_bdim] = unwrapTensorAtLevel(save_mean, cur_level);
  auto [save_rstd_value, save_rstd_bdim] = unwrapTensorAtLevel(save_rstd, cur_level);

  // results
  Tensor grad_bias;
  Tensor grad_weight;
  Tensor grad_input;

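  // grad_bias and grad_weight reduce over every dim except channels, so they
  // can be computed directly on the (possibly batched) tensors at this level,
  // letting vmap's broadcasting do the work. Only grad_input needs the
  // dedicated batch rule above.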
  TORCH_INTERNAL_ASSERT(grad_out_value.dim() > 1);  // batch_norm can't operate on 1D tensors so the output will be at least 2D
  if (output_mask[2]) {
    grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim()));
  }
  if (output_mask[1] && weight_value.has_value()) {
    // NB: output isn't saved...
    auto mean = training ? save_mean : running_mean;
    auto var = training ? save_rstd : (1 / at::sqrt(running_var + eps));
    const auto normalized_input = (input.transpose(0, 1) - padRight(mean, std::nullopt, input.dim())) * padRight(var, std::nullopt, input.dim());
    const auto expanded_grad_weight = normalized_input * grad_out.transpose(0, 1);
    grad_weight = expanded_grad_weight.sum(range(1, grad_out.dim()));
  }
  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
      grad_out.transpose(0, 1) * padRight(weight, std::nullopt, grad_out.dim()) : grad_out.transpose(0, 1);           // [B0, C, B, *]
    auto [grad_normalized_input_value, grad_normalized_input_bdim] =
        unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level);       // [B0, B, C, *]

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        running_mean_value, running_mean_bdim,
        running_var_value, running_var_bdim,
        save_mean_value, save_mean_bdim,
        save_rstd_value, save_rstd_bdim,
        training, eps);
    grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
    const Tensor & input, const std::optional<Tensor> & weight_opt,
    const std::optional<Tensor> & bias_opt, int64_t N, int64_t C,
    int64_t HxW, int64_t group, double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "native_group_norm_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({input, weight_opt, bias_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::native_group_norm(input, weight_opt, bias_opt, N, C, HxW, group, eps);
  }

  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);

  Tensor result0;
  Tensor mean;
  Tensor rstd;
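  // If the input is batched, fold the vmap dim B0 into the sample dim N and
  // call the kernel once with N * B0 samples (group norm statistics are
  // per-sample, so this is equivalent), then split B0 back out of the
  // outputs. Weight and bias are applied by hand afterwards since they may be
  // batched themselves.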
  if (input_bdim) {
    const auto input_ = reshape_dim_into(*input_bdim, 0, input_value);
    const auto bdim_size = input_value.size(*input_bdim);

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
    result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
    mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
    rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
    result0 = std::get<0>(result);
    mean = std::get<1>(result);
    rstd = std::get<2>(result);
  }

  if (weight.defined()) {
    const auto padded_weight = padRight(weight, std::nullopt, result0.dim() - 1);
    result0 = result0 * padded_weight;
  }

  if (bias.defined()) {
    const auto padded_bias = padRight(bias, std::nullopt, result0.dim() - 1);
    result0 = result0 + padded_bias;
  }

  return std::make_tuple(result0, mean, rstd);
}

static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
    const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
    const at::Tensor & input, std::optional<int64_t> input_bdim,
    const at::Tensor & mean, std::optional<int64_t> mean_bdim,
    const at::Tensor & rstd, std::optional<int64_t> rstd_bdim,
    int64_t N, int64_t C, int64_t HxW, int64_t group) {
  auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto mean_ = moveBatchDimToFront(mean, mean_bdim);
  auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);

  const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
  mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
  rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

  grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
  input_ = reshape_dim_into(0, 0, input_);       // [B0 * N, C, *]
  mean_ = reshape_dim_into(0, 0, mean_);         // [B0 * N, G]
  rstd_ = reshape_dim_into(0, 0, rstd_);         // [B0 * N, G]

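  // Group norm statistics are per (sample, group), so folding B0 into N lets
  // one native_group_norm_backward call with N * bdim_size samples handle all
  // vmapped examples, mirroring the forward plumbing above.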
  const auto result = native_group_norm_backward(
      grad_out_.contiguous(),
      input_.contiguous(),
      mean_.contiguous(),
      rstd_.contiguous(),
      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
  auto result0 = std::get<0>(result);
  result0 = reshape_dim_outof(0, bdim_size, result0);
  return std::make_tuple(result0, 0);
}

static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
  const Tensor & grad_out, const Tensor & input, const Tensor & mean,
  const Tensor & rstd, const std::optional<Tensor> & weight_opt,
  int64_t N, int64_t C, int64_t HxW, int64_t group, std::array<bool,3> output_mask
) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "native_group_norm_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::native_group_norm_backward(grad_out, input, mean, rstd, weight_opt, N, C, HxW, group, output_mask);
  }

  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  Tensor weight_value;
  std::optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level);
  auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level);

  // results
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  TORCH_INTERNAL_ASSERT(grad_out.dim() > 1);  // group_norm can't operate on 1D tensors so the output will be at least 2D
  if (output_mask[2]) {
    grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim()));
  }

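  // grad_weight: split channels into (group, C/group) so the saved per-group
  // mean/rstd broadcast, normalize the input, multiply by grad_out, and
  // reduce over everything except channels. This runs on the batched tensors
  // at this level, so vmap handles any batched weight/grad_out.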
  if (output_mask[1] && weight.defined()) {
    const auto reshaped_input = reshape_dim_outof(1, group, input);
    const auto normalized_input = (reshaped_input - padRight(mean, std::nullopt, reshaped_input.dim())) * padRight(rstd, std::nullopt, reshaped_input.dim());
    const auto expanded_grad_weight = reshape_dim_into(1, 1, normalized_input) * grad_out;
    grad_weight = expanded_grad_weight.transpose(0, 1).sum(range(1, expanded_grad_weight.dim()));
  }

  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
      grad_out * padRight(weight, std::nullopt, grad_out.dim() - 1) : grad_out;
    auto [grad_normalized_input_value, grad_normalized_input_bdim] =
        unwrapTensorAtLevel(grad_normalized_input, cur_level);

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto res = group_norm_backward_no_weight_bias_batch_rule(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        mean_value, mean_bdim,
        rstd_value, rstd_bdim,
        N, C, HxW, group
    );
    grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

C10_ALWAYS_INLINE bool has_same_shape(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    c10::SymIntArrayRef normalized_shape) {
  if (!tensor.defined()) {
    return true;
  }
  if (rankWithoutBatchDim(tensor, tensor_bdim) != (int64_t) normalized_shape.size()) {
    return false;
  }
  const auto tensor_shape = tensor.sizes();
  for (const auto i : c10::irange(normalized_shape.size())) {
    auto j = i;
    // Skip the batch dim when indexing into the physical shape, e.g. with
    // tensor_bdim = 1, logical dims (0, 1, 2) map to physical dims (0, 2, 3).
    if (tensor_bdim.has_value() && (int64_t)i >= tensor_bdim.value()) {
      j = j + 1;
    }
    if (normalized_shape[i] != tensor_shape[j]) {
      return false;
    }
  }
  return true;
}

C10_ALWAYS_INLINE void check_same_shape(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    c10::SymIntArrayRef normalized_shape, const std::string& name) {
  TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape),
      "Expected ", name, " to be of same shape as normalized_shape, but got ",
      name, " of shape ",
      tensor.sizes(),
      " and normalized_shape = ",
      normalized_shape);
}

// Ugh, hard to deduplicate
C10_ALWAYS_INLINE void _check_layer_norm_inputs(
    SymIntArrayRef normalized_shape,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    const Tensor& bias, std::optional<int64_t> bias_bdim) {

  const auto normalized_ndim = normalized_shape.size();
  TORCH_CHECK(
      normalized_ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);
  check_same_shape(weight, weight_bdim, normalized_shape, "weight");
  check_same_shape(bias, bias_bdim, normalized_shape, "bias");
}

static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
native_layer_norm_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    c10::SymIntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    double eps) {
  auto input_ = moveBatchDimToFront(input, input_bdim);
  if (!weight_bdim && !bias_bdim) {
    const auto result = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
    const auto mean = std::get<1>(result);
    const auto rstd = std::get<2>(result);
    const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
    return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim);
  }

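  // weight and/or bias are batched, which the kernel cannot handle directly.
  // Run layer norm without them (it normalizes only over the trailing
  // normalized_shape dims, so the vmap dim at the front is untouched) and
  // apply weight and bias manually with broadcasting.
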
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  _check_layer_norm_inputs(normalized_shape, weight, weight_bdim, bias, bias_bdim);

  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
  const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
  auto result0 = std::get<0>(result);
  const auto mean = std::get<1>(result);
  const auto rstd = std::get<2>(result);
  const auto stats_bdim = compute_stat_bdim(input_bdim, mean);

  if (weight.defined()) {
    auto weight_ = moveBatchDimToFront(weight, weight_bdim);
    weight_ = maybePadToLogicalRank(weight_, /*has_bdim*/weight_bdim, input_logical_rank);
    result0 = result0 * weight_;
  }
  if (bias.defined()) {
    const auto result_logical_rank = rankWithoutBatchDim(
        result0,
        input_bdim.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
    auto bias_ = moveBatchDimToFront(bias, bias_bdim);
    bias_ = maybePadToLogicalRank(bias_, /*has_bdim*/bias_bdim, result_logical_rank);
    result0 = result0 + bias_;
  }
  return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim);
}

static std::tuple<at::Tensor, std::optional<int64_t>> native_layer_norm_backward_no_weight_bias_batch_rule(
    const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
    const at::Tensor & input, std::optional<int64_t> input_bdim,
    at::IntArrayRef normalized_shape,
    const at::Tensor & mean, std::optional<int64_t> mean_bdim,
    const at::Tensor & rstd, std::optional<int64_t> rstd_bdim) {

  if (!grad_out_bdim.has_value() && !input_bdim.has_value() &&
      !mean_bdim.has_value() && !rstd_bdim.has_value()) {
    const auto result = at::native_layer_norm_backward(
        grad_out, input, normalized_shape, mean, rstd, std::nullopt, std::nullopt, {true, false, false});
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }

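  // Unlike batch/group norm, layer norm reduces only over the trailing
  // normalized_shape dims, so no dimension folding is needed: it is enough to
  // move every bdim to the front (and materialize it where missing) and call
  // the backward kernel once.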
  auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto mean_ = moveBatchDimToFront(mean, mean_bdim);
  auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);

  // ensure all inputs have a bdim.
  const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
  mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
  rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

  auto result = at::native_layer_norm_backward(
      grad_out_.contiguous(),
      input_.contiguous(),
      normalized_shape,
      mean_.contiguous(),
      rstd_.contiguous(),
      std::nullopt, std::nullopt, {true, false, false});

  return std::make_tuple(std::get<0>(result), 0);
}

static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_plumbing(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    at::IntArrayRef normalized_shape,
    const at::Tensor & mean,
    const at::Tensor & rstd,
    const std::optional<at::Tensor> & weight_opt,
    const std::optional<at::Tensor> & bias_opt,
    std::array<bool,3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "native_layer_norm_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt, bias_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::native_layer_norm_backward(grad_out, input, normalized_shape, mean, rstd,
        weight_opt, bias_opt, output_mask);
  }
  auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level);
  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level);
  auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level);
  std::optional<Tensor> weight_value;
  std::optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  std::optional<Tensor> bias_value;
  std::optional<int64_t> bias_bdim;
  if (bias.defined()) {
    std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias, cur_level);
  }

  // results
  Tensor grad_bias;
  Tensor grad_weight;
  Tensor grad_input;

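  // grad_bias and grad_weight only need reductions over the leading
  // (non-normalized) dims, so they are computed here on the batched tensors;
  // grad_input goes through the dedicated batch rule above.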
  if (output_mask[2] && bias_value.has_value()) {
    const auto num_front_dims_to_reduce = grad_out.dim() - normalized_shape.size();
    if (num_front_dims_to_reduce == 0) {
      grad_bias = grad_out;
    } else {
      grad_bias = grad_out.sum(range(0, static_cast<int64_t>(num_front_dims_to_reduce)));
    }
  }
  if (output_mask[1] && weight_value.has_value()) {
    // NB: output isn't saved...
    const auto normalized_input = (input - mean) * rstd;
    const auto expanded_grad_weight = normalized_input * grad_out;
    const auto num_front_dims_to_reduce =
        expanded_grad_weight.dim() - normalized_shape.size();
    if (num_front_dims_to_reduce == 0) {
      grad_weight = expanded_grad_weight;
    } else {
      grad_weight = expanded_grad_weight.sum(range(0, static_cast<int64_t>(num_front_dims_to_reduce)));
    }
  }
  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
      grad_out * weight : grad_out;
    auto [grad_normalized_input_value, grad_normalized_input_bdim] =
        unwrapTensorAtLevel(grad_normalized_input, cur_level);

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        normalized_shape,
        mean_value, mean_bdim,
        rstd_value, rstd_bdim);
    grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

template <typename F, F Func>
struct NativeBatchNormBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> apply(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
    bool training, double momentum, double eps) {
    return batch_norm_batch_rule<F, Func>(
        input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim,
        running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps);
  }
};

template <typename F, F Func>
struct CudnnBatchNormBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> apply(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
    bool training, double momentum, double eps) {
    auto reserve = at::empty({0}, input.options().dtype(kByte));  // in experiments, reserve was never set to anything other than empty by cuda
    auto res = batch_norm_batch_rule<F, Func>(
        input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim,
        running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps);
    return std::tuple_cat(res, std::make_tuple(reserve, std::nullopt));
  }
};

template <typename F, F Func>
struct MiopenBatchNormBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> apply(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
    bool training, double momentum, double eps) {
    return batch_norm_batch_rule<F, Func>(
        input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim,
        running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps);
  }
};

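// The helper structs above and the macros below exist so that the same
// templated batch rule can be instantiated for native_batch_norm,
// cudnn_batch_norm, and miopen_batch_norm: ATEN_FN(fn) supplies the concrete
// kernel as the Func template argument, and the cudnn variant additionally
// tacks on the (empty) reserve tensor its schema requires.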
#define NATIVE_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\
    NativeBatchNormBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

#define CUDNN_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\
   CudnnBatchNormBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

#define MIOPEN_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\
    MiopenBatchNormBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

template <typename F, F Func>
struct NativeBatchNormBackwardBatchRuleHelper {
  static std::tuple<Tensor,Tensor,Tensor> apply(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const std::optional<at::Tensor> & weight_opt,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {

    auto maybe_layer = maybeCurrentDynamicLayer();
    vmap_check_escaped(maybe_layer, "NativeBatchNormBackwardBatchRuleHelper.apply");
    int64_t cur_level = maybe_layer->layerId();

    if (!areAnyBatchedAtLevel({grad_out, input, weight_opt, running_mean_opt,
          running_var_opt, save_mean_opt, save_rstd_opt}, cur_level)) {
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      return at::native_batch_norm_backward(grad_out, input, weight_opt,
          running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt,
          training, eps, output_mask);
    }

    return batch_norm_backward_plumbing<F, Func>(
        grad_out, input, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, training, eps, output_mask);
  }
};

template <typename F, F Func>
struct CudnnBatchNormBackwardBatchRuleHelper {
  static std::tuple<Tensor,Tensor,Tensor> apply(
    const at::Tensor & input,
    const at::Tensor & grad_out,
    const at::Tensor & weight,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    double eps,
    const at::Tensor & reserve) {

    auto maybe_layer = maybeCurrentDynamicLayer();
    vmap_check_escaped(maybe_layer, "CudnnBatchNormBackwardBatchRuleHelper.apply");
    int64_t cur_level = maybe_layer->layerId();

    if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt,
          running_var_opt, save_mean_opt, save_rstd_opt, reserve}, cur_level)) {
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      return at::cudnn_batch_norm_backward(input, grad_out, weight,
          running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps, reserve);
    }

    return batch_norm_backward_plumbing<F, Func>(
        grad_out, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, true, eps, {true, true, true});
  }
};

template <typename F, F Func>
struct MiopenBatchNormBackwardBatchRuleHelper {
  static std::tuple<Tensor,Tensor,Tensor> apply(
    const at::Tensor & input,
    const at::Tensor & grad_out,
    const at::Tensor & weight,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    double eps) {

    auto maybe_layer = maybeCurrentDynamicLayer();
    vmap_check_escaped(maybe_layer, "MiopenBatchNormBackwardBatchRuleHelper.apply");
    int64_t cur_level = maybe_layer->layerId();

    if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt,
          running_var_opt, save_mean_opt, save_rstd_opt}, cur_level)) {
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      return at::miopen_batch_norm_backward(input, grad_out, weight,
          running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps);
    }

    return batch_norm_backward_plumbing<F, Func>(
        grad_out, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, true, eps, {true, true, true});
  }
};

#define NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\
    NativeBatchNormBackwardBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

#define CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\
   CudnnBatchNormBackwardBatchRuleHelper<\
      decltype(&fn),\
      &fn>::apply)

#define MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\
    MiopenBatchNormBackwardBatchRuleHelper<\
      decltype(&fn),\
      &fn>::apply)

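// The wrappers below adapt the native_batch_norm_backward-style signature
// (training, eps, output_mask) to the cudnn/miopen backward ops, which take
// neither flag; they are only reached in eval mode (hence the !training
// assert), and the cudnn variant fabricates the extra reserve argument.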
static std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_batch_norm_backward_wrapper(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const at::Tensor& weight_opt,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {
  TORCH_INTERNAL_ASSERT(!training);
  auto reserve = at::empty({0}, input.options().dtype(kByte));
  return at::cudnn_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps, reserve);
}

static std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_batch_norm_backward_wrapper(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const at::Tensor& weight_opt,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {
  TORCH_INTERNAL_ASSERT(!training); // this should be ensured by batch_norm_impl
  return at::miopen_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps);
}

// NB: This is NOT good. In the ideal world, we do NOT want to convert the new legit op back into native_batch_norm
// as native_batch_norm has a problematic schema--it promises it is functional when it is not. However, vmap doesn't
// work with dynamo anyway so we gain some buffer room to do wrong things here. The (reasonable) hope is that we will
// make native_batch_norm composite implicit within a few weeks and we can fix this before vmap works with dynamo.
static std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_batch(
  const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
  Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) {
    return at::native_batch_norm(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
}

static std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_no_stats_batch(
  const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
  bool train, double momentum, double eps) {
    return at::native_batch_norm(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
}

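// Registration: VMAP_SUPPORT wires up the decomposed batch rules, while
// m.impl registers the plumbing functions that unwrap BatchedTensors
// themselves.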
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm));
  VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm));
  VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm));
  m.impl("_native_batch_norm_legit", _native_batch_norm_legit_batch);
  m.impl("_native_batch_norm_legit.no_stats", _native_batch_norm_legit_no_stats_batch);
  m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward));
  m.impl("cudnn_batch_norm_backward", CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::cudnn_batch_norm_backward_wrapper));
  m.impl("miopen_batch_norm_backward", MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::miopen_batch_norm_backward_wrapper));
  m.impl("native_group_norm", native_group_norm_plumbing);
  m.impl("native_group_norm_backward", native_group_norm_backward_plumbing);
  VMAP_SUPPORT(native_layer_norm, native_layer_norm_batch_rule);
  m.impl("native_layer_norm_backward", native_layer_norm_backward_plumbing);
}

} // namespace at::functorch