// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/Operators.h>

// NB: most activation functions are handled by the generic pointwise unary or
// binary batch rules. This file contains only the activations that need a
// special batch rule, grouped here for organization.
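// Batch-rule convention (used throughout functorch): each tensor argument is
// paired with an optional index of its batch dimension (nullopt means the
// argument is not batched), and the rule returns the batched result together
// with the index of the batch dimension in that result (0 below, since the
// batch dimension is moved to the front).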
namespace at::functorch {
static std::tuple<Tensor, std::optional<int64_t>>
glu_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim) {
  // repeated error message from glu because 0D -> 1D when batched
  // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
  // can't be evenly halved, but give a nicer error message here.
  TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");

  const auto rank = rankWithoutBatchDim(self, self_bdim);
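  // Physical dims are shifted by one relative to the logical (unbatched) view
  // because the batch dim is moved to the front below. Illustrative example
  // (shapes made up for this comment): a logical [3, 6] input vmapped with
  // B = 4 is seen here as a [4, 3, 6] tensor, so glu over logical dim 1 must
  // run over physical dim 2, halving that dim to produce a [4, 3, 3] result.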
  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;

  const auto self_ = moveBatchDimToFront(self, self_bdim);

  const auto res = at::glu(self_, dim_);
  return std::make_tuple(res, 0);
}

static std::tuple<Tensor, std::optional<int64_t>> glu_backward_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim) {
  if (self_bdim) {
    // repeated error message from glu because 0D -> 1D when batched
    // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
    // can't be evenly halved, but give a nicer error message here.
    TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
  }

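  // Same logical-to-physical dim shift as in glu_batch_rule above: once the
  // batch dim sits at the front, the user-facing dim moves over by one.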
  const auto rank = rankWithoutBatchDim(self, self_bdim);
  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;

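  // grad_output and self need not both carry a batch dim (only one of them may
  // have been vmapped over). Compute the common batch size, then give any
  // unbatched operand a leading batch dim expanded to that size so both inputs
  // line up for the plain at::glu_backward call.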
  const auto batch_size = get_bdim_size2(grad_output, grad_output_bdim, self, self_bdim);
  const auto grad_output_ = ensure_has_bdim(moveBatchDimToFront(grad_output, grad_output_bdim), grad_output_bdim.has_value(), batch_size);
  const auto self_ = ensure_has_bdim(moveBatchDimToFront(self, self_bdim), self_bdim.has_value(), batch_size);

  const auto res = at::glu_backward(grad_output_, self_, dim_);
  return std::make_tuple(res, 0);
}

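// Register the rules for the FuncTorchBatched dispatch key: when a vmapped
// call to glu or glu_backward reaches this key, the plumbing generated by
// VMAP_SUPPORT unwraps each batched tensor into a (value, optional bdim) pair
// and forwards it to the corresponding batch rule above.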
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(glu_backward, glu_backward_batch_rule);
  VMAP_SUPPORT(glu, glu_batch_rule);
}
} // namespace at::functorch