// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at::functorch {

static bool is_empty_tensor(const Tensor& tensor) {
  const auto shape = tensor.sizes();
  return shape.size() == 1 && shape[0] == 0;
}

static std::optional<int64_t> compute_stat_bdim(
    std::optional<int64_t> input_bdim,
    const Tensor& stat) {
  // There's a weird case where mean, rstd can both have shape (0,).
  // It's possible that this is a bug on the PyTorch side.
  // When that happens we don't want to return a BatchedTensor.
  if (input_bdim.has_value() && !is_empty_tensor(stat)) {
    return 0;
  }
  return std::nullopt;
}

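// padRight appends trailing singleton dims until `tensor` reaches `logical_rank`
// (ignoring the leading batch dim, if any), so it broadcasts against an input of
// that rank. E.g. a per-channel weight of shape [C] padded to logical rank 3
// becomes a view of shape [C, 1, 1].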
static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank) {
  // NB: Batch dim, if it exists, is assumed to be the first dim
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.push_back(1);
  }
  return tensor.view(new_sizes);
}

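// Batch rule for the native/cudnn/miopen batch norm forwards. The overall strategy:
// fold the vmap dim B0 into the channel dim, so the underlying kernel computes
// independent statistics for every (b0, c) pair, and then apply weight/bias by hand
// afterwards (the kernel is called with dummy weight = 1 and bias = 0). In training
// mode with a batched input, running_mean/running_var must be batched too, since
// they are updated in place per vmap example.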
template<typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
batch_norm_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
    bool training, double momentum, double eps) {
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const auto& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const auto& running_var = *running_var_maybe_owned;
  TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))),
      "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ",
      "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False. ",
      "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ",
      "how to patch your module to work with vmap");
  std::optional<int64_t> bdim_size;
  Tensor result0;
  Tensor mean;
  Tensor rstd;
  if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
    const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda
    const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1); // [C, B, *]
    mean = std::get<1>(result);
    rstd = std::get<2>(result);
  } else {
    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
    auto input_ = moveBatchDimToFront(input, input_bdim);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value());
    input_ = reshape_dim_into(0, /*channels dim*/1, input_);

    std::optional<Tensor> running_mean_;
    std::optional<Tensor> running_var_;
    if (running_mean.defined()) {
      running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
      running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
      running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
    }
    if (running_var.defined()) {
      running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
      running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
      running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
    }

    const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda
    const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1); // [(B0, C), B, *]
    result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *]
    mean = std::get<1>(result);
    mean = reshape_dim_outof(0, bdim_size.value(), mean); // [B0, C]
    rstd = std::get<2>(result);
    rstd = reshape_dim_outof(0, bdim_size.value(), rstd); // [B0, C]
  }

  const auto stats_bdim = compute_stat_bdim(bdim_size, mean);
  if (weight.defined()) {
    const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
    auto weight_ = moveBatchDimToFront(weight, weight_bdim);
    weight_ = padRight(weight_, weight_bdim, input_logical_rank);
    result0 = result0 * weight_;
  }
  if (bias.defined()) {
    const auto result_logical_rank = rankWithoutBatchDim(
        result0,
        bdim_size.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
    auto bias_ = moveBatchDimToFront(bias, bias_bdim);
    bias_ = padRight(bias_, bias_bdim, result_logical_rank);
    result0 = result0 + bias_;
  }
  result0 = result0.transpose(1, 2); // [B0, B, C, *], because some arg must have been batched, the output must be batched
  return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim);
}

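// Batch rule used for the grad_input piece of the batch norm backwards. Only
// {true, false, false} is requested from the underlying kernel here; grad_weight
// and grad_bias are computed separately in batch_norm_backward_plumbing. The same
// trick as the forward is used: the vmap dim is folded into the channel dim so the
// kernel sees one "channel" per (b0, c) pair.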
template<typename F, F Func>
std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
    const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
    const at::Tensor & input, std::optional<int64_t> input_bdim,
    const std::optional<at::Tensor> & running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<at::Tensor> & running_var_opt, std::optional<int64_t> running_var_bdim,
    const at::Tensor & mean, std::optional<int64_t> mean_bdim,
    const at::Tensor & rstd, std::optional<int64_t> rstd_bdim,
    bool training, double eps) {
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const Tensor& running_var = *running_var_maybe_owned;

  if (!grad_out_bdim.has_value() && !input_bdim.has_value() && !running_mean_bdim.has_value() && !running_var_bdim.has_value()) {
    // for either of these to have bdims, the input, running_mean, or running_var must have had a bdim
    TORCH_INTERNAL_ASSERT(!mean_bdim);
    TORCH_INTERNAL_ASSERT(!rstd_bdim);
    const auto dummy_weight = at::ones(input.size(1), input.options());
    const auto result = Func(
        grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }

  auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto mean_ = moveBatchDimToFront(mean, mean_bdim);
  auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);

  // ensure all inputs have bdim.
  const auto bdim_size = get_bdim_size4(grad_out, grad_out_bdim, input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
  mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
  rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

  std::optional<Tensor> running_mean_;
  std::optional<Tensor> running_var_;
  if (running_mean.defined()) {
    running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
    running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size);
    running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
  }
  if (running_var.defined()) {
    running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
    running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size);
    running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
  }

  input_ = reshape_dim_into(0, /*channels dim*/1, input_);
  TORCH_INTERNAL_ASSERT(mean_.dim() == 2);
  TORCH_INTERNAL_ASSERT(rstd_.dim() == 2);
  mean_ = reshape_dim_into(0, 0, mean_);
  rstd_ = reshape_dim_into(0, 0, rstd_);
  grad_out_ = grad_out_.transpose(0, 1).flatten(1, 2); // [B0, B, C, *] -> [B, (B0, C), *]

  const auto dummy_weight = at::ones(input_.size(1), input_.options());
  auto result = at::native_batch_norm_backward(
      grad_out_.contiguous(),
      input_.contiguous(),
      dummy_weight,
      running_mean_, // contiguous called if there is a tensor given
      running_var_, // contiguous called if there is a tensor given
      mean_.contiguous(),
      rstd_.contiguous(),
      training, eps, {true, false, false});
  auto result0 = std::get<0>(result);
  result0 = reshape_dim_outof(1, bdim_size, result0); // [B, B0, C, *]
  result0 = result0.transpose(0, 1); // [B0, B, C, *]
  return std::make_tuple(result0, 0);
}

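// Plumbing for the batch norm backwards: unwraps the tensors at the current vmap
// level, computes grad_bias and grad_weight directly with batched elementwise ops
// (both are reductions over everything except the channel dim), and defers
// grad_input to batch_norm_backward_no_weight_bias_batch_rule above.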
template<typename F, F Func>
std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const std::optional<at::Tensor> & weight_opt,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const Tensor& running_var = *running_var_maybe_owned;
  // NB: not sure why these are optional...these are required from the forward
  const Tensor& save_mean = *save_mean_opt;
  const Tensor& save_rstd = *save_rstd_opt;
  TORCH_INTERNAL_ASSERT(save_mean.defined());
  TORCH_INTERNAL_ASSERT(save_rstd.defined());

  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "batch_norm_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level);
  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  Tensor mean_value;
  std::optional<Tensor> weight_value;
  std::optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  std::optional<Tensor> running_mean_value;
  std::optional<int64_t> running_mean_bdim;
  if (running_mean.defined()) {
    std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean, cur_level);
  }
  std::optional<Tensor> running_var_value;
  std::optional<int64_t> running_var_bdim;
  if (running_var.defined()) {
    std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var, cur_level);
  }
  auto [save_mean_value, save_mean_bdim] = unwrapTensorAtLevel(save_mean, cur_level);
  auto [save_rstd_value, save_rstd_bdim] = unwrapTensorAtLevel(save_rstd, cur_level);

  // results
  Tensor grad_bias;
  Tensor grad_weight;
  Tensor grad_input;

  TORCH_INTERNAL_ASSERT(grad_out_value.dim() > 1); // batch_norm can't operate on 1D tensors so the output will be at least 2D
  if (output_mask[2]) {
    grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim()));
  }
  if (output_mask[1] && weight_value.has_value()) {
    // NB: output isn't saved...
    auto mean = training ? save_mean : running_mean;
    auto var = training ? save_rstd : (1 / at::sqrt(running_var + eps));
    const auto normalized_input = (input.transpose(0, 1) - padRight(mean, std::nullopt, input.dim())) * padRight(var, std::nullopt, input.dim());
    const auto expanded_grad_weight = normalized_input * grad_out.transpose(0, 1);
    grad_weight = expanded_grad_weight.sum(range(1, grad_out.dim()));
  }
  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
      grad_out.transpose(0, 1) * padRight(weight, std::nullopt, grad_out.dim()) : grad_out.transpose(0, 1); // [B0, C, B, *]
    auto [grad_normalized_input_value, grad_normalized_input_bdim] =
        unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *]

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        running_mean_value, running_mean_bdim,
        running_var_value, running_var_bdim,
        save_mean_value, save_mean_bdim,
        save_rstd_value, save_rstd_bdim,
        training, eps);
    grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

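// Group norm forward: group norm computes statistics per (N, group), so a batched
// input can be handled by folding the vmap dim B0 into N and calling the kernel
// with N * bdim_size. As in batch norm, weight and bias are applied by hand
// afterwards so that batched affine parameters also work.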
static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
    const Tensor & input, const std::optional<Tensor> & weight_opt,
    const std::optional<Tensor> & bias_opt, int64_t N, int64_t C,
    int64_t HxW, int64_t group, double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "native_group_norm_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({input, weight_opt, bias_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::native_group_norm(input, weight_opt, bias_opt, N, C, HxW, group, eps);
  }

  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);

  Tensor result0;
  Tensor mean;
  Tensor rstd;
  if (input_bdim) {
    const auto input_ = reshape_dim_into(*input_bdim, 0, input_value);
    const auto bdim_size = input_value.size(*input_bdim);

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
    result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
    mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
    rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
    result0 = std::get<0>(result);
    mean = std::get<1>(result);
    rstd = std::get<2>(result);
  }

  if (weight.defined()) {
    const auto padded_weight = padRight(weight, std::nullopt, result0.dim() - 1);
    result0 = result0 * padded_weight;
  }

  if (bias.defined()) {
    const auto padded_bias = padRight(bias, std::nullopt, result0.dim() - 1);
    result0 = result0 + padded_bias;
  }

  return std::make_tuple(result0, mean, rstd);
}

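// Batch rule for the grad_input piece of the group norm backwards: fold the vmap
// dim into N and call native_group_norm_backward with N * bdim_size, requesting
// only grad_input ({true, false, false}).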
static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
    const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
    const at::Tensor & input, std::optional<int64_t> input_bdim,
    const at::Tensor & mean, std::optional<int64_t> mean_bdim,
    const at::Tensor & rstd, std::optional<int64_t> rstd_bdim,
    int64_t N, int64_t C, int64_t HxW, int64_t group) {
  auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto mean_ = moveBatchDimToFront(mean, mean_bdim);
  auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);

  const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size); // use the moved-to-front tensor, matching input_/mean_/rstd_ below
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
  mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
  rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

  grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
  input_ = reshape_dim_into(0, 0, input_); // [B0 * N, C, *]
  mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
  rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]

  const auto result = native_group_norm_backward(
      grad_out_.contiguous(),
      input_.contiguous(),
      mean_.contiguous(),
      rstd_.contiguous(),
      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
  auto result0 = std::get<0>(result);
  result0 = reshape_dim_outof(0, bdim_size, result0);
  return std::make_tuple(result0, 0);
}

static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
    const Tensor & grad_out, const Tensor & input, const Tensor & mean,
    const Tensor & rstd, const std::optional<Tensor> & weight_opt,
    int64_t N, int64_t C, int64_t HxW, int64_t group, std::array<bool,3> output_mask
) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "native_group_norm_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::native_group_norm_backward(grad_out, input, mean, rstd, weight_opt, N, C, HxW, group, output_mask);
  }

  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  Tensor weight_value;
  std::optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level);
  auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level);

  // results
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  TORCH_INTERNAL_ASSERT(grad_out.dim() > 1); // group_norm can't operate on 1D tensors so the output will be at least 2D
  if (output_mask[2]) {
    grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim()));
  }

  if (output_mask[1] && weight.defined()) {
    const auto reshaped_input = reshape_dim_outof(1, group, input);
    const auto normalized_input = (reshaped_input - padRight(mean, std::nullopt, reshaped_input.dim())) * padRight(rstd, std::nullopt, reshaped_input.dim());
    const auto expanded_grad_weight = reshape_dim_into(1, 1, normalized_input) * grad_out;
    grad_weight = expanded_grad_weight.transpose(0, 1).sum(range(1, expanded_grad_weight.dim()));
  }

  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
      grad_out * padRight(weight, std::nullopt, grad_out.dim() - 1) : grad_out;
    auto [grad_normalized_input_value, grad_normalized_input_bdim] =
        unwrapTensorAtLevel(grad_normalized_input, cur_level);

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto res = group_norm_backward_no_weight_bias_batch_rule(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        mean_value, mean_bdim,
        rstd_value, rstd_bdim,
        N, C, HxW, group
    );
    grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

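// Compares a (possibly batched) tensor's per-example shape against normalized_shape,
// skipping over the batch dim when indexing into the tensor's physical sizes.
// E.g. for a tensor of shape (2, B0, 3) with tensor_bdim = 1 and normalized_shape (2, 3),
// logical dim 0 maps to physical dim 0 and logical dim 1 maps to physical dim 2.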
C10_ALWAYS_INLINE bool has_same_shape(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    c10::SymIntArrayRef normalized_shape) {
  if (!tensor.defined()) {
    return true;
  }
  if (rankWithoutBatchDim(tensor, tensor_bdim) != (int64_t) normalized_shape.size()) {
    return false;
  }
  const auto tensor_shape = tensor.sizes();
  for (const auto i : c10::irange(normalized_shape.size())) {
    auto j = i;
    // (0, 1, 2), 1 -> (0, 2, 3)
    if (tensor_bdim.has_value() && (int64_t)i >= tensor_bdim.value()) {
      j = j + 1;
    }
    if (normalized_shape[i] != tensor_shape[j]) {
      return false;
    }
  }
  return true;
}

C10_ALWAYS_INLINE void check_same_shape(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    c10::SymIntArrayRef normalized_shape, const std::string& name) {
  TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape),
      "Expected ", name, " to be of same shape as normalized_shape, but got ",
      name, " of shape ",
      tensor.sizes(),
      " and normalized_shape = ",
      normalized_shape);
}

// Ugh, hard to deduplicate
C10_ALWAYS_INLINE void _check_layer_norm_inputs(
    SymIntArrayRef normalized_shape,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    const Tensor& bias, std::optional<int64_t> bias_bdim) {

  const auto normalized_ndim = normalized_shape.size();
  TORCH_CHECK(
      normalized_ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);
  check_same_shape(weight, weight_bdim, normalized_shape, "weight");
  check_same_shape(bias, bias_bdim, normalized_shape, "bias");
}

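// Layer norm forward batch rule. Layer norm statistics are computed per sample
// over normalized_shape, so moving the vmap dim to the front is enough. When
// weight and bias are not batched, the call goes straight through; otherwise the
// kernel is run without the affine parameters and they are applied by hand.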
static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
native_layer_norm_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    c10::SymIntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    double eps) {
  auto input_ = moveBatchDimToFront(input, input_bdim);
  if (!weight_bdim && !bias_bdim) {
    const auto result = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
    const auto mean = std::get<1>(result);
    const auto rstd = std::get<2>(result);
    const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
    return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim);
  }

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  _check_layer_norm_inputs(normalized_shape, weight, weight_bdim, bias, bias_bdim);

  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
  const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
  auto result0 = std::get<0>(result);
  const auto mean = std::get<1>(result);
  const auto rstd = std::get<2>(result);
  const auto stats_bdim = compute_stat_bdim(input_bdim, mean);

  if (weight.defined()) {
    auto weight_ = moveBatchDimToFront(weight, weight_bdim);
    weight_ = maybePadToLogicalRank(weight_, /*has_bdim*/weight_bdim, input_logical_rank);
    result0 = result0 * weight_;
  }
  if (bias.defined()) {
    const auto result_logical_rank = rankWithoutBatchDim(
        result0,
        input_bdim.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
    auto bias_ = moveBatchDimToFront(bias, bias_bdim);
    bias_ = maybePadToLogicalRank(bias_, /*has_bdim*/bias_bdim, result_logical_rank);
    result0 = result0 + bias_;
  }
  return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim);
}

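// Batch rule for the grad_input piece of the layer norm backwards. Unlike batch
// norm and group norm, no dim folding is needed: the per-sample reduction is over
// normalized_shape, so it suffices to make sure every argument carries the batch
// dim at the front and to request only grad_input from the kernel.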
static std::tuple<at::Tensor, std::optional<int64_t>> native_layer_norm_backward_no_weight_bias_batch_rule(
    const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
    const at::Tensor & input, std::optional<int64_t> input_bdim,
    at::IntArrayRef normalized_shape,
    const at::Tensor & mean, std::optional<int64_t> mean_bdim,
    const at::Tensor & rstd, std::optional<int64_t> rstd_bdim) {

  if (!grad_out_bdim.has_value() && !input_bdim.has_value() &&
      !mean_bdim.has_value() && !rstd_bdim.has_value()) {
    const auto result = at::native_layer_norm_backward(
        grad_out, input, normalized_shape, mean, rstd, std::nullopt, std::nullopt, {true, false, false});
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }

  auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto mean_ = moveBatchDimToFront(mean, mean_bdim);
  auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);

  // ensure grad_out / input have bdim.
  const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
  mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
  rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

  auto result = at::native_layer_norm_backward(
      grad_out_.contiguous(),
      input_.contiguous(),
      normalized_shape,
      mean_.contiguous(),
      rstd_.contiguous(),
      std::nullopt, std::nullopt, {true, false, false});

  return std::make_tuple(std::get<0>(result), 0);
}

static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_plumbing(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    at::IntArrayRef normalized_shape,
    const at::Tensor & mean,
    const at::Tensor & rstd,
    const std::optional<at::Tensor> & weight_opt,
    const std::optional<at::Tensor> & bias_opt,
    std::array<bool,3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "native_layer_norm_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt, bias_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::native_layer_norm_backward(grad_out, input, normalized_shape, mean, rstd,
        weight_opt, bias_opt, output_mask);
  }
  auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level);
  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level);
  auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level);
  std::optional<Tensor> weight_value;
  std::optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  std::optional<Tensor> bias_value;
  std::optional<int64_t> bias_bdim;
  if (bias.defined()) {
    std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias, cur_level);
  }

  // results
  Tensor grad_bias;
  Tensor grad_weight;
  Tensor grad_input;

  if (output_mask[2] && bias_value.has_value()) {
    const auto num_front_dims_to_reduce = grad_out.dim() - normalized_shape.size();
    if (num_front_dims_to_reduce == 0) {
      grad_bias = grad_out;
    } else {
      grad_bias = grad_out.sum(range(0, static_cast<int64_t>(num_front_dims_to_reduce)));
    }
  }
  if (output_mask[1] && weight_value.has_value()) {
    // NB: output isn't saved...
    const auto normalized_input = (input - mean) * rstd;
    const auto expanded_grad_weight = normalized_input * grad_out;
    const auto num_front_dims_to_reduce =
        expanded_grad_weight.dim() - normalized_shape.size();
    if (num_front_dims_to_reduce == 0) {
      grad_weight = expanded_grad_weight;
    } else {
      grad_weight = expanded_grad_weight.sum(range(0, static_cast<int64_t>(num_front_dims_to_reduce)));
    }
  }
  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
      grad_out * weight : grad_out;
    auto [grad_normalized_input_value, grad_normalized_input_bdim] =
        unwrapTensorAtLevel(grad_normalized_input, cur_level);

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        normalized_shape,
        mean_value, mean_bdim,
        rstd_value, rstd_bdim);
    grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

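// The helper structs below adapt batch_norm_batch_rule and batch_norm_backward_plumbing
// to the slightly different signatures of the native, cudnn, and miopen variants.
// The *_BATCH_RULE macros then stamp out the template instantiation for a given op.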
template <typename F, F Func>
struct NativeBatchNormBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_bdim,
      const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
      const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
      const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
      const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
      bool training, double momentum, double eps) {
    return batch_norm_batch_rule<F, Func>(
        input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim,
        running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps);
  }
};

template <typename F, F Func>
struct CudnnBatchNormBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_bdim,
      const Tensor& weight_opt, std::optional<int64_t> weight_bdim,
      const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
      const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
      const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
      bool training, double momentum, double eps) {
    auto reserve = at::empty({0}, input.options().dtype(kByte)); // in experiments, reserve was never set to anything other than empty by cuda
    auto res = batch_norm_batch_rule<F, Func>(
        input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim,
        running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps);
    return std::tuple_cat(res, std::make_tuple(reserve, std::nullopt));
  }
};

template <typename F, F Func>
struct MiopenBatchNormBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_bdim,
      const Tensor& weight_opt, std::optional<int64_t> weight_bdim,
      const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
      const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
      const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
      bool training, double momentum, double eps) {
    return batch_norm_batch_rule<F, Func>(
        input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim,
        running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps);
  }
};

#define NATIVE_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\
    NativeBatchNormBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

#define CUDNN_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\
    CudnnBatchNormBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

#define MIOPEN_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\
    MiopenBatchNormBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

template <typename F, F Func>
struct NativeBatchNormBackwardBatchRuleHelper {
  static std::tuple<Tensor,Tensor,Tensor> apply(
      const at::Tensor & grad_out,
      const at::Tensor & input,
      const std::optional<at::Tensor> & weight_opt,
      const std::optional<at::Tensor> & running_mean_opt,
      const std::optional<at::Tensor> & running_var_opt,
      const std::optional<at::Tensor> & save_mean_opt,
      const std::optional<at::Tensor> & save_rstd_opt,
      bool training,
      double eps,
      std::array<bool,3> output_mask) {

    auto maybe_layer = maybeCurrentDynamicLayer();
    vmap_check_escaped(maybe_layer, "NativeBatchNormBackwardBatchRuleHelper.apply");
    int64_t cur_level = maybe_layer->layerId();

    if (!areAnyBatchedAtLevel({grad_out, input, weight_opt, running_mean_opt,
          running_var_opt, save_mean_opt, save_rstd_opt}, cur_level)) {
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      return at::native_batch_norm_backward(grad_out, input, weight_opt,
          running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt,
          training, eps, output_mask);
    }

    return batch_norm_backward_plumbing<F, Func>(
        grad_out, input, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, training, eps, output_mask);
  }
};

template <typename F, F Func>
struct CudnnBatchNormBackwardBatchRuleHelper {
  static std::tuple<Tensor,Tensor,Tensor> apply(
      const at::Tensor & input,
      const at::Tensor & grad_out,
      const at::Tensor & weight,
      const std::optional<at::Tensor> & running_mean_opt,
      const std::optional<at::Tensor> & running_var_opt,
      const std::optional<at::Tensor> & save_mean_opt,
      const std::optional<at::Tensor> & save_rstd_opt,
      double eps,
      const at::Tensor & reserve) {

    auto maybe_layer = maybeCurrentDynamicLayer();
    vmap_check_escaped(maybe_layer, "CudnnBatchNormBackwardBatchRuleHelper.apply");
    int64_t cur_level = maybe_layer->layerId();

    if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt,
          running_var_opt, save_mean_opt, save_rstd_opt, reserve}, cur_level)) {
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      return at::cudnn_batch_norm_backward(input, grad_out, weight,
          running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps, reserve);
    }

    return batch_norm_backward_plumbing<F, Func>(
        grad_out, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, true, eps, {true, true, true});
  }
};

template <typename F, F Func>
struct MiopenBatchNormBackwardBatchRuleHelper {
  static std::tuple<Tensor,Tensor,Tensor> apply(
      const at::Tensor & input,
      const at::Tensor & grad_out,
      const at::Tensor & weight,
      const std::optional<at::Tensor> & running_mean_opt,
      const std::optional<at::Tensor> & running_var_opt,
      const std::optional<at::Tensor> & save_mean_opt,
      const std::optional<at::Tensor> & save_rstd_opt,
      double eps) {

    auto maybe_layer = maybeCurrentDynamicLayer();
    vmap_check_escaped(maybe_layer, "MiopenBatchNormBackwardBatchRuleHelper.apply");
    int64_t cur_level = maybe_layer->layerId();

    if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt,
          running_var_opt, save_mean_opt, save_rstd_opt}, cur_level)) {
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      return at::miopen_batch_norm_backward(input, grad_out, weight,
          running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps);
    }

    return batch_norm_backward_plumbing<F, Func>(
        grad_out, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, true, eps, {true, true, true});
  }
};

#define NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\
    NativeBatchNormBackwardBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn)>::apply)

#define CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\
    CudnnBatchNormBackwardBatchRuleHelper<\
      decltype(&fn),\
      &fn>::apply)

#define MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\
    MiopenBatchNormBackwardBatchRuleHelper<\
      decltype(&fn),\
      &fn>::apply)

static std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_batch_norm_backward_wrapper(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const at::Tensor& weight_opt,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {
  TORCH_INTERNAL_ASSERT(!training);
  auto reserve = at::empty({0}, input.options().dtype(kByte));
  return at::cudnn_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps, reserve);
}

static std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_batch_norm_backward_wrapper(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const at::Tensor& weight_opt,
    const std::optional<at::Tensor> & running_mean_opt,
    const std::optional<at::Tensor> & running_var_opt,
    const std::optional<at::Tensor> & save_mean_opt,
    const std::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {
  TORCH_INTERNAL_ASSERT(!training); // this should be ensured by batch_norm_impl
  return at::miopen_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps);
}

// NB: This is NOT good. In the ideal world, we do NOT want to convert the new legit op back into native_batch_norm
// as native_batch_norm has a problematic schema--it promises it is functional when it is not. However, vmap doesn't
// work with dynamo anyway so we gain some buffer room to do wrong things here. The (reasonable) hope is that we will
// make native_batch_norm composite implicit within a few weeks and we can fix this before vmap works with dynamo.
static std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_batch(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) {
  return at::native_batch_norm(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
}

static std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_no_stats_batch(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    bool train, double momentum, double eps) {
  return at::native_batch_norm(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
}

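// Registrations for the FuncTorchBatched dispatch key. VMAP_SUPPORT hooks an op up
// to a decomposed batch rule (op arguments plus their batch dims), while m.impl
// registers the plumbing functions, which unwrap/rewrap BatchedTensors themselves.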
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm));
  VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm));
  VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm));
  m.impl("_native_batch_norm_legit", _native_batch_norm_legit_batch);
  m.impl("_native_batch_norm_legit.no_stats", _native_batch_norm_legit_no_stats_batch);
  m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward));
  m.impl("cudnn_batch_norm_backward", CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::cudnn_batch_norm_backward_wrapper));
  m.impl("miopen_batch_norm_backward", MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::miopen_batch_norm_backward_wrapper));
  m.impl("native_group_norm", native_group_norm_plumbing);
  m.impl("native_group_norm_backward", native_group_norm_backward_plumbing);
  VMAP_SUPPORT(native_layer_norm, native_layer_norm_batch_rule);
  m.impl("native_layer_norm_backward", native_layer_norm_backward_plumbing);
}

} // namespace at::functorch