// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at::functorch {
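
// This file contains the vmap (batching) support for loss functions:
// batch rules for mse_loss, huber_loss, and smooth_l1_loss, plus handwritten
// plumbing for binary_cross_entropy and its backward.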
// Flattens out all dims except the batch dim, and also moves batch dim
// (if it exists) to front.
static at::Tensor flatten_logical(const Tensor& tensor, std::optional<int64_t> bdim) {
  if (bdim.has_value()) {
    auto result = moveBatchDimToFront(tensor, bdim);
    if (result.dim() > 1) {
      return result.flatten(1);
    } else {
      return result;
    }
  } else {
    return tensor.flatten();
  }
}

// Useful for many loss functions
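// Strategy: flatten each input to (batch_size, -1) (or to 1-D if it is not
// batched), compute the unreduced loss, and then apply the requested
// reduction along the flattened dim so that Sum/Mean happen per batch
// element. The returned output always has its batch dim at the front.
//
// For example (hypothetical shapes): vmapping mse_loss over self and target
// of shape (B, N) with reduction=Mean arrives here with
// self_bdim = target_bdim = 0 and returns a (B,) tensor of per-example means.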
template <typename Func>
static std::tuple<at::Tensor, std::optional<int64_t>>
loss_batch_rule_helper(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                       std::optional<int64_t> target_bdim, int64_t reduction,
                       Func loss_fn) {
  auto self_ = flatten_logical(self, self_bdim);
  auto target_ = flatten_logical(target, target_bdim);
  auto result = loss_fn(self_, target_, Reduction::None);
  if (result.dim() == 1) {
    return std::make_tuple(result, 0);
  } else if (reduction == Reduction::None) {
    DimVector end_shape;
    const auto batched_elem = self_bdim.has_value() ?
        moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
    return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
  } else if (reduction == Reduction::Sum) {
    return std::make_tuple(result.sum(-1), 0);
  } else if (reduction == Reduction::Mean) {
    return std::make_tuple(result.mean(-1), 0);
  }
  TORCH_INTERNAL_ASSERT(false);
};

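// mse_loss, huber_loss, and smooth_l1_loss are elementwise losses, so their
// batch rules all delegate to loss_batch_rule_helper; each one only forwards
// its extra scalar argument (delta for huber, beta for smooth_l1) to the
// underlying op.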
static std::tuple<at::Tensor, std::optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                    std::optional<int64_t> target_bdim, int64_t reduction) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::mse_loss(self, target, reduction);
                                });
};

static std::tuple<at::Tensor, std::optional<int64_t>>
huber_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                      std::optional<int64_t> target_bdim, int64_t reduction, double delta) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [delta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::huber_loss(self, target, reduction, delta);
                                });
};

static std::tuple<at::Tensor, std::optional<int64_t>>
smooth_l1_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                          std::optional<int64_t> target_bdim, int64_t reduction, double beta) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [beta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::smooth_l1_loss(self, target, reduction, beta);
                                });
};

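// Applies the requested Mean/Sum reduction to an unreduced loss (or returns it
// unchanged for Reduction::None). Used by the binary_cross_entropy plumbing
// below, which always computes the loss unreduced and reduces afterwards.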
static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
  if (reduction == at::Reduction::Mean) {
    return unreduced.mean();
  } else if (reduction == at::Reduction::Sum) {
    return unreduced.sum();
  }
  return unreduced;
}

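// binary_cross_entropy is plumbed by hand rather than written as a batch rule,
// presumably because of the optional weight argument: the inputs are unwrapped
// at the current vmap level, the unreduced loss is computed on the batched
// values with the weight left out, and then the weight and the reduction are
// applied manually. If nothing is batched at this level, fall through to the
// regular op.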
static Tensor binary_cross_entropy_plumbing(
    const Tensor& self, const Tensor& target,
    const std::optional<Tensor>& weight, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "binary_cross_entropy_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)
      && !isBatchedAtLevel(weight, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::binary_cross_entropy(self, target, weight, reduction);
  }

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level);

  Tensor result;
  if (self_bdim || target_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto bdim_size = get_bdim_size2(self_value, self_bdim, target_value, target_bdim);
    auto self_ = moveBatchDimToFront(self_value, self_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);
    self_ = ensure_has_bdim(self_, self_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
    result = at::binary_cross_entropy(self_, target_, std::nullopt, Reduction::None);
    result = makeBatched(result, 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    result = at::binary_cross_entropy(self_value, target_value, std::nullopt, Reduction::None);
  }
  if (weight.has_value() && weight->defined()) {
    result = result * weight.value();
  }
  return apply_loss_reduction(result, reduction);
}

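// The backward plumbing mirrors the forward: when the forward used Mean or
// Sum, the incoming grad is first broadcast to input's shape via expand_as,
// then everything is unwrapped and the unreduced backward is computed without
// the weight. The weight multiplication and the 1/numel factor for Mean are
// applied afterwards.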
static Tensor binary_cross_entropy_backward_plumbing(
    const Tensor& grad, const Tensor& input, const Tensor& target,
    const std::optional<Tensor>& weight_opt, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "binary_cross_entropy_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({grad, input, target, weight_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::binary_cross_entropy_backward(grad, input, target, weight_opt, reduction);
  }

  auto [grad_value, grad_bdim] = unwrapTensorAtLevel(
      reduction == Reduction::None ? grad : grad.expand_as(input), cur_level);
  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level);

  Tensor grad_input;
  if (grad_bdim || input_bdim || target_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto bdim_size = get_bdim_size3(
        grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim);

    auto grad_ = moveBatchDimToFront(grad_value, grad_bdim);
    auto input_ = moveBatchDimToFront(input_value, input_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);

    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bdim_size);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);

    grad_input = at::binary_cross_entropy_backward(
        grad_, input_, target_, std::nullopt, Reduction::None);
    grad_input = makeBatched(grad_input, 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    grad_input = at::binary_cross_entropy_backward(
        grad_value, input_value, target_value, std::nullopt, Reduction::None);
  }
  if (weight_opt.has_value() && weight_opt->defined()) {
    grad_input = grad_input * weight_opt.value();
  }
  if (reduction == Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}

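// The elementwise losses get their batch rules registered through
// VMAP_SUPPORT; binary_cross_entropy and its backward register the
// handwritten plumbing above directly on the FuncTorchBatched key.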
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
  // mse_loss_backward uses a decomposition for its batch rule
  VMAP_SUPPORT(huber_loss, huber_loss_batch_rule);
  // huber_loss_backward uses a decomposition for its batch rule
  VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
  // smooth_l1_loss_backward uses a decomposition for its batch rule
  m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
  m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
}

} // namespace at::functorch