/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <tuple>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using IntArrayRef = exec_aten::ArrayRef<int64_t>;
using OptIntArrayRef = exec_aten::OptionalArrayRef<int64_t>;

namespace {

bool check_convolution_backward_args(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    ET_UNUSED const OptIntArrayRef bias_sizes_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    ET_UNUSED exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      transposed == false, "Transposed Convolution Backward not supported yet");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.dim() == 4, "Only 2D Convolution Backward supported for now");

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_weight, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_bias, input));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      check_convolution_args(
          input,
          weight,
          exec_aten::optional<Tensor>(),
          stride,
          padding,
          dilation,
          transposed,
          output_padding,
          groups,
          grad_output),
      "Invalid convolution arguments");

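  // Recompute the forward output shape from the convolution parameters and
  // require grad_output to match it exactly.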
  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_convolution_out_target_size(
      input,
      weight,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups,
      output_sizes,
      &output_ndim);

  ET_LOG_AND_RETURN_IF_FALSE(
      output_size_is_valid({output_sizes, output_ndim}, input.dim() - 2));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      grad_output.dim() == input.dim(),
      "grad_output should have same number of dimensions as input");

  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));

  return true;
}

template <typename CTYPE>
void conv2d_backward_impl(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
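  // Layout assumptions, matching the size() reads below: input is
  // [N, C_in, H_in, W_in], weight is [C_out, C_in / groups, kH, kW], and
  // grad_output is [N, C_out, H_out, W_out]; each group owns a contiguous
  // slice of the input and output channels.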
  auto batch_size = input.size(0);
  auto in_channels = input.size(1);
  auto out_channels = weight.size(0);
  auto in_height = input.size(2);
  auto in_width = input.size(3);
  auto out_height = grad_output.size(2);
  auto out_width = grad_output.size(3);
  auto kernel_height = weight.size(2);
  auto kernel_width = weight.size(3);

  const int64_t stride_h = val_at(stride, 0);
  const int64_t padding_h = val_at(padding, 0, /*default_value=*/0);
  const int64_t dilation_h = val_at(dilation, 0);
  const int64_t stride_w = val_at(stride, 1);
  const int64_t padding_w = val_at(padding, 1, /*default_value=*/0);
  const int64_t dilation_w = val_at(dilation, 1);

  auto in_channels_per_group = in_channels / groups;
  auto out_channels_per_group = out_channels / groups;

  const CTYPE* grad_output_data = grad_output.const_data_ptr<CTYPE>();
  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
  const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();

  CTYPE* grad_input_data = nullptr;
  CTYPE* grad_weight_data = nullptr;
  CTYPE* grad_bias_data = nullptr;

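  // Zero-initialize only the gradients requested via output_mask; the loops
  // below accumulate into them.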
  if (output_mask[0]) {
    grad_input_data = grad_input.mutable_data_ptr<CTYPE>();
    memset(grad_input_data, 0, grad_input.nbytes());
  }

  if (output_mask[1]) {
    grad_weight_data = grad_weight.mutable_data_ptr<CTYPE>();
    memset(grad_weight_data, 0, grad_weight.nbytes());
  }

  if (output_mask[2]) {
    grad_bias_data = grad_bias.mutable_data_ptr<CTYPE>();
    memset(grad_bias_data, 0, grad_bias.nbytes());
  }

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType out_coord[kTensorDimensionLimit];
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType in_coord[kTensorDimensionLimit];
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType weight_coord[kTensorDimensionLimit];

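  // For each output element grad_output[b][oc][h][w] (oc global across
  // groups), the nested loops below accumulate:
  //   grad_bias[oc]               += grad_output[b][oc][h][w]
  //   grad_weight[oc][ic][kh][kw] += grad_output[b][oc][h][w] * input[b][ic_global][in_h][in_w]
  //   grad_input[b][ic_global][in_h][in_w]
  //                               += grad_output[b][oc][h][w] * weight[oc][ic][kh][kw]
  // where ic indexes channels within the group, ic_global = ic + g *
  // in_channels_per_group, in_h = h * stride_h - padding_h + kh * dilation_h
  // (likewise for in_w), and taps that fall outside the input are skipped.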
  // Compute gradients
  for (int64_t b = 0; b < batch_size; ++b) { // Loop over each batch
    in_coord[0] = b;
    out_coord[0] = b;
    for (int64_t g = 0; g < groups; ++g) { // Loop over each group
      for (int64_t h = 0; h < out_height; ++h) { // Loop over each output row
        out_coord[2] = h;
        for (int64_t w = 0; w < out_width; ++w) { // Loop over each output col
          out_coord[3] = w;

          // Loop over each output channel in the group
          for (int64_t oc = 0; oc < out_channels_per_group; ++oc) {
            int64_t oc_global = oc + g * out_channels_per_group;
            weight_coord[0] = oc_global;
            out_coord[1] = oc_global;

            int64_t out_idx = calculate_linear_index(
                out_coord, grad_output.strides().data(), 4);

            // Accumulate the gradient with respect to the bias if required
            if (output_mask[2]) {
              grad_bias_data[oc_global] += grad_output_data[out_idx];
            }

            // Loop over each input channel in the group
            for (int64_t ic = 0; ic < in_channels_per_group; ++ic) {
              int64_t ic_global = ic + g * in_channels_per_group;
              in_coord[1] = ic_global;
              weight_coord[1] = ic;

              // Loop over each element of the kernel
              for (int64_t kh = 0; kh < kernel_height; ++kh) {
                int64_t in_h = h * stride_h - padding_h + kh * dilation_h;
                if (in_h >= 0 && in_h < in_height) {
                  in_coord[2] = in_h;
                  weight_coord[2] = kh;

                  for (int64_t kw = 0; kw < kernel_width; ++kw) {
                    int64_t in_w = w * stride_w - padding_w + kw * dilation_w;
                    if (in_w >= 0 && in_w < in_width) {
                      in_coord[3] = in_w;
                      weight_coord[3] = kw;

                      int64_t in_idx = calculate_linear_index(
                          in_coord, input.strides().data(), 4);

                      int64_t weight_idx = calculate_linear_index(
                          weight_coord, weight.strides().data(), 4);

                      // Gradient with respect to the input if required
                      if (output_mask[0]) {
                        grad_input_data[in_idx] +=
                            grad_output_data[out_idx] * weight_data[weight_idx];
                      }
                      // Gradient with respect to the weight if required
                      if (output_mask[1]) {
                        grad_weight_data[weight_idx] +=
                            grad_output_data[out_idx] * input_data[in_idx];
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

} // namespace

std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
    KernelRuntimeContext& ctx,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    const OptIntArrayRef bias_sizes_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  (void)ctx;

  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(
      grad_input, grad_weight, grad_bias);

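  // ret_val is built up front so the ET_KERNEL_CHECK macros below can return
  // it if an argument check or resize fails.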
  ET_KERNEL_CHECK(
      ctx,
      check_convolution_backward_args(
          grad_output,
          input,
          weight,
          bias_sizes_opt,
          stride,
          padding,
          dilation,
          transposed,
          output_padding,
          groups,
          output_mask,
          grad_input,
          grad_weight,
          grad_bias),
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_input, input.sizes()) == Error::Ok,
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_weight, weight.sizes()) == Error::Ok,
      InvalidArgument,
      ret_val);

  if (bias_sizes_opt.has_value()) {
    ET_KERNEL_CHECK(
        ctx,
        resize_tensor(grad_bias, bias_sizes_opt.value()) == Error::Ok,
        InvalidArgument,
        ret_val);
  }

  constexpr auto name = "convolution_backward.out";

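  // Dispatch on the input's floating-point dtype (including half precision)
  // and run the reference implementation for that element type.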
  ET_SWITCH_FLOATH_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
    conv2d_backward_impl<CTYPE>(
        grad_output,
        input,
        weight,
        stride,
        padding,
        dilation,
        groups,
        output_mask,
        grad_input,
        grad_weight,
        grad_bias);
  });

  return ret_val;
}

} // namespace native
} // namespace executor
} // namespace torch