#pragma once

#include <torch/nn/options/conv.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

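// Overload set that converts each alternative of the padding variant
// (kValid, kSame, or an explicit per-dimension size) into the form the
// underlying torch::conv{1,2,3}d calls accept.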
inline std::string padding_unwrap(enumtype::kValid) {
  return "valid";
}

inline std::string padding_unwrap(enumtype::kSame) {
  return "same";
}

template <size_t D>
IntArrayRef padding_unwrap(const ExpandingArray<D>& array) {
  return array;
}

inline Tensor conv1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<1> stride,
    const Conv1dFuncOptions::padding_t& padding,
    ExpandingArray<1> dilation,
    int64_t groups) {
  return std::visit(
      [&](const auto& pad) {
        return torch::conv1d(
            input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
      },
      padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv1d
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::Conv1dFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1));
/// ```
inline Tensor conv1d(
    const Tensor& input,
    const Tensor& weight,
    const Conv1dFuncOptions& options = {}) {
  return detail::conv1d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}
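
// A fuller usage sketch (illustrative only; the tensor shapes below are
// assumptions, not taken from this header). "same" padding keeps the output
// length equal to the input length and requires stride 1:
//
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 4, 10}); // (batch, in_channels, length)
//   auto w = torch::randn({8, 4, 3});  // (out_channels, in_channels / groups, kernel_size)
//   auto y = F::conv1d(x, w, F::Conv1dFuncOptions().stride(1).padding(torch::kSame));
//   // y.sizes() == {1, 8, 10}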

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<2> stride,
    const Conv2dFuncOptions::padding_t& padding,
    ExpandingArray<2> dilation,
    int64_t groups) {
  return std::visit(
      [&](const auto& pad) {
        return torch::conv2d(
            input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
      },
      padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv2d
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::Conv2dFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
/// ```
inline Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const Conv2dFuncOptions& options = {}) {
  return detail::conv2d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}
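
// A fuller usage sketch (illustrative only; shapes below are assumptions).
// With groups == in_channels and one filter per channel this is a depthwise
// convolution:
//
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 8, 16, 16}); // (batch, in_channels, H, W)
//   auto w = torch::randn({8, 1, 3, 3});   // (out_channels, in_channels / groups, kH, kW)
//   auto y = F::conv2d(x, w, F::Conv2dFuncOptions().padding(1).groups(8));
//   // y.sizes() == {1, 8, 16, 16}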

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<3> stride,
    const Conv3dFuncOptions::padding_t& padding,
    ExpandingArray<3> dilation,
    int64_t groups) {
  return std::visit(
      [&](const auto& pad) {
        return torch::conv3d(
            input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
      },
      padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv3d
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::Conv3dFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
/// ```
inline Tensor conv3d(
    const Tensor& input,
    const Tensor& weight,
    const Conv3dFuncOptions& options = {}) {
  return detail::conv3d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}
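
// A fuller usage sketch (illustrative only; shapes below are assumptions). For
// a kernel of 3, a dilation of 2 with padding 2 keeps the spatial size:
//
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 2, 8, 8, 8}); // (batch, in_channels, D, H, W)
//   auto w = torch::randn({4, 2, 3, 3, 3}); // (out_channels, in_channels / groups, kD, kH, kW)
//   auto y = F::conv3d(x, w, F::Conv3dFuncOptions().padding(2).dilation(2));
//   // y.sizes() == {1, 4, 8, 8, 8}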

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv_transpose1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose1d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose1d
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::ConvTranspose1dFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1));
/// ```
inline Tensor conv_transpose1d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose1dFuncOptions& options = {}) {
  return detail::conv_transpose1d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.output_padding(),
      options.groups(),
      options.dilation());
}
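
// A fuller usage sketch (illustrative only; shapes below are assumptions).
// For transposed convolutions the weight layout is
// (in_channels, out_channels / groups, kernel_size), and output_padding must
// be smaller than the stride:
//
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 4, 10}); // (batch, in_channels, length)
//   auto w = torch::randn({4, 8, 3});
//   auto y = F::conv_transpose1d(
//       x, w, F::ConvTranspose1dFuncOptions().stride(2).output_padding(1));
//   // y.sizes() == {1, 8, 22}: (10 - 1) * 2 + (3 - 1) + 1 + 1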

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv_transpose2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose2d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose2d
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::ConvTranspose2dFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
/// ```
inline Tensor conv_transpose2d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose2dFuncOptions& options = {}) {
  return detail::conv_transpose2d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.output_padding(),
      options.groups(),
      options.dilation());
}
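
// A fuller usage sketch (illustrative only; shapes below are assumptions). A
// stride of 2 with a 4x4 kernel and padding 1 doubles the spatial size, a
// common upsampling configuration:
//
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 16, 8, 8}); // (batch, in_channels, H, W)
//   auto w = torch::randn({16, 8, 4, 4}); // (in_channels, out_channels / groups, kH, kW)
//   auto y = F::conv_transpose2d(
//       x, w, F::ConvTranspose2dFuncOptions().stride(2).padding(1));
//   // y.sizes() == {1, 8, 16, 16}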

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv_transpose3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose3d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose3d
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::ConvTranspose3dFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1));
/// ```
inline Tensor conv_transpose3d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose3dFuncOptions& options = {}) {
  return detail::conv_transpose3d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.output_padding(),
      options.groups(),
      options.dilation());
}

} // namespace functional
} // namespace nn
} // namespace torch