# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    expectedFailureXLA,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_nn import freeze_rng_state, NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    TEST_PRIVATEUSE1,
)


class TestDropoutNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _test_alpha_dropout(self, cls, input):
        mean = input.mean()
        std = input.std()

        for p in [0.2, 0.5, 0.8]:
            module = cls(p)
            input_var = input.detach().clone().requires_grad_()
            output = module(input_var)
            # output mean should be close to input mean
            self.assertLess(abs(output.data.mean() - mean), 0.1)
            # output std should be close to input std
            self.assertLess(abs(output.data.std() - std), 0.1)
            output.backward(input)

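    # AlphaDropout (from "Self-Normalizing Neural Networks", Klambauer et al.
    # 2017) sets dropped units to a negative saturation value rather than zero
    # and then applies an affine correction, so a zero-mean, unit-variance
    # input should keep approximately zero mean and unit variance. The helper
    # above checks exactly that property.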
    def test_AlphaDropout(self):
        # generate random tensor with zero mean and unit std
        input = torch.randn(5000)
        self._test_alpha_dropout(nn.AlphaDropout, input)

    def test_FeatureAlphaDropout(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.randn(num_features, b, d, w, h)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

        # no batch dims
        input = torch.randn(50, 20, 64, 64)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 unavailable"
    )
    def test_native_dropout_corner_case(self):
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
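        # torch.native_dropout returns an (output, mask) pair, while
        # torch.dropout returns only the output. At the corner probabilities
        # p=0 (keep everything) and p=1 (drop everything), both ops should
        # agree exactly in forward and backward, in train and eval mode.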
        for train in [True, False]:
            for p in [0.0, 1.0]:
                for current_device in [device, "cpu"]:
                    x = torch.randn(5).to(device=current_device).requires_grad_()
                    x_ref = x.detach().requires_grad_()
                    o = torch.native_dropout(x, p, train)[0]
                    o_ref = torch.dropout(x_ref, p, train)
                    o.sum().backward()
                    o_ref.sum().backward()
                    self.assertTrue(o.equal(o_ref))
                    self.assertTrue(x.grad.equal(x_ref.grad))

    def test_invalid_dropout_p(self):
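        # The dropout probability must lie in [0, 1]; values outside this
        # range should raise a ValueError for both modules and functionals.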
        v = torch.ones(1)
        self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(1.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, -0.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, 1.1))

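# A minimal standalone sketch (not exercised by the test suite) of the
# inverted dropout convention the assertions below rely on: at train time,
# surviving elements are scaled by 1 / (1 - p). The helper name is
# illustrative only, not part of the test suite's API.
def _inverted_dropout_mean_sketch(p=0.2):
    x = torch.full((10000,), 1 - p)
    y = F.dropout(x, p=p, training=True)
    # Each element is 0 with probability p and (1 - p) / (1 - p) = 1
    # otherwise, so y.mean() should be close to 1 - p.
    return y.mean()

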
class TestDropoutNNDeviceType(NNTestCase):
    def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_format):
        p = 0.2
        input = input.to(device).fill_(1 - p)

        module = cls(p)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        module = cls(p, True)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
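        # Adding 0 below creates a fresh non-leaf tensor, so the in-place
        # dropout (inplace=True above) does not overwrite a leaf tensor that
        # requires grad.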
        output = module(input_var + 0)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        # check eval mode doesn't change anything
        for inplace in [True, False]:
            module = cls(p, inplace).eval()
            self.assertEqual(input, module(input))

        # Check that these don't raise errors
        module.__repr__()
        str(module)

    def _test_dropout_discontiguous(
        self, cls, device, memory_format=torch.contiguous_format
    ):
        # Verify that dropout preserves layout and data for different memory
        # formats: the output should equal the input when the dropout
        # probability is 0 or very close to 0.
        # Reference: https://github.com/pytorch/pytorch/issues/47176
        close_to_zero_p = 1e-10  # almost zero but not zero, since p=0 takes a different code path
        for p in [0, close_to_zero_p]:
            inp = torch.ones(2, 3, 3, 3, device=device)
            inp_discontiguous = torch.empty(
                2, 3, 3, 6, device=device, memory_format=memory_format
            )[..., ::2]
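            # Slicing with step 2 yields a view with stride 2 in the last
            # dimension, i.e. a genuinely discontiguous input tensor.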
            inp_discontiguous.copy_(inp)
            mod = cls(p=p)
            out = mod(inp_discontiguous)
            if p != 0:  # p == 0 keeps the input strides as-is.
                # When p == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2)
                # When p != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1)
                self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(inp_discontiguous, out)

    def _test_dropout_stride_mean_preserve(self, cls, device):
        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2], d[3])
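        # invert_perm(p) is the inverse permutation of p, so
        # inp.permute(perm).contiguous().permute(invert_perm(perm)) restores
        # the original shape while permuting the underlying memory layout.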

        inp = torch.ones(2, 3, 4, 5, device=device)
        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
        for perm in itertools.permutations((0, 1, 2, 3), r=4):
            for shift in shifts:
                for p in [1e-10, 0.3, 0.5, 0.7]:
                    mod = cls(p=p)
                    permuted_inp = (
                        inp.permute(perm).contiguous().permute(invert_perm(perm))
                    )
                    permuted_inp = permuted_inp[shift[0] :, shift[1] :, :, :]
                    out = mod(permuted_inp)

                    self.assertTrue(out.permute(perm).is_contiguous())
                    self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5)
                    if p == 1e-10:
                        self.assertEqual(permuted_inp, out)
                    else:
                        self.assertNotEqual(permuted_inp, out)

    def test_Dropout(self, device):
        input = torch.empty(1000)
        self._test_dropout(nn.Dropout, device, input)

        self._test_dropout_discontiguous(nn.Dropout, device)
        self._test_dropout_discontiguous(
            nn.Dropout, device, memory_format=torch.channels_last
        )

        self._test_dropout_stride_mean_preserve(nn.Dropout, device)

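        # Dropout should also work on bfloat16 inputs (on CPU and CUDA).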
        if self.device_type == "cuda" or self.device_type == "cpu":
            input = input.bfloat16()
            self._test_dropout(nn.Dropout, device, input)

    def _test_dropoutNd_no_batch(self, dropout, input):
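        # With the RNG state frozen, the unbatched input and its unsqueezed
        # (batch of 1) counterpart must draw the same dropout mask, so the
        # two results should match elementwise.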
        input_clone = input.clone()
        with freeze_rng_state():
            res_no_batch = dropout(input)

        with freeze_rng_state():
            res_batched = dropout(input_clone.unsqueeze(0)).squeeze(0)

        self.assertEqual(res_no_batch, res_batched)

    def _test_dropoutNd_channel_zero(self, dropout, input):
        # Verify that each channel is either fully dropped (all zeros) or
        # fully kept: for a strictly positive input, any zero in the output
        # can only have been produced by dropout.
        shape = input.shape
        B = shape[0]
        C = shape[1]
        channel_numel = torch.tensor(shape[2:]).prod()
        result = dropout(input)

        for b, c in product(range(B), range(C)):
            self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel))

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout1d(self, device):
        with set_default_dtype(torch.double):
            N, C, L = (
                random.randint(10, 15),
                random.randint(10, 15),
                random.randint(10, 15),
            )
            input = torch.empty(N, C, L)
            self._test_dropout(nn.Dropout1d, device, input)

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 4D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(1, 2, 2, 2, device=device))

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 1D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(2, device=device))

            # no batch dims
            input = torch.rand(50, 2, device=device)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5, inplace=True), input)

            # check that complete channels are dropped
            input = torch.ones(10, 4, 2, device=device)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout2d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.empty(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, device, input)
        self._test_dropout(
            nn.Dropout2d, device, input, memory_format=torch.channels_last
        )

        self._test_dropout_discontiguous(nn.Dropout2d, device)
        self._test_dropout_discontiguous(
            nn.Dropout2d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 5-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device))

        # TODO: Uncomment these lines once no-batch-dim inputs are supported.
        # For now, the historical dropout1d behavior is performed for 3D inputs.
        # See https://github.com/pytorch/pytorch/issues/77081

        # input = torch.rand(50, 2, 2, device=device)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)

        with self.assertWarnsRegex(
            UserWarning, "assuming that channel-wise 1D dropout behavior is desired"
        ):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout3d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.empty(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, device, input)

        self._test_dropout_discontiguous(nn.Dropout3d, device)
        self._test_dropout_discontiguous(
            nn.Dropout3d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 6-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 3-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # no batch dims
        input = torch.rand(50, 2, 2, 2, device=device)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5, inplace=True), input)

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5, inplace=True), input)

    def test_empty_dropout(self, device):
        x = torch.tensor([]).to(device)
        out = torch.nn.functional.dropout(x)
        self.assertEqual(out.size(), x.size())


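# A minimal standalone sketch (not exercised by the test suite) of the
# channel-wise semantics checked above: Dropout2d zeroes entire
# (batch, channel) planes at train time, so within a single channel the
# output is either all zeros or uniformly scaled by 1 / (1 - p). The helper
# name is illustrative only.
def _dropout2d_channel_sketch(p=0.5):
    x = torch.ones(4, 3, 2, 2)
    y = nn.Dropout2d(p=p)(x)  # a fresh module defaults to training mode
    # Each y[b, c] holds a single value: 0.0 if the channel was dropped,
    # otherwise 1 / (1 - p) = 2.0.
    return [y[b, c].unique() for b in range(4) for c in range(3)]

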
instantiate_device_type_tests(TestDropoutNNDeviceType, globals())
instantiate_parametrized_tests(TestDropoutNN)

if __name__ == "__main__":
    run_tests()