// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at::functorch {
// Flattens out all dims except the batch dim, and also moves batch dim
// (if it exists) to front.
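// For example, an input of shape [B, C, H, W] with bdim set is returned as
// [B, C*H*W]; an input without a bdim is flattened to a single 1-D tensor.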
static at::Tensor flatten_logical(const Tensor& tensor, std::optional<int64_t> bdim) {
  if (bdim.has_value()) {
    auto result = moveBatchDimToFront(tensor, bdim);
    if (result.dim() > 1) {
      return result.flatten(1);
    } else {
      return result;
    }
  } else {
    return tensor.flatten();
  }
}

// Useful for many loss functions
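// Strategy: compute the unreduced loss on the flattened inputs, then apply
// the requested reduction over the non-batch dimension, so each example in
// the vmapped batch gets its own reduced loss. For example, with
// Reduction::Mean and a batched self of shape [B, N], the result is a
// [B]-shaped tensor of per-example means rather than a single scalar.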
template <typename Func>
static std::tuple<at::Tensor, std::optional<int64_t>>
loss_batch_rule_helper(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
          std::optional<int64_t> target_bdim, int64_t reduction,
          Func loss_fn) {
  auto self_ = flatten_logical(self, self_bdim);
  auto target_ = flatten_logical(target, target_bdim);
  auto result = loss_fn(self_, target_, Reduction::None);
  if (result.dim() == 1) {
    return std::make_tuple(result, 0);
  } else if (reduction == Reduction::None) {
    DimVector end_shape;
    const auto batched_elem = self_bdim.has_value() ?
        moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
    return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
  } else if (reduction == Reduction::Sum) {
    return std::make_tuple(result.sum(-1), 0);
  } else if (reduction == Reduction::Mean) {
    return std::make_tuple(result.mean(-1), 0);
  }
  TORCH_INTERNAL_ASSERT(false);
};

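// The batch rules below forward to the helper above; extra scalar arguments
// (delta for huber_loss, beta for smooth_l1_loss) are captured by the lambda.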
static std::tuple<at::Tensor, std::optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
          std::optional<int64_t> target_bdim, int64_t reduction) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::mse_loss(self, target, reduction);
                                });
};

static std::tuple<at::Tensor, std::optional<int64_t>>
huber_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
          std::optional<int64_t> target_bdim, int64_t reduction, double delta) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [delta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::huber_loss(self, target, reduction, delta);
                                });
};

static std::tuple<at::Tensor, std::optional<int64_t>>
smooth_l1_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
          std::optional<int64_t> target_bdim, int64_t reduction, double beta) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [beta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::smooth_l1_loss(self, target, reduction, beta);
                                });
};

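// Applies the requested reduction (Mean/Sum/None) over all elements of an
// unreduced loss. Used by the binary_cross_entropy plumbings below.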
static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
  if (reduction == at::Reduction::Mean) {
    return unreduced.mean();
  } else if (reduction == at::Reduction::Sum) {
    return unreduced.sum();
  }
  return unreduced;
}

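// binary_cross_entropy is handled with manual plumbing rather than a batch
// rule: only the unweighted, unreduced loss is computed on the unwrapped
// (batched) inputs; the optional weight and the reduction are then applied
// at the current vmap level, on the re-wrapped result.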
static Tensor binary_cross_entropy_plumbing(
    const Tensor& self, const Tensor& target,
    const std::optional<Tensor>& weight, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "binary_cross_entropy_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)
      && !isBatchedAtLevel(weight, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::binary_cross_entropy(self, target, weight, reduction);
  }

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level);

  Tensor result;
  if (self_bdim || target_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto bdim_size = get_bdim_size2(self_value, self_bdim, target_value, target_bdim);
    auto self_ = moveBatchDimToFront(self_value, self_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);
    self_ = ensure_has_bdim(self_, self_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
    result = at::binary_cross_entropy(self_, target_, std::nullopt, Reduction::None);
    result = makeBatched(result, 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    result = at::binary_cross_entropy(self_value, target_value, std::nullopt, Reduction::None);
  }
  if (weight.has_value() && weight->defined()) {
    result = result * weight.value();
  }
  return apply_loss_reduction(result, reduction);
}

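// The backward follows the same scheme: grad is expanded to the input shape
// when the forward was reduced, the unweighted, unreduced backward is
// computed on the unwrapped inputs, and the weight plus the 1/numel factor
// for Reduction::Mean are applied afterwards at the current vmap level.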
static Tensor binary_cross_entropy_backward_plumbing(
    const Tensor& grad, const Tensor& input, const Tensor& target,
    const std::optional<Tensor>& weight_opt, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "binary_cross_entropy_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({grad, input, target, weight_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::binary_cross_entropy_backward(grad, input, target, weight_opt, reduction);
  }

  auto [grad_value, grad_bdim] = unwrapTensorAtLevel(
      reduction == Reduction::None ? grad : grad.expand_as(input), cur_level);
  auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
  auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level);

  Tensor grad_input;
  if (grad_bdim || input_bdim || target_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto bdim_size = get_bdim_size3(
        grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim);

    auto grad_ = moveBatchDimToFront(grad_value, grad_bdim);
    auto input_ = moveBatchDimToFront(input_value, input_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);

    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bdim_size);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);

    grad_input = at::binary_cross_entropy_backward(
        grad_, input_, target_, std::nullopt, Reduction::None);
    grad_input = makeBatched(grad_input, 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    grad_input = at::binary_cross_entropy_backward(
        grad_value, input_value, target_value, std::nullopt, Reduction::None);
  }
  if (weight_opt.has_value() && weight_opt->defined()) {
    grad_input = grad_input * weight_opt.value();
  }
  if (reduction == Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}

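// These registrations route vmap calls over mse_loss, huber_loss,
// smooth_l1_loss, binary_cross_entropy, and binary_cross_entropy_backward
// to the rules defined in this file.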
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
  // mse_loss_backward uses a decomposition for its batch rule
  VMAP_SUPPORT(huber_loss, huber_loss_batch_rule);
  // huber_loss_backward uses a decomposition for its batch rule
  VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
  // smooth_l1_loss_backward uses a decomposition for its batch rule
  m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
  m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
}

} // namespace at::functorch