/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <tuple>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using IntArrayRef = exec_aten::ArrayRef<int64_t>;
using OptIntArrayRef = exec_aten::OptionalArrayRef<int64_t>;

namespace {

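// Validates the arguments to convolution_backward: only non-transposed 2D
// convolutions are supported, all tensors must share the input's dtype, and
// grad_output must have the shape the forward convolution would produce.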
bool check_convolution_backward_args(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    ET_UNUSED const OptIntArrayRef bias_sizes_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    ET_UNUSED exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      transposed == false, "Transposed Convolution Backward not supported yet");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.dim() == 4, "Only 2D Convolution Backward supported for now");

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_weight, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_bias, input));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      check_convolution_args(
          input,
          weight,
          exec_aten::optional<Tensor>(),
          stride,
          padding,
          dilation,
          transposed,
          output_padding,
          groups,
          grad_output),
      "Invalid convolution arguments");

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_convolution_out_target_size(
      input,
      weight,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups,
      output_sizes,
      &output_ndim);

  ET_LOG_AND_RETURN_IF_FALSE(
      output_size_is_valid({output_sizes, output_ndim}, input.dim() - 2));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      grad_output.dim() == input.dim(),
      "grad_output should have same number of dimensions as input");

  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));

  return true;
}

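// Computes the backward pass of a 2D convolution with a direct (naive) loop
// nest over batch, group, output position, and kernel position. output_mask
// selects which gradients to compute: {grad_input, grad_weight, grad_bias}.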
template <typename CTYPE>
void conv2d_backward_impl(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  auto batch_size = input.size(0);
  auto in_channels = input.size(1);
  auto out_channels = weight.size(0);
  auto in_height = input.size(2);
  auto in_width = input.size(3);
  auto out_height = grad_output.size(2);
  auto out_width = grad_output.size(3);
  auto kernel_height = weight.size(2);
  auto kernel_width = weight.size(3);

  const int64_t stride_h = val_at(stride, 0);
  const int64_t padding_h = val_at(padding, 0, /*default_value=*/0);
  const int64_t dilation_h = val_at(dilation, 0);
  const int64_t stride_w = val_at(stride, 1);
  const int64_t padding_w = val_at(padding, 1, /*default_value=*/0);
  const int64_t dilation_w = val_at(dilation, 1);

  auto in_channels_per_group = in_channels / groups;
  auto out_channels_per_group = out_channels / groups;

  const CTYPE* grad_output_data = grad_output.const_data_ptr<CTYPE>();
  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
  const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();

  CTYPE* grad_input_data = nullptr;
  CTYPE* grad_weight_data = nullptr;
  CTYPE* grad_bias_data = nullptr;

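  // Gradients are accumulated in place, so each requested buffer is
  // zero-initialized before the loop nest below runs.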
  if (output_mask[0]) {
    grad_input_data = grad_input.mutable_data_ptr<CTYPE>();
    memset(grad_input_data, 0, grad_input.nbytes());
  }

  if (output_mask[1]) {
    grad_weight_data = grad_weight.mutable_data_ptr<CTYPE>();
    memset(grad_weight_data, 0, grad_weight.nbytes());
  }

  if (output_mask[2]) {
    grad_bias_data = grad_bias.mutable_data_ptr<CTYPE>();
    memset(grad_bias_data, 0, grad_bias.nbytes());
  }

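  // Scratch (batch, channel, height, width) coordinates, converted to linear
  // offsets with each tensor's strides via calculate_linear_index.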
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType out_coord[kTensorDimensionLimit];
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType in_coord[kTensorDimensionLimit];
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType weight_coord[kTensorDimensionLimit];

  // Compute gradients
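  // For each output element grad_output[b, oc_global, h, w]:
  //   grad_bias[oc_global] += grad_output[b, oc_global, h, w]
  // and, for every kernel tap (kh, kw) that maps to a valid input position
  // (in_h, in_w) = (h * stride - padding + k * dilation), per input channel:
  //   grad_input[b, ic_global, in_h, in_w] +=
  //       grad_output[b, oc_global, h, w] * weight[oc_global, ic, kh, kw]
  //   grad_weight[oc_global, ic, kh, kw] +=
  //       grad_output[b, oc_global, h, w] * input[b, ic_global, in_h, in_w]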
  for (int64_t b = 0; b < batch_size; ++b) { // Loop over each batch
    in_coord[0] = b;
    out_coord[0] = b;
    for (int64_t g = 0; g < groups; ++g) { // Loop over each group
      for (int64_t h = 0; h < out_height; ++h) { // Loop over each output row
        out_coord[2] = h;
        for (int64_t w = 0; w < out_width; ++w) { // Loop over each output col
          out_coord[3] = w;

          // Loop over each output channel in the group
          for (int64_t oc = 0; oc < out_channels_per_group; ++oc) {
            int64_t oc_global = oc + g * out_channels_per_group;
            weight_coord[0] = oc_global;
            out_coord[1] = oc_global;

            int64_t out_idx = calculate_linear_index(
                out_coord, grad_output.strides().data(), 4);

            // Accumulate the gradient with respect to the bias if required
            if (output_mask[2]) {
              grad_bias_data[oc_global] += grad_output_data[out_idx];
            }

            // Loop over each input channel in the group
            for (int64_t ic = 0; ic < in_channels_per_group; ++ic) {
              int64_t ic_global = ic + g * in_channels_per_group;
              in_coord[1] = ic_global;
              weight_coord[1] = ic;

              // Loop over each element of the kernel
              for (int64_t kh = 0; kh < kernel_height; ++kh) {
                int64_t in_h = h * stride_h - padding_h + kh * dilation_h;
                if (in_h >= 0 && in_h < in_height) {
                  in_coord[2] = in_h;
                  weight_coord[2] = kh;

                  for (int64_t kw = 0; kw < kernel_width; ++kw) {
                    int64_t in_w = w * stride_w - padding_w + kw * dilation_w;
                    if (in_w >= 0 && in_w < in_width) {
                      in_coord[3] = in_w;
                      weight_coord[3] = kw;

                      int64_t in_idx = calculate_linear_index(
                          in_coord, input.strides().data(), 4);

                      int64_t weight_idx = calculate_linear_index(
                          weight_coord, weight.strides().data(), 4);

                      // Gradient with respect to the input if required
                      if (output_mask[0]) {
                        grad_input_data[in_idx] +=
                            grad_output_data[out_idx] * weight_data[weight_idx];
                      }
                      // Gradient with respect to the weight if required
                      if (output_mask[1]) {
                        grad_weight_data[weight_idx] +=
                            grad_output_data[out_idx] * input_data[in_idx];
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

} // namespace

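// convolution_backward.out: computes the gradients of a non-transposed 2D
// convolution with respect to the input, weight, and bias, writing them into
// the provided out tensors and returning them as a tuple. The grad tensors
// are resized to match input, weight, and bias_sizes_opt before computation.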
std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
    KernelRuntimeContext& ctx,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    const OptIntArrayRef bias_sizes_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  (void)ctx;

  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(
      grad_input, grad_weight, grad_bias);

  ET_KERNEL_CHECK(
      ctx,
      check_convolution_backward_args(
          grad_output,
          input,
          weight,
          bias_sizes_opt,
          stride,
          padding,
          dilation,
          transposed,
          output_padding,
          groups,
          output_mask,
          grad_input,
          grad_weight,
          grad_bias),
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_input, input.sizes()) == Error::Ok,
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_weight, weight.sizes()) == Error::Ok,
      InvalidArgument,
      ret_val);

  if (bias_sizes_opt.has_value()) {
    ET_KERNEL_CHECK(
        ctx,
        resize_tensor(grad_bias, bias_sizes_opt.value()) == Error::Ok,
        InvalidArgument,
        ret_val);
  }

  constexpr auto name = "convolution_backward.out";

  ET_SWITCH_FLOATH_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
    conv2d_backward_impl<CTYPE>(
        grad_output,
        input,
        weight,
        stride,
        padding,
        dilation,
        groups,
        output_mask,
        grad_input,
        grad_weight,
        grad_bias);
  });

  return ret_val;
}

} // namespace native
} // namespace executor
} // namespace torch