xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/nn/modules/_functions.h>
3 
4 using namespace torch::autograd;
5 
6 namespace torch {
7 namespace nn {
8 namespace functions {
9 
forward(AutogradContext * ctx,const Variable & input,const CrossMapLRN2dOptions & options)10 Variable CrossMapLRN2d::forward(
11     AutogradContext* ctx,
12     const Variable& input,
13     const CrossMapLRN2dOptions& options) {
14   ctx->saved_data["size"] = options.size();
15   ctx->saved_data["alpha"] = options.alpha();
16   ctx->saved_data["beta"] = options.beta();
17   ctx->saved_data["k"] = options.k();
18   ctx->saved_data["scale"] = torch::Tensor();
19 
20   TORCH_CHECK(input.dim() == 4);
21 
22   ctx->saved_data["scale"] = ctx->saved_data["scale"].toTensor().defined()
23       ? ctx->saved_data["scale"]
24       : torch::empty({0}, input.options());
25 
26   torch::Tensor output = torch::empty({0}, input.options());
27 
28   int64_t channels = input.size(1);
29 
30   output.resize_as_(input);
31   ctx->saved_data["scale"].toTensor().resize_as_(input);
32 
33   /// use output storage as temporary buffer
34   auto input_square = output;
35   torch::pow_out(input_square, input, 2);
36 
37   int64_t pre_pad =
38       static_cast<int64_t>((ctx->saved_data["size"].toInt() - 1) / 2 + 1);
39   int64_t pre_pad_crop = pre_pad > channels ? channels : pre_pad;
40 
41   auto scale_first = ctx->saved_data["scale"].toTensor().select(1, 0);
42   scale_first.zero_();
43 
44   /// compute first feature map normalization
45   for (const auto c : c10::irange(pre_pad_crop)) {
46     scale_first.add_(input_square.select(1, c));
47   }
48 
49   /// reuse computations for next feature maps normalization
50   /// by adding the next feature map and removing the previous
51   torch::Tensor scale_previous, scale_current, square_next, square_previous;
52 
53   for (const auto c : c10::irange(1, channels)) {
54     scale_previous = ctx->saved_data["scale"].toTensor().select(1, c - 1);
55     scale_current = ctx->saved_data["scale"].toTensor().select(1, c);
56     scale_current.copy_(scale_previous);
57 
58     if (c < channels - pre_pad + 1) {
59       square_next = input_square.select(1, c + pre_pad - 1);
60       scale_current.add_(square_next, 1);
61     }
62 
63     if (c > pre_pad) {
64       square_previous = input_square.select(1, c - pre_pad);
65       scale_current.add_(square_previous, -1);
66     }
67   }
68 
69   ctx->saved_data["scale"]
70       .toTensor()
71       .mul_(
72           ctx->saved_data["alpha"].toDouble() / ctx->saved_data["size"].toInt())
73       .add_(ctx->saved_data["k"].toInt());
74 
75   torch::pow_out(
76       output,
77       ctx->saved_data["scale"].toTensor(),
78       -ctx->saved_data["beta"].toDouble());
79   output.mul_(input);
80 
81   ctx->save_for_backward({input, output});
82   return output;
83 }
84 
backward(AutogradContext * ctx,variable_list grad_outputs)85 variable_list CrossMapLRN2d::backward(
86     AutogradContext* ctx,
87     variable_list grad_outputs) {
88   auto grad_output = grad_outputs[0];
89   auto input = ctx->get_saved_variables()[0];
90   auto output = ctx->get_saved_variables()[1];
91   auto grad_input = torch::empty({0}, grad_output.options());
92 
93   int64_t batch_size = input.size(0);
94   int64_t channels = input.size(1);
95   int64_t input_height = input.size(2);
96   int64_t input_width = input.size(3);
97 
98   auto padded_ratio = torch::empty(
99       {channels + ctx->saved_data["size"].toInt() - 1,
100        input_height,
101        input_width},
102       input.options());
103   auto accum_ratio = torch::empty({input_height, input_width}, input.options());
104   double cache_ratio_value = 2 * ctx->saved_data["alpha"].toDouble() *
105       ctx->saved_data["beta"].toDouble() / ctx->saved_data["size"].toInt();
106   int64_t inversePrePad = static_cast<int64_t>(
107       ctx->saved_data["size"].toInt() -
108       (ctx->saved_data["size"].toInt() - 1) / 2);
109 
110   grad_input.resize_as_(input);
111   torch::pow_out(
112       grad_input,
113       ctx->saved_data["scale"].toTensor(),
114       -ctx->saved_data["beta"].toDouble())
115       .mul_(grad_output);
116 
117   padded_ratio.zero_();
118   auto padded_ratio_center = padded_ratio.narrow(0, inversePrePad, channels);
119 
120   for (const auto n : c10::irange(batch_size)) {
121     torch::mul_out(padded_ratio_center, grad_output[n], output[n]);
122     padded_ratio_center.div_(ctx->saved_data["scale"].toTensor()[n]);
123     torch::sum_out(
124         accum_ratio,
125         padded_ratio.narrow(0, 0, ctx->saved_data["size"].toInt() - 1),
126         0,
127         /*keepdim=*/false);
128     for (const auto c : c10::irange(channels)) {
129       accum_ratio.add_(padded_ratio[c + ctx->saved_data["size"].toInt() - 1]);
130       grad_input[n][c].addcmul_(input[n][c], accum_ratio, -cache_ratio_value);
131       accum_ratio.add_(padded_ratio[c], -1);
132     }
133   }
134 
135   return variable_list{
136       grad_input, Variable(), Variable(), Variable(), Variable()};
137 }
138 
139 } // namespace functions
140 } // namespace nn
141 } // namespace torch
142