xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/conv.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 #include <c10/util/overloaded.h>
5 
6 #include <torch/expanding_array.h>
7 #include <torch/nn/cloneable.h>
8 #include <torch/nn/init.h>
9 #include <torch/nn/modules/common.h>
10 #include <torch/nn/modules/utils.h>
11 #include <torch/nn/options/conv.h>
12 #include <torch/nn/pimpl.h>
13 #include <torch/types.h>
14 
15 #include <torch/csrc/Export.h>
16 
17 #include <cstddef>
18 #include <vector>
19 
20 namespace torch {
21 namespace nn {
22 
23 /// Base class for all (dimension-specialized) convolution modules.
24 template <size_t D, typename Derived>
25 class ConvNdImpl : public torch::nn::Cloneable<Derived> {
26  public:
ConvNdImpl(detail::ConvNdOptions<D> options_)27   explicit ConvNdImpl(detail::ConvNdOptions<D> options_)
28       : options(std::move(options_)) {
29     // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
30     reset();
31   }
32 
reset()33   void reset() override {
34     TORCH_CHECK(
35         options.in_channels() > 0 && options.groups() > 0 &&
36             options.out_channels() > 0,
37         "in_channels, groups and out_channels must be a positive integer.");
38     TORCH_CHECK(
39         options.in_channels() % options.groups() == 0,
40         "in_channels must be divisible by groups");
41     TORCH_CHECK(
42         options.out_channels() % options.groups() == 0,
43         "out_channels must be divisible by groups");
44 
45     std::visit(
46         c10::overloaded(
47             [&](enumtype::kValid) {
48               _reversed_padding_repeated_twice.resize(2 * D);
49               std::fill_n(_reversed_padding_repeated_twice.begin(), 2 * D, 0);
50             },
51             [&](enumtype::kSame) {
52               for (const auto i : c10::irange(D)) {
53                 const auto stride = (*options.stride())[i];
54                 TORCH_CHECK(
55                     stride == 1,
56                     "padding='same' is not supported for strided convolutions");
57               }
58 
59               _reversed_padding_repeated_twice.resize(2 * D);
60               for (const auto i : c10::irange(D)) {
61                 const auto dilation = (*options.dilation())[i];
62                 const auto kernel_size = (*options.kernel_size())[i];
63                 const auto total_padding = dilation * (kernel_size - 1);
64                 auto left_pad = total_padding / 2;
65                 auto right_pad = total_padding - left_pad;
66                 _reversed_padding_repeated_twice[2 * i] = left_pad;
67                 _reversed_padding_repeated_twice[2 * i + 1] = right_pad;
68               }
69             },
70             [&](const ExpandingArray<D>& pad) {
71               _reversed_padding_repeated_twice =
72                   torch::nn::modules::utils::_reverse_repeat_vector(pad, 2);
73             }),
74         options.padding());
75 
76     if (options.transposed()) {
77       std::vector<int64_t> weight_sizes = {
78           options.in_channels(), options.out_channels() / options.groups()};
79       weight_sizes.insert(
80           weight_sizes.end(),
81           (*options.kernel_size()).begin(),
82           (*options.kernel_size()).end());
83       weight = this->register_parameter("weight", torch::empty(weight_sizes));
84     } else {
85       std::vector<int64_t> weight_sizes = {
86           options.out_channels(), options.in_channels() / options.groups()};
87       weight_sizes.insert(
88           weight_sizes.end(),
89           (*options.kernel_size()).begin(),
90           (*options.kernel_size()).end());
91       weight = this->register_parameter("weight", torch::empty(weight_sizes));
92     }
93 
94     if (options.bias()) {
95       bias = this->register_parameter(
96           "bias", torch::empty({options.out_channels()}));
97     } else {
98       this->register_parameter("bias", Tensor(), /*requires_grad=*/false);
99     }
100 
101     reset_parameters();
102   }
103 
reset_parameters()104   void reset_parameters() {
105     init::kaiming_uniform_(
106         weight,
107         /*a=*/std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
108 
109     if (bias.defined()) {
110       auto [fan_in, fan_out] = init::_calculate_fan_in_and_fan_out(weight);
111       auto bound = 1 / std::sqrt(fan_in);
112       init::uniform_(bias, -bound, bound);
113     }
114   }
115 
116   /// Pretty prints the `Conv{1,2,3}d` module into the given `stream`.
pretty_print(std::ostream & stream)117   void pretty_print(std::ostream& stream) const override {
118     stream << "torch::nn::Conv" << D << "d"
119            << "(" << options.in_channels() << ", " << options.out_channels()
120            << ", kernel_size=" << options.kernel_size()
121            << ", stride=" << options.stride();
122     std::visit(
123         c10::overloaded(
124             [&](enumtype::kValid) { stream << ", padding='valid'"; },
125             [&](enumtype::kSame) { stream << ", padding='same'"; },
126             [&](const ExpandingArray<D>& pad) {
127               if (*pad != *ExpandingArray<D>(0)) {
128                 stream << ", padding=" << pad;
129               }
130             }),
131         options.padding());
132     if (*options.dilation() != *ExpandingArray<D>(1)) {
133       stream << ", dilation=" << options.dilation();
134     }
135     if (*options.output_padding() != *ExpandingArray<D>(0)) {
136       stream << ", output_padding=" << options.output_padding();
137     }
138     if (options.groups() != 1) {
139       stream << ", groups=" << options.groups();
140     }
141     if (!options.bias()) {
142       stream << ", bias=" << std::boolalpha << false;
143     }
144     if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
145       stream << ", padding_mode="
146              << enumtype::get_enum_name(options.padding_mode());
147     }
148     stream << ")";
149   }
150 
151   /// The options with which this `Module` was constructed.
152   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
153   detail::ConvNdOptions<D> options;
154 
155   /// The learned kernel (or "weight").
156   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
157   Tensor weight;
158 
159   /// The learned bias. Only defined if the `bias` option was true.
160   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
161   Tensor bias;
162 
163  protected:
164   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
165   std::vector<int64_t> _reversed_padding_repeated_twice;
166 };
167 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 1-D input.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv1d to learn about
/// the exact behavior of this module.
///
/// See the documentation for `torch::nn::Conv1dOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false));
/// ```
class TORCH_API Conv1dImpl : public ConvNdImpl<1, Conv1dImpl> {
 public:
  /// Convenience constructor: builds `Conv1dOptions` from the three mandatory
  /// arguments, leaving every other option at its default.
  Conv1dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<1> kernel_size)
      : Conv1dImpl(
            Conv1dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv1dImpl(Conv1dOptions options_);
  /// Applies 1-D convolution to `input` (defined out of line).
  Tensor forward(const Tensor& input);
};

/// A `ModuleHolder` subclass for `Conv1dImpl`.
/// See the documentation for `Conv1dImpl` class to learn what methods it
/// provides, and examples of how to use `Conv1d` with
/// `torch::nn::Conv1dOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Conv1d);
199 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 2-D input.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv2d to learn about
/// the exact behavior of this module.
///
/// See the documentation for `torch::nn::Conv2dOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
/// ```
class TORCH_API Conv2dImpl : public ConvNdImpl<2, Conv2dImpl> {
 public:
  /// Convenience constructor: builds `Conv2dOptions` from the three mandatory
  /// arguments, leaving every other option at its default.
  Conv2dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<2> kernel_size)
      : Conv2dImpl(
            Conv2dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv2dImpl(Conv2dOptions options_);
  /// Applies 2-D convolution to `input` (defined out of line).
  Tensor forward(const Tensor& input);

 protected:
  /// Shared convolution implementation taking an explicit `weight`
  /// (defined out of line).
  Tensor _conv_forward(const Tensor& input, const Tensor& weight);
};

/// A `ModuleHolder` subclass for `Conv2dImpl`.
/// See the documentation for `Conv2dImpl` class to learn what methods it
/// provides, and examples of how to use `Conv2d` with
/// `torch::nn::Conv2dOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Conv2d);
234 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 3-D input.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv3d to learn about
/// the exact behavior of this module.
///
/// See the documentation for `torch::nn::Conv3dOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
/// ```
class TORCH_API Conv3dImpl : public ConvNdImpl<3, Conv3dImpl> {
 public:
  /// Convenience constructor: builds `Conv3dOptions` from the three mandatory
  /// arguments, leaving every other option at its default.
  Conv3dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<3> kernel_size)
      : Conv3dImpl(
            Conv3dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv3dImpl(Conv3dOptions options_);
  /// Applies 3-D convolution to `input` (defined out of line).
  Tensor forward(const Tensor& input);
};

/// A `ModuleHolder` subclass for `Conv3dImpl`.
/// See the documentation for `Conv3dImpl` class to learn what methods it
/// provides, and examples of how to use `Conv3d` with
/// `torch::nn::Conv3dOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Conv3d);
266 
267 // ~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
268 
269 /// Base class for all (dimension-specialized) convolution transpose modules.
270 template <size_t D, typename Derived>
271 class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {
272  public:
273   using torch::nn::ConvNdImpl<D, Derived>::ConvNdImpl;
ConvTransposeNdImpl(detail::ConvNdOptions<D> options_)274   explicit ConvTransposeNdImpl(detail::ConvNdOptions<D> options_)
275       : ConvNdImpl<D, Derived>(options_) {
276     TORCH_INTERNAL_ASSERT(
277         std::holds_alternative<ExpandingArray<D>>(this->options.padding()),
278         "ConvTranspose padding cannot be a string");
279   }
280 
281   /// Pretty prints the `ConvTranspose{1,2,3}d` module into the given `stream`.
pretty_print(std::ostream & stream)282   void pretty_print(std::ostream& stream) const override {
283     stream << "torch::nn::ConvTranspose" << D << "d"
284            << "(" << this->options.in_channels() << ", "
285            << this->options.out_channels()
286            << ", kernel_size=" << this->options.kernel_size()
287            << ", stride=" << this->options.stride();
288     const auto& pad = padding();
289     if (*pad != *ExpandingArray<D>(0)) {
290       stream << ", padding=" << pad;
291     }
292     if (*this->options.dilation() != *ExpandingArray<D>(1)) {
293       stream << ", dilation=" << this->options.dilation();
294     }
295     if (*this->options.output_padding() != *ExpandingArray<D>(0)) {
296       stream << ", output_padding=" << this->options.output_padding();
297     }
298     if (this->options.groups() != 1) {
299       stream << ", groups=" << this->options.groups();
300     }
301     if (!this->options.bias()) {
302       stream << ", bias=" << std::boolalpha << false;
303     }
304     if (!std::get_if<enumtype::kZeros>(&this->options.padding_mode())) {
305       stream << ", padding_mode="
306              << enumtype::get_enum_name(this->options.padding_mode());
307     }
308     stream << ")";
309   }
310 
311  protected:
padding()312   const ExpandingArray<D>& padding() const {
313     return std::get<ExpandingArray<D>>(this->options.padding());
314   }
315 
316   std::vector<int64_t> _output_padding(
317       const Tensor& input,
318       const std::optional<at::IntArrayRef>& output_size,
319       const ExpandingArray<D>& stride,
320       const ExpandingArray<D>& padding,
321       const ExpandingArray<D>& kernel_size);
322 };
323 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose1d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies the ConvTranspose1d function.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose1d to
/// learn about the exact behavior of this module.
///
/// See the documentation for `torch::nn::ConvTranspose1dOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2,
/// 3).stride(1).bias(false));
/// ```
class TORCH_API ConvTranspose1dImpl
    : public ConvTransposeNdImpl<1, ConvTranspose1dImpl> {
 public:
  /// Convenience constructor: builds `ConvTranspose1dOptions` from the three
  /// mandatory arguments, leaving every other option at its default.
  ConvTranspose1dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<1> kernel_size)
      : ConvTranspose1dImpl(ConvTranspose1dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose1dImpl(ConvTranspose1dOptions options_);
  /// Applies transposed 1-D convolution to `input`; `output_size` optionally
  /// selects the desired spatial output size (defined out of line).
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  // Makes the `output_size` default visible to the AnyModule call machinery.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

/// A `ModuleHolder` subclass for `ConvTranspose1dImpl`.
/// See the documentation for `ConvTranspose1dImpl` class to learn what methods
/// it provides, and examples of how to use `ConvTranspose1d` with
/// `torch::nn::ConvTranspose1dOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(ConvTranspose1d);
365 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose2d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies the ConvTranspose2d function.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose2d to
/// learn about the exact behavior of this module.
///
/// See the documentation for `torch::nn::ConvTranspose2dOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2,
/// 3).stride(1).bias(false));
/// ```
class TORCH_API ConvTranspose2dImpl
    : public ConvTransposeNdImpl<2, ConvTranspose2dImpl> {
 public:
  /// Convenience constructor: builds `ConvTranspose2dOptions` from the three
  /// mandatory arguments, leaving every other option at its default.
  ConvTranspose2dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<2> kernel_size)
      : ConvTranspose2dImpl(ConvTranspose2dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose2dImpl(ConvTranspose2dOptions options_);
  /// Applies transposed 2-D convolution to `input`; `output_size` optionally
  /// selects the desired spatial output size (defined out of line).
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  // Makes the `output_size` default visible to the AnyModule call machinery.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

/// A `ModuleHolder` subclass for `ConvTranspose2dImpl`.
/// See the documentation for `ConvTranspose2dImpl` class to learn what methods
/// it provides, and examples of how to use `ConvTranspose2d` with
/// `torch::nn::ConvTranspose2dOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(ConvTranspose2d);
407 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose3d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies the ConvTranspose3d function.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose3d to
/// learn about the exact behavior of this module.
///
/// See the documentation for `torch::nn::ConvTranspose3dOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2,
/// 2).stride(1).bias(false));
/// ```
class TORCH_API ConvTranspose3dImpl
    : public ConvTransposeNdImpl<3, ConvTranspose3dImpl> {
 public:
  /// Convenience constructor: builds `ConvTranspose3dOptions` from the three
  /// mandatory arguments, leaving every other option at its default.
  ConvTranspose3dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<3> kernel_size)
      : ConvTranspose3dImpl(ConvTranspose3dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose3dImpl(ConvTranspose3dOptions options_);
  /// Applies transposed 3-D convolution to `input`; `output_size` optionally
  /// selects the desired spatial output size (defined out of line).
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  // Makes the `output_size` default visible to the AnyModule call machinery.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

/// A `ModuleHolder` subclass for `ConvTranspose3dImpl`.
/// See the documentation for `ConvTranspose3dImpl` class to learn what methods
/// it provides, and examples of how to use `ConvTranspose3d` with
/// `torch::nn::ConvTranspose3dOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(ConvTranspose3d);
449 
450 } // namespace nn
451 } // namespace torch
452