# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    expectedFailureXLA,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_nn import freeze_rng_state, NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    TEST_PRIVATEUSE1,
)


class TestDropoutNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _test_alpha_dropout(self, cls, input):
        mean = input.mean()
        std = input.std()

        for p in [0.2, 0.5, 0.8]:
            module = cls(p)
            input_var = input.detach().clone().requires_grad_()
            output = module(input_var)
            # output mean should be close to input mean
            self.assertLess(abs(output.data.mean() - mean), 0.1)
            # output std should be close to input std
            self.assertLess(abs(output.data.std() - std), 0.1)
            output.backward(input)

    def test_AlphaDropout(self):
        # generate random tensor with zero mean and unit std
        input = torch.randn(5000)
        self._test_alpha_dropout(nn.AlphaDropout, input)

    def test_FeatureAlphaDropout(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.randn(num_features, b, d, w, h)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

        # no batch dims
        input = torch.randn(50, 20, 64, 64)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 unavailable"
    )
    def test_native_dropout_corner_case(self):
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        for train in [True, False]:
            for p in [0.0, 1.0]:
                for current_device in [device, "cpu"]:
                    x = torch.randn(5).to(device=current_device).requires_grad_()
                    x_ref = x.detach().requires_grad_()
                    o = torch.native_dropout(x, p, train)[0]
                    o_ref = torch.dropout(x_ref, p, train)
                    o.sum().backward()
                    o_ref.sum().backward()
                    assert o.equal(o_ref)
                    assert x.grad.equal(x_ref.grad)

    def test_invalid_dropout_p(self):
        v = torch.ones(1)
        self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(1.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, -0.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, 1.1))


class TestDropoutNNDeviceType(NNTestCase):
    def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_format):
        p = 0.2
        input = input.to(device).fill_(1 - p)

        module = cls(p)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
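        # The backward pass applies the same scaled mask to the upstream gradient,
        # so the gradient mean should also stay near 1 - p.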
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        module = cls(p, True)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var + 0)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        # check eval mode doesn't change anything
        for inplace in [True, False]:
            module = cls(p, inplace).eval()
            self.assertEqual(input, module(input))

        # Check that these don't raise errors
        module.__repr__()
        str(module)

    def _test_dropout_discontiguous(
        self, cls, device, memory_format=torch.contiguous_format
    ):
        # In this test, we verify that dropout preserves the layout and data for different memory formats.
        # We check whether we get the same values for the output of dropout when the probability
        # of dropout is 0 or very close to 0.
        # Reference: https://github.com/pytorch/pytorch/issues/47176
        close_to_zero_p = 1e-10  # Should be almost zero but not zero, as a different path is taken for p=0
        for p in [0, close_to_zero_p]:
            inp = torch.ones(2, 3, 3, 3, device=device)
            inp_discontiguous = torch.empty(
                2, 3, 3, 6, device=device, memory_format=memory_format
            )[..., ::2]
            inp_discontiguous.copy_(inp)
            mod = cls(p=p)
            out = mod(inp_discontiguous)
            if p != 0:  # Zero will keep strides as is based on input.
                # When prob == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2)
                # When prob != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1)
                self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(inp_discontiguous, out)

    def _test_dropout_stride_mean_preserve(self, cls, device):
        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2], d[3])

        inp = torch.ones(2, 3, 4, 5, device=device)
        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
        for perm in itertools.permutations((0, 1, 2, 3), r=4):
            for shift in shifts:
                for p in [1e-10, 0.3, 0.5, 0.7]:
                    mod = cls(p=p)
                    permuted_inp = (
                        inp.permute(perm).contiguous().permute(invert_perm(perm))
                    )
                    permuted_inp = permuted_inp[shift[0] :, shift[1] :, :, :]
                    out = mod(permuted_inp)

                    self.assertTrue(out.permute(perm).is_contiguous())
                    self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5)
                    if p == 1e-10:
                        self.assertEqual(permuted_inp, out)
                    else:
                        self.assertNotEqual(permuted_inp, out)

    def test_Dropout(self, device):
        input = torch.empty(1000)
        self._test_dropout(nn.Dropout, device, input)

        self._test_dropout_discontiguous(nn.Dropout, device)
        self._test_dropout_discontiguous(
            nn.Dropout, device, memory_format=torch.channels_last
        )

        self._test_dropout_stride_mean_preserve(nn.Dropout, device)

        if self.device_type == "cuda" or self.device_type == "cpu":
            input = input.bfloat16()
            self._test_dropout(nn.Dropout, device, input)

    def _test_dropoutNd_no_batch(self, dropout, input):
        input_clone = input.clone()
        with freeze_rng_state():
            res_no_batch = dropout(input)

        with freeze_rng_state():
            res_batched = dropout(input_clone.unsqueeze(0)).squeeze(0)

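        # With the RNG frozen to the same state for both calls, the same channel mask
        # should be drawn for the unbatched input and for its batched view, so the
        # two results are expected to match elementwise.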
        self.assertEqual(res_no_batch, res_batched)

    def _test_dropoutNd_channel_zero(self, dropout, input):
        # Verify the number of zeros in a channel is 0 or the number of elements in the channel
        # for a fully positive input tensor
        shape = input.shape
        B = shape[0]
        C = shape[1]
        channel_numel = torch.tensor(shape[2:]).prod()
        result = dropout(input)

        for b, c in product(range(B), range(C)):
            self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel))

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout1d(self, device):
        with set_default_dtype(torch.double):
            N, C, L = (
                random.randint(10, 15),
                random.randint(10, 15),
                random.randint(10, 15),
            )
            input = torch.empty(N, C, L)
            self._test_dropout(nn.Dropout1d, device, input)

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 4D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(1, 2, 2, 2, device=device))

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 1D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(2, device=device))

            # no batch dims
            input = torch.rand(50, 2, device=device)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5, inplace=True), input)

            # check that complete channels are dropped
            input = torch.ones(10, 4, 2, device=device)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout2d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.empty(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, device, input)
        self._test_dropout(
            nn.Dropout2d, device, input, memory_format=torch.channels_last
        )

        self._test_dropout_discontiguous(nn.Dropout2d, device)
        self._test_dropout_discontiguous(
            nn.Dropout2d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 5-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device))

        # TODO: Uncomment these lines once no-batch-dim inputs are supported.
        # For now, the historical dropout1d behavior is performed for 3D inputs.
        # See https://github.com/pytorch/pytorch/issues/77081

        # input = torch.rand(50, 2, 2, device=device)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)

        with self.assertWarnsRegex(
            UserWarning, "assuming that channel-wise 1D dropout behavior is desired"
        ):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout3d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.empty(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, device, input)

        self._test_dropout_discontiguous(nn.Dropout3d, device)
        self._test_dropout_discontiguous(
            nn.Dropout3d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 6-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 3-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # no batch dims
        input = torch.rand(50, 2, 2, 2, device=device)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5, inplace=True), input)

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5, inplace=True), input)

    def test_empty_dropout(self, device):
        x = torch.tensor([]).to(device)
        out = torch.nn.functional.dropout(x)
        self.assertEqual(out.size(), x.size())


instantiate_device_type_tests(TestDropoutNNDeviceType, globals())
instantiate_parametrized_tests(TestDropoutNN)

if __name__ == "__main__":
    run_tests()