1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: nn"] 2*da0073e9SAndroid Build Coastguard Workerimport itertools 3*da0073e9SAndroid Build Coastguard Workerimport random 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerfrom itertools import product 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 9*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 12*da0073e9SAndroid Build Coastguard Worker expectedFailureXLA, 13*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 14*da0073e9SAndroid Build Coastguard Worker) 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import freeze_rng_state, NNTestCase 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 17*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 18*da0073e9SAndroid Build Coastguard Worker run_tests, 19*da0073e9SAndroid Build Coastguard Worker set_default_dtype, 20*da0073e9SAndroid Build Coastguard Worker TEST_PRIVATEUSE1, 21*da0073e9SAndroid Build Coastguard Worker) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerclass TestDropoutNN(NNTestCase): 25*da0073e9SAndroid Build Coastguard Worker _do_cuda_memory_leak_check = True 26*da0073e9SAndroid Build Coastguard Worker _do_cuda_non_default_stream = True 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker def _test_alpha_dropout(self, cls, input): 29*da0073e9SAndroid Build Coastguard Worker mean = input.mean() 30*da0073e9SAndroid Build Coastguard Worker std = input.std() 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker for p in [0.2, 0.5, 0.8]: 33*da0073e9SAndroid Build Coastguard Worker module = cls(p) 34*da0073e9SAndroid Build Coastguard Worker input_var = input.detach().clone().requires_grad_() 35*da0073e9SAndroid Build Coastguard Worker output = module(input_var) 36*da0073e9SAndroid Build Coastguard Worker # output mean should be close to input mean 37*da0073e9SAndroid Build Coastguard Worker self.assertLess(abs(output.data.mean() - mean), 0.1) 38*da0073e9SAndroid Build Coastguard Worker # output std should be close to input std 39*da0073e9SAndroid Build Coastguard Worker self.assertLess(abs(output.data.std() - std), 0.1) 40*da0073e9SAndroid Build Coastguard Worker output.backward(input) 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker def test_AlphaDropout(self): 43*da0073e9SAndroid Build Coastguard Worker # generate random tensor with zero mean and unit std 44*da0073e9SAndroid Build Coastguard Worker input = torch.randn(5000) 45*da0073e9SAndroid Build Coastguard Worker self._test_alpha_dropout(nn.AlphaDropout, input) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker def test_FeatureAlphaDropout(self): 48*da0073e9SAndroid Build Coastguard Worker b = random.randint(1, 5) 49*da0073e9SAndroid Build Coastguard Worker w = random.randint(1, 5) 50*da0073e9SAndroid Build Coastguard Worker h = random.randint(1, 5) 51*da0073e9SAndroid Build Coastguard Worker d = random.randint(1, 2) 52*da0073e9SAndroid Build Coastguard Worker num_features = 1000 53*da0073e9SAndroid Build Coastguard Worker input = torch.randn(num_features, b, d, w, h) 54*da0073e9SAndroid Build Coastguard Worker self._test_alpha_dropout(nn.FeatureAlphaDropout, input) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker # no batch dims 57*da0073e9SAndroid Build Coastguard Worker input = torch.randn(50, 20, 64, 64) 58*da0073e9SAndroid Build Coastguard Worker self._test_alpha_dropout(nn.FeatureAlphaDropout, input) 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 61*da0073e9SAndroid Build Coastguard Worker not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 unavailable" 62*da0073e9SAndroid Build Coastguard Worker ) 63*da0073e9SAndroid Build Coastguard Worker def test_native_dropout_corner_case(self): 64*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 65*da0073e9SAndroid Build Coastguard Worker device = "cuda" 66*da0073e9SAndroid Build Coastguard Worker elif TEST_PRIVATEUSE1: 67*da0073e9SAndroid Build Coastguard Worker device = torch._C._get_privateuse1_backend_name() 68*da0073e9SAndroid Build Coastguard Worker for train in [True, False]: 69*da0073e9SAndroid Build Coastguard Worker for p in [0.0, 1.0]: 70*da0073e9SAndroid Build Coastguard Worker for current_device in [device, "cpu"]: 71*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5).to(device=current_device).requires_grad_() 72*da0073e9SAndroid Build Coastguard Worker x_ref = x.detach().requires_grad_() 73*da0073e9SAndroid Build Coastguard Worker o = torch.native_dropout(x, p, train)[0] 74*da0073e9SAndroid Build Coastguard Worker o_ref = torch.dropout(x_ref, p, train) 75*da0073e9SAndroid Build Coastguard Worker o.sum().backward() 76*da0073e9SAndroid Build Coastguard Worker o_ref.sum().backward() 77*da0073e9SAndroid Build Coastguard Worker assert o.equal(o_ref) 78*da0073e9SAndroid Build Coastguard Worker assert x.grad.equal(x_ref.grad) 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker def test_invalid_dropout_p(self): 81*da0073e9SAndroid Build Coastguard Worker v = torch.ones(1) 82*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout(-0.1)) 83*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout(1.1)) 84*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout1d(-0.1)) 85*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout1d(1.1)) 86*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1)) 87*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1)) 88*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1)) 89*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: nn.Dropout3d(1.1)) 90*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: F.dropout(v, -0.1)) 91*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: F.dropout(v, 1.1)) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Workerclass TestDropoutNNDeviceType(NNTestCase): 95*da0073e9SAndroid Build Coastguard Worker def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_format): 96*da0073e9SAndroid Build Coastguard Worker p = 0.2 97*da0073e9SAndroid Build Coastguard Worker input = input.to(device).fill_(1 - p) 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker module = cls(p) 100*da0073e9SAndroid Build Coastguard Worker input_var = input.clone(memory_format=memory_format).requires_grad_() 101*da0073e9SAndroid Build Coastguard Worker output = module(input_var) 102*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.is_contiguous(memory_format=memory_format)) 103*da0073e9SAndroid Build Coastguard Worker self.assertLess(abs(output.data.mean() - (1 - p)), 0.05) 104*da0073e9SAndroid Build Coastguard Worker output.backward(input) 105*da0073e9SAndroid Build Coastguard Worker self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format)) 106*da0073e9SAndroid Build Coastguard Worker self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker module = cls(p, True) 109*da0073e9SAndroid Build Coastguard Worker input_var = input.clone(memory_format=memory_format).requires_grad_() 110*da0073e9SAndroid Build Coastguard Worker output = module(input_var + 0) 111*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.is_contiguous(memory_format=memory_format)) 112*da0073e9SAndroid Build Coastguard Worker self.assertLess(abs(output.data.mean() - (1 - p)), 0.05) 113*da0073e9SAndroid Build Coastguard Worker output.backward(input) 114*da0073e9SAndroid Build Coastguard Worker self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format)) 115*da0073e9SAndroid Build Coastguard Worker self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05) 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker # check eval mode doesn't change anything 118*da0073e9SAndroid Build Coastguard Worker for inplace in [True, False]: 119*da0073e9SAndroid Build Coastguard Worker module = cls(p, inplace).eval() 120*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, module(input)) 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker # Check that these don't raise errors 123*da0073e9SAndroid Build Coastguard Worker module.__repr__() 124*da0073e9SAndroid Build Coastguard Worker str(module) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker def _test_dropout_discontiguous( 127*da0073e9SAndroid Build Coastguard Worker self, cls, device, memory_format=torch.contiguous_format 128*da0073e9SAndroid Build Coastguard Worker ): 129*da0073e9SAndroid Build Coastguard Worker # In this test, we verify that dropout preserves the layout and data for different memory formats. 130*da0073e9SAndroid Build Coastguard Worker # We check whether, we get same values for the output of dropout, when the probability 131*da0073e9SAndroid Build Coastguard Worker # of dropout is 0 or very close to 0. 132*da0073e9SAndroid Build Coastguard Worker # Reference: https://github.com/pytorch/pytorch/issues/47176 133*da0073e9SAndroid Build Coastguard Worker close_to_zero_p = 1e-10 # Should be almost zero but not zero, as for p=0 different path is taken 134*da0073e9SAndroid Build Coastguard Worker for p in [0, close_to_zero_p]: 135*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2, 3, 3, 3, device=device) 136*da0073e9SAndroid Build Coastguard Worker inp_discontiguous = torch.empty( 137*da0073e9SAndroid Build Coastguard Worker 2, 3, 3, 6, device=device, memory_format=memory_format 138*da0073e9SAndroid Build Coastguard Worker )[..., ::2] 139*da0073e9SAndroid Build Coastguard Worker inp_discontiguous.copy_(inp) 140*da0073e9SAndroid Build Coastguard Worker mod = cls(p=p) 141*da0073e9SAndroid Build Coastguard Worker out = mod(inp_discontiguous) 142*da0073e9SAndroid Build Coastguard Worker if p != 0: # Zero will keep strides as is based on input. 143*da0073e9SAndroid Build Coastguard Worker # When prob == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2) 144*da0073e9SAndroid Build Coastguard Worker # When prob != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1) 145*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.is_contiguous(memory_format=memory_format)) 146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inp_discontiguous, out) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker def _test_dropout_stride_mean_preserve(self, cls, device): 149*da0073e9SAndroid Build Coastguard Worker def invert_perm(p): 150*da0073e9SAndroid Build Coastguard Worker d = {x: i for i, x in enumerate(p)} 151*da0073e9SAndroid Build Coastguard Worker return (d[0], d[1], d[2], d[3]) 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(2, 3, 4, 5, device=device) 154*da0073e9SAndroid Build Coastguard Worker shifts = [(0, 0), (1, 0), (0, 1), (1, 1)] 155*da0073e9SAndroid Build Coastguard Worker for perm in itertools.permutations((0, 1, 2, 3), r=4): 156*da0073e9SAndroid Build Coastguard Worker for shift in shifts: 157*da0073e9SAndroid Build Coastguard Worker for p in [1e-10, 0.3, 0.5, 0.7]: 158*da0073e9SAndroid Build Coastguard Worker mod = cls(p=p) 159*da0073e9SAndroid Build Coastguard Worker permuted_inp = ( 160*da0073e9SAndroid Build Coastguard Worker inp.permute(perm).contiguous().permute(invert_perm(perm)) 161*da0073e9SAndroid Build Coastguard Worker ) 162*da0073e9SAndroid Build Coastguard Worker permuted_inp = permuted_inp[shift[0] :, shift[1] :, :, :] 163*da0073e9SAndroid Build Coastguard Worker out = mod(permuted_inp) 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.permute(perm).is_contiguous()) 166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5) 167*da0073e9SAndroid Build Coastguard Worker if p == 1e-10: 168*da0073e9SAndroid Build Coastguard Worker self.assertEqual(permuted_inp, out) 169*da0073e9SAndroid Build Coastguard Worker else: 170*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(permuted_inp, out) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker def test_Dropout(self, device): 173*da0073e9SAndroid Build Coastguard Worker input = torch.empty(1000) 174*da0073e9SAndroid Build Coastguard Worker self._test_dropout(nn.Dropout, device, input) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker self._test_dropout_discontiguous(nn.Dropout, device) 177*da0073e9SAndroid Build Coastguard Worker self._test_dropout_discontiguous( 178*da0073e9SAndroid Build Coastguard Worker nn.Dropout, device, memory_format=torch.channels_last 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker self._test_dropout_stride_mean_preserve(nn.Dropout, device) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cuda" or self.device_type == "cpu": 184*da0073e9SAndroid Build Coastguard Worker input = input.bfloat16() 185*da0073e9SAndroid Build Coastguard Worker self._test_dropout(nn.Dropout, device, input) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker def _test_dropoutNd_no_batch(self, dropout, input): 188*da0073e9SAndroid Build Coastguard Worker input_clone = input.clone() 189*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 190*da0073e9SAndroid Build Coastguard Worker res_no_batch = dropout(input) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 193*da0073e9SAndroid Build Coastguard Worker res_batched = dropout(input_clone.unsqueeze(0)).squeeze(0) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_no_batch, res_batched) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker def _test_dropoutNd_channel_zero(self, dropout, input): 198*da0073e9SAndroid Build Coastguard Worker # Verify the number of zeros in a channel is 0 or the number of elements in the channel 199*da0073e9SAndroid Build Coastguard Worker # for a fully positive input tensor 200*da0073e9SAndroid Build Coastguard Worker shape = input.shape 201*da0073e9SAndroid Build Coastguard Worker B = shape[0] 202*da0073e9SAndroid Build Coastguard Worker C = shape[1] 203*da0073e9SAndroid Build Coastguard Worker channel_numel = torch.tensor(shape[2:]).prod() 204*da0073e9SAndroid Build Coastguard Worker result = dropout(input) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker for b, c in product(range(B), range(C)): 207*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel)) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker @expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA 210*da0073e9SAndroid Build Coastguard Worker def test_Dropout1d(self, device): 211*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 212*da0073e9SAndroid Build Coastguard Worker N, C, L = ( 213*da0073e9SAndroid Build Coastguard Worker random.randint(10, 15), 214*da0073e9SAndroid Build Coastguard Worker random.randint(10, 15), 215*da0073e9SAndroid Build Coastguard Worker random.randint(10, 15), 216*da0073e9SAndroid Build Coastguard Worker ) 217*da0073e9SAndroid Build Coastguard Worker input = torch.empty(N, C, L) 218*da0073e9SAndroid Build Coastguard Worker self._test_dropout(nn.Dropout1d, device, input) 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 221*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected 2D or 3D input, but received a 4D input" 222*da0073e9SAndroid Build Coastguard Worker ): 223*da0073e9SAndroid Build Coastguard Worker nn.Dropout1d(p=0.5)(torch.rand(1, 2, 2, 2, device=device)) 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 226*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected 2D or 3D input, but received a 1D input" 227*da0073e9SAndroid Build Coastguard Worker ): 228*da0073e9SAndroid Build Coastguard Worker nn.Dropout1d(p=0.5)(torch.rand(2, device=device)) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker # no batch dims 231*da0073e9SAndroid Build Coastguard Worker input = torch.rand(50, 2, device=device) 232*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5), input) 233*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5, inplace=True), input) 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker # check that complete channels are dropped 236*da0073e9SAndroid Build Coastguard Worker input = torch.ones(10, 4, 2, device=device) 237*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5), input) 238*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5, inplace=True), input) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker @expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA 241*da0073e9SAndroid Build Coastguard Worker def test_Dropout2d(self, device): 242*da0073e9SAndroid Build Coastguard Worker b = random.randint(1, 5) 243*da0073e9SAndroid Build Coastguard Worker w = random.randint(1, 5) 244*da0073e9SAndroid Build Coastguard Worker h = random.randint(1, 5) 245*da0073e9SAndroid Build Coastguard Worker num_features = 1000 246*da0073e9SAndroid Build Coastguard Worker input = torch.empty(num_features, b, w, h) 247*da0073e9SAndroid Build Coastguard Worker self._test_dropout(nn.Dropout2d, device, input) 248*da0073e9SAndroid Build Coastguard Worker self._test_dropout( 249*da0073e9SAndroid Build Coastguard Worker nn.Dropout2d, device, input, memory_format=torch.channels_last 250*da0073e9SAndroid Build Coastguard Worker ) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker self._test_dropout_discontiguous(nn.Dropout2d, device) 253*da0073e9SAndroid Build Coastguard Worker self._test_dropout_discontiguous( 254*da0073e9SAndroid Build Coastguard Worker nn.Dropout2d, device, memory_format=torch.channels_last 255*da0073e9SAndroid Build Coastguard Worker ) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Received a 5-D input to dropout2d"): 258*da0073e9SAndroid Build Coastguard Worker nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, 2, 2, device=device)) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"): 261*da0073e9SAndroid Build Coastguard Worker nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device)) 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker # TODO: Uncomment these lines once no-batch-dim inputs are supported. 264*da0073e9SAndroid Build Coastguard Worker # For now, the historical dropout1d behavior is performed for 3D inputs. 265*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/77081 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker # input = torch.rand(50, 2, 2, device=device) 268*da0073e9SAndroid Build Coastguard Worker # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input) 269*da0073e9SAndroid Build Coastguard Worker # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 272*da0073e9SAndroid Build Coastguard Worker UserWarning, "assuming that channel-wise 1D dropout behavior is desired" 273*da0073e9SAndroid Build Coastguard Worker ): 274*da0073e9SAndroid Build Coastguard Worker nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device)) 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker # check that complete channels are dropped 277*da0073e9SAndroid Build Coastguard Worker input = torch.ones(10, 4, 2, 2, device=device) 278*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5), input) 279*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input) 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker @expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA 282*da0073e9SAndroid Build Coastguard Worker def test_Dropout3d(self, device): 283*da0073e9SAndroid Build Coastguard Worker b = random.randint(1, 5) 284*da0073e9SAndroid Build Coastguard Worker w = random.randint(1, 5) 285*da0073e9SAndroid Build Coastguard Worker h = random.randint(1, 5) 286*da0073e9SAndroid Build Coastguard Worker d = random.randint(1, 2) 287*da0073e9SAndroid Build Coastguard Worker num_features = 1000 288*da0073e9SAndroid Build Coastguard Worker input = torch.empty(num_features, b, d, w, h) 289*da0073e9SAndroid Build Coastguard Worker self._test_dropout(nn.Dropout3d, device, input) 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker self._test_dropout_discontiguous(nn.Dropout3d, device) 292*da0073e9SAndroid Build Coastguard Worker self._test_dropout_discontiguous( 293*da0073e9SAndroid Build Coastguard Worker nn.Dropout3d, device, memory_format=torch.channels_last 294*da0073e9SAndroid Build Coastguard Worker ) 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Received a 6-D input to dropout3d"): 297*da0073e9SAndroid Build Coastguard Worker nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, 2, 2, 2, device=device)) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Received a 3-D input to dropout3d"): 300*da0073e9SAndroid Build Coastguard Worker nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, device=device)) 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker # no batch dims 303*da0073e9SAndroid Build Coastguard Worker input = torch.rand(50, 2, 2, 2, device=device) 304*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5), input) 305*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5, inplace=True), input) 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker # check that complete channels are dropped 308*da0073e9SAndroid Build Coastguard Worker input = torch.ones(10, 4, 2, 2, 2, device=device) 309*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5), input) 310*da0073e9SAndroid Build Coastguard Worker self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5, inplace=True), input) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker def test_empty_dropout(self, device): 313*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([]).to(device) 314*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.dropout(x) 315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.size(), x.size()) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestDropoutNNDeviceType, globals()) 319*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestDropoutNN) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 322*da0073e9SAndroid Build Coastguard Worker run_tests() 323