xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/conv.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/options/conv.h>
4 #include <torch/types.h>
5 
6 namespace torch {
7 namespace nn {
8 namespace functional {
9 
10 #ifndef DOXYGEN_SHOULD_SKIP_THIS
11 namespace detail {
12 
padding_unwrap(enumtype::kValid)13 inline std::string padding_unwrap(enumtype::kValid) {
14   return "valid";
15 }
16 
padding_unwrap(enumtype::kSame)17 inline std::string padding_unwrap(enumtype::kSame) {
18   return "same";
19 }
20 
21 template <size_t D>
padding_unwrap(const ExpandingArray<D> & array)22 IntArrayRef padding_unwrap(const ExpandingArray<D>& array) {
23   return array;
24 }
25 
/// Resolves the `padding` variant and forwards everything to `torch::conv1d`.
///
/// `padding_unwrap` has overloads for `kValid`, `kSame` and `ExpandingArray`.
/// The visitor therefore selects the matching `torch::conv1d` overload
/// (string padding or explicit sizes) for whichever alternative is active.
inline Tensor conv1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<1> stride,
    const Conv1dFuncOptions::padding_t& padding,
    ExpandingArray<1> dilation,
    int64_t groups) {
  const auto call_with = [&](const auto& active_padding) -> Tensor {
    return torch::conv1d(
        input,
        weight,
        bias,
        stride,
        padding_unwrap(active_padding),
        dilation,
        groups);
  };
  return std::visit(call_with, padding);
}
41 } // namespace detail
42 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
43 
44 /// See
45 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv1d
46 /// about the exact behavior of this functional.
47 ///
48 /// See the documentation for `torch::nn::functional::Conv1dFuncOptions` class
49 /// to learn what optional arguments are supported for this functional.
50 ///
51 /// Example:
52 /// ```
53 /// namespace F = torch::nn::functional;
54 /// F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1));
55 /// ```
56 inline Tensor conv1d(
57     const Tensor& input,
58     const Tensor& weight,
59     const Conv1dFuncOptions& options = {}) {
60   return detail::conv1d(
61       input,
62       weight,
63       options.bias(),
64       options.stride(),
65       options.padding(),
66       options.dilation(),
67       options.groups());
68 }
69 
70 #ifndef DOXYGEN_SHOULD_SKIP_THIS
71 namespace detail {
/// Resolves the `padding` variant and forwards everything to `torch::conv2d`.
///
/// `padding_unwrap` has overloads for `kValid`, `kSame` and `ExpandingArray`.
/// The visitor therefore selects the matching `torch::conv2d` overload
/// (string padding or explicit sizes) for whichever alternative is active.
inline Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<2> stride,
    const Conv2dFuncOptions::padding_t& padding,
    ExpandingArray<2> dilation,
    int64_t groups) {
  const auto call_with = [&](const auto& active_padding) -> Tensor {
    return torch::conv2d(
        input,
        weight,
        bias,
        stride,
        padding_unwrap(active_padding),
        dilation,
        groups);
  };
  return std::visit(call_with, padding);
}
87 } // namespace detail
88 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
89 
90 /// See
91 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv2d
92 /// about the exact behavior of this functional.
93 ///
94 /// See the documentation for `torch::nn::functional::Conv2dFuncOptions` class
95 /// to learn what optional arguments are supported for this functional.
96 ///
97 /// Example:
98 /// ```
99 /// namespace F = torch::nn::functional;
100 /// F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
101 /// ```
102 inline Tensor conv2d(
103     const Tensor& input,
104     const Tensor& weight,
105     const Conv2dFuncOptions& options = {}) {
106   return detail::conv2d(
107       input,
108       weight,
109       options.bias(),
110       options.stride(),
111       options.padding(),
112       options.dilation(),
113       options.groups());
114 }
115 
116 #ifndef DOXYGEN_SHOULD_SKIP_THIS
117 namespace detail {
/// Resolves the `padding` variant and forwards everything to `torch::conv3d`.
///
/// `padding_unwrap` has overloads for `kValid`, `kSame` and `ExpandingArray`.
/// The visitor therefore selects the matching `torch::conv3d` overload
/// (string padding or explicit sizes) for whichever alternative is active.
inline Tensor conv3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<3> stride,
    const Conv3dFuncOptions::padding_t& padding,
    ExpandingArray<3> dilation,
    int64_t groups) {
  const auto call_with = [&](const auto& active_padding) -> Tensor {
    return torch::conv3d(
        input,
        weight,
        bias,
        stride,
        padding_unwrap(active_padding),
        dilation,
        groups);
  };
  return std::visit(call_with, padding);
}
133 } // namespace detail
134 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
135 
136 /// See
137 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv3d
138 /// about the exact behavior of this functional.
139 ///
140 /// See the documentation for `torch::nn::functional::Conv3dFuncOptions` class
141 /// to learn what optional arguments are supported for this functional.
142 ///
143 /// Example:
144 /// ```
145 /// namespace F = torch::nn::functional;
146 /// F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
147 /// ```
148 inline Tensor conv3d(
149     const Tensor& input,
150     const Tensor& weight,
151     const Conv3dFuncOptions& options = {}) {
152   return detail::conv3d(
153       input,
154       weight,
155       options.bias(),
156       options.stride(),
157       options.padding(),
158       options.dilation(),
159       options.groups());
160 }
161 
162 // ============================================================================
163 
164 #ifndef DOXYGEN_SHOULD_SKIP_THIS
165 namespace detail {
/// Direct forwarder to `torch::conv_transpose1d`.
///
/// Unlike the `conv{1,2,3}d` helpers above, which pass `dilation` before
/// `groups`, this op takes `groups` before `dilation`. The parameter order here
/// mirrors that.
inline Tensor conv_transpose1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose1d(
      input,
      weight,
      bias,
      stride,
      padding,
      output_padding,
      groups,
      dilation);
}
178 } // namespace detail
179 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
180 
181 /// See
182 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose1d
183 /// about the exact behavior of this functional.
184 ///
185 /// See the documentation for
186 /// `torch::nn::functional::ConvTranspose1dFuncOptions` class to learn what
187 /// optional arguments are supported for this functional.
188 ///
189 /// Example:
190 /// ```
191 /// namespace F = torch::nn::functional;
192 /// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1));
193 /// ```
194 inline Tensor conv_transpose1d(
195     const Tensor& input,
196     const Tensor& weight,
197     const ConvTranspose1dFuncOptions& options = {}) {
198   return detail::conv_transpose1d(
199       input,
200       weight,
201       options.bias(),
202       options.stride(),
203       options.padding(),
204       options.output_padding(),
205       options.groups(),
206       options.dilation());
207 }
208 
209 #ifndef DOXYGEN_SHOULD_SKIP_THIS
210 namespace detail {
/// Direct forwarder to `torch::conv_transpose2d`.
///
/// Unlike the `conv{1,2,3}d` helpers above, which pass `dilation` before
/// `groups`, this op takes `groups` before `dilation`. The parameter order here
/// mirrors that.
inline Tensor conv_transpose2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose2d(
      input,
      weight,
      bias,
      stride,
      padding,
      output_padding,
      groups,
      dilation);
}
223 } // namespace detail
224 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
225 
226 /// See
227 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose2d
228 /// about the exact behavior of this functional.
229 ///
230 /// See the documentation for
231 /// `torch::nn::functional::ConvTranspose2dFuncOptions` class to learn what
232 /// optional arguments are supported for this functional.
233 ///
234 /// Example:
235 /// ```
236 /// namespace F = torch::nn::functional;
237 /// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
238 /// ```
239 inline Tensor conv_transpose2d(
240     const Tensor& input,
241     const Tensor& weight,
242     const ConvTranspose2dFuncOptions& options = {}) {
243   return detail::conv_transpose2d(
244       input,
245       weight,
246       options.bias(),
247       options.stride(),
248       options.padding(),
249       options.output_padding(),
250       options.groups(),
251       options.dilation());
252 }
253 
254 #ifndef DOXYGEN_SHOULD_SKIP_THIS
255 namespace detail {
/// Direct forwarder to `torch::conv_transpose3d`.
///
/// Unlike the `conv{1,2,3}d` helpers above, which pass `dilation` before
/// `groups`, this op takes `groups` before `dilation`. The parameter order here
/// mirrors that.
inline Tensor conv_transpose3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose3d(
      input,
      weight,
      bias,
      stride,
      padding,
      output_padding,
      groups,
      dilation);
}
268 } // namespace detail
269 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
270 
271 /// See
272 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose3d
273 /// about the exact behavior of this functional.
274 ///
275 /// See the documentation for
276 /// `torch::nn::functional::ConvTranspose3dFuncOptions` class to learn what
277 /// optional arguments are supported for this functional.
278 ///
279 /// Example:
280 /// ```
281 /// namespace F = torch::nn::functional;
282 /// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1));
283 /// ```
284 inline Tensor conv_transpose3d(
285     const Tensor& input,
286     const Tensor& weight,
287     const ConvTranspose3dFuncOptions& options = {}) {
288   return detail::conv_transpose3d(
289       input,
290       weight,
291       options.bias(),
292       options.stride(),
293       options.padding(),
294       options.output_padding(),
295       options.groups(),
296       options.dilation());
297 }
298 
299 } // namespace functional
300 } // namespace nn
301 } // namespace torch
302