xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/dropout.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/options/dropout.h>
4 
5 #include <utility>
6 
7 namespace torch {
8 namespace nn {
9 namespace functional {
10 
11 #ifndef DOXYGEN_SHOULD_SKIP_THIS
12 namespace detail {
13 
dropout(Tensor input,double p,bool training,bool inplace)14 inline Tensor dropout(Tensor input, double p, bool training, bool inplace) {
15   TORCH_CHECK(
16       p >= 0. && p <= 1.,
17       "dropout probability has to be between 0 and 1, but got ",
18       p);
19   if (inplace) {
20     return torch::dropout_(input, p, training);
21   } else {
22     return torch::dropout(input, p, training);
23   }
24 }
25 
26 } // namespace detail
27 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
28 
29 /// See
30 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout
31 /// about the exact behavior of this functional.
32 ///
33 /// See the documentation for `torch::nn::functional::DropoutFuncOptions` class
34 /// to learn what optional arguments are supported for this functional.
35 ///
36 /// Example:
37 /// ```
38 /// namespace F = torch::nn::functional;
39 /// F::dropout(input, F::DropoutFuncOptions().p(0.5));
40 /// ```
41 inline Tensor dropout(Tensor input, const DropoutFuncOptions& options = {}) {
42   return detail::dropout(
43       std::move(input), options.p(), options.training(), options.inplace());
44 }
45 
46 // ============================================================================
47 
48 #ifndef DOXYGEN_SHOULD_SKIP_THIS
49 namespace detail {
50 
51 template <int64_t unbatched_dim, int64_t batched_dim>
_dropoutNd_helper(Tensor input,double p,bool training,bool inplace,const char * fn_name)52 inline Tensor _dropoutNd_helper(
53     Tensor input,
54     double p,
55     bool training,
56     bool inplace,
57     const char* fn_name) {
58   TORCH_CHECK(
59       p >= 0. && p <= 1.,
60       "dropout probability has to be between 0 and 1, but got ",
61       p);
62 
63   auto inp_dim = input.dim();
64   auto is_batched = inp_dim == batched_dim;
65   if (!is_batched) {
66     if (inplace) {
67       input = input.unsqueeze_(0);
68     } else {
69       input = input.unsqueeze(0);
70     }
71   }
72 
73   Tensor result;
74   if (inplace) {
75     result = torch::feature_dropout_(input, p, training);
76   } else {
77     result = torch::feature_dropout(input, p, training);
78   }
79 
80   if (!is_batched) {
81     if (inplace) {
82       result = result.squeeze_(0);
83     } else {
84       result = result.squeeze(0);
85     }
86   }
87   return result;
88 }
89 
dropout2d(Tensor input,double p,bool training,bool inplace)90 inline Tensor dropout2d(Tensor input, double p, bool training, bool inplace) {
91   return _dropoutNd_helper<3, 4>(
92       std::move(input), p, training, inplace, "dropout2d");
93 }
94 
95 } // namespace detail
96 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
97 
98 /// See
99 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout2d
100 /// about the exact behavior of this functional.
101 ///
102 /// See the documentation for `torch::nn::functional::Dropout2dFuncOptions`
103 /// class to learn what optional arguments are supported for this functional.
104 ///
105 /// Example:
106 /// ```
107 /// namespace F = torch::nn::functional;
108 /// F::dropout2d(input, F::Dropout2dFuncOptions().p(0.5));
109 /// ```
110 inline Tensor dropout2d(
111     Tensor input,
112     const Dropout2dFuncOptions& options = {}) {
113   return detail::dropout2d(
114       std::move(input), options.p(), options.training(), options.inplace());
115 }
116 
117 // ============================================================================
118 
119 #ifndef DOXYGEN_SHOULD_SKIP_THIS
120 namespace detail {
121 
dropout3d(Tensor input,double p,bool training,bool inplace)122 inline Tensor dropout3d(Tensor input, double p, bool training, bool inplace) {
123   return _dropoutNd_helper<4, 5>(
124       std::move(input), p, training, inplace, "dropout3d");
125 }
126 
127 } // namespace detail
128 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
129 
130 /// See
131 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout3d
132 /// about the exact behavior of this functional.
133 ///
134 /// See the documentation for `torch::nn::functional::Dropout3dFuncOptions`
135 /// class to learn what optional arguments are supported for this functional.
136 ///
137 /// Example:
138 /// ```
139 /// namespace F = torch::nn::functional;
140 /// F::dropout3d(input, F::Dropout3dFuncOptions().p(0.5));
141 /// ```
142 inline Tensor dropout3d(
143     Tensor input,
144     const Dropout3dFuncOptions& options = {}) {
145   return detail::dropout3d(
146       std::move(input), options.p(), options.training(), options.inplace());
147 }
148 
149 // ============================================================================
150 
151 #ifndef DOXYGEN_SHOULD_SKIP_THIS
152 namespace detail {
153 
alpha_dropout(Tensor input,double p,bool training,bool inplace)154 inline Tensor alpha_dropout(
155     Tensor input,
156     double p,
157     bool training,
158     bool inplace) {
159   if (p < 0. || p > 1.) {
160     TORCH_CHECK(
161         false, "dropout probability has to be between 0 and 1, but got ", p);
162   }
163   return inplace ? torch::alpha_dropout_(input, p, training)
164                  : torch::alpha_dropout(input, p, training);
165 }
166 
167 } // namespace detail
168 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
169 
170 /// See
171 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.alpha_dropout
172 /// about the exact behavior of this functional.
173 ///
174 /// See the documentation for `torch::nn::functional::AlphaDropoutFuncOptions`
175 /// class to learn what optional arguments are supported for this functional.
176 ///
177 /// Example:
178 /// ```
179 /// namespace F = torch::nn::functional;
180 /// F::alpha_dropout(input,
181 /// F::AlphaDropoutFuncOptions().p(0.5).training(false));
182 /// ```
183 inline Tensor alpha_dropout(
184     Tensor input,
185     const AlphaDropoutFuncOptions& options = {}) {
186   return detail::alpha_dropout(
187       std::move(input), options.p(), options.training(), options.inplace());
188 }
189 
190 // ============================================================================
191 
192 #ifndef DOXYGEN_SHOULD_SKIP_THIS
193 namespace detail {
194 
feature_alpha_dropout(Tensor input,double p,bool training,bool inplace)195 inline Tensor feature_alpha_dropout(
196     Tensor input,
197     double p,
198     bool training,
199     bool inplace) {
200   if (p < 0. || p > 1.) {
201     TORCH_CHECK(
202         false, "dropout probability has to be between 0 and 1, but got ", p);
203   }
204   return inplace ? torch::feature_alpha_dropout_(input, p, training)
205                  : torch::feature_alpha_dropout(input, p, training);
206 }
207 
208 } // namespace detail
209 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
210 
211 /// See
212 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.feature_alpha_dropout
213 /// about the exact behavior of this functional.
214 ///
215 /// See the documentation for
216 /// `torch::nn::functional::FeatureAlphaDropoutFuncOptions` class to learn what
217 /// optional arguments are supported for this functional.
218 ///
219 /// Example:
220 /// ```
221 /// namespace F = torch::nn::functional;
222 /// F::feature_alpha_dropout(input,
223 /// F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false));
224 /// ```
225 inline Tensor feature_alpha_dropout(
226     Tensor input,
227     const FeatureAlphaDropoutFuncOptions& options = {}) {
228   return detail::feature_alpha_dropout(
229       std::move(input), options.p(), options.training(), options.inplace());
230 }
231 
232 } // namespace functional
233 } // namespace nn
234 } // namespace torch
235