# xref: /aosp_15_r20/external/pytorch/test/nn/test_dropout.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    expectedFailureXLA,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_nn import freeze_rng_state, NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    TEST_PRIVATEUSE1,
)


class TestDropoutNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _test_alpha_dropout(self, cls, input):
        mean = input.mean()
        std = input.std()

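        # AlphaDropout (and its feature-wise variant FeatureAlphaDropout) is designed
        # for use with SELU: dropped units are set to SELU's negative saturation value
        # and the result is rescaled and shifted, so the output should keep roughly the
        # same mean and std as the input. The 0.1 tolerances below rely on that
        # self-normalizing property.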
        for p in [0.2, 0.5, 0.8]:
            module = cls(p)
            input_var = input.detach().clone().requires_grad_()
            output = module(input_var)
            # output mean should be close to input mean
            self.assertLess(abs(output.data.mean() - mean), 0.1)
            # output std should be close to input std
            self.assertLess(abs(output.data.std() - std), 0.1)
            output.backward(input)

    def test_AlphaDropout(self):
        # generate random tensor with zero mean and unit std
        input = torch.randn(5000)
        self._test_alpha_dropout(nn.AlphaDropout, input)

    def test_FeatureAlphaDropout(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.randn(num_features, b, d, w, h)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

        # no batch dims
        input = torch.randn(50, 20, 64, 64)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 unavailable"
    )
    def test_native_dropout_corner_case(self):
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
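        # p=0.0 and p=1.0 are deterministic corner cases: with p=0 (or train=False)
        # dropout is the identity, and with p=1 in training mode the output and the
        # gradient are all zeros. native_dropout and torch.dropout should agree exactly.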
        for train in [True, False]:
            for p in [0.0, 1.0]:
                for current_device in [device, "cpu"]:
                    x = torch.randn(5).to(device=current_device).requires_grad_()
                    x_ref = x.detach().requires_grad_()
                    o = torch.native_dropout(x, p, train)[0]
                    o_ref = torch.dropout(x_ref, p, train)
                    o.sum().backward()
                    o_ref.sum().backward()
                    assert o.equal(o_ref)
                    assert x.grad.equal(x_ref.grad)

    def test_invalid_dropout_p(self):
        v = torch.ones(1)
        self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(1.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, -0.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, 1.1))


class TestDropoutNNDeviceType(NNTestCase):
    def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_format):
        p = 0.2
        input = input.to(device).fill_(1 - p)

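        # Dropout zeroes each element with probability p and scales the survivors by
        # 1 / (1 - p). For a constant input filled with (1 - p), the expected output
        # mean is therefore (1 - p); the same holds for the gradient mean, since
        # backward() below is called with `input` as the incoming gradient.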
        module = cls(p)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        module = cls(p, True)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
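        # inplace=True: apply the module to a fresh copy (input_var + 0) so the
        # in-place dropout does not overwrite the leaf tensor that requires grad.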
        output = module(input_var + 0)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        # check eval mode doesn't change anything
        for inplace in [True, False]:
            module = cls(p, inplace).eval()
            self.assertEqual(input, module(input))

        # Check that these don't raise errors
        module.__repr__()
        str(module)

    def _test_dropout_discontiguous(
        self, cls, device, memory_format=torch.contiguous_format
    ):
        # In this test, we verify that dropout preserves the layout and data for different memory formats.
        # We check that we get the same values for the output of dropout when the probability
        # of dropout is 0 or very close to 0.
        # Reference: https://github.com/pytorch/pytorch/issues/47176
        close_to_zero_p = 1e-10  # Should be almost zero but not zero, as a different code path is taken for p=0
        for p in [0, close_to_zero_p]:
            inp = torch.ones(2, 3, 3, 3, device=device)
            inp_discontiguous = torch.empty(
                2, 3, 3, 6, device=device, memory_format=memory_format
            )[..., ::2]
            inp_discontiguous.copy_(inp)
            mod = cls(p=p)
            out = mod(inp_discontiguous)
            if p != 0:  # For p == 0, the output keeps the input's strides as-is.
                # When prob == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2)
                # When prob != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1)
                self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(inp_discontiguous, out)

    def _test_dropout_stride_mean_preserve(self, cls, device):
        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2], d[3])

        inp = torch.ones(2, 3, 4, 5, device=device)
        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
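        # Permuting, calling .contiguous(), and permuting back with the inverse
        # permutation yields a tensor equal to `inp` but laid out in memory according
        # to `perm`; slicing with `shift` additionally offsets the view into its storage.
        # Dropout should preserve that layout (so out.permute(perm) is contiguous)
        # and keep the mean roughly unchanged.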
        for perm in itertools.permutations((0, 1, 2, 3), r=4):
            for shift in shifts:
                for p in [1e-10, 0.3, 0.5, 0.7]:
                    mod = cls(p=p)
                    permuted_inp = (
                        inp.permute(perm).contiguous().permute(invert_perm(perm))
                    )
                    permuted_inp = permuted_inp[shift[0] :, shift[1] :, :, :]
                    out = mod(permuted_inp)

                    self.assertTrue(out.permute(perm).is_contiguous())
                    self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5)
                    if p == 1e-10:
                        self.assertEqual(permuted_inp, out)
                    else:
                        self.assertNotEqual(permuted_inp, out)

    def test_Dropout(self, device):
        input = torch.empty(1000)
        self._test_dropout(nn.Dropout, device, input)

        self._test_dropout_discontiguous(nn.Dropout, device)
        self._test_dropout_discontiguous(
            nn.Dropout, device, memory_format=torch.channels_last
        )

        self._test_dropout_stride_mean_preserve(nn.Dropout, device)

        if self.device_type == "cuda" or self.device_type == "cpu":
            input = input.bfloat16()
            self._test_dropout(nn.Dropout, device, input)

    def _test_dropoutNd_no_batch(self, dropout, input):
        input_clone = input.clone()
        with freeze_rng_state():
            res_no_batch = dropout(input)

        with freeze_rng_state():
            res_batched = dropout(input_clone.unsqueeze(0)).squeeze(0)

        self.assertEqual(res_no_batch, res_batched)

    def _test_dropoutNd_channel_zero(self, dropout, input):
        # Verify that, for a strictly positive input tensor, the number of zeros in each
        # channel is either 0 or the total number of elements in the channel.
        shape = input.shape
        B = shape[0]
        C = shape[1]
        channel_numel = torch.tensor(shape[2:]).prod()
        result = dropout(input)
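        # The input is strictly positive, so any zero in the output must come from
        # dropout. DropoutNd zeroes entire channels, hence each (batch, channel)
        # slice must be either fully kept (rescaled by 1 / (1 - p)) or all zeros.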

        for b, c in product(range(B), range(C)):
            self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel))

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout1d(self, device):
        with set_default_dtype(torch.double):
            N, C, L = (
                random.randint(10, 15),
                random.randint(10, 15),
                random.randint(10, 15),
            )
            input = torch.empty(N, C, L)
            self._test_dropout(nn.Dropout1d, device, input)

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 4D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(1, 2, 2, 2, device=device))

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 1D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(2, device=device))

            # no batch dims
            input = torch.rand(50, 2, device=device)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5, inplace=True), input)

            # check that complete channels are dropped
            input = torch.ones(10, 4, 2, device=device)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout2d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.empty(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, device, input)
        self._test_dropout(
            nn.Dropout2d, device, input, memory_format=torch.channels_last
        )

        self._test_dropout_discontiguous(nn.Dropout2d, device)
        self._test_dropout_discontiguous(
            nn.Dropout2d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 5-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device))

        # TODO: Uncomment these lines once no-batch-dim inputs are supported.
        # For now, the historical dropout1d behavior is performed for 3D inputs.
        # See https://github.com/pytorch/pytorch/issues/77081

        # input = torch.rand(50, 2, 2, device=device)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)

        with self.assertWarnsRegex(
            UserWarning, "assuming that channel-wise 1D dropout behavior is desired"
        ):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout3d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.empty(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, device, input)

        self._test_dropout_discontiguous(nn.Dropout3d, device)
        self._test_dropout_discontiguous(
            nn.Dropout3d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 6-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 3-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # no batch dims
        input = torch.rand(50, 2, 2, 2, device=device)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5, inplace=True), input)

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5, inplace=True), input)

    def test_empty_dropout(self, device):
        x = torch.tensor([]).to(device)
        out = torch.nn.functional.dropout(x)
        self.assertEqual(out.size(), x.size())


instantiate_device_type_tests(TestDropoutNNDeviceType, globals())
instantiate_parametrized_tests(TestDropoutNN)

if __name__ == "__main__":
    run_tests()