#pragma once

#include <torch/nn/options/dropout.h>

#include <utility>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

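// Validates the dropout probability and dispatches to the in-place or
// out-of-place ATen dropout op.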
inline Tensor dropout(Tensor input, double p, bool training, bool inplace) {
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);
  if (inplace) {
    return torch::dropout_(input, p, training);
  } else {
    return torch::dropout(input, p, training);
  }
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::DropoutFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::dropout(input, F::DropoutFuncOptions().p(0.5));
/// ```
inline Tensor dropout(Tensor input, const DropoutFuncOptions& options = {}) {
  return detail::dropout(
      std::move(input), options.p(), options.training(), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

template <int64_t unbatched_dim, int64_t batched_dim>
inline Tensor _dropoutNd_helper(
    Tensor input,
    double p,
    bool training,
    bool inplace,
    const char* fn_name) {
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);

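  // torch::feature_dropout drops whole channels and assumes a batched input
  // with channels at dim 1. An unbatched input (one with `unbatched_dim`
  // dimensions) is therefore temporarily given a leading batch dimension here
  // and squeezed back after the dropout call.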
  auto inp_dim = input.dim();
  auto is_batched = inp_dim == batched_dim;
  if (!is_batched) {
    if (inplace) {
      input = input.unsqueeze_(0);
    } else {
      input = input.unsqueeze(0);
    }
  }

  Tensor result;
  if (inplace) {
    result = torch::feature_dropout_(input, p, training);
  } else {
    result = torch::feature_dropout(input, p, training);
  }

  if (!is_batched) {
    if (inplace) {
      result = result.squeeze_(0);
    } else {
      result = result.squeeze(0);
    }
  }
  return result;
}

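// Channel-wise dropout for 2-D feature maps: expects a 3-D (C, H, W) or
// 4-D (N, C, H, W) input.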
inline Tensor dropout2d(Tensor input, double p, bool training, bool inplace) {
  return _dropoutNd_helper<3, 4>(
      std::move(input), p, training, inplace, "dropout2d");
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout2d
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::Dropout2dFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::dropout2d(input, F::Dropout2dFuncOptions().p(0.5));
/// ```
inline Tensor dropout2d(
    Tensor input,
    const Dropout2dFuncOptions& options = {}) {
  return detail::dropout2d(
      std::move(input), options.p(), options.training(), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

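// Channel-wise dropout for 3-D feature maps: expects a 4-D (C, D, H, W) or
// 5-D (N, C, D, H, W) input.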
inline Tensor dropout3d(Tensor input, double p, bool training, bool inplace) {
  return _dropoutNd_helper<4, 5>(
      std::move(input), p, training, inplace, "dropout3d");
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout3d
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::Dropout3dFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::dropout3d(input, F::Dropout3dFuncOptions().p(0.5));
/// ```
inline Tensor dropout3d(
    Tensor input,
    const Dropout3dFuncOptions& options = {}) {
  return detail::dropout3d(
      std::move(input), options.p(), options.training(), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

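// Alpha dropout is intended for SELU activations: dropped units are set to
// the SELU negative saturation value rather than zero, so the output keeps
// the input's mean and standard deviation.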
inline Tensor alpha_dropout(
    Tensor input,
    double p,
    bool training,
    bool inplace) {
  if (p < 0. || p > 1.) {
    TORCH_CHECK(
        false, "dropout probability has to be between 0 and 1, but got ", p);
  }
  return inplace ? torch::alpha_dropout_(input, p, training)
                 : torch::alpha_dropout(input, p, training);
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.alpha_dropout
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::AlphaDropoutFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::alpha_dropout(input,
/// F::AlphaDropoutFuncOptions().p(0.5).training(false));
/// ```
inline Tensor alpha_dropout(
    Tensor input,
    const AlphaDropoutFuncOptions& options = {}) {
  return detail::alpha_dropout(
      std::move(input), options.p(), options.training(), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

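// Like alpha_dropout, but masks entire channels (feature maps) at once
// rather than individual elements.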
inline Tensor feature_alpha_dropout(
    Tensor input,
    double p,
    bool training,
    bool inplace) {
  if (p < 0. || p > 1.) {
    TORCH_CHECK(
        false, "dropout probability has to be between 0 and 1, but got ", p);
  }
  return inplace ? torch::feature_alpha_dropout_(input, p, training)
                 : torch::feature_alpha_dropout(input, p, training);
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.feature_alpha_dropout
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::FeatureAlphaDropoutFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::feature_alpha_dropout(input,
/// F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false));
/// ```
inline Tensor feature_alpha_dropout(
    Tensor input,
    const FeatureAlphaDropoutFuncOptions& options = {}) {
  return detail::feature_alpha_dropout(
      std::move(input), options.p(), options.training(), options.inplace());
}

} // namespace functional
} // namespace nn
} // namespace torch