1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dataloader"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport ctypes 4*da0073e9SAndroid Build Coastguard Workerimport errno 5*da0073e9SAndroid Build Coastguard Workerimport faulthandler 6*da0073e9SAndroid Build Coastguard Workerimport functools 7*da0073e9SAndroid Build Coastguard Workerimport gc 8*da0073e9SAndroid Build Coastguard Workerimport itertools 9*da0073e9SAndroid Build Coastguard Workerimport math 10*da0073e9SAndroid Build Coastguard Workerimport operator 11*da0073e9SAndroid Build Coastguard Workerimport os 12*da0073e9SAndroid Build Coastguard Workerimport signal 13*da0073e9SAndroid Build Coastguard Workerimport sys 14*da0073e9SAndroid Build Coastguard Workerimport tempfile 15*da0073e9SAndroid Build Coastguard Workerimport time 16*da0073e9SAndroid Build Coastguard Workerimport unittest 17*da0073e9SAndroid Build Coastguard Workerimport warnings 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerimport torch 20*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.datapipes as dp 21*da0073e9SAndroid Build Coastguard Workerfrom torch import multiprocessing as mp 22*da0073e9SAndroid Build Coastguard Workerfrom torch._utils import ExceptionWrapper 23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests 24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 25*da0073e9SAndroid Build Coastguard Worker IS_CI, 26*da0073e9SAndroid Build Coastguard Worker IS_JETSON, 27*da0073e9SAndroid Build Coastguard Worker IS_MACOS, 28*da0073e9SAndroid Build Coastguard Worker IS_SANDCASTLE, 29*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 30*da0073e9SAndroid Build Coastguard Worker load_tests, 31*da0073e9SAndroid Build Coastguard Worker NO_MULTIPROCESSING_SPAWN, 32*da0073e9SAndroid Build Coastguard Worker 
parametrize, 33*da0073e9SAndroid Build Coastguard Worker run_tests, 34*da0073e9SAndroid Build Coastguard Worker skipIfNoDill, 35*da0073e9SAndroid Build Coastguard Worker skipIfRocm, 36*da0073e9SAndroid Build Coastguard Worker slowTest, 37*da0073e9SAndroid Build Coastguard Worker TEST_CUDA, 38*da0073e9SAndroid Build Coastguard Worker TEST_NUMPY, 39*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ASAN, 40*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM, 41*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 42*da0073e9SAndroid Build Coastguard Worker TestCase, 43*da0073e9SAndroid Build Coastguard Worker) 44*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import ( 45*da0073e9SAndroid Build Coastguard Worker _utils, 46*da0073e9SAndroid Build Coastguard Worker ChainDataset, 47*da0073e9SAndroid Build Coastguard Worker ConcatDataset, 48*da0073e9SAndroid Build Coastguard Worker DataLoader, 49*da0073e9SAndroid Build Coastguard Worker Dataset, 50*da0073e9SAndroid Build Coastguard Worker IterableDataset, 51*da0073e9SAndroid Build Coastguard Worker IterDataPipe, 52*da0073e9SAndroid Build Coastguard Worker StackDataset, 53*da0073e9SAndroid Build Coastguard Worker Subset, 54*da0073e9SAndroid Build Coastguard Worker TensorDataset, 55*da0073e9SAndroid Build Coastguard Worker) 56*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL 57*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.iter import IterableWrapper 58*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.dataset import random_split 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Workertry: 62*da0073e9SAndroid Build Coastguard Worker import psutil 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker HAS_PSUTIL = True 65*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError: 
66*da0073e9SAndroid Build Coastguard Worker HAS_PSUTIL = False 67*da0073e9SAndroid Build Coastguard Worker psutil = None 68*da0073e9SAndroid Build Coastguard Worker err_msg = ( 69*da0073e9SAndroid Build Coastguard Worker "psutil not found. Some critical data loader tests relying on it " 70*da0073e9SAndroid Build Coastguard Worker "(e.g., TestDataLoader.test_proper_exit) will not run." 71*da0073e9SAndroid Build Coastguard Worker ) 72*da0073e9SAndroid Build Coastguard Worker if IS_CI: 73*da0073e9SAndroid Build Coastguard Worker raise ModuleNotFoundError(err_msg) from None 74*da0073e9SAndroid Build Coastguard Worker else: 75*da0073e9SAndroid Build Coastguard Worker warnings.warn(err_msg) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Workertry: 79*da0073e9SAndroid Build Coastguard Worker import numpy as np 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker HAS_NUMPY = True 82*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError: 83*da0073e9SAndroid Build Coastguard Worker HAS_NUMPY = False 84*da0073e9SAndroid Build Coastguard Worker np = None 85*da0073e9SAndroid Build Coastguard WorkerskipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy") 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for 88*da0073e9SAndroid Build Coastguard Worker# sharding on sandcastle. 
This line silences flake warnings 89*da0073e9SAndroid Build Coastguard Workerload_tests = load_tests 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard WorkerTEST_CUDA_IPC = ( 92*da0073e9SAndroid Build Coastguard Worker torch.cuda.is_available() 93*da0073e9SAndroid Build Coastguard Worker and sys.platform != "darwin" 94*da0073e9SAndroid Build Coastguard Worker and sys.platform != "win32" 95*da0073e9SAndroid Build Coastguard Worker and not IS_JETSON 96*da0073e9SAndroid Build Coastguard Worker and not TEST_WITH_ROCM 97*da0073e9SAndroid Build Coastguard Worker) # https://github.com/pytorch/pytorch/issues/90940 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard WorkerTEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Workerif not NO_MULTIPROCESSING_SPAWN: 102*da0073e9SAndroid Build Coastguard Worker # We want to use `spawn` if able because some of our tests check that the 103*da0073e9SAndroid Build Coastguard Worker # data loader terminiates gracefully. To prevent hanging in the testing 104*da0073e9SAndroid Build Coastguard Worker # process, such data loaders are run in a separate subprocess. 105*da0073e9SAndroid Build Coastguard Worker # 106*da0073e9SAndroid Build Coastguard Worker # We also want to test the `pin_memory=True` configuration, thus `spawn` is 107*da0073e9SAndroid Build Coastguard Worker # required to launch such processes and they initialize the CUDA context. 108*da0073e9SAndroid Build Coastguard Worker # 109*da0073e9SAndroid Build Coastguard Worker # Mixing different start method is a recipe for disaster (e.g., using a fork 110*da0073e9SAndroid Build Coastguard Worker # `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally 111*da0073e9SAndroid Build Coastguard Worker # to avoid bugs. 
112*da0073e9SAndroid Build Coastguard Worker # 113*da0073e9SAndroid Build Coastguard Worker # Get a multiprocessing context because some test / third party library will 114*da0073e9SAndroid Build Coastguard Worker # set start_method when imported, and setting again triggers `RuntimeError`. 115*da0073e9SAndroid Build Coastguard Worker mp = mp.get_context(method="spawn") 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker# 60s of timeout? 119*da0073e9SAndroid Build Coastguard Worker# Yes, in environments where physical CPU resources are shared, e.g., CI, the 120*da0073e9SAndroid Build Coastguard Worker# time for a inter-process communication can be highly varying. With 15~17s of 121*da0073e9SAndroid Build Coastguard Worker# timeout, we have observed flakiness in some CI builds (see 122*da0073e9SAndroid Build Coastguard Worker# pytorch/pytorch#14501, pytorch/pytorch#16608). We follow the CPython 123*da0073e9SAndroid Build Coastguard Worker# multiprocessing setup and set the timeout to 60s here: 124*da0073e9SAndroid Build Coastguard Worker# 125*da0073e9SAndroid Build Coastguard Worker# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73 126*da0073e9SAndroid Build Coastguard WorkerJOIN_TIMEOUT = 60.0 # seconds 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Workersupported_multiprocessing_contexts = [None] + list( 130*da0073e9SAndroid Build Coastguard Worker torch.multiprocessing.get_all_start_methods() 131*da0073e9SAndroid Build Coastguard Worker) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker# collate_fn that returns the batch cloned; defined globally here for pickle purposes. 
135*da0073e9SAndroid Build Coastguard Workerdef _clone_collate(b): 136*da0073e9SAndroid Build Coastguard Worker return [x.clone() for x in b] 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 140*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 141*da0073e9SAndroid Build Coastguard Worker "Fails with TSAN with the following error: starting new threads after multi-threaded " 142*da0073e9SAndroid Build Coastguard Worker "fork is not supported. Dying (set die_after_fork=0 to override)", 143*da0073e9SAndroid Build Coastguard Worker) 144*da0073e9SAndroid Build Coastguard Workerclass TestDatasetRandomSplit(TestCase): 145*da0073e9SAndroid Build Coastguard Worker def test_lengths_must_equal_dataset_size(self): 146*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 147*da0073e9SAndroid Build Coastguard Worker random_split([1, 2, 3, 4], [1, 2]) 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker def test_splits_have_correct_size(self): 150*da0073e9SAndroid Build Coastguard Worker splits = random_split([1, 2, 3, 4, 5, 6], [2, 4]) 151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits), 2) 152*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[0]), 2) 153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[1]), 4) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5]) 156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits), 2) 157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[0]), 3) 158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[1]), 3) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker # Odd size splits 161*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual( 162*da0073e9SAndroid Build Coastguard Worker len( 163*da0073e9SAndroid Build Coastguard Worker random_split( 164*da0073e9SAndroid Build Coastguard Worker range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1) 165*da0073e9SAndroid Build Coastguard Worker ) 166*da0073e9SAndroid Build Coastguard Worker ), 167*da0073e9SAndroid Build Coastguard Worker 2, 168*da0073e9SAndroid Build Coastguard Worker ) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker # Odd sized round-robin splits 171*da0073e9SAndroid Build Coastguard Worker splits = random_split( 172*da0073e9SAndroid Build Coastguard Worker range(106), [0.1, 0.2, 0.3, 0.4], generator=torch.Generator().manual_seed(1) 173*da0073e9SAndroid Build Coastguard Worker ) 174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[0]), 11) 175*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[1]), 22) 176*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[2]), 31) 177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(splits[3]), 42) 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker def test_splits_are_mutually_exclusive(self): 180*da0073e9SAndroid Build Coastguard Worker data = [5, 2, 3, 4, 1, 6] 181*da0073e9SAndroid Build Coastguard Worker splits = random_split(data, [2, 4]) 182*da0073e9SAndroid Build Coastguard Worker all_values = [] 183*da0073e9SAndroid Build Coastguard Worker all_values.extend(list(splits[0])) 184*da0073e9SAndroid Build Coastguard Worker all_values.extend(list(splits[1])) 185*da0073e9SAndroid Build Coastguard Worker data.sort() 186*da0073e9SAndroid Build Coastguard Worker all_values.sort() 187*da0073e9SAndroid Build Coastguard Worker self.assertListEqual(data, all_values) 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker splits = random_split(data, [0.33, 0.67]) 190*da0073e9SAndroid Build Coastguard 
Worker all_values = [] 191*da0073e9SAndroid Build Coastguard Worker all_values.extend(list(splits[0])) 192*da0073e9SAndroid Build Coastguard Worker all_values.extend(list(splits[1])) 193*da0073e9SAndroid Build Coastguard Worker data.sort() 194*da0073e9SAndroid Build Coastguard Worker all_values.sort() 195*da0073e9SAndroid Build Coastguard Worker self.assertListEqual(data, all_values) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker data = [1, 2, 3, 4] 198*da0073e9SAndroid Build Coastguard Worker splits = random_split(data, [0.25, 0.75]) 199*da0073e9SAndroid Build Coastguard Worker all_values = [] 200*da0073e9SAndroid Build Coastguard Worker all_values.extend(list(splits[0])) 201*da0073e9SAndroid Build Coastguard Worker all_values.extend(list(splits[1])) 202*da0073e9SAndroid Build Coastguard Worker data.sort() 203*da0073e9SAndroid Build Coastguard Worker all_values.sort() 204*da0073e9SAndroid Build Coastguard Worker self.assertListEqual(data, all_values) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker def test_splits_indexing_type(self): 207*da0073e9SAndroid Build Coastguard Worker r"""Indices generated by random_split 208*da0073e9SAndroid Build Coastguard Worker should be of integer type 209*da0073e9SAndroid Build Coastguard Worker """ 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker class CustomDataset: 212*da0073e9SAndroid Build Coastguard Worker def __init__(self, test_object, custom_list): 213*da0073e9SAndroid Build Coastguard Worker self.data = custom_list 214*da0073e9SAndroid Build Coastguard Worker self.test_object = test_object 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, key): 217*da0073e9SAndroid Build Coastguard Worker self.test_object.assertEqual(type(key), int) 218*da0073e9SAndroid Build Coastguard Worker return self.data[key] 219*da0073e9SAndroid Build Coastguard 
Worker 220*da0073e9SAndroid Build Coastguard Worker def __len__(self): 221*da0073e9SAndroid Build Coastguard Worker return len(self.data) 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker x = [1, 2, 3, 4, 5] 224*da0073e9SAndroid Build Coastguard Worker dataset = CustomDataset(self, x) 225*da0073e9SAndroid Build Coastguard Worker dataset = random_split(dataset, [5])[0] 226*da0073e9SAndroid Build Coastguard Worker data_loader = DataLoader(dataset) 227*da0073e9SAndroid Build Coastguard Worker for batch in data_loader: 228*da0073e9SAndroid Build Coastguard Worker pass 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker # fractional splitting 231*da0073e9SAndroid Build Coastguard Worker dataset = CustomDataset(self, x) 232*da0073e9SAndroid Build Coastguard Worker dataset = random_split(dataset, [1.0])[0] 233*da0073e9SAndroid Build Coastguard Worker data_loader = DataLoader(dataset) 234*da0073e9SAndroid Build Coastguard Worker for batch in data_loader: 235*da0073e9SAndroid Build Coastguard Worker pass 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker def test_splits_reproducibility(self): 238*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 239*da0073e9SAndroid Build Coastguard Worker [ 240*da0073e9SAndroid Build Coastguard Worker list(x) 241*da0073e9SAndroid Build Coastguard Worker for x in random_split( 242*da0073e9SAndroid Build Coastguard Worker range(10), [3, 7], generator=torch.Generator().manual_seed(1) 243*da0073e9SAndroid Build Coastguard Worker ) 244*da0073e9SAndroid Build Coastguard Worker ], 245*da0073e9SAndroid Build Coastguard Worker [[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]], 246*da0073e9SAndroid Build Coastguard Worker ) 247*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 248*da0073e9SAndroid Build Coastguard Worker random_split( 249*da0073e9SAndroid Build Coastguard Worker range(100), [60, 40], 
generator=torch.Generator().manual_seed(42) 250*da0073e9SAndroid Build Coastguard Worker ), 251*da0073e9SAndroid Build Coastguard Worker random_split( 252*da0073e9SAndroid Build Coastguard Worker range(100), [60, 40], generator=torch.Generator().manual_seed(42) 253*da0073e9SAndroid Build Coastguard Worker ), 254*da0073e9SAndroid Build Coastguard Worker ) 255*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 256*da0073e9SAndroid Build Coastguard Worker random_split( 257*da0073e9SAndroid Build Coastguard Worker range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42) 258*da0073e9SAndroid Build Coastguard Worker ), 259*da0073e9SAndroid Build Coastguard Worker random_split( 260*da0073e9SAndroid Build Coastguard Worker range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42) 261*da0073e9SAndroid Build Coastguard Worker ), 262*da0073e9SAndroid Build Coastguard Worker ) 263*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 264*da0073e9SAndroid Build Coastguard Worker random_split( 265*da0073e9SAndroid Build Coastguard Worker range(100), 266*da0073e9SAndroid Build Coastguard Worker [0.33, 0.33, 0.34], 267*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 268*da0073e9SAndroid Build Coastguard Worker ), 269*da0073e9SAndroid Build Coastguard Worker random_split( 270*da0073e9SAndroid Build Coastguard Worker range(100), 271*da0073e9SAndroid Build Coastguard Worker [0.33, 0.33, 0.34], 272*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 273*da0073e9SAndroid Build Coastguard Worker ), 274*da0073e9SAndroid Build Coastguard Worker ) 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker def test_incomplete_fractional_splits(self): 277*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 278*da0073e9SAndroid Build Coastguard Worker # should raise since the sum of fractions is not 1 279*da0073e9SAndroid Build 
Coastguard Worker random_split([1, 2, 3, 4], [0.1]) 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 282*da0073e9SAndroid Build Coastguard Worker # should raise since fraction > 1 283*da0073e9SAndroid Build Coastguard Worker random_split([1, 2, 3, 4], [1.1]) 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker def test_splits_generator(self): 286*da0073e9SAndroid Build Coastguard Worker # A random_split without a specific generator should affect the default one 287*da0073e9SAndroid Build Coastguard Worker state = torch.get_rng_state() 288*da0073e9SAndroid Build Coastguard Worker a = torch.rand(10) 289*da0073e9SAndroid Build Coastguard Worker torch.set_rng_state(state) 290*da0073e9SAndroid Build Coastguard Worker random_split(range(10), [5, 5]) 291*da0073e9SAndroid Build Coastguard Worker b = torch.rand(10) 292*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(a, b) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker # A random_split with a specific generator should not affect the default one 295*da0073e9SAndroid Build Coastguard Worker state = torch.get_rng_state() 296*da0073e9SAndroid Build Coastguard Worker a = torch.rand(10) 297*da0073e9SAndroid Build Coastguard Worker torch.set_rng_state(state) 298*da0073e9SAndroid Build Coastguard Worker random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42)) 299*da0073e9SAndroid Build Coastguard Worker b = torch.rand(10) 300*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker def test_slicing_of_subset_of_dataset(self): 303*da0073e9SAndroid Build Coastguard Worker # Testing slicing a subset initialized with a dataset 304*da0073e9SAndroid Build Coastguard Worker dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5])) 305*da0073e9SAndroid Build 
Coastguard Worker subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4]) 306*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_dataset[:], dataset[:]) 307*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_dataset[1:2], dataset[1:2]) 308*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2]) 309*da0073e9SAndroid Build Coastguard Worker # Testing slicing of subset from random split 310*da0073e9SAndroid Build Coastguard Worker subset1, subset2 = random_split(dataset, [3, 2]) 311*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset1[:], dataset[subset1.indices[:]]) 312*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]]) 313*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]]) 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker def test_slicing_of_subset_of_subset(self): 316*da0073e9SAndroid Build Coastguard Worker # Testing slicing a subset initialized with a subset 317*da0073e9SAndroid Build Coastguard Worker dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5])) 318*da0073e9SAndroid Build Coastguard Worker subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4]) 319*da0073e9SAndroid Build Coastguard Worker subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4]) 320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_subset[:], dataset[:]) 321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_subset[0:2], dataset[0:2]) 322*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2]) 323*da0073e9SAndroid Build Coastguard Worker # Testing slicing of subset of subset from random split 324*da0073e9SAndroid Build Coastguard Worker subset1, subset2 = random_split(dataset, [4, 1]) 325*da0073e9SAndroid Build Coastguard Worker subset_of_subset1, subset_of_subset2 = 
random_split(subset1, [3, 1]) 326*da0073e9SAndroid Build Coastguard Worker idx = [subset1.indices[i] for i in subset_of_subset1.indices] 327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_subset1[:], dataset[idx.copy()]) 328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]]) 329*da0073e9SAndroid Build Coastguard Worker self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]]) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Workerclass CUDACountingDataset(Dataset): 333*da0073e9SAndroid Build Coastguard Worker def __init__(self, n): 334*da0073e9SAndroid Build Coastguard Worker super().__init__() 335*da0073e9SAndroid Build Coastguard Worker self.n = n 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, i): 338*da0073e9SAndroid Build Coastguard Worker return torch.as_tensor(i, device="cuda") 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker def __len__(self): 341*da0073e9SAndroid Build Coastguard Worker return self.n 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Workerclass CountingDataset(Dataset): 345*da0073e9SAndroid Build Coastguard Worker def __init__(self, n): 346*da0073e9SAndroid Build Coastguard Worker super().__init__() 347*da0073e9SAndroid Build Coastguard Worker self.n = n 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, i): 350*da0073e9SAndroid Build Coastguard Worker return i 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker def __len__(self): 353*da0073e9SAndroid Build Coastguard Worker return self.n 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build 
Coastguard Workerclass CountingIterableDataset(IterableDataset): 357*da0073e9SAndroid Build Coastguard Worker def __init__(self, n): 358*da0073e9SAndroid Build Coastguard Worker super().__init__() 359*da0073e9SAndroid Build Coastguard Worker self.n = n 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 362*da0073e9SAndroid Build Coastguard Worker return iter(range(self.n)) 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker def __len__(self): 365*da0073e9SAndroid Build Coastguard Worker return self.n 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 369*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 370*da0073e9SAndroid Build Coastguard Worker "Fails with TSAN with the following error: starting new threads after multi-threaded " 371*da0073e9SAndroid Build Coastguard Worker "fork is not supported. 
Dying (set die_after_fork=0 to override)", 372*da0073e9SAndroid Build Coastguard Worker) 373*da0073e9SAndroid Build Coastguard Workerclass TestTensorDataset(TestCase): 374*da0073e9SAndroid Build Coastguard Worker def test_len(self): 375*da0073e9SAndroid Build Coastguard Worker source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15)) 376*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(source), 15) 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker def test_getitem(self): 379*da0073e9SAndroid Build Coastguard Worker t = torch.randn(15, 10, 2, 3, 4, 5) 380*da0073e9SAndroid Build Coastguard Worker l = torch.randn(15, 10) 381*da0073e9SAndroid Build Coastguard Worker source = TensorDataset(t, l) 382*da0073e9SAndroid Build Coastguard Worker for i in range(15): 383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[i], source[i][0]) 384*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l[i], source[i][1]) 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker def test_getitem_1d(self): 387*da0073e9SAndroid Build Coastguard Worker t = torch.randn(15) 388*da0073e9SAndroid Build Coastguard Worker l = torch.randn(15) 389*da0073e9SAndroid Build Coastguard Worker source = TensorDataset(t, l) 390*da0073e9SAndroid Build Coastguard Worker for i in range(15): 391*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[i], source[i][0]) 392*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l[i], source[i][1]) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker def test_single_tensor(self): 395*da0073e9SAndroid Build Coastguard Worker t = torch.randn(5, 10) 396*da0073e9SAndroid Build Coastguard Worker source = TensorDataset(t) 397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(source), 5) 398*da0073e9SAndroid Build Coastguard Worker for i in range(5): 399*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(t[i], source[i][0]) 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker def test_many_tensors(self): 402*da0073e9SAndroid Build Coastguard Worker t0 = torch.randn(5, 10, 2, 3, 4, 5) 403*da0073e9SAndroid Build Coastguard Worker t1 = torch.randn(5, 10) 404*da0073e9SAndroid Build Coastguard Worker t2 = torch.randn(5, 10, 2, 5) 405*da0073e9SAndroid Build Coastguard Worker t3 = torch.randn(5, 10, 3, 7) 406*da0073e9SAndroid Build Coastguard Worker source = TensorDataset(t0, t1, t2, t3) 407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(source), 5) 408*da0073e9SAndroid Build Coastguard Worker for i in range(5): 409*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t0[i], source[i][0]) 410*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1[i], source[i][1]) 411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2[i], source[i][2]) 412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t3[i], source[i][3]) 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 416*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 417*da0073e9SAndroid Build Coastguard Worker "Fails with TSAN with the following error: starting new threads after multi-threaded " 418*da0073e9SAndroid Build Coastguard Worker "fork is not supported. 
Dying (set die_after_fork=0 to override)", 419*da0073e9SAndroid Build Coastguard Worker) 420*da0073e9SAndroid Build Coastguard Workerclass TestStackDataset(TestCase): 421*da0073e9SAndroid Build Coastguard Worker def test_empty(self): 422*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 423*da0073e9SAndroid Build Coastguard Worker ValueError, "At least one dataset should be passed" 424*da0073e9SAndroid Build Coastguard Worker ): 425*da0073e9SAndroid Build Coastguard Worker StackDataset() 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker def test_mixed(self): 428*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Supported either"): 429*da0073e9SAndroid Build Coastguard Worker StackDataset( 430*da0073e9SAndroid Build Coastguard Worker TensorDataset(torch.randn(15, 10)), a=TensorDataset(torch.randn(10, 15)) 431*da0073e9SAndroid Build Coastguard Worker ) 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker def test_size_mismatch(self): 434*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"): 435*da0073e9SAndroid Build Coastguard Worker StackDataset( 436*da0073e9SAndroid Build Coastguard Worker TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(10, 15)) 437*da0073e9SAndroid Build Coastguard Worker ) 438*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"): 439*da0073e9SAndroid Build Coastguard Worker StackDataset( 440*da0073e9SAndroid Build Coastguard Worker a=TensorDataset(torch.randn(15, 10)), 441*da0073e9SAndroid Build Coastguard Worker b=TensorDataset(torch.randn(10, 15)), 442*da0073e9SAndroid Build Coastguard Worker ) 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker def test_len(self): 445*da0073e9SAndroid Build Coastguard Worker source = StackDataset( 
446*da0073e9SAndroid Build Coastguard Worker TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(15)) 447*da0073e9SAndroid Build Coastguard Worker ) 448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(source), 15) 449*da0073e9SAndroid Build Coastguard Worker source = StackDataset(TensorDataset(torch.randn(15, 10))) 450*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(source), 15) 451*da0073e9SAndroid Build Coastguard Worker source = StackDataset( 452*da0073e9SAndroid Build Coastguard Worker a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(15)) 453*da0073e9SAndroid Build Coastguard Worker ) 454*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(source), 15) 455*da0073e9SAndroid Build Coastguard Worker source = StackDataset(a=TensorDataset(torch.randn(15, 10))) 456*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(source), 15) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker def test_single(self): 459*da0073e9SAndroid Build Coastguard Worker t = TensorDataset(torch.randn(15, 10)) 460*da0073e9SAndroid Build Coastguard Worker source = StackDataset(t) 461*da0073e9SAndroid Build Coastguard Worker for i in range(15): 462*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[i], source[i][0]) 463*da0073e9SAndroid Build Coastguard Worker source = StackDataset(a=t) 464*da0073e9SAndroid Build Coastguard Worker for i in range(15): 465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[i], source[i]["a"]) 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker def test_getitem(self): 468*da0073e9SAndroid Build Coastguard Worker t = TensorDataset(torch.randn(15, 10)) 469*da0073e9SAndroid Build Coastguard Worker l = TensorDataset(torch.randn(15, 5, 4)) 470*da0073e9SAndroid Build Coastguard Worker source = StackDataset(t, l) 471*da0073e9SAndroid Build Coastguard Worker for i in range(15): 472*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(t[i], source[i][0]) 473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l[i], source[i][1]) 474*da0073e9SAndroid Build Coastguard Worker source = StackDataset(a=t, b=l) 475*da0073e9SAndroid Build Coastguard Worker for i in range(15): 476*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[i], source[i]["a"]) 477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l[i], source[i]["b"]) 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker def test_getitems(self): 480*da0073e9SAndroid Build Coastguard Worker class GetItemsDataset(Dataset): 481*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 482*da0073e9SAndroid Build Coastguard Worker self.data = torch.randn(4) 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, item): 485*da0073e9SAndroid Build Coastguard Worker return self.data[item] 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker def __getitems__(self, items): 488*da0073e9SAndroid Build Coastguard Worker return self.data[items] 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker def __len__(self): 491*da0073e9SAndroid Build Coastguard Worker return 4 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker t = GetItemsDataset() 494*da0073e9SAndroid Build Coastguard Worker l = [1, 2, 3, 4] 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker source = StackDataset(t, l) 497*da0073e9SAndroid Build Coastguard Worker batch = source.__getitems__([0, 1, 2, 3]) 498*da0073e9SAndroid Build Coastguard Worker for i in range(4): 499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[i], batch[i][0]) 500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l[i], batch[i][1]) 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid 
Build Coastguard Worker source = StackDataset(t=t, l=l) 503*da0073e9SAndroid Build Coastguard Worker batch = source.__getitems__([0, 1, 2, 3]) 504*da0073e9SAndroid Build Coastguard Worker for i in range(4): 505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[i], batch[i]["t"]) 506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l[i], batch[i]["l"]) 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker def test_getitems_raises_index_error(self): 509*da0073e9SAndroid Build Coastguard Worker class GetItemsDataset(Dataset): 510*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 511*da0073e9SAndroid Build Coastguard Worker self.data = torch.randn(4) 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, item): 514*da0073e9SAndroid Build Coastguard Worker return self.data[item] 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker def __getitems__(self, items): 517*da0073e9SAndroid Build Coastguard Worker return self.data[items] 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker def __len__(self): 520*da0073e9SAndroid Build Coastguard Worker return 4 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker t = GetItemsDataset() 523*da0073e9SAndroid Build Coastguard Worker l = [1, 2, 3, 4] 524*da0073e9SAndroid Build Coastguard Worker 525*da0073e9SAndroid Build Coastguard Worker source = StackDataset(t, l) 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 528*da0073e9SAndroid Build Coastguard Worker source.__getitems__([0, 4]) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker def test_getitems_value_error(self): 531*da0073e9SAndroid Build Coastguard Worker class GetItemsDataset(Dataset): 532*da0073e9SAndroid Build 
Coastguard Worker def __init__(self) -> None: 533*da0073e9SAndroid Build Coastguard Worker self.data = torch.randn(4) 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, item): 536*da0073e9SAndroid Build Coastguard Worker return self.data[item] 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker def __getitems__(self, items): 539*da0073e9SAndroid Build Coastguard Worker return self.data[items][:-1] # return less 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker def __len__(self): 542*da0073e9SAndroid Build Coastguard Worker return 4 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker t = GetItemsDataset() 545*da0073e9SAndroid Build Coastguard Worker l = [1, 2, 3, 4] 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker source = StackDataset(t, l) 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 550*da0073e9SAndroid Build Coastguard Worker ValueError, "Nested dataset's output size mismatch. Expected 4, got 3" 551*da0073e9SAndroid Build Coastguard Worker ): 552*da0073e9SAndroid Build Coastguard Worker source.__getitems__([0, 1, 2, 3]) 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 556*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 557*da0073e9SAndroid Build Coastguard Worker "Fails with TSAN with the following error: starting new threads after multi-threaded " 558*da0073e9SAndroid Build Coastguard Worker "fork is not supported. 
Dying (set die_after_fork=0 to override)", 559*da0073e9SAndroid Build Coastguard Worker) 560*da0073e9SAndroid Build Coastguard Workerclass TestConcatDataset(TestCase): 561*da0073e9SAndroid Build Coastguard Worker def test_concat_two_singletons(self): 562*da0073e9SAndroid Build Coastguard Worker result = ConcatDataset([[0], [1]]) 563*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, len(result)) 564*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, result[0]) 565*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, result[1]) 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker def test_concat_two_non_singletons(self): 568*da0073e9SAndroid Build Coastguard Worker result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) 569*da0073e9SAndroid Build Coastguard Worker self.assertEqual(10, len(result)) 570*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, result[0]) 571*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, result[5]) 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker def test_concat_two_non_singletons_with_empty(self): 574*da0073e9SAndroid Build Coastguard Worker # Adding an empty dataset somewhere is correctly handled 575*da0073e9SAndroid Build Coastguard Worker result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]]) 576*da0073e9SAndroid Build Coastguard Worker self.assertEqual(10, len(result)) 577*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, result[0]) 578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, result[5]) 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker def test_concat_raises_index_error(self): 581*da0073e9SAndroid Build Coastguard Worker result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) 582*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 583*da0073e9SAndroid Build Coastguard Worker # this one goes to 11 
584*da0073e9SAndroid Build Coastguard Worker result[11] 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker def test_add_dataset(self): 587*da0073e9SAndroid Build Coastguard Worker d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) 588*da0073e9SAndroid Build Coastguard Worker d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) 589*da0073e9SAndroid Build Coastguard Worker d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) 590*da0073e9SAndroid Build Coastguard Worker result = d1 + d2 + d3 591*da0073e9SAndroid Build Coastguard Worker self.assertEqual(21, len(result)) 592*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum()) 593*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum()) 594*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum()) 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker def test_iterable_dataset_err(self): 597*da0073e9SAndroid Build Coastguard Worker d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) 598*da0073e9SAndroid Build Coastguard Worker it1 = CountingIterableDataset(5) 599*da0073e9SAndroid Build Coastguard Worker it2 = CountingIterableDataset(10) 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"): 602*da0073e9SAndroid Build Coastguard Worker ConcatDataset([d1, it2, it1]) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"): 605*da0073e9SAndroid Build Coastguard Worker ConcatDataset([it2]) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"): 
608*da0073e9SAndroid Build Coastguard Worker ConcatDataset([it1, d1]) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker# takes in dummy var so this can also be used as a `worker_init_fn` 612*da0073e9SAndroid Build Coastguard Workerdef set_faulthander_if_available(_=None): 613*da0073e9SAndroid Build Coastguard Worker faulthandler.enable(sys.__stderr__) 614*da0073e9SAndroid Build Coastguard Worker if not IS_WINDOWS: 615*da0073e9SAndroid Build Coastguard Worker # windows does not have faulthandler.register 616*da0073e9SAndroid Build Coastguard Worker # chain=False prevents the default behavior of killing the process 617*da0073e9SAndroid Build Coastguard Worker faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Workerset_faulthander_if_available() 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker# Process `pid` must have called `set_faulthander_if_available` 624*da0073e9SAndroid Build Coastguard Workerdef print_traces_of_all_threads(pid): 625*da0073e9SAndroid Build Coastguard Worker if not IS_WINDOWS: 626*da0073e9SAndroid Build Coastguard Worker # use the custom signal if available 627*da0073e9SAndroid Build Coastguard Worker os.kill(pid, signal.SIGUSR1) 628*da0073e9SAndroid Build Coastguard Worker else: 629*da0073e9SAndroid Build Coastguard Worker # otherwise we can still use the handler given by faulthandler.enable() 630*da0073e9SAndroid Build Coastguard Worker # at the cost of killing the process. 
631*da0073e9SAndroid Build Coastguard Worker os.kill(pid, signal.SIGSEGV) 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker # wait in parent process to give subprocess some time to print 634*da0073e9SAndroid Build Coastguard Worker time.sleep(5) 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker# The following `ErrorTrackingProcess` stores the first encountered exception in 638*da0073e9SAndroid Build Coastguard Worker# its `.exception` attribute. 639*da0073e9SAndroid Build Coastguard Worker# Inspired by https://stackoverflow.com/a/33599967 640*da0073e9SAndroid Build Coastguard Workerclass ErrorTrackingProcess(mp.Process): 641*da0073e9SAndroid Build Coastguard Worker # Why no *args? 642*da0073e9SAndroid Build Coastguard Worker # py2 doesn't support def fn(x, *args, key=val, **kwargs) 643*da0073e9SAndroid Build Coastguard Worker # Setting disable_stderr=True may generate a lot of unrelated error outputs 644*da0073e9SAndroid Build Coastguard Worker # but could be helpful for debugging. 645*da0073e9SAndroid Build Coastguard Worker def __init__(self, disable_stderr=True, **kwargs): 646*da0073e9SAndroid Build Coastguard Worker super().__init__(**kwargs) 647*da0073e9SAndroid Build Coastguard Worker self._pconn, self._cconn = mp.Pipe() 648*da0073e9SAndroid Build Coastguard Worker self._exception = None 649*da0073e9SAndroid Build Coastguard Worker self.disable_stderr = disable_stderr 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker def run(self): 652*da0073e9SAndroid Build Coastguard Worker set_faulthander_if_available() 653*da0073e9SAndroid Build Coastguard Worker if self.disable_stderr: 654*da0073e9SAndroid Build Coastguard Worker # Disable polluting stderr with errors that are supposed to happen. 
655*da0073e9SAndroid Build Coastguard Worker with open(os.devnull, "w") as devnull: 656*da0073e9SAndroid Build Coastguard Worker os.dup2(devnull.fileno(), sys.stderr.fileno()) 657*da0073e9SAndroid Build Coastguard Worker try: 658*da0073e9SAndroid Build Coastguard Worker super().run() 659*da0073e9SAndroid Build Coastguard Worker self._cconn.send(None) 660*da0073e9SAndroid Build Coastguard Worker except Exception: 661*da0073e9SAndroid Build Coastguard Worker self._cconn.send(ExceptionWrapper(sys.exc_info())) 662*da0073e9SAndroid Build Coastguard Worker raise 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker def print_traces_of_all_threads(self): 665*da0073e9SAndroid Build Coastguard Worker assert ( 666*da0073e9SAndroid Build Coastguard Worker self.is_alive() 667*da0073e9SAndroid Build Coastguard Worker ), "can only use print_traces_of_all_threads if the process is alive" 668*da0073e9SAndroid Build Coastguard Worker assert ( 669*da0073e9SAndroid Build Coastguard Worker not self.disable_stderr 670*da0073e9SAndroid Build Coastguard Worker ), "do not disable stderr if you use print_traces_of_all_threads" 671*da0073e9SAndroid Build Coastguard Worker # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets 672*da0073e9SAndroid Build Coastguard Worker # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill 673*da0073e9SAndroid Build Coastguard Worker # the process. 
So let's poll the exception first 674*da0073e9SAndroid Build Coastguard Worker _ = self.exception 675*da0073e9SAndroid Build Coastguard Worker print_traces_of_all_threads(self.pid) 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker @property 678*da0073e9SAndroid Build Coastguard Worker def exception(self): 679*da0073e9SAndroid Build Coastguard Worker if self._pconn.poll(): 680*da0073e9SAndroid Build Coastguard Worker self._exception = self._pconn.recv() 681*da0073e9SAndroid Build Coastguard Worker if self._exception is None: 682*da0073e9SAndroid Build Coastguard Worker return None 683*da0073e9SAndroid Build Coastguard Worker else: 684*da0073e9SAndroid Build Coastguard Worker return self._exception.exc_type(self._exception.exc_msg) 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Worker # ESRCH means that os.kill can't finds alive proc 687*da0073e9SAndroid Build Coastguard Worker def send_signal(self, signum, ignore_ESRCH=False): 688*da0073e9SAndroid Build Coastguard Worker try: 689*da0073e9SAndroid Build Coastguard Worker os.kill(self.pid, signum) 690*da0073e9SAndroid Build Coastguard Worker except OSError as e: 691*da0073e9SAndroid Build Coastguard Worker if not ignore_ESRCH or e.errno != errno.ESRCH: 692*da0073e9SAndroid Build Coastguard Worker raise 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Workerclass ErrorDataset(Dataset): 696*da0073e9SAndroid Build Coastguard Worker def __init__(self, size): 697*da0073e9SAndroid Build Coastguard Worker self.size = size 698*da0073e9SAndroid Build Coastguard Worker 699*da0073e9SAndroid Build Coastguard Worker def __len__(self): 700*da0073e9SAndroid Build Coastguard Worker return self.size 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Workerclass SegfaultDataset(Dataset): 704*da0073e9SAndroid 
Build Coastguard Worker def __init__(self, size): 705*da0073e9SAndroid Build Coastguard Worker self.size = size 706*da0073e9SAndroid Build Coastguard Worker 707*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 708*da0073e9SAndroid Build Coastguard Worker return ctypes.string_at(0) 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker def __len__(self): 711*da0073e9SAndroid Build Coastguard Worker return self.size 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker 714*da0073e9SAndroid Build Coastguard Workerclass SleepDataset(Dataset): 715*da0073e9SAndroid Build Coastguard Worker def __init__(self, size, sleep_sec): 716*da0073e9SAndroid Build Coastguard Worker self.size = size 717*da0073e9SAndroid Build Coastguard Worker self.sleep_sec = sleep_sec 718*da0073e9SAndroid Build Coastguard Worker self.sleeped = False 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 721*da0073e9SAndroid Build Coastguard Worker if not self.sleeped: 722*da0073e9SAndroid Build Coastguard Worker time.sleep(self.sleep_sec) 723*da0073e9SAndroid Build Coastguard Worker self.sleeped = True 724*da0073e9SAndroid Build Coastguard Worker return idx 725*da0073e9SAndroid Build Coastguard Worker 726*da0073e9SAndroid Build Coastguard Worker def __len__(self): 727*da0073e9SAndroid Build Coastguard Worker return self.size 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Workerclass SeedDataset(Dataset): 731*da0073e9SAndroid Build Coastguard Worker def __init__(self, size): 732*da0073e9SAndroid Build Coastguard Worker self.size = size 733*da0073e9SAndroid Build Coastguard Worker 734*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 735*da0073e9SAndroid Build Coastguard Worker return torch.initial_seed() 736*da0073e9SAndroid Build Coastguard Worker 
737*da0073e9SAndroid Build Coastguard Worker def __len__(self): 738*da0073e9SAndroid Build Coastguard Worker return self.size 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Workerclass WorkerSpecificIterableDataset(IterableDataset): 742*da0073e9SAndroid Build Coastguard Worker def __init__(self, sizes_for_all_workers): 743*da0073e9SAndroid Build Coastguard Worker self.sizes_for_all_workers = sizes_for_all_workers 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 746*da0073e9SAndroid Build Coastguard Worker worker_info = torch.utils.data.get_worker_info() 747*da0073e9SAndroid Build Coastguard Worker assert worker_info is not None 748*da0073e9SAndroid Build Coastguard Worker return iter(range(self.sizes_for_all_workers[worker_info.id])) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker def __len__(self): 751*da0073e9SAndroid Build Coastguard Worker return sum(self.sizes_for_all_workers) 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker# Inspired by https://stackoverflow.com/a/26703365 755*da0073e9SAndroid Build Coastguard Worker# If all workers will call `sync_once`, they will be blocked until all workers 756*da0073e9SAndroid Build Coastguard Worker# reach the call (i.e., acting like a barrier). 757*da0073e9SAndroid Build Coastguard Worker# This can be used to ensure that each worker at least processes one data. 
758*da0073e9SAndroid Build Coastguard Workerclass SynchronizedDataset(Dataset): 759*da0073e9SAndroid Build Coastguard Worker def __init__(self, size, batch_size, num_workers): 760*da0073e9SAndroid Build Coastguard Worker assert size >= num_workers * batch_size 761*da0073e9SAndroid Build Coastguard Worker self.count = mp.Value("i", 0, lock=True) 762*da0073e9SAndroid Build Coastguard Worker self.barrier = mp.Semaphore(0) 763*da0073e9SAndroid Build Coastguard Worker self.num_workers = num_workers 764*da0073e9SAndroid Build Coastguard Worker self.size = size 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker def sync_once(self): 767*da0073e9SAndroid Build Coastguard Worker with self.count.get_lock(): 768*da0073e9SAndroid Build Coastguard Worker self.count.value += 1 769*da0073e9SAndroid Build Coastguard Worker if self.count.value == self.num_workers: 770*da0073e9SAndroid Build Coastguard Worker self.barrier.release() 771*da0073e9SAndroid Build Coastguard Worker self.barrier.acquire() 772*da0073e9SAndroid Build Coastguard Worker self.barrier.release() 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 775*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker def __len__(self): 778*da0073e9SAndroid Build Coastguard Worker return self.size 779*da0073e9SAndroid Build Coastguard Worker 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Workerclass EmptyTensorDataset(torch.utils.data.Dataset): 782*da0073e9SAndroid Build Coastguard Worker def __init__(self, len): 783*da0073e9SAndroid Build Coastguard Worker self.len = len 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker def __len__(self): 786*da0073e9SAndroid Build Coastguard Worker return self.len 787*da0073e9SAndroid Build Coastguard Worker 
788*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, any): 789*da0073e9SAndroid Build Coastguard Worker return torch.empty(0) 790*da0073e9SAndroid Build Coastguard Worker 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard Workerclass SynchronizedSeedDataset(SynchronizedDataset): 793*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 794*da0073e9SAndroid Build Coastguard Worker self.sync_once() 795*da0073e9SAndroid Build Coastguard Worker return torch.initial_seed() 796*da0073e9SAndroid Build Coastguard Worker 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Workerdef _test_timeout(persistent_workers): 799*da0073e9SAndroid Build Coastguard Worker dataset = SleepDataset(10, 3) 800*da0073e9SAndroid Build Coastguard Worker dataloader = DataLoader( 801*da0073e9SAndroid Build Coastguard Worker dataset, 802*da0073e9SAndroid Build Coastguard Worker batch_size=2, 803*da0073e9SAndroid Build Coastguard Worker num_workers=2, 804*da0073e9SAndroid Build Coastguard Worker timeout=1, 805*da0073e9SAndroid Build Coastguard Worker persistent_workers=persistent_workers, 806*da0073e9SAndroid Build Coastguard Worker ) 807*da0073e9SAndroid Build Coastguard Worker _ = next(iter(dataloader)) 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Workerdef _test_timeout_pin_memory(persistent_workers): 811*da0073e9SAndroid Build Coastguard Worker dataset = SleepDataset(10, 3) 812*da0073e9SAndroid Build Coastguard Worker dataloader = DataLoader( 813*da0073e9SAndroid Build Coastguard Worker dataset, 814*da0073e9SAndroid Build Coastguard Worker batch_size=2, 815*da0073e9SAndroid Build Coastguard Worker num_workers=2, 816*da0073e9SAndroid Build Coastguard Worker timeout=1, 817*da0073e9SAndroid Build Coastguard Worker pin_memory=True, 818*da0073e9SAndroid Build Coastguard Worker persistent_workers=persistent_workers, 
819*da0073e9SAndroid Build Coastguard Worker ) 820*da0073e9SAndroid Build Coastguard Worker _ = next(iter(dataloader)) 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker 823*da0073e9SAndroid Build Coastguard Workerdef _test_large_sampler_indices(persistent_workers): 824*da0073e9SAndroid Build Coastguard Worker # See 825*da0073e9SAndroid Build Coastguard Worker # test_large_sampler_indices 826*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/48666 827*da0073e9SAndroid Build Coastguard Worker 828*da0073e9SAndroid Build Coastguard Worker dataloader = torch.utils.data.DataLoader( 829*da0073e9SAndroid Build Coastguard Worker EmptyTensorDataset(10000000), 830*da0073e9SAndroid Build Coastguard Worker batch_size=40960, 831*da0073e9SAndroid Build Coastguard Worker persistent_workers=persistent_workers, 832*da0073e9SAndroid Build Coastguard Worker num_workers=1, 833*da0073e9SAndroid Build Coastguard Worker ) 834*da0073e9SAndroid Build Coastguard Worker 835*da0073e9SAndroid Build Coastguard Worker it = iter(dataloader) 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker for x in it: 838*da0073e9SAndroid Build Coastguard Worker assert x.numel() == 0 839*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("My Error") 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Workerdef disable_stderr(worker_id): 843*da0073e9SAndroid Build Coastguard Worker r""" 844*da0073e9SAndroid Build Coastguard Worker Avoids printing "ERROR: Unexpected segmentation fault encountered in worker." 845*da0073e9SAndroid Build Coastguard Worker from workers. Since worker signal handler prints with low-level write(), 846*da0073e9SAndroid Build Coastguard Worker this has to be done on OS level via dup. 
847*da0073e9SAndroid Build Coastguard Worker 848*da0073e9SAndroid Build Coastguard Worker This is used as worker_init_fn for test_segfault. 849*da0073e9SAndroid Build Coastguard Worker """ 850*da0073e9SAndroid Build Coastguard Worker sys.stderr.flush() # flush library buffers that dup2 knows nothing about 851*da0073e9SAndroid Build Coastguard Worker # Can't use a with-block because otherwise the fd will be closed when this 852*da0073e9SAndroid Build Coastguard Worker # function ends. 853*da0073e9SAndroid Build Coastguard Worker with open(os.devnull, "w") as devnull: 854*da0073e9SAndroid Build Coastguard Worker os.dup2(devnull.fileno(), sys.stderr.fileno()) 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Workerdef _test_segfault(): 858*da0073e9SAndroid Build Coastguard Worker dataset = SegfaultDataset(10) 859*da0073e9SAndroid Build Coastguard Worker dataloader = DataLoader( 860*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr 861*da0073e9SAndroid Build Coastguard Worker ) 862*da0073e9SAndroid Build Coastguard Worker _ = next(iter(dataloader)) 863*da0073e9SAndroid Build Coastguard Worker 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Workerdef _test_no_segfault(): 866*da0073e9SAndroid Build Coastguard Worker dataset = [1, 2, 3] 867*da0073e9SAndroid Build Coastguard Worker num_threads = torch.get_num_threads() 868*da0073e9SAndroid Build Coastguard Worker if num_threads < 4: 869*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(4) 870*da0073e9SAndroid Build Coastguard Worker else: 871*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(num_threads) 872*da0073e9SAndroid Build Coastguard Worker mp_ctx = torch.multiprocessing.get_context(method="fork") 873*da0073e9SAndroid Build Coastguard Worker dataloader = DataLoader( 874*da0073e9SAndroid Build Coastguard Worker dataset, 
875*da0073e9SAndroid Build Coastguard Worker num_workers=1, 876*da0073e9SAndroid Build Coastguard Worker worker_init_fn=disable_stderr, 877*da0073e9SAndroid Build Coastguard Worker multiprocessing_context=mp_ctx, 878*da0073e9SAndroid Build Coastguard Worker ) 879*da0073e9SAndroid Build Coastguard Worker _ = next(iter(dataloader)) 880*da0073e9SAndroid Build Coastguard Worker 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Workerclass TestProperExitDataset(Dataset): 883*da0073e9SAndroid Build Coastguard Worker def __init__(self, size, error_event): 884*da0073e9SAndroid Build Coastguard Worker self.size = size 885*da0073e9SAndroid Build Coastguard Worker self.error_event = error_event 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker def __len__(self): 888*da0073e9SAndroid Build Coastguard Worker return self.size 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 891*da0073e9SAndroid Build Coastguard Worker worker_info = torch.utils.data.get_worker_info() 892*da0073e9SAndroid Build Coastguard Worker if ( 893*da0073e9SAndroid Build Coastguard Worker self.error_event is not None 894*da0073e9SAndroid Build Coastguard Worker and self.error_event.is_set() 895*da0073e9SAndroid Build Coastguard Worker and worker_info.id == worker_info.num_workers - 1 896*da0073e9SAndroid Build Coastguard Worker ): 897*da0073e9SAndroid Build Coastguard Worker # only error in the last worker 898*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Worker error") 899*da0073e9SAndroid Build Coastguard Worker return torch.tensor([idx]) 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker 902*da0073e9SAndroid Build Coastguard Workerclass TestProperExitIterableDataset(IterableDataset): 903*da0073e9SAndroid Build Coastguard Worker def __init__(self, size, error_event): 904*da0073e9SAndroid Build Coastguard Worker 
self.error_event = error_event 905*da0073e9SAndroid Build Coastguard Worker self.size = size 906*da0073e9SAndroid Build Coastguard Worker self.remaining = size 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Worker def __len__(self): 909*da0073e9SAndroid Build Coastguard Worker return self.size 910*da0073e9SAndroid Build Coastguard Worker 911*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 912*da0073e9SAndroid Build Coastguard Worker return self 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker def __next__(self): 915*da0073e9SAndroid Build Coastguard Worker worker_info = torch.utils.data.get_worker_info() 916*da0073e9SAndroid Build Coastguard Worker if ( 917*da0073e9SAndroid Build Coastguard Worker self.error_event is not None 918*da0073e9SAndroid Build Coastguard Worker and self.error_event.is_set() 919*da0073e9SAndroid Build Coastguard Worker and worker_info.id == worker_info.num_workers - 1 920*da0073e9SAndroid Build Coastguard Worker ): 921*da0073e9SAndroid Build Coastguard Worker # only error in the last worker 922*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Worker error") 923*da0073e9SAndroid Build Coastguard Worker self.remaining -= 1 924*da0073e9SAndroid Build Coastguard Worker if self.remaining < 0: 925*da0073e9SAndroid Build Coastguard Worker raise StopIteration 926*da0073e9SAndroid Build Coastguard Worker return torch.tensor(-1000) 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker 929*da0073e9SAndroid Build Coastguard Worker# See TestDataLoader.test_proper_exit for usage 930*da0073e9SAndroid Build Coastguard Workerdef _test_proper_exit( 931*da0073e9SAndroid Build Coastguard Worker is_iterable_dataset, 932*da0073e9SAndroid Build Coastguard Worker use_workers, 933*da0073e9SAndroid Build Coastguard Worker pin_memory, 934*da0073e9SAndroid Build Coastguard Worker exit_method, 935*da0073e9SAndroid Build Coastguard 
Worker hold_iter_reference, 936*da0073e9SAndroid Build Coastguard Worker loader_setup_event, 937*da0073e9SAndroid Build Coastguard Worker tester_setup_event, 938*da0073e9SAndroid Build Coastguard Worker persistent_workers, 939*da0073e9SAndroid Build Coastguard Worker): 940*da0073e9SAndroid Build Coastguard Worker num_workers = 2 if use_workers else 0 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker if exit_method == "worker_error" or exit_method == "worker_kill": 943*da0073e9SAndroid Build Coastguard Worker assert use_workers is True 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker if exit_method == "worker_error": 946*da0073e9SAndroid Build Coastguard Worker worker_error_event = mp.Event() 947*da0073e9SAndroid Build Coastguard Worker else: 948*da0073e9SAndroid Build Coastguard Worker worker_error_event = None 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker if is_iterable_dataset: 951*da0073e9SAndroid Build Coastguard Worker ds = TestProperExitIterableDataset(7, worker_error_event) 952*da0073e9SAndroid Build Coastguard Worker else: 953*da0073e9SAndroid Build Coastguard Worker ds = TestProperExitDataset(12, worker_error_event) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker loader = DataLoader( 956*da0073e9SAndroid Build Coastguard Worker ds, 957*da0073e9SAndroid Build Coastguard Worker batch_size=1, 958*da0073e9SAndroid Build Coastguard Worker shuffle=False, 959*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 960*da0073e9SAndroid Build Coastguard Worker pin_memory=pin_memory, 961*da0073e9SAndroid Build Coastguard Worker worker_init_fn=set_faulthander_if_available, 962*da0073e9SAndroid Build Coastguard Worker persistent_workers=persistent_workers, 963*da0073e9SAndroid Build Coastguard Worker ) 964*da0073e9SAndroid Build Coastguard Worker 965*da0073e9SAndroid Build Coastguard Worker 
error_it = 2 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker if use_workers: 968*da0073e9SAndroid Build Coastguard Worker # 2 is the magical per-worker prefetch number... 969*da0073e9SAndroid Build Coastguard Worker # FIXME: change this after the number becomes configurable. 970*da0073e9SAndroid Build Coastguard Worker if is_iterable_dataset: 971*da0073e9SAndroid Build Coastguard Worker assert len(ds) * num_workers > (error_it + 2 + 1) 972*da0073e9SAndroid Build Coastguard Worker else: 973*da0073e9SAndroid Build Coastguard Worker assert len(loader) > (error_it + 2 + 1) * num_workers 974*da0073e9SAndroid Build Coastguard Worker else: 975*da0073e9SAndroid Build Coastguard Worker if is_iterable_dataset: 976*da0073e9SAndroid Build Coastguard Worker assert len(ds) > error_it + 1 977*da0073e9SAndroid Build Coastguard Worker else: 978*da0073e9SAndroid Build Coastguard Worker assert len(loader) > error_it + 1 979*da0073e9SAndroid Build Coastguard Worker 980*da0073e9SAndroid Build Coastguard Worker it = iter(loader) 981*da0073e9SAndroid Build Coastguard Worker if use_workers: 982*da0073e9SAndroid Build Coastguard Worker workers = it._workers 983*da0073e9SAndroid Build Coastguard Worker 984*da0073e9SAndroid Build Coastguard Worker def kill_pid(pid): 985*da0073e9SAndroid Build Coastguard Worker psutil_p = psutil.Process(pid) 986*da0073e9SAndroid Build Coastguard Worker psutil_p.kill() 987*da0073e9SAndroid Build Coastguard Worker psutil_p.wait(JOIN_TIMEOUT) 988*da0073e9SAndroid Build Coastguard Worker assert not psutil_p.is_running() 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker for i, _ in enumerate(it): 991*da0073e9SAndroid Build Coastguard Worker if i == 0: 992*da0073e9SAndroid Build Coastguard Worker if not hold_iter_reference: 993*da0073e9SAndroid Build Coastguard Worker del it 994*da0073e9SAndroid Build Coastguard Worker del loader 995*da0073e9SAndroid Build Coastguard Worker 
loader_setup_event.set() 996*da0073e9SAndroid Build Coastguard Worker tester_setup_event.wait() 997*da0073e9SAndroid Build Coastguard Worker # ensure that the workers are still alive 998*da0073e9SAndroid Build Coastguard Worker if use_workers: 999*da0073e9SAndroid Build Coastguard Worker for w in workers: 1000*da0073e9SAndroid Build Coastguard Worker assert w.is_alive() 1001*da0073e9SAndroid Build Coastguard Worker if worker_error_event is not None: 1002*da0073e9SAndroid Build Coastguard Worker worker_error_event.set() 1003*da0073e9SAndroid Build Coastguard Worker 1004*da0073e9SAndroid Build Coastguard Worker if i == error_it: 1005*da0073e9SAndroid Build Coastguard Worker if exit_method == "loader_error": 1006*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Loader error") 1007*da0073e9SAndroid Build Coastguard Worker elif exit_method == "loader_kill": 1008*da0073e9SAndroid Build Coastguard Worker kill_pid(os.getpid()) 1009*da0073e9SAndroid Build Coastguard Worker elif exit_method == "worker_kill": 1010*da0073e9SAndroid Build Coastguard Worker kill_pid(workers[-1].pid) # kill last worker 1011*da0073e9SAndroid Build Coastguard Worker 1012*da0073e9SAndroid Build Coastguard Worker if not hold_iter_reference: 1013*da0073e9SAndroid Build Coastguard Worker # Tries to trigger the __del__ clean-up rather than the automatic 1014*da0073e9SAndroid Build Coastguard Worker # exiting of daemonic children. Technically it should be automatically 1015*da0073e9SAndroid Build Coastguard Worker # triggered, but I don't want to rely on the implementation detail of 1016*da0073e9SAndroid Build Coastguard Worker # Python gc. 
1017*da0073e9SAndroid Build Coastguard Worker gc.collect() 1018*da0073e9SAndroid Build Coastguard Worker 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Workerclass TestWorkerInfoDataset(SynchronizedDataset): 1021*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 1022*da0073e9SAndroid Build Coastguard Worker self.sync_once() 1023*da0073e9SAndroid Build Coastguard Worker return torch.tensor(self.value) 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker# Should be used as worker_init_fn with TestWorkerInfoDataset. 1027*da0073e9SAndroid Build Coastguard Worker# See _test_get_worker_info below for usage. 1028*da0073e9SAndroid Build Coastguard Workerdef _test_worker_info_init_fn(worker_id): 1029*da0073e9SAndroid Build Coastguard Worker worker_info = torch.utils.data.get_worker_info() 1030*da0073e9SAndroid Build Coastguard Worker assert ( 1031*da0073e9SAndroid Build Coastguard Worker worker_id == worker_info.id 1032*da0073e9SAndroid Build Coastguard Worker ), "worker_init_fn and worker_info should have consistent id" 1033*da0073e9SAndroid Build Coastguard Worker assert ( 1034*da0073e9SAndroid Build Coastguard Worker worker_id < worker_info.num_workers 1035*da0073e9SAndroid Build Coastguard Worker ), "worker_init_fn and worker_info should have valid id" 1036*da0073e9SAndroid Build Coastguard Worker assert ( 1037*da0073e9SAndroid Build Coastguard Worker worker_info.seed == torch.initial_seed() 1038*da0073e9SAndroid Build Coastguard Worker ), "worker_init_fn and worker_info should have consistent seed" 1039*da0073e9SAndroid Build Coastguard Worker dataset = worker_info.dataset 1040*da0073e9SAndroid Build Coastguard Worker assert isinstance( 1041*da0073e9SAndroid Build Coastguard Worker dataset, TestWorkerInfoDataset 1042*da0073e9SAndroid Build Coastguard Worker ), "worker_info should have correct dataset copy" 1043*da0073e9SAndroid 
Build Coastguard Worker assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy" 1044*da0073e9SAndroid Build Coastguard Worker # test that WorkerInfo attributes are read-only 1045*da0073e9SAndroid Build Coastguard Worker try: 1046*da0073e9SAndroid Build Coastguard Worker worker_info.id = 3999 1047*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1048*da0073e9SAndroid Build Coastguard Worker assert str(e) == "Cannot assign attributes to WorkerInfo objects" 1049*da0073e9SAndroid Build Coastguard Worker try: 1050*da0073e9SAndroid Build Coastguard Worker worker_info.a = 3 1051*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1052*da0073e9SAndroid Build Coastguard Worker assert str(e) == "Cannot assign attributes to WorkerInfo objects" 1053*da0073e9SAndroid Build Coastguard Worker for k in ["id", "num_workers", "seed", "dataset"]: 1054*da0073e9SAndroid Build Coastguard Worker assert f"{k}=" in repr(worker_info) 1055*da0073e9SAndroid Build Coastguard Worker dataset.value = [worker_id, os.getpid()] 1056*da0073e9SAndroid Build Coastguard Worker 1057*da0073e9SAndroid Build Coastguard Worker 1058*da0073e9SAndroid Build Coastguard Workerdef _test_get_worker_info(): 1059*da0073e9SAndroid Build Coastguard Worker # get_worker_info returns None in main proc 1060*da0073e9SAndroid Build Coastguard Worker assert torch.utils.data.get_worker_info() is None 1061*da0073e9SAndroid Build Coastguard Worker num_workers = 2 1062*da0073e9SAndroid Build Coastguard Worker batch_size = 2 1063*da0073e9SAndroid Build Coastguard Worker dataset = TestWorkerInfoDataset(6, batch_size, num_workers) 1064*da0073e9SAndroid Build Coastguard Worker dataloader = DataLoader( 1065*da0073e9SAndroid Build Coastguard Worker dataset, 1066*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 1067*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 1068*da0073e9SAndroid Build Coastguard Worker 
worker_init_fn=_test_worker_info_init_fn, 1069*da0073e9SAndroid Build Coastguard Worker ) 1070*da0073e9SAndroid Build Coastguard Worker it = iter(dataloader) 1071*da0073e9SAndroid Build Coastguard Worker data = [] 1072*da0073e9SAndroid Build Coastguard Worker for d in it: 1073*da0073e9SAndroid Build Coastguard Worker data.append(d) # noqa: PERF402 1074*da0073e9SAndroid Build Coastguard Worker worker_pids = [w.pid for w in it._workers] 1075*da0073e9SAndroid Build Coastguard Worker data = torch.cat(data, 0) 1076*da0073e9SAndroid Build Coastguard Worker for d in data: 1077*da0073e9SAndroid Build Coastguard Worker # each `d` is a [worker_id, worker_pid] pair, which is set in 1078*da0073e9SAndroid Build Coastguard Worker # _test_worker_info_init_fn 1079*da0073e9SAndroid Build Coastguard Worker assert d[1] == worker_pids[d[0]] 1080*da0073e9SAndroid Build Coastguard Worker # get_worker_info returns None in main proc after data loading 1081*da0073e9SAndroid Build Coastguard Worker assert torch.utils.data.get_worker_info() is None 1082*da0073e9SAndroid Build Coastguard Worker # main proc dataset was never assigned this attribute 1083*da0073e9SAndroid Build Coastguard Worker assert not hasattr(dataset, "value") 1084*da0073e9SAndroid Build Coastguard Worker try: 1085*da0073e9SAndroid Build Coastguard Worker _ = dataset[0] 1086*da0073e9SAndroid Build Coastguard Worker except AttributeError: 1087*da0073e9SAndroid Build Coastguard Worker return 1088*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Expected AttributeError") 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard Worker# test custom init function 1092*da0073e9SAndroid Build Coastguard Workerdef init_fn(worker_id): 1093*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(12345) 1094*da0073e9SAndroid Build Coastguard Worker 1095*da0073e9SAndroid Build Coastguard Worker 1096*da0073e9SAndroid Build Coastguard Worker# used 
# with test_error_in_init
class ErrorIterableDataset(IterableDataset):
    """Iterable dataset whose ``__iter__`` always raises ``RuntimeError``."""

    def __iter__(self):
        raise RuntimeError("Error in __iter__")


# used with test_error_in_init
def error_worker_init_fn(_):
    """``worker_init_fn`` that always raises ``RuntimeError``."""
    raise RuntimeError("Error in worker_init_fn")


class BulkLoadingDataset(Dataset):
    """Dataset indexed by a whole list/tuple of indices at once.

    ``__getitem__`` returns the received indices as a tensor.
    """

    def __init__(self, length):
        self.length = length

    def __getitem__(self, indices):
        assert isinstance(indices, (list, tuple))
        return torch.as_tensor(indices)

    def __len__(self):
        return self.length


class BulkLoadingSampler(torch.utils.data.Sampler):
    """Sampler yielding lists of up to ``batch_size`` randomly permuted indices."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        for x in torch.randperm(len(self.dataset)).split(self.batch_size):
            yield x.tolist()

    def __len__(self):
        return int(math.ceil(len(self.dataset) / float(self.batch_size)))


class TestMultiEpochDataset(IterableDataset):
    """Iterable dataset in which each worker yields its own worker id.

    Each worker yields ``length // num_workers`` items. Iterating outside a
    worker fails the ``worker_info is not None`` assertion.
    """

    def __init__(self, length):
        self.length = length

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is not None
        worker_id = worker_info.id
        for _ in range(self.length // worker_info.num_workers):
            yield worker_id

    def __len__(self):
        return self.length


class CustomList(list):
    pass


class CustomDict(dict):
    pass


def row_processor(row):
    """Return ``row`` with 1 added element-wise (via ``np.add``)."""
    return np.add(row, 1)


def filter_len(row):
    """Return whether ``row`` has exactly four elements."""
    return len(row) == 4


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(
    TEST_WITH_ASAN,
    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
)
class TestDataLoader(TestCase):
    def setUp(self):
        super().setUp()
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)
        self.persistent_workers = False

    def _get_data_loader(self, dataset, **kwargs):
        """Build a ``DataLoader``, defaulting ``persistent_workers`` from ``self``.

        ``persistent_workers`` is forced to ``False`` when ``num_workers`` is
        absent or 0.
        """
        persistent_workers = kwargs.get("persistent_workers", self.persistent_workers)
        if persistent_workers and kwargs.get("num_workers", 0) == 0:
            persistent_workers = False
        kwargs["persistent_workers"] = persistent_workers
        return DataLoader(dataset, **kwargs)

    def _test_sequential(self, loader):
        """Assert ``loader`` yields ``self.data`` / ``self.labels`` in order."""
        batch_size = loader.batch_size
        if batch_size is None:
            for idx, (sample, target) in enumerate(loader):
                self.assertEqual(sample, self.data[idx])
                self.assertEqual(target, self.labels[idx])
            self.assertEqual(idx, len(self.dataset) - 1)
        else:
            for i, (sample, target) in enumerate(loader):
                idx = i * batch_size
                self.assertEqual(sample, self.data[idx : idx + batch_size])
                self.assertEqual(target, self.labels[idx : idx + batch_size])
            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        """Assert ``loader`` yields every sample once with its matching label."""
        found_data = dict.fromkeys(range(self.data.size(0)), 0)
        found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
        batch_size = loader.batch_size
        if batch_size is None:
            for i, (batch_samples, batch_targets) in enumerate(loader):
                sample, target = (batch_samples, batch_targets)
                # Locate the sample in the reference data by value.
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
                self.assertEqual(sum(found_data.values()), (i + 1))
                self.assertEqual(sum(found_labels.values()), (i + 1))
            self.assertEqual(i, (len(self.dataset) - 1))
        else:
            for i, (batch_samples, batch_targets) in enumerate(loader):
                for sample, target in zip(batch_samples, batch_targets):
                    # Locate the sample in the reference data by value.
                    for data_point_idx, data_point in enumerate(self.data):
                        if data_point.eq(sample).all():
                            self.assertFalse(found_data[data_point_idx])
                            found_data[data_point_idx] += 1
                            break
                    self.assertEqual(target, self.labels[data_point_idx])
                    found_labels[data_point_idx] += 1
                self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
                self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        """Assert that every batch of ``loader`` raises ``NotImplementedError``."""
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(
                    errors, math.ceil(float(len(loader.dataset)) / loader.batch_size)
                )
                return

    def test_error_in_init(self):
        for num_workers in [0, 2]:
            loader = self._get_data_loader(
                ErrorIterableDataset(), num_workers=num_workers
            )
            with self.assertRaisesRegex(RuntimeError, "Error in __iter__"):
                list(iter(loader))

        loader = self._get_data_loader(
            self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn
        )
        with self.assertRaisesRegex(RuntimeError, "Error in worker_init_fn"):
            list(iter(loader))

    def test_typing(self):
        from typing import List

        # Make sure there is no TypeError

        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
            pass

        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
            pass

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
    def test_fd_limit_exceeded(self):
        # See NOTE [ DataLoader on Linux and open files limit ]
        import subprocess

        # The script lowers RLIMIT_NOFILE and keeps every received tensor
        # alive. If that produces a RuntimeError, its message must mention
        # both remedies. ``check_output`` fails the test on a non-zero exit.
        subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import torch
import resource
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

try:
    keep_fds_alive = []
    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
                               num_workers=1):
        random_t.max(dim=0)
        keep_fds_alive.append(random_t)
except RuntimeError as e:
    assert "ulimit -n" in str(e)
    assert "set_sharing_strategy" in str(e)
""",
            ]
        )

    def test_invalid_assign_after_init(self):
        dl = self._get_data_loader(self.dataset)
        for attr in ("batch_size", "sampler", "batch_sampler", "drop_last", "dataset"):
            with self.assertRaises(ValueError):
                setattr(dl, attr, {})

    def test_sequential_nonbatch(self):
        self._test_sequential(self._get_data_loader(self.dataset, batch_size=None))

    def test_sequential_batch(self):
        self._test_sequential(self._get_data_loader(self.dataset))
        self._test_sequential(self._get_data_loader(self.dataset, batch_size=2))

    def test_bulk_loading_nobatch(self):
        n = 35
        bs = 4
        ds = BulkLoadingDataset(n)
        sampler = BulkLoadingSampler(ds, batch_size=bs)

        for num_workers in [0, 4]:
            dl = self._get_data_loader(
                ds,
                num_workers=num_workers,
                batch_size=None,
                sampler=sampler,
                pin_memory=TEST_CUDA,
            )
            self.assertFalse(dl._auto_collation)
            samples = list(dl)
            self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
            self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))

    def test_growing_dataset(self):
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = self._get_data_loader(dataset, shuffle=False)
        dataloader_shuffle = self._get_data_loader(dataset, shuffle=True)
        # Loader lengths must reflect an item appended after construction.
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_sequential_pin_memory(self):
        loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_multiple_dataloaders(self):
        for multiprocessing_context in supported_multiprocessing_contexts:
            loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1))
            loader2_it = iter(
                self._get_data_loader(
                    self.dataset,
                    num_workers=2,
                    multiprocessing_context=multiprocessing_context,
                )
            )
            # Interleave reads from the two live iterators.
            next(loader1_it)
            next(loader1_it)
            next(loader2_it)
            next(loader2_it)
            next(loader1_it)
            next(loader2_it)
            del loader1_it
            del loader2_it

    @unittest.skipIf(True, "This test is disabled in pytorch/pytorch")
    def test_segfault(self):
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r"access violation reading ")
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(
                    str(p.exception),
                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
                )
        finally:
            p.terminate()

    # Tests if the child process forked by the DataLoader segfaults due to having more than 3 threads
    # in the parent process after at least one set_num_threads invocation in the parent process.
    # After forking, set_num_threads(1) in the child process entails handling some inherited data-structures
    # of the Caffe2 thread-pool of the parent process, culminating in a segfault.
1409*da0073e9SAndroid Build Coastguard Worker # Reference: https://github.com/pytorch/pytorch/issues/54752 1410*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "Needs fork") 1411*da0073e9SAndroid Build Coastguard Worker def test_no_segfault(self): 1412*da0073e9SAndroid Build Coastguard Worker p = ErrorTrackingProcess(target=_test_no_segfault) 1413*da0073e9SAndroid Build Coastguard Worker p.start() 1414*da0073e9SAndroid Build Coastguard Worker p.join(JOIN_TIMEOUT) 1415*da0073e9SAndroid Build Coastguard Worker try: 1416*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 1417*da0073e9SAndroid Build Coastguard Worker if p.exception: 1418*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(p.exception, RuntimeError) 1419*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1420*da0073e9SAndroid Build Coastguard Worker str(p.exception), 1421*da0073e9SAndroid Build Coastguard Worker r"DataLoader worker \(pid \d+\) is killed by signal: ", 1422*da0073e9SAndroid Build Coastguard Worker ) 1423*da0073e9SAndroid Build Coastguard Worker self.fail("Segfault occurred in worker process after fork") 1424*da0073e9SAndroid Build Coastguard Worker finally: 1425*da0073e9SAndroid Build Coastguard Worker p.terminate() 1426*da0073e9SAndroid Build Coastguard Worker 1427*da0073e9SAndroid Build Coastguard Worker def test_timeout(self): 1428*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN: 1429*da0073e9SAndroid Build Coastguard Worker # This test runs in a subprocess, which can only initialize CUDA with spawn. 1430*da0073e9SAndroid Build Coastguard Worker # _test_timeout_pin_memory with pin_memory=True initializes CUDA when the iterator is 1431*da0073e9SAndroid Build Coastguard Worker # constructed. 
1432*da0073e9SAndroid Build Coastguard Worker targets = (_test_timeout, _test_timeout_pin_memory) 1433*da0073e9SAndroid Build Coastguard Worker else: 1434*da0073e9SAndroid Build Coastguard Worker targets = (_test_timeout,) 1435*da0073e9SAndroid Build Coastguard Worker for target in targets: 1436*da0073e9SAndroid Build Coastguard Worker p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,)) 1437*da0073e9SAndroid Build Coastguard Worker p.start() 1438*da0073e9SAndroid Build Coastguard Worker p.join(JOIN_TIMEOUT) 1439*da0073e9SAndroid Build Coastguard Worker try: 1440*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 1441*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(p.exitcode, 0) 1442*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(p.exception, RuntimeError) 1443*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1444*da0073e9SAndroid Build Coastguard Worker str(p.exception), r"DataLoader timed out after \d+ seconds" 1445*da0073e9SAndroid Build Coastguard Worker ) 1446*da0073e9SAndroid Build Coastguard Worker finally: 1447*da0073e9SAndroid Build Coastguard Worker p.terminate() 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker def test_large_sampler_indices(self): 1450*da0073e9SAndroid Build Coastguard Worker # Test that the data loader cleanly exit when the process errors 1451*da0073e9SAndroid Build Coastguard Worker # 1. having an reference to the iterator 1452*da0073e9SAndroid Build Coastguard Worker # 2. using a sampler that yields big elements s.t. 
_index_queues putters block 1453*da0073e9SAndroid Build Coastguard Worker # 1454*da0073e9SAndroid Build Coastguard Worker # More context: https://github.com/pytorch/pytorch/issues/48666 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker p = ErrorTrackingProcess( 1457*da0073e9SAndroid Build Coastguard Worker target=_test_large_sampler_indices, args=(self.persistent_workers,) 1458*da0073e9SAndroid Build Coastguard Worker ) 1459*da0073e9SAndroid Build Coastguard Worker p.start() 1460*da0073e9SAndroid Build Coastguard Worker p.join(JOIN_TIMEOUT) 1461*da0073e9SAndroid Build Coastguard Worker try: 1462*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 1463*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(p.exitcode, 0) 1464*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(p.exception, RuntimeError) 1465*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(p.exception), r"My Error") 1466*da0073e9SAndroid Build Coastguard Worker finally: 1467*da0073e9SAndroid Build Coastguard Worker p.terminate() 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker def test_invalid_ctor_args_combinations(self): 1470*da0073e9SAndroid Build Coastguard Worker # general 1471*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1472*da0073e9SAndroid Build Coastguard Worker ValueError, "num_workers option should be non-negative" 1473*da0073e9SAndroid Build Coastguard Worker ): 1474*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, num_workers=-1) 1475*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1476*da0073e9SAndroid Build Coastguard Worker ValueError, "timeout option should be non-negative" 1477*da0073e9SAndroid Build Coastguard Worker ): 1478*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, timeout=-1) 1479*da0073e9SAndroid Build Coastguard Worker 
1480*da0073e9SAndroid Build Coastguard Worker # disable auto-batching 1481*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1482*da0073e9SAndroid Build Coastguard Worker ValueError, 1483*da0073e9SAndroid Build Coastguard Worker "batch_size=None option disables auto-batching and is mutually exclusive", 1484*da0073e9SAndroid Build Coastguard Worker ): 1485*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, batch_size=None, drop_last=True) 1486*da0073e9SAndroid Build Coastguard Worker 1487*da0073e9SAndroid Build Coastguard Worker valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1] 1488*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1489*da0073e9SAndroid Build Coastguard Worker ValueError, r"multi-process loading \(num_workers > 0\), but got" 1490*da0073e9SAndroid Build Coastguard Worker ): 1491*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1492*da0073e9SAndroid Build Coastguard Worker self.dataset, num_workers=0, multiprocessing_context=valid_ctx 1493*da0073e9SAndroid Build Coastguard Worker ) 1494*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1495*da0073e9SAndroid Build Coastguard Worker ValueError, "should specify a valid start method in" 1496*da0073e9SAndroid Build Coastguard Worker ): 1497*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1498*da0073e9SAndroid Build Coastguard Worker self.dataset, num_workers=1, multiprocessing_context="bad" 1499*da0073e9SAndroid Build Coastguard Worker ) 1500*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1501*da0073e9SAndroid Build Coastguard Worker TypeError, "multiprocessing_context option should be a valid context " 1502*da0073e9SAndroid Build Coastguard Worker ): 1503*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1504*da0073e9SAndroid Build Coastguard Worker self.dataset, num_workers=1, multiprocessing_context=object() 1505*da0073e9SAndroid 
Build Coastguard Worker ) 1506*da0073e9SAndroid Build Coastguard Worker 1507*da0073e9SAndroid Build Coastguard Worker # map-style 1508*da0073e9SAndroid Build Coastguard Worker sampler = torch.utils.data.SequentialSampler(self.dataset) 1509*da0073e9SAndroid Build Coastguard Worker batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False) 1510*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1511*da0073e9SAndroid Build Coastguard Worker ValueError, "sampler option is mutually exclusive with shuffle" 1512*da0073e9SAndroid Build Coastguard Worker ): 1513*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1514*da0073e9SAndroid Build Coastguard Worker self.dataset, batch_size=11, sampler=sampler, shuffle=True 1515*da0073e9SAndroid Build Coastguard Worker ) 1516*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1517*da0073e9SAndroid Build Coastguard Worker ValueError, "sampler option is mutually exclusive with shuffle" 1518*da0073e9SAndroid Build Coastguard Worker ): 1519*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1520*da0073e9SAndroid Build Coastguard Worker self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True 1521*da0073e9SAndroid Build Coastguard Worker ) 1522*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1523*da0073e9SAndroid Build Coastguard Worker ValueError, "sampler option is mutually exclusive with shuffle" 1524*da0073e9SAndroid Build Coastguard Worker ): 1525*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1526*da0073e9SAndroid Build Coastguard Worker self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3 1527*da0073e9SAndroid Build Coastguard Worker ) 1528*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1529*da0073e9SAndroid Build Coastguard Worker ValueError, "batch_sampler option is mutually exclusive with" 1530*da0073e9SAndroid Build Coastguard Worker ): 1531*da0073e9SAndroid Build 
Coastguard Worker self._get_data_loader( 1532*da0073e9SAndroid Build Coastguard Worker self.dataset, batch_size=11, batch_sampler=batch_sampler 1533*da0073e9SAndroid Build Coastguard Worker ) 1534*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1535*da0073e9SAndroid Build Coastguard Worker ValueError, "batch_sampler option is mutually exclusive with" 1536*da0073e9SAndroid Build Coastguard Worker ): 1537*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1538*da0073e9SAndroid Build Coastguard Worker self.dataset, shuffle=True, batch_sampler=batch_sampler 1539*da0073e9SAndroid Build Coastguard Worker ) 1540*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1541*da0073e9SAndroid Build Coastguard Worker ValueError, "batch_sampler option is mutually exclusive with" 1542*da0073e9SAndroid Build Coastguard Worker ): 1543*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1544*da0073e9SAndroid Build Coastguard Worker self.dataset, drop_last=True, batch_sampler=batch_sampler 1545*da0073e9SAndroid Build Coastguard Worker ) 1546*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1547*da0073e9SAndroid Build Coastguard Worker ValueError, "batch_sampler option is mutually exclusive with" 1548*da0073e9SAndroid Build Coastguard Worker ): 1549*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1550*da0073e9SAndroid Build Coastguard Worker self.dataset, drop_last=3, batch_sampler=batch_sampler 1551*da0073e9SAndroid Build Coastguard Worker ) 1552*da0073e9SAndroid Build Coastguard Worker 1553*da0073e9SAndroid Build Coastguard Worker # iterable-style 1554*da0073e9SAndroid Build Coastguard Worker dataset = CountingIterableDataset(20) 1555*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1556*da0073e9SAndroid Build Coastguard Worker ValueError, "DataLoader with IterableDataset: expected unspecified shuffle" 1557*da0073e9SAndroid Build Coastguard Worker ): 
1558*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(dataset, shuffle=True) 1559*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1560*da0073e9SAndroid Build Coastguard Worker ValueError, "DataLoader with IterableDataset: expected unspecified shuffle" 1561*da0073e9SAndroid Build Coastguard Worker ): 1562*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(dataset, shuffle=3) 1563*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1564*da0073e9SAndroid Build Coastguard Worker ValueError, "DataLoader with IterableDataset: expected unspecified sampler" 1565*da0073e9SAndroid Build Coastguard Worker ): 1566*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1567*da0073e9SAndroid Build Coastguard Worker dataset, sampler=torch.utils.data.SequentialSampler(dataset) 1568*da0073e9SAndroid Build Coastguard Worker ) 1569*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1570*da0073e9SAndroid Build Coastguard Worker ValueError, "DataLoader with IterableDataset: expected unspecified sampler" 1571*da0073e9SAndroid Build Coastguard Worker ): 1572*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(dataset, sampler=3) 1573*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1574*da0073e9SAndroid Build Coastguard Worker ValueError, 1575*da0073e9SAndroid Build Coastguard Worker "DataLoader with IterableDataset: expected unspecified batch_sampler", 1576*da0073e9SAndroid Build Coastguard Worker ): 1577*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1578*da0073e9SAndroid Build Coastguard Worker dataset, 1579*da0073e9SAndroid Build Coastguard Worker batch_sampler=torch.utils.data.BatchSampler( 1580*da0073e9SAndroid Build Coastguard Worker torch.utils.data.SequentialSampler(dataset), 3, False 1581*da0073e9SAndroid Build Coastguard Worker ), 1582*da0073e9SAndroid Build Coastguard Worker ) 1583*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex( 1584*da0073e9SAndroid Build Coastguard Worker ValueError, 1585*da0073e9SAndroid Build Coastguard Worker "DataLoader with IterableDataset: expected unspecified batch_sampler", 1586*da0073e9SAndroid Build Coastguard Worker ): 1587*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(dataset, batch_sampler=3) 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker def test_builtin_collection_conversion(self): 1590*da0073e9SAndroid Build Coastguard Worker for coll_ty in (list, tuple): 1591*da0073e9SAndroid Build Coastguard Worker for num_workers in (0, 1): 1592*da0073e9SAndroid Build Coastguard Worker # map-style dataset 1593*da0073e9SAndroid Build Coastguard Worker dataset = CountingDataset(20) 1594*da0073e9SAndroid Build Coastguard Worker # no auto-batching 1595*da0073e9SAndroid Build Coastguard Worker fetched = coll_ty( 1596*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1597*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=None, num_workers=num_workers 1598*da0073e9SAndroid Build Coastguard Worker ) 1599*da0073e9SAndroid Build Coastguard Worker ) 1600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched, coll_ty(range(20))) 1601*da0073e9SAndroid Build Coastguard Worker # auto-batching 1602*da0073e9SAndroid Build Coastguard Worker fetched = coll_ty( 1603*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1604*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=2, num_workers=num_workers 1605*da0073e9SAndroid Build Coastguard Worker ) 1606*da0073e9SAndroid Build Coastguard Worker ) 1607*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1608*da0073e9SAndroid Build Coastguard Worker fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2)) 1609*da0073e9SAndroid Build Coastguard Worker ) 1610*da0073e9SAndroid Build Coastguard Worker 1611*da0073e9SAndroid Build Coastguard Worker # iterable-style dataset 
1612*da0073e9SAndroid Build Coastguard Worker dataset = CountingIterableDataset(20) 1613*da0073e9SAndroid Build Coastguard Worker # no auto-batching 1614*da0073e9SAndroid Build Coastguard Worker fetched = coll_ty( 1615*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1616*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=None, num_workers=num_workers 1617*da0073e9SAndroid Build Coastguard Worker ) 1618*da0073e9SAndroid Build Coastguard Worker ) 1619*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched, coll_ty(range(20))) 1620*da0073e9SAndroid Build Coastguard Worker # auto-batching 1621*da0073e9SAndroid Build Coastguard Worker # this IterableDataset isn't configured for each worker, so for 1622*da0073e9SAndroid Build Coastguard Worker # the equality test below to be valid, we cannot have more than 1 workers. 1623*da0073e9SAndroid Build Coastguard Worker assert num_workers in [0, 1], "invalid test" 1624*da0073e9SAndroid Build Coastguard Worker fetched = coll_ty( 1625*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1626*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=2, num_workers=num_workers 1627*da0073e9SAndroid Build Coastguard Worker ) 1628*da0073e9SAndroid Build Coastguard Worker ) 1629*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1630*da0073e9SAndroid Build Coastguard Worker fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2)) 1631*da0073e9SAndroid Build Coastguard Worker ) 1632*da0073e9SAndroid Build Coastguard Worker 1633*da0073e9SAndroid Build Coastguard Worker def test_iterable_style_dataset(self): 1634*da0073e9SAndroid Build Coastguard Worker # [no auto-batching] single process loading 1635*da0073e9SAndroid Build Coastguard Worker dataset = CountingIterableDataset(20) 1636*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader(dataset, batch_size=None) 1637*da0073e9SAndroid Build Coastguard Worker fetched = list(dataloader) 
1638*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fetched), 20) 1639*da0073e9SAndroid Build Coastguard Worker for i, d in enumerate(fetched): 1640*da0073e9SAndroid Build Coastguard Worker # non-batched should not convert ints into tensors 1641*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(d, int) 1642*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d, i) 1643*da0073e9SAndroid Build Coastguard Worker # DataLoader should match len of the iterable-style dataset (if implemented) 1644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dataloader), len(dataset)) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker # [no auto-batching] multiprocessing loading 1647*da0073e9SAndroid Build Coastguard Worker num_workers = 3 1648*da0073e9SAndroid Build Coastguard Worker sizes_for_all_workers = [0, 4, 20] 1649*da0073e9SAndroid Build Coastguard Worker expected = sorted( 1650*da0073e9SAndroid Build Coastguard Worker functools.reduce( 1651*da0073e9SAndroid Build Coastguard Worker operator.iadd, (list(range(s)) for s in sizes_for_all_workers), [] 1652*da0073e9SAndroid Build Coastguard Worker ) 1653*da0073e9SAndroid Build Coastguard Worker ) 1654*da0073e9SAndroid Build Coastguard Worker assert len(sizes_for_all_workers) == num_workers, "invalid test case" 1655*da0073e9SAndroid Build Coastguard Worker for prefetch_factor in [2, 3, 4]: 1656*da0073e9SAndroid Build Coastguard Worker dataset = WorkerSpecificIterableDataset(sizes_for_all_workers) 1657*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader( 1658*da0073e9SAndroid Build Coastguard Worker dataset, 1659*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 1660*da0073e9SAndroid Build Coastguard Worker batch_size=None, 1661*da0073e9SAndroid Build Coastguard Worker worker_init_fn=set_faulthander_if_available, 1662*da0073e9SAndroid Build Coastguard Worker prefetch_factor=prefetch_factor, 
1663*da0073e9SAndroid Build Coastguard Worker ) 1664*da0073e9SAndroid Build Coastguard Worker dataloader_iter = iter(dataloader) 1665*da0073e9SAndroid Build Coastguard Worker fetched = sorted(dataloader_iter) 1666*da0073e9SAndroid Build Coastguard Worker for a, b in zip(fetched, expected): 1667*da0073e9SAndroid Build Coastguard Worker # non-batched should not convert ints into tensors 1668*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(a, int) 1669*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 1670*da0073e9SAndroid Build Coastguard Worker # DataLoader should match len of the iterable-style dataset (if implemented) 1671*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dataloader), len(dataset)) 1672*da0073e9SAndroid Build Coastguard Worker # When loading more than len(dataset) data, after accessing len(dataloader), 1673*da0073e9SAndroid Build Coastguard Worker # we should get a warning. See NOTE [ IterableDataset and __len__ ]. 1674*da0073e9SAndroid Build Coastguard Worker dataset = CountingIterableDataset(20) 1675*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader( 1676*da0073e9SAndroid Build Coastguard Worker dataset, 1677*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 1678*da0073e9SAndroid Build Coastguard Worker worker_init_fn=set_faulthander_if_available, 1679*da0073e9SAndroid Build Coastguard Worker prefetch_factor=prefetch_factor, 1680*da0073e9SAndroid Build Coastguard Worker ) 1681*da0073e9SAndroid Build Coastguard Worker it = iter(dataloader) 1682*da0073e9SAndroid Build Coastguard Worker for _ in range(40): 1683*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn( 1684*da0073e9SAndroid Build Coastguard Worker lambda: next(it), "Should not warn before accessing len(dataloader)" 1685*da0073e9SAndroid Build Coastguard Worker ) 1686*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dataloader), len(dataset)) 1687*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(len(dataloader), 20) 1688*da0073e9SAndroid Build Coastguard Worker it = iter(dataloader) 1689*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 1690*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn( 1691*da0073e9SAndroid Build Coastguard Worker lambda: next(it), "Should not warn before exceeding length" 1692*da0073e9SAndroid Build Coastguard Worker ) 1693*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 1694*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1695*da0073e9SAndroid Build Coastguard Worker UserWarning, 1696*da0073e9SAndroid Build Coastguard Worker r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this", 1697*da0073e9SAndroid Build Coastguard Worker msg="Should always warn after exceeding length", 1698*da0073e9SAndroid Build Coastguard Worker ): 1699*da0073e9SAndroid Build Coastguard Worker next(it) 1700*da0073e9SAndroid Build Coastguard Worker # [no auto-batching] test that workers exit gracefully 1701*da0073e9SAndroid Build Coastguard Worker workers = dataloader_iter._workers 1702*da0073e9SAndroid Build Coastguard Worker del dataloader_iter 1703*da0073e9SAndroid Build Coastguard Worker del dataloader 1704*da0073e9SAndroid Build Coastguard Worker try: 1705*da0073e9SAndroid Build Coastguard Worker for w in workers: 1706*da0073e9SAndroid Build Coastguard Worker w.join(JOIN_TIMEOUT) 1707*da0073e9SAndroid Build Coastguard Worker self.assertFalse(w.is_alive()) 1708*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w.exitcode, 0) 1709*da0073e9SAndroid Build Coastguard Worker finally: 1710*da0073e9SAndroid Build Coastguard Worker for w in workers: 1711*da0073e9SAndroid Build Coastguard Worker w.terminate() 1712*da0073e9SAndroid Build Coastguard Worker 1713*da0073e9SAndroid Build Coastguard Worker # [auto-batching] single process loading 1714*da0073e9SAndroid Build Coastguard Worker dataset = CountingIterableDataset(20) 1715*da0073e9SAndroid Build 
Coastguard Worker fetched = list(self._get_data_loader(dataset, batch_size=7)) 1716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fetched), 3) 1717*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched[0].tolist(), list(range(7))) 1718*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched[1].tolist(), list(range(7, 14))) 1719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched[2].tolist(), list(range(14, 20))) 1720*da0073e9SAndroid Build Coastguard Worker 1721*da0073e9SAndroid Build Coastguard Worker # [auto-batching] multiprocessing loading 1722*da0073e9SAndroid Build Coastguard Worker num_workers = 3 1723*da0073e9SAndroid Build Coastguard Worker sizes_for_all_workers = [0, 4, 20] 1724*da0073e9SAndroid Build Coastguard Worker expected = sorted( 1725*da0073e9SAndroid Build Coastguard Worker functools.reduce( 1726*da0073e9SAndroid Build Coastguard Worker operator.iadd, (list(range(s)) for s in sizes_for_all_workers), [] 1727*da0073e9SAndroid Build Coastguard Worker ) 1728*da0073e9SAndroid Build Coastguard Worker ) 1729*da0073e9SAndroid Build Coastguard Worker assert len(sizes_for_all_workers) == num_workers, "invalid test case" 1730*da0073e9SAndroid Build Coastguard Worker for prefetch_factor in [2, 3, 4]: 1731*da0073e9SAndroid Build Coastguard Worker dataset = WorkerSpecificIterableDataset(sizes_for_all_workers) 1732*da0073e9SAndroid Build Coastguard Worker # worker 0 should return 0 batches 1733*da0073e9SAndroid Build Coastguard Worker # worker 1 should return 1 batches 1734*da0073e9SAndroid Build Coastguard Worker # worker 2 should return 3 batches 1735*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader( 1736*da0073e9SAndroid Build Coastguard Worker dataset, 1737*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 1738*da0073e9SAndroid Build Coastguard Worker batch_size=7, 1739*da0073e9SAndroid Build Coastguard Worker prefetch_factor=prefetch_factor, 1740*da0073e9SAndroid 
Build Coastguard Worker ) 1741*da0073e9SAndroid Build Coastguard Worker dataloader_iter = iter(dataloader) 1742*da0073e9SAndroid Build Coastguard Worker fetched = list(dataloader_iter) 1743*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fetched), 4) 1744*da0073e9SAndroid Build Coastguard Worker fetched = {tuple(t.tolist()) for t in fetched} 1745*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1746*da0073e9SAndroid Build Coastguard Worker fetched, 1747*da0073e9SAndroid Build Coastguard Worker { 1748*da0073e9SAndroid Build Coastguard Worker tuple(range(4)), 1749*da0073e9SAndroid Build Coastguard Worker tuple(range(7)), 1750*da0073e9SAndroid Build Coastguard Worker tuple(range(7, 14)), 1751*da0073e9SAndroid Build Coastguard Worker tuple(range(14, 20)), 1752*da0073e9SAndroid Build Coastguard Worker }, 1753*da0073e9SAndroid Build Coastguard Worker ) 1754*da0073e9SAndroid Build Coastguard Worker 1755*da0073e9SAndroid Build Coastguard Worker # [auto-batching] test that workers exit gracefully 1756*da0073e9SAndroid Build Coastguard Worker workers = dataloader_iter._workers 1757*da0073e9SAndroid Build Coastguard Worker del dataloader_iter 1758*da0073e9SAndroid Build Coastguard Worker del dataloader 1759*da0073e9SAndroid Build Coastguard Worker try: 1760*da0073e9SAndroid Build Coastguard Worker for w in workers: 1761*da0073e9SAndroid Build Coastguard Worker w.join(JOIN_TIMEOUT) 1762*da0073e9SAndroid Build Coastguard Worker self.assertFalse(w.is_alive()) 1763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w.exitcode, 0) 1764*da0073e9SAndroid Build Coastguard Worker finally: 1765*da0073e9SAndroid Build Coastguard Worker for w in workers: 1766*da0073e9SAndroid Build Coastguard Worker w.terminate() 1767*da0073e9SAndroid Build Coastguard Worker # [auto-batching & drop_last] single process loading 1768*da0073e9SAndroid Build Coastguard Worker dataset = CountingIterableDataset(20) 1769*da0073e9SAndroid Build Coastguard Worker fetched = 
list(self._get_data_loader(dataset, batch_size=7, drop_last=True)) 1770*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fetched), 2) 1771*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched[0].tolist(), list(range(7))) 1772*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched[1].tolist(), list(range(7, 14))) 1773*da0073e9SAndroid Build Coastguard Worker 1774*da0073e9SAndroid Build Coastguard Worker # [auto-batching & drop_last] multiprocessing loading 1775*da0073e9SAndroid Build Coastguard Worker num_workers = 3 1776*da0073e9SAndroid Build Coastguard Worker sizes_for_all_workers = [0, 4, 20] 1777*da0073e9SAndroid Build Coastguard Worker expected = sorted( 1778*da0073e9SAndroid Build Coastguard Worker functools.reduce( 1779*da0073e9SAndroid Build Coastguard Worker operator.iadd, (list(range(s)) for s in sizes_for_all_workers), [] 1780*da0073e9SAndroid Build Coastguard Worker ) 1781*da0073e9SAndroid Build Coastguard Worker ) 1782*da0073e9SAndroid Build Coastguard Worker assert len(sizes_for_all_workers) == num_workers, "invalid test case" 1783*da0073e9SAndroid Build Coastguard Worker for prefetch_factor in [2, 3, 4]: 1784*da0073e9SAndroid Build Coastguard Worker dataset = WorkerSpecificIterableDataset(sizes_for_all_workers) 1785*da0073e9SAndroid Build Coastguard Worker # worker 0 should return 0 batches 1786*da0073e9SAndroid Build Coastguard Worker # worker 1 should return 1 batches 1787*da0073e9SAndroid Build Coastguard Worker # worker 2 should return 3 batches 1788*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader( 1789*da0073e9SAndroid Build Coastguard Worker dataset, 1790*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 1791*da0073e9SAndroid Build Coastguard Worker batch_size=7, 1792*da0073e9SAndroid Build Coastguard Worker drop_last=True, 1793*da0073e9SAndroid Build Coastguard Worker worker_init_fn=set_faulthander_if_available, 1794*da0073e9SAndroid Build Coastguard Worker 
prefetch_factor=prefetch_factor, 1795*da0073e9SAndroid Build Coastguard Worker ) 1796*da0073e9SAndroid Build Coastguard Worker dataloader_iter = iter(dataloader) 1797*da0073e9SAndroid Build Coastguard Worker fetched = list(dataloader_iter) 1798*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fetched), 2) 1799*da0073e9SAndroid Build Coastguard Worker fetched = {tuple(t.tolist()) for t in fetched} 1800*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))}) 1801*da0073e9SAndroid Build Coastguard Worker 1802*da0073e9SAndroid Build Coastguard Worker # [auto-batching & drop_last] test that workers exit gracefully 1803*da0073e9SAndroid Build Coastguard Worker workers = dataloader_iter._workers 1804*da0073e9SAndroid Build Coastguard Worker del dataloader_iter 1805*da0073e9SAndroid Build Coastguard Worker del dataloader 1806*da0073e9SAndroid Build Coastguard Worker try: 1807*da0073e9SAndroid Build Coastguard Worker for w in workers: 1808*da0073e9SAndroid Build Coastguard Worker w.join(JOIN_TIMEOUT) 1809*da0073e9SAndroid Build Coastguard Worker self.assertFalse(w.is_alive()) 1810*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w.exitcode, 0) 1811*da0073e9SAndroid Build Coastguard Worker finally: 1812*da0073e9SAndroid Build Coastguard Worker for w in workers: 1813*da0073e9SAndroid Build Coastguard Worker w.terminate() 1814*da0073e9SAndroid Build Coastguard Worker 1815*da0073e9SAndroid Build Coastguard Worker def test_chain_iterable_style_dataset(self): 1816*da0073e9SAndroid Build Coastguard Worker # chaining (concatenation) 1817*da0073e9SAndroid Build Coastguard Worker dataset1 = CountingIterableDataset(20) 1818*da0073e9SAndroid Build Coastguard Worker dataset2 = CountingIterableDataset(15) 1819*da0073e9SAndroid Build Coastguard Worker expected = list(range(20)) + list(range(15)) 1820*da0073e9SAndroid Build Coastguard Worker for num_workers in [0, 1]: 1821*da0073e9SAndroid Build Coastguard Worker for 
chained_dataset in [ 1822*da0073e9SAndroid Build Coastguard Worker dataset1 + dataset2, 1823*da0073e9SAndroid Build Coastguard Worker ChainDataset([dataset1, dataset2]), 1824*da0073e9SAndroid Build Coastguard Worker ]: 1825*da0073e9SAndroid Build Coastguard Worker fetched = list( 1826*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(chained_dataset, num_workers=num_workers) 1827*da0073e9SAndroid Build Coastguard Worker ) 1828*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fetched), len(expected)) 1829*da0073e9SAndroid Build Coastguard Worker for e, d in zip(expected, fetched): 1830*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(d, torch.Tensor) 1831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e, d) 1832*da0073e9SAndroid Build Coastguard Worker 1833*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1834*da0073e9SAndroid Build Coastguard Worker AssertionError, "ChainDataset only supports IterableDataset" 1835*da0073e9SAndroid Build Coastguard Worker ): 1836*da0073e9SAndroid Build Coastguard Worker list(iter(dataset1 + self.dataset)) 1837*da0073e9SAndroid Build Coastguard Worker 1838*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1839*da0073e9SAndroid Build Coastguard Worker AssertionError, "ChainDataset only supports IterableDataset" 1840*da0073e9SAndroid Build Coastguard Worker ): 1841*da0073e9SAndroid Build Coastguard Worker list(iter(ChainDataset([dataset1, self.dataset]))) 1842*da0073e9SAndroid Build Coastguard Worker 1843*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Not working on macos") 1844*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1845*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # https://github.com/pytorch/pytorch/issues/90940 1846*da0073e9SAndroid Build Coastguard Worker def test_multiprocessing_contexts(self): 1847*da0073e9SAndroid Build Coastguard Worker 
reference = [ 1848*da0073e9SAndroid Build Coastguard Worker torch.arange(3), 1849*da0073e9SAndroid Build Coastguard Worker torch.arange(3, 6), 1850*da0073e9SAndroid Build Coastguard Worker torch.arange(6, 9), 1851*da0073e9SAndroid Build Coastguard Worker torch.arange(9, 11), 1852*da0073e9SAndroid Build Coastguard Worker ] 1853*da0073e9SAndroid Build Coastguard Worker counting_ds_n = 11 1854*da0073e9SAndroid Build Coastguard Worker dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA)) 1855*da0073e9SAndroid Build Coastguard Worker for ctx in supported_multiprocessing_contexts: 1856*da0073e9SAndroid Build Coastguard Worker # windows and jetson devices don't support sharing cuda tensor; ROCm does not yet fully support IPC 1857*da0073e9SAndroid Build Coastguard Worker if ( 1858*da0073e9SAndroid Build Coastguard Worker ctx in ["spawn", "forkserver"] 1859*da0073e9SAndroid Build Coastguard Worker and TEST_CUDA 1860*da0073e9SAndroid Build Coastguard Worker and not IS_WINDOWS 1861*da0073e9SAndroid Build Coastguard Worker and not IS_JETSON 1862*da0073e9SAndroid Build Coastguard Worker ): 1863*da0073e9SAndroid Build Coastguard Worker ds_cls = CUDACountingDataset 1864*da0073e9SAndroid Build Coastguard Worker else: 1865*da0073e9SAndroid Build Coastguard Worker ds_cls = CountingDataset 1866*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1867*da0073e9SAndroid Build Coastguard Worker reference, 1868*da0073e9SAndroid Build Coastguard Worker list( 1869*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1870*da0073e9SAndroid Build Coastguard Worker ds_cls(counting_ds_n), 1871*da0073e9SAndroid Build Coastguard Worker multiprocessing_context=ctx, 1872*da0073e9SAndroid Build Coastguard Worker **dl_common_args, 1873*da0073e9SAndroid Build Coastguard Worker ) 1874*da0073e9SAndroid Build Coastguard Worker ), 1875*da0073e9SAndroid Build Coastguard Worker ) 1876*da0073e9SAndroid Build Coastguard Worker if ctx is not None: 1877*da0073e9SAndroid 
Build Coastguard Worker # test ctx object 1878*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context(ctx) 1879*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1880*da0073e9SAndroid Build Coastguard Worker reference, 1881*da0073e9SAndroid Build Coastguard Worker list( 1882*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 1883*da0073e9SAndroid Build Coastguard Worker ds_cls(counting_ds_n), 1884*da0073e9SAndroid Build Coastguard Worker multiprocessing_context=ctx, 1885*da0073e9SAndroid Build Coastguard Worker **dl_common_args, 1886*da0073e9SAndroid Build Coastguard Worker ) 1887*da0073e9SAndroid Build Coastguard Worker ), 1888*da0073e9SAndroid Build Coastguard Worker ) 1889*da0073e9SAndroid Build Coastguard Worker 1890*da0073e9SAndroid Build Coastguard Worker def _test_multiprocessing_iterdatapipe(self, with_dill): 1891*da0073e9SAndroid Build Coastguard Worker # Testing to make sure that function from global scope (e.g. imported from library) can be serialized 1892*da0073e9SAndroid Build Coastguard Worker # and used with multiprocess DataLoader 1893*da0073e9SAndroid Build Coastguard Worker 1894*da0073e9SAndroid Build Coastguard Worker reference = [ 1895*da0073e9SAndroid Build Coastguard Worker torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64), 1896*da0073e9SAndroid Build Coastguard Worker torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64), 1897*da0073e9SAndroid Build Coastguard Worker ] 1898*da0073e9SAndroid Build Coastguard Worker datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]]) 1899*da0073e9SAndroid Build Coastguard Worker datapipe = datapipe.map(row_processor) 1900*da0073e9SAndroid Build Coastguard Worker datapipe = ( 1901*da0073e9SAndroid Build Coastguard Worker datapipe.filter(lambda row: len(row) == 4) 1902*da0073e9SAndroid Build Coastguard Worker if with_dill 1903*da0073e9SAndroid Build Coastguard Worker else datapipe.filter(filter_len) 1904*da0073e9SAndroid Build Coastguard Worker ) 
1905*da0073e9SAndroid Build Coastguard Worker 1906*da0073e9SAndroid Build Coastguard Worker dl_common_args = dict( 1907*da0073e9SAndroid Build Coastguard Worker num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA) 1908*da0073e9SAndroid Build Coastguard Worker ) 1909*da0073e9SAndroid Build Coastguard Worker for ctx in supported_multiprocessing_contexts: 1910*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1911*da0073e9SAndroid Build Coastguard Worker reference, 1912*da0073e9SAndroid Build Coastguard Worker [ 1913*da0073e9SAndroid Build Coastguard Worker t.type(torch.int64) 1914*da0073e9SAndroid Build Coastguard Worker for t in self._get_data_loader( 1915*da0073e9SAndroid Build Coastguard Worker datapipe, multiprocessing_context=ctx, **dl_common_args 1916*da0073e9SAndroid Build Coastguard Worker ) 1917*da0073e9SAndroid Build Coastguard Worker ], 1918*da0073e9SAndroid Build Coastguard Worker ) 1919*da0073e9SAndroid Build Coastguard Worker if ctx is not None: 1920*da0073e9SAndroid Build Coastguard Worker # test ctx object 1921*da0073e9SAndroid Build Coastguard Worker ctx = mp.get_context(ctx) 1922*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1923*da0073e9SAndroid Build Coastguard Worker reference, 1924*da0073e9SAndroid Build Coastguard Worker [ 1925*da0073e9SAndroid Build Coastguard Worker t.type(torch.int64) 1926*da0073e9SAndroid Build Coastguard Worker for t in self._get_data_loader( 1927*da0073e9SAndroid Build Coastguard Worker datapipe, multiprocessing_context=ctx, **dl_common_args 1928*da0073e9SAndroid Build Coastguard Worker ) 1929*da0073e9SAndroid Build Coastguard Worker ], 1930*da0073e9SAndroid Build Coastguard Worker ) 1931*da0073e9SAndroid Build Coastguard Worker 1932*da0073e9SAndroid Build Coastguard Worker @skipIfNoNumpy 1933*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1934*da0073e9SAndroid Build Coastguard Worker def test_multiprocessing_iterdatapipe(self): 
1935*da0073e9SAndroid Build Coastguard Worker self._test_multiprocessing_iterdatapipe(with_dill=False) 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 1938*da0073e9SAndroid Build Coastguard Worker @skipIfNoNumpy 1939*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1940*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 1941*da0073e9SAndroid Build Coastguard Worker def test_multiprocessing_iterdatapipe_with_dill(self): 1942*da0073e9SAndroid Build Coastguard Worker self._test_multiprocessing_iterdatapipe(with_dill=True) 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker def test_worker_seed(self): 1945*da0073e9SAndroid Build Coastguard Worker num_workers = 6 1946*da0073e9SAndroid Build Coastguard Worker batch_size = 1 1947*da0073e9SAndroid Build Coastguard Worker dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers) 1948*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader( 1949*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=batch_size, num_workers=num_workers 1950*da0073e9SAndroid Build Coastguard Worker ) 1951*da0073e9SAndroid Build Coastguard Worker seeds = set() 1952*da0073e9SAndroid Build Coastguard Worker seeds.update(batch[0] for batch in dataloader) 1953*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(seeds), num_workers) 1954*da0073e9SAndroid Build Coastguard Worker 1955*da0073e9SAndroid Build Coastguard Worker def test_worker_seed_reproducibility(self): 1956*da0073e9SAndroid Build Coastguard Worker def get_dataloader(): 1957*da0073e9SAndroid Build Coastguard Worker return DataLoader( 1958*da0073e9SAndroid Build Coastguard Worker dataset, 1959*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 1960*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 1961*da0073e9SAndroid Build Coastguard Worker 
generator=torch.Generator().manual_seed(42), 1962*da0073e9SAndroid Build Coastguard Worker ) 1963*da0073e9SAndroid Build Coastguard Worker 1964*da0073e9SAndroid Build Coastguard Worker num_workers = 6 1965*da0073e9SAndroid Build Coastguard Worker batch_size = 1 1966*da0073e9SAndroid Build Coastguard Worker dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers) 1967*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1968*da0073e9SAndroid Build Coastguard Worker {int(batch) for batch in get_dataloader()}, 1969*da0073e9SAndroid Build Coastguard Worker {int(batch) for batch in get_dataloader()}, 1970*da0073e9SAndroid Build Coastguard Worker ) 1971*da0073e9SAndroid Build Coastguard Worker 1972*da0073e9SAndroid Build Coastguard Worker def test_multi_epochs_reproducibility(self): 1973*da0073e9SAndroid Build Coastguard Worker num_workers = 2 1974*da0073e9SAndroid Build Coastguard Worker batch_size = 10 1975*da0073e9SAndroid Build Coastguard Worker num_epochs = 3 1976*da0073e9SAndroid Build Coastguard Worker 1977*da0073e9SAndroid Build Coastguard Worker dataset = TestMultiEpochDataset(batch_size * num_workers) 1978*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader( 1979*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 1980*da0073e9SAndroid Build Coastguard Worker ) 1981*da0073e9SAndroid Build Coastguard Worker 1982*da0073e9SAndroid Build Coastguard Worker for ind in range(num_epochs): 1983*da0073e9SAndroid Build Coastguard Worker for batch_idx, sample in enumerate(dataloader): 1984*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1985*da0073e9SAndroid Build Coastguard Worker sample.tolist(), [batch_idx % num_workers] * batch_size 1986*da0073e9SAndroid Build Coastguard Worker ) 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker def test_worker_init_fn(self): 1989*da0073e9SAndroid Build Coastguard Worker dataset = 
SeedDataset(4) 1990*da0073e9SAndroid Build Coastguard Worker dataloader = self._get_data_loader( 1991*da0073e9SAndroid Build Coastguard Worker dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn 1992*da0073e9SAndroid Build Coastguard Worker ) 1993*da0073e9SAndroid Build Coastguard Worker for batch in dataloader: 1994*da0073e9SAndroid Build Coastguard Worker self.assertEqual(12345, batch[0]) 1995*da0073e9SAndroid Build Coastguard Worker self.assertEqual(12345, batch[1]) 1996*da0073e9SAndroid Build Coastguard Worker 1997*da0073e9SAndroid Build Coastguard Worker def test_get_worker_info(self): 1998*da0073e9SAndroid Build Coastguard Worker p = ErrorTrackingProcess(target=_test_get_worker_info) 1999*da0073e9SAndroid Build Coastguard Worker p.start() 2000*da0073e9SAndroid Build Coastguard Worker p.join(JOIN_TIMEOUT) 2001*da0073e9SAndroid Build Coastguard Worker try: 2002*da0073e9SAndroid Build Coastguard Worker self.assertFalse(p.is_alive()) 2003*da0073e9SAndroid Build Coastguard Worker self.assertEqual(p.exitcode, 0) 2004*da0073e9SAndroid Build Coastguard Worker finally: 2005*da0073e9SAndroid Build Coastguard Worker p.terminate() 2006*da0073e9SAndroid Build Coastguard Worker 2007*da0073e9SAndroid Build Coastguard Worker def test_shuffle(self): 2008*da0073e9SAndroid Build Coastguard Worker self._test_shuffle(self._get_data_loader(self.dataset, shuffle=True)) 2009*da0073e9SAndroid Build Coastguard Worker 2010*da0073e9SAndroid Build Coastguard Worker def test_shuffle_batch_none(self): 2011*da0073e9SAndroid Build Coastguard Worker self._test_shuffle(DataLoader(self.dataset, batch_size=None, shuffle=True)) 2012*da0073e9SAndroid Build Coastguard Worker 2013*da0073e9SAndroid Build Coastguard Worker def test_shuffle_batch(self): 2014*da0073e9SAndroid Build Coastguard Worker self._test_shuffle( 2015*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, batch_size=2, shuffle=True) 2016*da0073e9SAndroid Build Coastguard Worker ) 
2017*da0073e9SAndroid Build Coastguard Worker 2018*da0073e9SAndroid Build Coastguard Worker def test_shuffle_reproducibility(self): 2019*da0073e9SAndroid Build Coastguard Worker for fn in ( 2020*da0073e9SAndroid Build Coastguard Worker lambda: DataLoader( 2021*da0073e9SAndroid Build Coastguard Worker self.dataset, 2022*da0073e9SAndroid Build Coastguard Worker shuffle=True, 2023*da0073e9SAndroid Build Coastguard Worker num_workers=0, 2024*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 2025*da0073e9SAndroid Build Coastguard Worker ), 2026*da0073e9SAndroid Build Coastguard Worker lambda: DataLoader( 2027*da0073e9SAndroid Build Coastguard Worker self.dataset, 2028*da0073e9SAndroid Build Coastguard Worker shuffle=True, 2029*da0073e9SAndroid Build Coastguard Worker num_workers=2, 2030*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 2031*da0073e9SAndroid Build Coastguard Worker ), 2032*da0073e9SAndroid Build Coastguard Worker ): 2033*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(fn()), list(fn())) 2034*da0073e9SAndroid Build Coastguard Worker 2035*da0073e9SAndroid Build Coastguard Worker def test_sequential_workers(self): 2036*da0073e9SAndroid Build Coastguard Worker self._test_sequential(self._get_data_loader(self.dataset, num_workers=4)) 2037*da0073e9SAndroid Build Coastguard Worker 2038*da0073e9SAndroid Build Coastguard Worker def test_seqential_batch_workers(self): 2039*da0073e9SAndroid Build Coastguard Worker self._test_sequential( 2040*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, batch_size=2, num_workers=4) 2041*da0073e9SAndroid Build Coastguard Worker ) 2042*da0073e9SAndroid Build Coastguard Worker 2043*da0073e9SAndroid Build Coastguard Worker def test_seqential_batch_workers_prefetch(self): 2044*da0073e9SAndroid Build Coastguard Worker self._test_sequential( 2045*da0073e9SAndroid Build Coastguard Worker DataLoader(self.dataset, 
batch_size=2, num_workers=4, prefetch_factor=3) 2046*da0073e9SAndroid Build Coastguard Worker ) 2047*da0073e9SAndroid Build Coastguard Worker 2048*da0073e9SAndroid Build Coastguard Worker def test_shuffle_workers(self): 2049*da0073e9SAndroid Build Coastguard Worker self._test_shuffle( 2050*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, shuffle=True, num_workers=4) 2051*da0073e9SAndroid Build Coastguard Worker ) 2052*da0073e9SAndroid Build Coastguard Worker 2053*da0073e9SAndroid Build Coastguard Worker def test_shuffle_batch_workers(self): 2054*da0073e9SAndroid Build Coastguard Worker self._test_shuffle( 2055*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 2056*da0073e9SAndroid Build Coastguard Worker self.dataset, batch_size=2, shuffle=True, num_workers=4 2057*da0073e9SAndroid Build Coastguard Worker ) 2058*da0073e9SAndroid Build Coastguard Worker ) 2059*da0073e9SAndroid Build Coastguard Worker 2060*da0073e9SAndroid Build Coastguard Worker def test_shuffle_batch_workers_prefetch(self): 2061*da0073e9SAndroid Build Coastguard Worker self._test_shuffle( 2062*da0073e9SAndroid Build Coastguard Worker DataLoader( 2063*da0073e9SAndroid Build Coastguard Worker self.dataset, 2064*da0073e9SAndroid Build Coastguard Worker batch_size=2, 2065*da0073e9SAndroid Build Coastguard Worker shuffle=True, 2066*da0073e9SAndroid Build Coastguard Worker num_workers=4, 2067*da0073e9SAndroid Build Coastguard Worker prefetch_factor=3, 2068*da0073e9SAndroid Build Coastguard Worker ) 2069*da0073e9SAndroid Build Coastguard Worker ) 2070*da0073e9SAndroid Build Coastguard Worker 2071*da0073e9SAndroid Build Coastguard Worker def test_random_sampler(self): 2072*da0073e9SAndroid Build Coastguard Worker from collections import Counter 2073*da0073e9SAndroid Build Coastguard Worker 2074*da0073e9SAndroid Build Coastguard Worker from torch.utils.data import RandomSampler 2075*da0073e9SAndroid Build Coastguard Worker 2076*da0073e9SAndroid Build Coastguard 
Worker def sample_stat(sampler, num_samples): 2077*da0073e9SAndroid Build Coastguard Worker counts = Counter(sampler) 2078*da0073e9SAndroid Build Coastguard Worker count_repeated = sum(val > 1 for val in counts.values()) 2079*da0073e9SAndroid Build Coastguard Worker return ( 2080*da0073e9SAndroid Build Coastguard Worker count_repeated, 2081*da0073e9SAndroid Build Coastguard Worker min(counts.keys()), 2082*da0073e9SAndroid Build Coastguard Worker max(counts.keys()), 2083*da0073e9SAndroid Build Coastguard Worker sum(counts.values()), 2084*da0073e9SAndroid Build Coastguard Worker ) 2085*da0073e9SAndroid Build Coastguard Worker 2086*da0073e9SAndroid Build Coastguard Worker # test sample with replacement 2087*da0073e9SAndroid Build Coastguard Worker n = len(self.dataset) + 1 # ensure at least one sample is drawn more than once 2088*da0073e9SAndroid Build Coastguard Worker sampler_with_replacement = RandomSampler( 2089*da0073e9SAndroid Build Coastguard Worker self.dataset, replacement=True, num_samples=n 2090*da0073e9SAndroid Build Coastguard Worker ) 2091*da0073e9SAndroid Build Coastguard Worker count_repeated, minval, maxval, count_total = sample_stat( 2092*da0073e9SAndroid Build Coastguard Worker sampler_with_replacement, n 2093*da0073e9SAndroid Build Coastguard Worker ) 2094*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_repeated > 0) 2095*da0073e9SAndroid Build Coastguard Worker self.assertTrue(minval >= 0) 2096*da0073e9SAndroid Build Coastguard Worker self.assertTrue(maxval < len(self.dataset)) 2097*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_total == n) 2098*da0073e9SAndroid Build Coastguard Worker 2099*da0073e9SAndroid Build Coastguard Worker # test sample without replacement and without specified num_samples 2100*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement = RandomSampler(self.dataset) 2101*da0073e9SAndroid Build Coastguard Worker count_repeated, minval, maxval, count_total = sample_stat( 
2102*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement, len(self.dataset) 2103*da0073e9SAndroid Build Coastguard Worker ) 2104*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_repeated == 0) 2105*da0073e9SAndroid Build Coastguard Worker self.assertTrue(minval == 0) 2106*da0073e9SAndroid Build Coastguard Worker self.assertTrue(maxval == len(self.dataset) - 1) 2107*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_total == len(self.dataset)) 2108*da0073e9SAndroid Build Coastguard Worker 2109*da0073e9SAndroid Build Coastguard Worker # test sample without replacement and with specified num_samples 2110*da0073e9SAndroid Build Coastguard Worker n = len(self.dataset) * 2 2111*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement = RandomSampler(self.dataset, num_samples=n) 2112*da0073e9SAndroid Build Coastguard Worker count_repeated, minval, maxval, count_total = sample_stat( 2113*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement, len(self.dataset) 2114*da0073e9SAndroid Build Coastguard Worker ) 2115*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_repeated == len(self.dataset)) 2116*da0073e9SAndroid Build Coastguard Worker self.assertTrue(minval == 0) 2117*da0073e9SAndroid Build Coastguard Worker self.assertTrue(maxval == len(self.dataset) - 1) 2118*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_total == n) 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker n = len(self.dataset) - 1 2121*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement = RandomSampler(self.dataset, num_samples=n) 2122*da0073e9SAndroid Build Coastguard Worker count_repeated, minval, maxval, count_total = sample_stat( 2123*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement, len(self.dataset) 2124*da0073e9SAndroid Build Coastguard Worker ) 2125*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_repeated == 0) 
2126*da0073e9SAndroid Build Coastguard Worker self.assertTrue(minval >= 0) 2127*da0073e9SAndroid Build Coastguard Worker self.assertTrue(maxval < len(self.dataset)) 2128*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_total == n) 2129*da0073e9SAndroid Build Coastguard Worker 2130*da0073e9SAndroid Build Coastguard Worker n = len(self.dataset) + 1 2131*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement = RandomSampler(self.dataset, num_samples=n) 2132*da0073e9SAndroid Build Coastguard Worker count_repeated, minval, maxval, count_total = sample_stat( 2133*da0073e9SAndroid Build Coastguard Worker sampler_without_replacement, len(self.dataset) 2134*da0073e9SAndroid Build Coastguard Worker ) 2135*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_repeated == 1) 2136*da0073e9SAndroid Build Coastguard Worker self.assertTrue(minval == 0) 2137*da0073e9SAndroid Build Coastguard Worker self.assertTrue(maxval == len(self.dataset) - 1) 2138*da0073e9SAndroid Build Coastguard Worker self.assertTrue(count_total == n) 2139*da0073e9SAndroid Build Coastguard Worker 2140*da0073e9SAndroid Build Coastguard Worker # raise error when replacement is non-boolean 2141*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2142*da0073e9SAndroid Build Coastguard Worker TypeError, "replacement should be a boolean value, but got replacement=0" 2143*da0073e9SAndroid Build Coastguard Worker ): 2144*da0073e9SAndroid Build Coastguard Worker RandomSampler(self.dataset, replacement=0) 2145*da0073e9SAndroid Build Coastguard Worker 2146*da0073e9SAndroid Build Coastguard Worker def test_random_sampler_len_with_replacement(self): 2147*da0073e9SAndroid Build Coastguard Worker from torch.utils.data import RandomSampler 2148*da0073e9SAndroid Build Coastguard Worker 2149*da0073e9SAndroid Build Coastguard Worker # add 5 extra samples 2150*da0073e9SAndroid Build Coastguard Worker num_samples = len(self.dataset) + 5 2151*da0073e9SAndroid Build Coastguard 
Worker sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples) 2152*da0073e9SAndroid Build Coastguard Worker # test len method 2153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_samples, len(sampler)) 2154*da0073e9SAndroid Build Coastguard Worker 2155*da0073e9SAndroid Build Coastguard Worker # test with iteration 2156*da0073e9SAndroid Build Coastguard Worker count_num_samples = sum(1 for _ in sampler) 2157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_samples, count_num_samples) 2158*da0073e9SAndroid Build Coastguard Worker 2159*da0073e9SAndroid Build Coastguard Worker # test with dataloader, batch_size = 1 2160*da0073e9SAndroid Build Coastguard Worker batch_size = 1 2161*da0073e9SAndroid Build Coastguard Worker count_num_samples_in_data_loader = len( 2162*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2163*da0073e9SAndroid Build Coastguard Worker ) 2164*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_samples, count_num_samples_in_data_loader) 2165*da0073e9SAndroid Build Coastguard Worker 2166*da0073e9SAndroid Build Coastguard Worker # test with dataloader, batch_size = 6 2167*da0073e9SAndroid Build Coastguard Worker batch_size = 6 2168*da0073e9SAndroid Build Coastguard Worker count_num_samples_in_data_loader = len( 2169*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2170*da0073e9SAndroid Build Coastguard Worker ) 2171*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2172*da0073e9SAndroid Build Coastguard Worker int(math.ceil(float(num_samples) / batch_size)), 2173*da0073e9SAndroid Build Coastguard Worker count_num_samples_in_data_loader, 2174*da0073e9SAndroid Build Coastguard Worker ) 2175*da0073e9SAndroid Build Coastguard Worker 2176*da0073e9SAndroid Build Coastguard Worker def test_random_sampler_len_without_replacement(self): 
2177*da0073e9SAndroid Build Coastguard Worker from torch.utils.data import RandomSampler 2178*da0073e9SAndroid Build Coastguard Worker 2179*da0073e9SAndroid Build Coastguard Worker # add 5 extra samples 2180*da0073e9SAndroid Build Coastguard Worker num_samples = len(self.dataset) + 5 2181*da0073e9SAndroid Build Coastguard Worker sampler = RandomSampler( 2182*da0073e9SAndroid Build Coastguard Worker self.dataset, replacement=False, num_samples=num_samples 2183*da0073e9SAndroid Build Coastguard Worker ) 2184*da0073e9SAndroid Build Coastguard Worker # test len method 2185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_samples, len(sampler)) 2186*da0073e9SAndroid Build Coastguard Worker 2187*da0073e9SAndroid Build Coastguard Worker # test with iteration 2188*da0073e9SAndroid Build Coastguard Worker count_num_samples = sum(1 for _ in sampler) 2189*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_samples, count_num_samples) 2190*da0073e9SAndroid Build Coastguard Worker 2191*da0073e9SAndroid Build Coastguard Worker # test with dataloader, batch_size = 1 2192*da0073e9SAndroid Build Coastguard Worker batch_size = 1 2193*da0073e9SAndroid Build Coastguard Worker count_num_samples_in_data_loader = len( 2194*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2195*da0073e9SAndroid Build Coastguard Worker ) 2196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_samples, count_num_samples_in_data_loader) 2197*da0073e9SAndroid Build Coastguard Worker 2198*da0073e9SAndroid Build Coastguard Worker # test with dataloader, batch_size = 6 2199*da0073e9SAndroid Build Coastguard Worker batch_size = 6 2200*da0073e9SAndroid Build Coastguard Worker count_num_samples_in_data_loader = len( 2201*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2202*da0073e9SAndroid Build Coastguard Worker ) 2203*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual( 2204*da0073e9SAndroid Build Coastguard Worker num_samples // batch_size + (num_samples % batch_size > 0), 2205*da0073e9SAndroid Build Coastguard Worker count_num_samples_in_data_loader, 2206*da0073e9SAndroid Build Coastguard Worker ) 2207*da0073e9SAndroid Build Coastguard Worker 2208*da0073e9SAndroid Build Coastguard Worker def test_distributed_sampler_invalid_rank(self): 2209*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.distributed import DistributedSampler 2210*da0073e9SAndroid Build Coastguard Worker 2211*da0073e9SAndroid Build Coastguard Worker dataset = torch.IntTensor(range(10)) 2212*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Invalid rank"): 2213*da0073e9SAndroid Build Coastguard Worker sampler = DistributedSampler(dataset, 3, 3) 2214*da0073e9SAndroid Build Coastguard Worker 2215*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Invalid rank"): 2216*da0073e9SAndroid Build Coastguard Worker sampler = DistributedSampler(dataset, 3, -1) 2217*da0073e9SAndroid Build Coastguard Worker 2218*da0073e9SAndroid Build Coastguard Worker def test_duplicating_data_with_drop_last(self): 2219*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.distributed import DistributedSampler 2220*da0073e9SAndroid Build Coastguard Worker 2221*da0073e9SAndroid Build Coastguard Worker num_processes = 4 2222*da0073e9SAndroid Build Coastguard Worker num_batches = 9 2223*da0073e9SAndroid Build Coastguard Worker data_set = torch.IntTensor(range(num_batches)) 2224*da0073e9SAndroid Build Coastguard Worker scanned_data = torch.IntTensor([]) 2225*da0073e9SAndroid Build Coastguard Worker for i in range(num_processes): 2226*da0073e9SAndroid Build Coastguard Worker s = DistributedSampler(data_set, num_processes, i) 2227*da0073e9SAndroid Build Coastguard Worker d_loader = self._get_data_loader( 2228*da0073e9SAndroid Build Coastguard Worker data_set, 
2229*da0073e9SAndroid Build Coastguard Worker batch_size=int(num_batches / num_processes), 2230*da0073e9SAndroid Build Coastguard Worker drop_last=True, 2231*da0073e9SAndroid Build Coastguard Worker sampler=s, 2232*da0073e9SAndroid Build Coastguard Worker ) 2233*da0073e9SAndroid Build Coastguard Worker for data in d_loader: 2234*da0073e9SAndroid Build Coastguard Worker scanned_data = torch.cat((scanned_data, data), 0) 2235*da0073e9SAndroid Build Coastguard Worker 2236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scanned_data.size(), scanned_data.unique().size()) 2237*da0073e9SAndroid Build Coastguard Worker 2238*da0073e9SAndroid Build Coastguard Worker def test_sampler_reproducibility(self): 2239*da0073e9SAndroid Build Coastguard Worker from torch.utils.data import ( 2240*da0073e9SAndroid Build Coastguard Worker RandomSampler, 2241*da0073e9SAndroid Build Coastguard Worker SubsetRandomSampler, 2242*da0073e9SAndroid Build Coastguard Worker WeightedRandomSampler, 2243*da0073e9SAndroid Build Coastguard Worker ) 2244*da0073e9SAndroid Build Coastguard Worker 2245*da0073e9SAndroid Build Coastguard Worker weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6] 2246*da0073e9SAndroid Build Coastguard Worker for fn in ( 2247*da0073e9SAndroid Build Coastguard Worker lambda: RandomSampler( 2248*da0073e9SAndroid Build Coastguard Worker self.dataset, 2249*da0073e9SAndroid Build Coastguard Worker num_samples=5, 2250*da0073e9SAndroid Build Coastguard Worker replacement=True, 2251*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 2252*da0073e9SAndroid Build Coastguard Worker ), 2253*da0073e9SAndroid Build Coastguard Worker lambda: RandomSampler( 2254*da0073e9SAndroid Build Coastguard Worker self.dataset, 2255*da0073e9SAndroid Build Coastguard Worker replacement=False, 2256*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 2257*da0073e9SAndroid Build Coastguard Worker ), 2258*da0073e9SAndroid Build Coastguard 
Worker lambda: WeightedRandomSampler( 2259*da0073e9SAndroid Build Coastguard Worker weights, 2260*da0073e9SAndroid Build Coastguard Worker num_samples=5, 2261*da0073e9SAndroid Build Coastguard Worker replacement=True, 2262*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 2263*da0073e9SAndroid Build Coastguard Worker ), 2264*da0073e9SAndroid Build Coastguard Worker lambda: WeightedRandomSampler( 2265*da0073e9SAndroid Build Coastguard Worker weights, 2266*da0073e9SAndroid Build Coastguard Worker num_samples=5, 2267*da0073e9SAndroid Build Coastguard Worker replacement=False, 2268*da0073e9SAndroid Build Coastguard Worker generator=torch.Generator().manual_seed(42), 2269*da0073e9SAndroid Build Coastguard Worker ), 2270*da0073e9SAndroid Build Coastguard Worker lambda: SubsetRandomSampler( 2271*da0073e9SAndroid Build Coastguard Worker range(10), generator=torch.Generator().manual_seed(42) 2272*da0073e9SAndroid Build Coastguard Worker ), 2273*da0073e9SAndroid Build Coastguard Worker ): 2274*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(fn()), list(fn())) 2275*da0073e9SAndroid Build Coastguard Worker 2276*da0073e9SAndroid Build Coastguard Worker for sampler in ( 2277*da0073e9SAndroid Build Coastguard Worker RandomSampler(self.dataset, num_samples=5, replacement=True), 2278*da0073e9SAndroid Build Coastguard Worker RandomSampler(self.dataset, replacement=False), 2279*da0073e9SAndroid Build Coastguard Worker WeightedRandomSampler(weights, num_samples=5, replacement=True), 2280*da0073e9SAndroid Build Coastguard Worker WeightedRandomSampler(weights, num_samples=5, replacement=False), 2281*da0073e9SAndroid Build Coastguard Worker SubsetRandomSampler(range(10)), 2282*da0073e9SAndroid Build Coastguard Worker ): 2283*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(0) 2284*da0073e9SAndroid Build Coastguard Worker l1 = list(sampler) + list(sampler) 2285*da0073e9SAndroid Build Coastguard Worker 2286*da0073e9SAndroid Build 
Coastguard Worker torch.manual_seed(0) 2287*da0073e9SAndroid Build Coastguard Worker l2 = list(sampler) + list(sampler) 2288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(l1, l2) 2289*da0073e9SAndroid Build Coastguard Worker 2290*da0073e9SAndroid Build Coastguard Worker its = (iter(sampler), iter(sampler)) 2291*da0073e9SAndroid Build Coastguard Worker ls = ([], []) 2292*da0073e9SAndroid Build Coastguard Worker for idx in range(len(sampler)): 2293*da0073e9SAndroid Build Coastguard Worker for i in range(2): 2294*da0073e9SAndroid Build Coastguard Worker if idx == 0: 2295*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(0) 2296*da0073e9SAndroid Build Coastguard Worker ls[i].append(next(its[i])) 2297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ls[0], ls[1]) 2298*da0073e9SAndroid Build Coastguard Worker 2299*da0073e9SAndroid Build Coastguard Worker def _test_sampler(self, **kwargs): 2300*da0073e9SAndroid Build Coastguard Worker indices = range(2, 12) # using a regular iterable 2301*da0073e9SAndroid Build Coastguard Worker dl = self._get_data_loader( 2302*da0073e9SAndroid Build Coastguard Worker self.dataset, sampler=indices, batch_size=2, **kwargs 2303*da0073e9SAndroid Build Coastguard Worker ) 2304*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dl), 5) 2305*da0073e9SAndroid Build Coastguard Worker for i, (input, _target) in enumerate(dl): 2306*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input), 2) 2307*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, self.data[i * 2 + 2 : i * 2 + 4]) 2308*da0073e9SAndroid Build Coastguard Worker 2309*da0073e9SAndroid Build Coastguard Worker def test_sampler(self): 2310*da0073e9SAndroid Build Coastguard Worker self._test_sampler() 2311*da0073e9SAndroid Build Coastguard Worker self._test_sampler(num_workers=4) 2312*da0073e9SAndroid Build Coastguard Worker if not NO_MULTIPROCESSING_SPAWN: 2313*da0073e9SAndroid Build Coastguard Worker 
self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn") 2314*da0073e9SAndroid Build Coastguard Worker 2315*da0073e9SAndroid Build Coastguard Worker def _test_batch_sampler(self, **kwargs): 2316*da0073e9SAndroid Build Coastguard Worker # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...] 2317*da0073e9SAndroid Build Coastguard Worker batches = [] # using a regular iterable 2318*da0073e9SAndroid Build Coastguard Worker for i in range(0, 20, 5): 2319*da0073e9SAndroid Build Coastguard Worker batches.append(tuple(range(i, i + 2))) 2320*da0073e9SAndroid Build Coastguard Worker batches.append(tuple(range(i + 2, i + 5))) 2321*da0073e9SAndroid Build Coastguard Worker 2322*da0073e9SAndroid Build Coastguard Worker dl = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs) 2323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dl), 8) 2324*da0073e9SAndroid Build Coastguard Worker for i, (input, _target) in enumerate(dl): 2325*da0073e9SAndroid Build Coastguard Worker if i % 2 == 0: 2326*da0073e9SAndroid Build Coastguard Worker offset = i * 5 // 2 2327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input), 2) 2328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, self.data[offset : offset + 2]) 2329*da0073e9SAndroid Build Coastguard Worker else: 2330*da0073e9SAndroid Build Coastguard Worker offset = i * 5 // 2 2331*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input), 3) 2332*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, self.data[offset : offset + 3]) 2333*da0073e9SAndroid Build Coastguard Worker 2334*da0073e9SAndroid Build Coastguard Worker def test_batch_sampler(self): 2335*da0073e9SAndroid Build Coastguard Worker self._test_batch_sampler() 2336*da0073e9SAndroid Build Coastguard Worker self._test_batch_sampler(num_workers=4) 2337*da0073e9SAndroid Build Coastguard Worker if not NO_MULTIPROCESSING_SPAWN: 2338*da0073e9SAndroid Build Coastguard Worker 
self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn") 2339*da0073e9SAndroid Build Coastguard Worker 2340*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 2341*da0073e9SAndroid Build Coastguard Worker def test_shuffle_pin_memory(self): 2342*da0073e9SAndroid Build Coastguard Worker loader = self._get_data_loader( 2343*da0073e9SAndroid Build Coastguard Worker self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True 2344*da0073e9SAndroid Build Coastguard Worker ) 2345*da0073e9SAndroid Build Coastguard Worker for input, target in loader: 2346*da0073e9SAndroid Build Coastguard Worker self.assertTrue(input.is_pinned()) 2347*da0073e9SAndroid Build Coastguard Worker self.assertTrue(target.is_pinned()) 2348*da0073e9SAndroid Build Coastguard Worker 2349*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2350*da0073e9SAndroid Build Coastguard Worker def test_numpy(self): 2351*da0073e9SAndroid Build Coastguard Worker import numpy as np 2352*da0073e9SAndroid Build Coastguard Worker 2353*da0073e9SAndroid Build Coastguard Worker class TestDataset(torch.utils.data.Dataset): 2354*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, i): 2355*da0073e9SAndroid Build Coastguard Worker return np.ones((2, 3, 4)) * i 2356*da0073e9SAndroid Build Coastguard Worker 2357*da0073e9SAndroid Build Coastguard Worker def __len__(self): 2358*da0073e9SAndroid Build Coastguard Worker return 1000 2359*da0073e9SAndroid Build Coastguard Worker 2360*da0073e9SAndroid Build Coastguard Worker loader = self._get_data_loader(TestDataset(), batch_size=12) 2361*da0073e9SAndroid Build Coastguard Worker batch = next(iter(loader)) 2362*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(batch, torch.DoubleTensor) 2363*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4])) 2364*da0073e9SAndroid Build Coastguard Worker 
2365*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2366*da0073e9SAndroid Build Coastguard Worker def test_numpy_gen_state(self): 2367*da0073e9SAndroid Build Coastguard Worker from torch.utils.data._utils.worker import _generate_state 2368*da0073e9SAndroid Build Coastguard Worker 2369*da0073e9SAndroid Build Coastguard Worker # Using NumPy generated states as the reference to test `_generate_state` 2370*da0073e9SAndroid Build Coastguard Worker # having the same result. 2371*da0073e9SAndroid Build Coastguard Worker # Test case: ((worker_id, base_seed), expected_state) 2372*da0073e9SAndroid Build Coastguard Worker test_cases = [ 2373*da0073e9SAndroid Build Coastguard Worker ( 2374*da0073e9SAndroid Build Coastguard Worker (4, 13434589827475259383), 2375*da0073e9SAndroid Build Coastguard Worker (2884386318, 1088094898, 3523808998, 3860348662), 2376*da0073e9SAndroid Build Coastguard Worker ), 2377*da0073e9SAndroid Build Coastguard Worker ((1, 15014285634777110771), (1934848465, 763213760, 2959016433, 179751970)), 2378*da0073e9SAndroid Build Coastguard Worker ( 2379*da0073e9SAndroid Build Coastguard Worker (10, 978296274032934101), 2380*da0073e9SAndroid Build Coastguard Worker (1759791917, 3550927336, 1225977135, 1036538043), 2381*da0073e9SAndroid Build Coastguard Worker ), 2382*da0073e9SAndroid Build Coastguard Worker ( 2383*da0073e9SAndroid Build Coastguard Worker (12, 11868770762134256968), 2384*da0073e9SAndroid Build Coastguard Worker (3974661794, 3331131333, 3630387033, 2885815368), 2385*da0073e9SAndroid Build Coastguard Worker ), 2386*da0073e9SAndroid Build Coastguard Worker ( 2387*da0073e9SAndroid Build Coastguard Worker (9, 15378787925219019706), 2388*da0073e9SAndroid Build Coastguard Worker (3815056996, 3162224466, 2735102421, 3190253477), 2389*da0073e9SAndroid Build Coastguard Worker ), 2390*da0073e9SAndroid Build Coastguard Worker ((5, 9055612723125076328), (3522565701, 3368424109, 959377806, 621878693)), 
2391*da0073e9SAndroid Build Coastguard Worker ( 2392*da0073e9SAndroid Build Coastguard Worker (15, 14617792358407278405), 2393*da0073e9SAndroid Build Coastguard Worker (3402479508, 1588702753, 1169536393, 3675067356), 2394*da0073e9SAndroid Build Coastguard Worker ), 2395*da0073e9SAndroid Build Coastguard Worker ( 2396*da0073e9SAndroid Build Coastguard Worker (9, 17363320784006640087), 2397*da0073e9SAndroid Build Coastguard Worker (957989458, 2518334477, 1421725660, 3086155459), 2398*da0073e9SAndroid Build Coastguard Worker ), 2399*da0073e9SAndroid Build Coastguard Worker ( 2400*da0073e9SAndroid Build Coastguard Worker (12, 480002904169484764), 2401*da0073e9SAndroid Build Coastguard Worker (2732851467, 1762620729, 4055801988, 1277640511), 2402*da0073e9SAndroid Build Coastguard Worker ), 2403*da0073e9SAndroid Build Coastguard Worker ( 2404*da0073e9SAndroid Build Coastguard Worker (15, 16803975943592702950), 2405*da0073e9SAndroid Build Coastguard Worker (3479415043, 4022359553, 295994005, 3358606349), 2406*da0073e9SAndroid Build Coastguard Worker ), 2407*da0073e9SAndroid Build Coastguard Worker ( 2408*da0073e9SAndroid Build Coastguard Worker (9, 11704776406047813044), 2409*da0073e9SAndroid Build Coastguard Worker (1968928009, 710113752, 2442656196, 1587420279), 2410*da0073e9SAndroid Build Coastguard Worker ), 2411*da0073e9SAndroid Build Coastguard Worker ( 2412*da0073e9SAndroid Build Coastguard Worker (10, 16357891985431864516), 2413*da0073e9SAndroid Build Coastguard Worker (1271733898, 4197047399, 3727213786, 2338547348), 2414*da0073e9SAndroid Build Coastguard Worker ), 2415*da0073e9SAndroid Build Coastguard Worker ( 2416*da0073e9SAndroid Build Coastguard Worker (2, 17423369006318065007), 2417*da0073e9SAndroid Build Coastguard Worker (544294336, 1911284083, 3299147734, 3231058347), 2418*da0073e9SAndroid Build Coastguard Worker ), 2419*da0073e9SAndroid Build Coastguard Worker ((2, 2889492011444113593), (3721591783, 2595811276, 2212881745, 977682627)), 
2420*da0073e9SAndroid Build Coastguard Worker ((0, 8979703111668486195), (4276723937, 2556068849, 2962827292, 233130238)), 2421*da0073e9SAndroid Build Coastguard Worker ( 2422*da0073e9SAndroid Build Coastguard Worker (6, 6269787272229682235), 2423*da0073e9SAndroid Build Coastguard Worker (2548857855, 1216457374, 1012973562, 2999759647), 2424*da0073e9SAndroid Build Coastguard Worker ), 2425*da0073e9SAndroid Build Coastguard Worker ] 2426*da0073e9SAndroid Build Coastguard Worker 2427*da0073e9SAndroid Build Coastguard Worker for (worker_id, base_seed), exp in test_cases: 2428*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp, _generate_state(base_seed, worker_id)) 2429*da0073e9SAndroid Build Coastguard Worker 2430*da0073e9SAndroid Build Coastguard Worker def test_error(self): 2431*da0073e9SAndroid Build Coastguard Worker self._test_error( 2432*da0073e9SAndroid Build Coastguard Worker self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True) 2433*da0073e9SAndroid Build Coastguard Worker ) 2434*da0073e9SAndroid Build Coastguard Worker 2435*da0073e9SAndroid Build Coastguard Worker def test_error_workers(self): 2436*da0073e9SAndroid Build Coastguard Worker self._test_error( 2437*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 2438*da0073e9SAndroid Build Coastguard Worker ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4 2439*da0073e9SAndroid Build Coastguard Worker ) 2440*da0073e9SAndroid Build Coastguard Worker ) 2441*da0073e9SAndroid Build Coastguard Worker 2442*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test") 2443*da0073e9SAndroid Build Coastguard Worker def test_partial_workers(self): 2444*da0073e9SAndroid Build Coastguard Worker r"""Check that workers exit even if the iterator is not exhausted.""" 2445*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 2446*da0073e9SAndroid Build Coastguard Worker pin_memory_configs = (True, False) 2447*da0073e9SAndroid Build Coastguard 
Worker else: 2448*da0073e9SAndroid Build Coastguard Worker pin_memory_configs = (False,) 2449*da0073e9SAndroid Build Coastguard Worker 2450*da0073e9SAndroid Build Coastguard Worker for pin_memory in pin_memory_configs: 2451*da0073e9SAndroid Build Coastguard Worker loader = iter( 2452*da0073e9SAndroid Build Coastguard Worker self._get_data_loader( 2453*da0073e9SAndroid Build Coastguard Worker self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory 2454*da0073e9SAndroid Build Coastguard Worker ) 2455*da0073e9SAndroid Build Coastguard Worker ) 2456*da0073e9SAndroid Build Coastguard Worker workers = loader._workers 2457*da0073e9SAndroid Build Coastguard Worker if pin_memory: 2458*da0073e9SAndroid Build Coastguard Worker pin_memory_thread = loader._pin_memory_thread 2459*da0073e9SAndroid Build Coastguard Worker for i, _ in enumerate(loader): 2460*da0073e9SAndroid Build Coastguard Worker if i == 10: 2461*da0073e9SAndroid Build Coastguard Worker break 2462*da0073e9SAndroid Build Coastguard Worker assert i == 10 2463*da0073e9SAndroid Build Coastguard Worker del loader 2464*da0073e9SAndroid Build Coastguard Worker for w in workers: 2465*da0073e9SAndroid Build Coastguard Worker w.join(JOIN_TIMEOUT) 2466*da0073e9SAndroid Build Coastguard Worker self.assertFalse(w.is_alive(), "subprocess not terminated") 2467*da0073e9SAndroid Build Coastguard Worker if pin_memory: 2468*da0073e9SAndroid Build Coastguard Worker pin_memory_thread.join(JOIN_TIMEOUT) 2469*da0073e9SAndroid Build Coastguard Worker self.assertFalse(pin_memory_thread.is_alive()) 2470*da0073e9SAndroid Build Coastguard Worker 2471*da0073e9SAndroid Build Coastguard Worker # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065 2472*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 2473*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not HAS_PSUTIL, "psutil not found") 2474*da0073e9SAndroid Build Coastguard Worker @slowTest 2475*da0073e9SAndroid Build Coastguard Worker def 
test_proper_exit(self): 2476*da0073e9SAndroid Build Coastguard Worker ( 2477*da0073e9SAndroid Build Coastguard Worker r"""There might be ConnectionResetError or leaked semaphore warning """ 2478*da0073e9SAndroid Build Coastguard Worker r"""(due to dirty process exit), but they are all safe to ignore""" 2479*da0073e9SAndroid Build Coastguard Worker ) 2480*da0073e9SAndroid Build Coastguard Worker 2481*da0073e9SAndroid Build Coastguard Worker # TODO: test the case where the pin_memory_thread triggers an 2482*da0073e9SAndroid Build Coastguard Worker # error/fatal signal. I haven't found out how to properly do that. 2483*da0073e9SAndroid Build Coastguard Worker 2484*da0073e9SAndroid Build Coastguard Worker for ( 2485*da0073e9SAndroid Build Coastguard Worker is_iterable_dataset, 2486*da0073e9SAndroid Build Coastguard Worker use_workers, 2487*da0073e9SAndroid Build Coastguard Worker pin_memory, 2488*da0073e9SAndroid Build Coastguard Worker hold_iter_reference, 2489*da0073e9SAndroid Build Coastguard Worker ) in itertools.product([True, False], repeat=4): 2490*da0073e9SAndroid Build Coastguard Worker # `hold_iter_reference` specifies whether we hold a reference to the 2491*da0073e9SAndroid Build Coastguard Worker # iterator. This is interesting because Python3 error traces holds a 2492*da0073e9SAndroid Build Coastguard Worker # reference to the frames, which hold references to all the local 2493*da0073e9SAndroid Build Coastguard Worker # variables including the iterator, and then the iterator dtor may 2494*da0073e9SAndroid Build Coastguard Worker # not be called before process end. It is important to see that the 2495*da0073e9SAndroid Build Coastguard Worker # processes still exit in both cases. 
2496*da0073e9SAndroid Build Coastguard Worker 2497*da0073e9SAndroid Build Coastguard Worker if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS): 2498*da0073e9SAndroid Build Coastguard Worker # This test runs in a subprocess, which can only initialize CUDA with spawn. 2499*da0073e9SAndroid Build Coastguard Worker # DataLoader with pin_memory=True initializes CUDA when its iterator is constructed. 2500*da0073e9SAndroid Build Coastguard Worker # For windows, pin_memory sometimes causes CUDA oom. 2501*da0073e9SAndroid Build Coastguard Worker continue 2502*da0073e9SAndroid Build Coastguard Worker 2503*da0073e9SAndroid Build Coastguard Worker # `exit_method` controls the way the loader process ends. 2504*da0073e9SAndroid Build Coastguard Worker # - `*_kill` means that `*` is killed by OS. 2505*da0073e9SAndroid Build Coastguard Worker # - `*_error` means that `*` raises an error. 2506*da0073e9SAndroid Build Coastguard Worker # - `None` means that no error happens. 2507*da0073e9SAndroid Build Coastguard Worker # In all cases, all processes should end properly. 2508*da0073e9SAndroid Build Coastguard Worker if use_workers: 2509*da0073e9SAndroid Build Coastguard Worker # TODO: Fix test for 'loader_kill' that would cause running out of shared memory. 
2510*da0073e9SAndroid Build Coastguard Worker # Killing loader process would prevent DataLoader iterator clean up all queues 2511*da0073e9SAndroid Build Coastguard Worker # and worker processes 2512*da0073e9SAndroid Build Coastguard Worker exit_methods = [None, "loader_error", "worker_error", "worker_kill"] 2513*da0073e9SAndroid Build Coastguard Worker persistent_workers = self.persistent_workers 2514*da0073e9SAndroid Build Coastguard Worker else: 2515*da0073e9SAndroid Build Coastguard Worker exit_methods = [None, "loader_error", "loader_kill"] 2516*da0073e9SAndroid Build Coastguard Worker persistent_workers = False 2517*da0073e9SAndroid Build Coastguard Worker 2518*da0073e9SAndroid Build Coastguard Worker for exit_method in exit_methods: 2519*da0073e9SAndroid Build Coastguard Worker if exit_method == "worker_kill": 2520*da0073e9SAndroid Build Coastguard Worker # FIXME: This sometimes hangs. See #16608. 2521*da0073e9SAndroid Build Coastguard Worker continue 2522*da0073e9SAndroid Build Coastguard Worker 2523*da0073e9SAndroid Build Coastguard Worker desc = [] 2524*da0073e9SAndroid Build Coastguard Worker desc.append(f"is_iterable_dataset={is_iterable_dataset}") 2525*da0073e9SAndroid Build Coastguard Worker desc.append(f"use_workers={use_workers}") 2526*da0073e9SAndroid Build Coastguard Worker desc.append(f"pin_memory={pin_memory}") 2527*da0073e9SAndroid Build Coastguard Worker desc.append(f"hold_iter_reference={hold_iter_reference}") 2528*da0073e9SAndroid Build Coastguard Worker desc.append(f"exit_method={exit_method}") 2529*da0073e9SAndroid Build Coastguard Worker desc = "test_proper_exit with " + ", ".join(desc) 2530*da0073e9SAndroid Build Coastguard Worker 2531*da0073e9SAndroid Build Coastguard Worker # Event that the loader process uses to signal testing process 2532*da0073e9SAndroid Build Coastguard Worker # that various things are setup, including that the worker pids 2533*da0073e9SAndroid Build Coastguard Worker # are specified in `worker_pids` array. 
2534*da0073e9SAndroid Build Coastguard Worker loader_setup_event = mp.Event() 2535*da0073e9SAndroid Build Coastguard Worker 2536*da0073e9SAndroid Build Coastguard Worker # Event that this process has finished setting up, and the 2537*da0073e9SAndroid Build Coastguard Worker # loader process can now proceed to trigger error events or 2538*da0073e9SAndroid Build Coastguard Worker # finish normally. 2539*da0073e9SAndroid Build Coastguard Worker tester_setup_event = mp.Event() 2540*da0073e9SAndroid Build Coastguard Worker 2541*da0073e9SAndroid Build Coastguard Worker loader_p = ErrorTrackingProcess( 2542*da0073e9SAndroid Build Coastguard Worker target=_test_proper_exit, 2543*da0073e9SAndroid Build Coastguard Worker args=( 2544*da0073e9SAndroid Build Coastguard Worker is_iterable_dataset, 2545*da0073e9SAndroid Build Coastguard Worker use_workers, 2546*da0073e9SAndroid Build Coastguard Worker pin_memory, 2547*da0073e9SAndroid Build Coastguard Worker exit_method, 2548*da0073e9SAndroid Build Coastguard Worker hold_iter_reference, 2549*da0073e9SAndroid Build Coastguard Worker loader_setup_event, 2550*da0073e9SAndroid Build Coastguard Worker tester_setup_event, 2551*da0073e9SAndroid Build Coastguard Worker persistent_workers, 2552*da0073e9SAndroid Build Coastguard Worker ), 2553*da0073e9SAndroid Build Coastguard Worker disable_stderr=False, 2554*da0073e9SAndroid Build Coastguard Worker ) 2555*da0073e9SAndroid Build Coastguard Worker loader_p.start() 2556*da0073e9SAndroid Build Coastguard Worker loader_psutil_p = psutil.Process(loader_p.pid) 2557*da0073e9SAndroid Build Coastguard Worker 2558*da0073e9SAndroid Build Coastguard Worker # Wait for loader process to set everything up, e.g., starting 2559*da0073e9SAndroid Build Coastguard Worker # workers. 
2560*da0073e9SAndroid Build Coastguard Worker loader_setup_event.wait(timeout=JOIN_TIMEOUT) 2561*da0073e9SAndroid Build Coastguard Worker if not loader_setup_event.is_set(): 2562*da0073e9SAndroid Build Coastguard Worker fail_msg = ( 2563*da0073e9SAndroid Build Coastguard Worker desc + ": loader process failed to setup within given time" 2564*da0073e9SAndroid Build Coastguard Worker ) 2565*da0073e9SAndroid Build Coastguard Worker if loader_p.exception is not None: 2566*da0073e9SAndroid Build Coastguard Worker fail_msg += f", and had exception {loader_p.exception}" 2567*da0073e9SAndroid Build Coastguard Worker elif not loader_p.is_alive(): 2568*da0073e9SAndroid Build Coastguard Worker fail_msg += f", and exited with code {loader_p.exitcode} but had no exception" 2569*da0073e9SAndroid Build Coastguard Worker else: 2570*da0073e9SAndroid Build Coastguard Worker fail_msg += ", and is still alive." 2571*da0073e9SAndroid Build Coastguard Worker if loader_p.is_alive(): 2572*da0073e9SAndroid Build Coastguard Worker # this may kill the process, needs to run after the above lines 2573*da0073e9SAndroid Build Coastguard Worker loader_p.print_traces_of_all_threads() 2574*da0073e9SAndroid Build Coastguard Worker self.fail(fail_msg) 2575*da0073e9SAndroid Build Coastguard Worker 2576*da0073e9SAndroid Build Coastguard Worker # We are certain that the workers have started now. 
2577*da0073e9SAndroid Build Coastguard Worker worker_psutil_ps = loader_psutil_p.children() 2578*da0073e9SAndroid Build Coastguard Worker 2579*da0073e9SAndroid Build Coastguard Worker def fail(reason): 2580*da0073e9SAndroid Build Coastguard Worker report_psutil_attrs = [ 2581*da0073e9SAndroid Build Coastguard Worker "pid", 2582*da0073e9SAndroid Build Coastguard Worker "name", 2583*da0073e9SAndroid Build Coastguard Worker "cpu_times", 2584*da0073e9SAndroid Build Coastguard Worker "io_counters", 2585*da0073e9SAndroid Build Coastguard Worker "memory_full_info", 2586*da0073e9SAndroid Build Coastguard Worker "num_ctx_switches", 2587*da0073e9SAndroid Build Coastguard Worker "open_files", 2588*da0073e9SAndroid Build Coastguard Worker "threads", 2589*da0073e9SAndroid Build Coastguard Worker "status", 2590*da0073e9SAndroid Build Coastguard Worker "nice", 2591*da0073e9SAndroid Build Coastguard Worker "ionice", 2592*da0073e9SAndroid Build Coastguard Worker ] 2593*da0073e9SAndroid Build Coastguard Worker if reason is None: 2594*da0073e9SAndroid Build Coastguard Worker err_msg = desc 2595*da0073e9SAndroid Build Coastguard Worker else: 2596*da0073e9SAndroid Build Coastguard Worker err_msg = f"{desc}: {reason}" 2597*da0073e9SAndroid Build Coastguard Worker err_msg += "\nLoader info:\n\t" 2598*da0073e9SAndroid Build Coastguard Worker if loader_psutil_p.is_running(): 2599*da0073e9SAndroid Build Coastguard Worker err_msg += str( 2600*da0073e9SAndroid Build Coastguard Worker loader_psutil_p.as_dict(attrs=report_psutil_attrs) 2601*da0073e9SAndroid Build Coastguard Worker ) 2602*da0073e9SAndroid Build Coastguard Worker # this may kill the process, needs to run after the above line 2603*da0073e9SAndroid Build Coastguard Worker loader_p.print_traces_of_all_threads() 2604*da0073e9SAndroid Build Coastguard Worker else: 2605*da0073e9SAndroid Build Coastguard Worker err_msg += f"exited with code {loader_p.exitcode}" 2606*da0073e9SAndroid Build Coastguard Worker if use_workers: 
2607*da0073e9SAndroid Build Coastguard Worker err_msg += "\nWorker(s) info:" 2608*da0073e9SAndroid Build Coastguard Worker for idx, worker_psutil_p in enumerate(worker_psutil_ps): 2609*da0073e9SAndroid Build Coastguard Worker err_msg += f"\n\tWorker {idx}:\n\t\t" 2610*da0073e9SAndroid Build Coastguard Worker if worker_psutil_p.is_running(): 2611*da0073e9SAndroid Build Coastguard Worker err_msg += str( 2612*da0073e9SAndroid Build Coastguard Worker worker_psutil_p.as_dict(attrs=report_psutil_attrs) 2613*da0073e9SAndroid Build Coastguard Worker ) 2614*da0073e9SAndroid Build Coastguard Worker # this may kill the process, needs to run after the above line 2615*da0073e9SAndroid Build Coastguard Worker print_traces_of_all_threads(worker_psutil_p.pid) 2616*da0073e9SAndroid Build Coastguard Worker else: 2617*da0073e9SAndroid Build Coastguard Worker err_msg += "exited with unknown code" 2618*da0073e9SAndroid Build Coastguard Worker self.fail(err_msg) 2619*da0073e9SAndroid Build Coastguard Worker 2620*da0073e9SAndroid Build Coastguard Worker tester_setup_event.set() 2621*da0073e9SAndroid Build Coastguard Worker 2622*da0073e9SAndroid Build Coastguard Worker try: 2623*da0073e9SAndroid Build Coastguard Worker loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL) 2624*da0073e9SAndroid Build Coastguard Worker if loader_p.is_alive(): 2625*da0073e9SAndroid Build Coastguard Worker fail_reason = "loader process did not terminate" 2626*da0073e9SAndroid Build Coastguard Worker if loader_p.exception is not None: 2627*da0073e9SAndroid Build Coastguard Worker fail( 2628*da0073e9SAndroid Build Coastguard Worker fail_reason 2629*da0073e9SAndroid Build Coastguard Worker + f", and had exception {loader_p.exception}" 2630*da0073e9SAndroid Build Coastguard Worker ) 2631*da0073e9SAndroid Build Coastguard Worker else: 2632*da0073e9SAndroid Build Coastguard Worker fail(fail_reason + ", and had no exception") 2633*da0073e9SAndroid Build Coastguard Worker _, alive = psutil.wait_procs( 
2634*da0073e9SAndroid Build Coastguard Worker worker_psutil_ps, 2635*da0073e9SAndroid Build Coastguard Worker timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT), 2636*da0073e9SAndroid Build Coastguard Worker ) 2637*da0073e9SAndroid Build Coastguard Worker if len(alive) > 0: 2638*da0073e9SAndroid Build Coastguard Worker fail( 2639*da0073e9SAndroid Build Coastguard Worker "worker process (pid(s) {}) did not terminate".format( 2640*da0073e9SAndroid Build Coastguard Worker ", ".join(str(p.pid) for p in alive) 2641*da0073e9SAndroid Build Coastguard Worker ) 2642*da0073e9SAndroid Build Coastguard Worker ) 2643*da0073e9SAndroid Build Coastguard Worker if exit_method is None: 2644*da0073e9SAndroid Build Coastguard Worker if loader_p.exitcode != 0: 2645*da0073e9SAndroid Build Coastguard Worker fail( 2646*da0073e9SAndroid Build Coastguard Worker f"loader process had nonzero exitcode {loader_p.exitcode}" 2647*da0073e9SAndroid Build Coastguard Worker ) 2648*da0073e9SAndroid Build Coastguard Worker else: 2649*da0073e9SAndroid Build Coastguard Worker if loader_p.exitcode == 0: 2650*da0073e9SAndroid Build Coastguard Worker fail("loader process had zero exitcode") 2651*da0073e9SAndroid Build Coastguard Worker if exit_method == "loader_error": 2652*da0073e9SAndroid Build Coastguard Worker if not isinstance( 2653*da0073e9SAndroid Build Coastguard Worker loader_p.exception, RuntimeError 2654*da0073e9SAndroid Build Coastguard Worker ) or "Loader error" not in str(loader_p.exception): 2655*da0073e9SAndroid Build Coastguard Worker fail( 2656*da0073e9SAndroid Build Coastguard Worker f"loader process did not raise expected exception, but had {loader_p.exception}" 2657*da0073e9SAndroid Build Coastguard Worker ) 2658*da0073e9SAndroid Build Coastguard Worker elif exit_method == "worker_kill": 2659*da0073e9SAndroid Build Coastguard Worker if isinstance(loader_p.exception, RuntimeError): 2660*da0073e9SAndroid Build Coastguard Worker if "DataLoader worker (pid" not in str( 2661*da0073e9SAndroid 
Build Coastguard Worker loader_p.exception 2662*da0073e9SAndroid Build Coastguard Worker ): 2663*da0073e9SAndroid Build Coastguard Worker fail( 2664*da0073e9SAndroid Build Coastguard Worker f"loader process did not raise expected exception, but had {loader_p.exception}" 2665*da0073e9SAndroid Build Coastguard Worker ) 2666*da0073e9SAndroid Build Coastguard Worker elif isinstance(loader_p.exception, ConnectionRefusedError): 2667*da0073e9SAndroid Build Coastguard Worker # Sometimes, when the worker is being killed and is freeing its 2668*da0073e9SAndroid Build Coastguard Worker # resources, the unpickling in loader process will be met an 2669*da0073e9SAndroid Build Coastguard Worker # a `ConnectionRefusedError` as it can not open a socket to receive 2670*da0073e9SAndroid Build Coastguard Worker # resource. In such cases, the worker may not have fully exited, 2671*da0073e9SAndroid Build Coastguard Worker # and the loader can't know this via `is_alive` check or `SIGCHLD` 2672*da0073e9SAndroid Build Coastguard Worker # handler. So we permit this as an allowed error as well. 2673*da0073e9SAndroid Build Coastguard Worker # After all, we are happy as long as it terminates. 
2674*da0073e9SAndroid Build Coastguard Worker pass 2675*da0073e9SAndroid Build Coastguard Worker else: 2676*da0073e9SAndroid Build Coastguard Worker fail( 2677*da0073e9SAndroid Build Coastguard Worker f"loader process did not raise expected exception, but had {loader_p.exception}" 2678*da0073e9SAndroid Build Coastguard Worker ) 2679*da0073e9SAndroid Build Coastguard Worker elif exit_method == "worker_error": 2680*da0073e9SAndroid Build Coastguard Worker if not isinstance( 2681*da0073e9SAndroid Build Coastguard Worker loader_p.exception, RuntimeError 2682*da0073e9SAndroid Build Coastguard Worker ) or "Worker error" not in str(loader_p.exception): 2683*da0073e9SAndroid Build Coastguard Worker fail( 2684*da0073e9SAndroid Build Coastguard Worker f"loader process did not raise expected exception, but had {loader_p.exception}" 2685*da0073e9SAndroid Build Coastguard Worker ) 2686*da0073e9SAndroid Build Coastguard Worker finally: 2687*da0073e9SAndroid Build Coastguard Worker loader_p.terminate() 2688*da0073e9SAndroid Build Coastguard Worker 2689*da0073e9SAndroid Build Coastguard Worker def test_len(self): 2690*da0073e9SAndroid Build Coastguard Worker def check_len(dl, expected): 2691*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dl), expected) 2692*da0073e9SAndroid Build Coastguard Worker n = 0 2693*da0073e9SAndroid Build Coastguard Worker for _ in dl: 2694*da0073e9SAndroid Build Coastguard Worker n += 1 2695*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n, expected) 2696*da0073e9SAndroid Build Coastguard Worker 2697*da0073e9SAndroid Build Coastguard Worker check_len(self.dataset, 100) 2698*da0073e9SAndroid Build Coastguard Worker check_len(self._get_data_loader(self.dataset, batch_size=2), 50) 2699*da0073e9SAndroid Build Coastguard Worker check_len(self._get_data_loader(self.dataset, batch_size=3), 34) 2700*da0073e9SAndroid Build Coastguard Worker 2701*da0073e9SAndroid Build Coastguard Worker def test_iterabledataset_len(self): 
2702*da0073e9SAndroid Build Coastguard Worker class IterableDataset(torch.utils.data.IterableDataset): 2703*da0073e9SAndroid Build Coastguard Worker def __len__(self): 2704*da0073e9SAndroid Build Coastguard Worker return 10 2705*da0073e9SAndroid Build Coastguard Worker 2706*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 2707*da0073e9SAndroid Build Coastguard Worker return iter(range(10)) 2708*da0073e9SAndroid Build Coastguard Worker 2709*da0073e9SAndroid Build Coastguard Worker iterable_loader = DataLoader(IterableDataset(), batch_size=1) 2710*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(iterable_loader), 10) 2711*da0073e9SAndroid Build Coastguard Worker iterable_loader = DataLoader(IterableDataset(), batch_size=1, drop_last=True) 2712*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(iterable_loader), 10) 2713*da0073e9SAndroid Build Coastguard Worker 2714*da0073e9SAndroid Build Coastguard Worker iterable_loader = DataLoader(IterableDataset(), batch_size=2) 2715*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(iterable_loader), 5) 2716*da0073e9SAndroid Build Coastguard Worker iterable_loader = DataLoader(IterableDataset(), batch_size=2, drop_last=True) 2717*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(iterable_loader), 5) 2718*da0073e9SAndroid Build Coastguard Worker 2719*da0073e9SAndroid Build Coastguard Worker iterable_loader = DataLoader(IterableDataset(), batch_size=3) 2720*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(iterable_loader), 4) 2721*da0073e9SAndroid Build Coastguard Worker iterable_loader = DataLoader(IterableDataset(), batch_size=3, drop_last=True) 2722*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(iterable_loader), 3) 2723*da0073e9SAndroid Build Coastguard Worker 2724*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2725*da0073e9SAndroid Build Coastguard Worker def test_numpy_scalars(self): 
2726*da0073e9SAndroid Build Coastguard Worker import numpy as np 2727*da0073e9SAndroid Build Coastguard Worker 2728*da0073e9SAndroid Build Coastguard Worker class ScalarDataset(torch.utils.data.Dataset): 2729*da0073e9SAndroid Build Coastguard Worker def __init__(self, dtype): 2730*da0073e9SAndroid Build Coastguard Worker self.dtype = dtype 2731*da0073e9SAndroid Build Coastguard Worker 2732*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, i): 2733*da0073e9SAndroid Build Coastguard Worker return self.dtype() 2734*da0073e9SAndroid Build Coastguard Worker 2735*da0073e9SAndroid Build Coastguard Worker def __len__(self): 2736*da0073e9SAndroid Build Coastguard Worker return 4 2737*da0073e9SAndroid Build Coastguard Worker 2738*da0073e9SAndroid Build Coastguard Worker dtypes = { 2739*da0073e9SAndroid Build Coastguard Worker np.float64: torch.DoubleTensor, 2740*da0073e9SAndroid Build Coastguard Worker np.float32: torch.FloatTensor, 2741*da0073e9SAndroid Build Coastguard Worker np.float16: torch.HalfTensor, 2742*da0073e9SAndroid Build Coastguard Worker np.int64: torch.LongTensor, 2743*da0073e9SAndroid Build Coastguard Worker np.int32: torch.IntTensor, 2744*da0073e9SAndroid Build Coastguard Worker np.int16: torch.ShortTensor, 2745*da0073e9SAndroid Build Coastguard Worker np.int8: torch.CharTensor, 2746*da0073e9SAndroid Build Coastguard Worker np.uint8: torch.ByteTensor, 2747*da0073e9SAndroid Build Coastguard Worker } 2748*da0073e9SAndroid Build Coastguard Worker for dt, tt in dtypes.items(): 2749*da0073e9SAndroid Build Coastguard Worker dset = ScalarDataset(dt) 2750*da0073e9SAndroid Build Coastguard Worker loader = self._get_data_loader(dset, batch_size=2) 2751*da0073e9SAndroid Build Coastguard Worker batch = next(iter(loader)) 2752*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(batch, tt) 2753*da0073e9SAndroid Build Coastguard Worker 2754*da0073e9SAndroid Build Coastguard Worker def test_default_convert_mapping_keep_type(self): 
2755*da0073e9SAndroid Build Coastguard Worker data = CustomDict({"a": 1, "b": 2}) 2756*da0073e9SAndroid Build Coastguard Worker converted = _utils.collate.default_convert(data) 2757*da0073e9SAndroid Build Coastguard Worker 2758*da0073e9SAndroid Build Coastguard Worker self.assertEqual(converted, data) 2759*da0073e9SAndroid Build Coastguard Worker 2760*da0073e9SAndroid Build Coastguard Worker def test_default_convert_sequence_keep_type(self): 2761*da0073e9SAndroid Build Coastguard Worker data = CustomList([1, 2, 3]) 2762*da0073e9SAndroid Build Coastguard Worker converted = _utils.collate.default_convert(data) 2763*da0073e9SAndroid Build Coastguard Worker 2764*da0073e9SAndroid Build Coastguard Worker self.assertEqual(converted, data) 2765*da0073e9SAndroid Build Coastguard Worker 2766*da0073e9SAndroid Build Coastguard Worker def test_default_convert_sequence_dont_keep_type(self): 2767*da0073e9SAndroid Build Coastguard Worker data = range(2) 2768*da0073e9SAndroid Build Coastguard Worker converted = _utils.collate.default_convert(data) 2769*da0073e9SAndroid Build Coastguard Worker 2770*da0073e9SAndroid Build Coastguard Worker self.assertEqual(converted, [0, 1]) 2771*da0073e9SAndroid Build Coastguard Worker 2772*da0073e9SAndroid Build Coastguard Worker def test_default_collate_dtype(self): 2773*da0073e9SAndroid Build Coastguard Worker arr = [1, 2, -1] 2774*da0073e9SAndroid Build Coastguard Worker collated = _utils.collate.default_collate(arr) 2775*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated, torch.tensor(arr)) 2776*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated.dtype, torch.int64) 2777*da0073e9SAndroid Build Coastguard Worker 2778*da0073e9SAndroid Build Coastguard Worker arr = [1.1, 2.3, -0.9] 2779*da0073e9SAndroid Build Coastguard Worker collated = _utils.collate.default_collate(arr) 2780*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64)) 2781*da0073e9SAndroid Build 
Coastguard Worker 2782*da0073e9SAndroid Build Coastguard Worker arr = [True, False] 2783*da0073e9SAndroid Build Coastguard Worker collated = _utils.collate.default_collate(arr) 2784*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated, torch.tensor(arr)) 2785*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated.dtype, torch.bool) 2786*da0073e9SAndroid Build Coastguard Worker 2787*da0073e9SAndroid Build Coastguard Worker # Should be a no-op 2788*da0073e9SAndroid Build Coastguard Worker arr = ["a", "b", "c"] 2789*da0073e9SAndroid Build Coastguard Worker self.assertEqual(arr, _utils.collate.default_collate(arr)) 2790*da0073e9SAndroid Build Coastguard Worker 2791*da0073e9SAndroid Build Coastguard Worker def test_default_collate_mapping_keep_type(self): 2792*da0073e9SAndroid Build Coastguard Worker batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})] 2793*da0073e9SAndroid Build Coastguard Worker collated = _utils.collate.default_collate(batch) 2794*da0073e9SAndroid Build Coastguard Worker 2795*da0073e9SAndroid Build Coastguard Worker expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])}) 2796*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated, expected) 2797*da0073e9SAndroid Build Coastguard Worker 2798*da0073e9SAndroid Build Coastguard Worker def test_default_collate_sequence_keep_type(self): 2799*da0073e9SAndroid Build Coastguard Worker batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])] 2800*da0073e9SAndroid Build Coastguard Worker collated = _utils.collate.default_collate(batch) 2801*da0073e9SAndroid Build Coastguard Worker 2802*da0073e9SAndroid Build Coastguard Worker expected = CustomList( 2803*da0073e9SAndroid Build Coastguard Worker [ 2804*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 4]), 2805*da0073e9SAndroid Build Coastguard Worker torch.tensor([2, 5]), 2806*da0073e9SAndroid Build Coastguard Worker torch.tensor([3, 6]), 2807*da0073e9SAndroid Build Coastguard 
Worker ] 2808*da0073e9SAndroid Build Coastguard Worker ) 2809*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated, expected) 2810*da0073e9SAndroid Build Coastguard Worker 2811*da0073e9SAndroid Build Coastguard Worker def test_default_collate_sequence_dont_keep_type(self): 2812*da0073e9SAndroid Build Coastguard Worker batch = [range(2), range(2)] 2813*da0073e9SAndroid Build Coastguard Worker collated = _utils.collate.default_collate(batch) 2814*da0073e9SAndroid Build Coastguard Worker 2815*da0073e9SAndroid Build Coastguard Worker self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])]) 2816*da0073e9SAndroid Build Coastguard Worker 2817*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2818*da0073e9SAndroid Build Coastguard Worker def test_default_collate_bad_numpy_types(self): 2819*da0073e9SAndroid Build Coastguard Worker import numpy as np 2820*da0073e9SAndroid Build Coastguard Worker 2821*da0073e9SAndroid Build Coastguard Worker # Should be a no-op 2822*da0073e9SAndroid Build Coastguard Worker arr = np.array(["a", "b", "c"]) 2823*da0073e9SAndroid Build Coastguard Worker self.assertEqual(arr, _utils.collate.default_collate(arr)) 2824*da0073e9SAndroid Build Coastguard Worker 2825*da0073e9SAndroid Build Coastguard Worker arr = np.array([[["a", "b", "c"]]]) 2826*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) 2827*da0073e9SAndroid Build Coastguard Worker 2828*da0073e9SAndroid Build Coastguard Worker arr = np.array([object(), object(), object()]) 2829*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) 2830*da0073e9SAndroid Build Coastguard Worker 2831*da0073e9SAndroid Build Coastguard Worker arr = np.array([[[object(), object(), object()]]]) 2832*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) 
2833*da0073e9SAndroid Build Coastguard Worker 2834*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2835*da0073e9SAndroid Build Coastguard Worker def test_default_collate_numpy_memmap(self): 2836*da0073e9SAndroid Build Coastguard Worker import numpy as np 2837*da0073e9SAndroid Build Coastguard Worker 2838*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryFile() as f: 2839*da0073e9SAndroid Build Coastguard Worker arr = np.array([[0, 1], [2, 3], [4, 5], [6, 7]]) 2840*da0073e9SAndroid Build Coastguard Worker arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape) 2841*da0073e9SAndroid Build Coastguard Worker arr_memmap[:] = arr[:] 2842*da0073e9SAndroid Build Coastguard Worker arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape) 2843*da0073e9SAndroid Build Coastguard Worker tensor = _utils.collate.default_collate(list(arr_new)) 2844*da0073e9SAndroid Build Coastguard Worker 2845*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2846*da0073e9SAndroid Build Coastguard Worker (tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item() 2847*da0073e9SAndroid Build Coastguard Worker ) 2848*da0073e9SAndroid Build Coastguard Worker 2849*da0073e9SAndroid Build Coastguard Worker def test_default_collate_bad_sequence_type(self): 2850*da0073e9SAndroid Build Coastguard Worker batch = [["X"], ["X", "X"]] 2851*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch)) 2852*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 2853*da0073e9SAndroid Build Coastguard Worker RuntimeError, lambda: _utils.collate.default_collate(batch[::-1]) 2854*da0073e9SAndroid Build Coastguard Worker ) 2855*da0073e9SAndroid Build Coastguard Worker 2856*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2857*da0073e9SAndroid Build Coastguard Worker def 
test_default_collate_shared_tensor(self): 2858*da0073e9SAndroid Build Coastguard Worker import numpy as np 2859*da0073e9SAndroid Build Coastguard Worker 2860*da0073e9SAndroid Build Coastguard Worker t_in = torch.zeros(1) 2861*da0073e9SAndroid Build Coastguard Worker n_in = np.zeros(1) 2862*da0073e9SAndroid Build Coastguard Worker 2863*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t_in.is_shared(), False) 2864*da0073e9SAndroid Build Coastguard Worker 2865*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False) 2866*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False) 2867*da0073e9SAndroid Build Coastguard Worker 2868*da0073e9SAndroid Build Coastguard Worker # FIXME: fix the following hack that makes `default_collate` believe 2869*da0073e9SAndroid Build Coastguard Worker # that it is in a worker process (since it tests 2870*da0073e9SAndroid Build Coastguard Worker # `get_worker_info() != None`), even though it is not. 
2871*da0073e9SAndroid Build Coastguard Worker old = _utils.worker._worker_info 2872*da0073e9SAndroid Build Coastguard Worker try: 2873*da0073e9SAndroid Build Coastguard Worker _utils.worker._worker_info = "x" 2874*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True) 2875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True) 2876*da0073e9SAndroid Build Coastguard Worker finally: 2877*da0073e9SAndroid Build Coastguard Worker _utils.worker._worker_info = old 2878*da0073e9SAndroid Build Coastguard Worker 2879*da0073e9SAndroid Build Coastguard Worker def test_excessive_thread_creation_warning(self): 2880*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 2881*da0073e9SAndroid Build Coastguard Worker UserWarning, 2882*da0073e9SAndroid Build Coastguard Worker r"excessive worker creation might get DataLoader running slow or even freeze", 2883*da0073e9SAndroid Build Coastguard Worker ): 2884*da0073e9SAndroid Build Coastguard Worker dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) 2885*da0073e9SAndroid Build Coastguard Worker 2886*da0073e9SAndroid Build Coastguard Worker 2887*da0073e9SAndroid Build Coastguard Workerclass TestDataLoaderDeviceType(TestCase): 2888*da0073e9SAndroid Build Coastguard Worker @parametrize( 2889*da0073e9SAndroid Build Coastguard Worker "context", 2890*da0073e9SAndroid Build Coastguard Worker [ctx for ctx in supported_multiprocessing_contexts if ctx is not None], 2891*da0073e9SAndroid Build Coastguard Worker ) 2892*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 2893*da0073e9SAndroid Build Coastguard Worker def test_nested_tensor_multiprocessing(self, device, context): 2894*da0073e9SAndroid Build Coastguard Worker # The 'fork' multiprocessing context doesn't work for CUDA so skip it 2895*da0073e9SAndroid Build Coastguard Worker if 
"cuda" in device and context == "fork": 2896*da0073e9SAndroid Build Coastguard Worker # TODO: Skip this better in a better way when the test framework allows 2897*da0073e9SAndroid Build Coastguard Worker return 2898*da0073e9SAndroid Build Coastguard Worker 2899*da0073e9SAndroid Build Coastguard Worker dataset = [ 2900*da0073e9SAndroid Build Coastguard Worker torch.nested.nested_tensor([torch.randn(5)], device=device) 2901*da0073e9SAndroid Build Coastguard Worker for _ in range(10) 2902*da0073e9SAndroid Build Coastguard Worker ] 2903*da0073e9SAndroid Build Coastguard Worker 2904*da0073e9SAndroid Build Coastguard Worker pin_memory_settings = [False] 2905*da0073e9SAndroid Build Coastguard Worker if device == "cpu" and torch.cuda.is_available(): 2906*da0073e9SAndroid Build Coastguard Worker pin_memory_settings.append(True) 2907*da0073e9SAndroid Build Coastguard Worker 2908*da0073e9SAndroid Build Coastguard Worker for pin_memory in pin_memory_settings: 2909*da0073e9SAndroid Build Coastguard Worker loader = torch.utils.data.DataLoader( 2910*da0073e9SAndroid Build Coastguard Worker dataset, 2911*da0073e9SAndroid Build Coastguard Worker batch_size=1, 2912*da0073e9SAndroid Build Coastguard Worker num_workers=4, 2913*da0073e9SAndroid Build Coastguard Worker collate_fn=_clone_collate, 2914*da0073e9SAndroid Build Coastguard Worker pin_memory=pin_memory, 2915*da0073e9SAndroid Build Coastguard Worker multiprocessing_context=context, 2916*da0073e9SAndroid Build Coastguard Worker ) 2917*da0073e9SAndroid Build Coastguard Worker 2918*da0073e9SAndroid Build Coastguard Worker for i, batch in enumerate(loader): 2919*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch[0], dataset[i]) 2920*da0073e9SAndroid Build Coastguard Worker 2921*da0073e9SAndroid Build Coastguard Worker # Error case: default collate_fn doesn't currently support batches of nested tensors. 
2922*da0073e9SAndroid Build Coastguard Worker # Following the current semantics, we'd need to stack them, which isn't possible atm. 2923*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2924*da0073e9SAndroid Build Coastguard Worker RuntimeError, "not currently supported by the default collate_fn" 2925*da0073e9SAndroid Build Coastguard Worker ): 2926*da0073e9SAndroid Build Coastguard Worker loader = torch.utils.data.DataLoader( 2927*da0073e9SAndroid Build Coastguard Worker dataset, 2928*da0073e9SAndroid Build Coastguard Worker batch_size=1, 2929*da0073e9SAndroid Build Coastguard Worker num_workers=4, 2930*da0073e9SAndroid Build Coastguard Worker multiprocessing_context=context, 2931*da0073e9SAndroid Build Coastguard Worker ) 2932*da0073e9SAndroid Build Coastguard Worker 2933*da0073e9SAndroid Build Coastguard Worker next(iter(loader)) 2934*da0073e9SAndroid Build Coastguard Worker 2935*da0073e9SAndroid Build Coastguard Worker 2936*da0073e9SAndroid Build Coastguard Workerclass IntegrationTestDataLoaderDataPipe(TestCase): 2937*da0073e9SAndroid Build Coastguard Worker r""" 2938*da0073e9SAndroid Build Coastguard Worker Verify the behavior of a certain ``DataPipes`` with ``DataLoader`` 2939*da0073e9SAndroid Build Coastguard Worker """ 2940*da0073e9SAndroid Build Coastguard Worker 2941*da0073e9SAndroid Build Coastguard Worker def test_shuffler_iterdatapipe(self): 2942*da0073e9SAndroid Build Coastguard Worker r""" 2943*da0073e9SAndroid Build Coastguard Worker Verify ``IterDataPipe.shuffle`` is controlled by ``DataLoader`` 2944*da0073e9SAndroid Build Coastguard Worker to generate different seeds deterministically per epoch. 
2945*da0073e9SAndroid Build Coastguard Worker """ 2946*da0073e9SAndroid Build Coastguard Worker exp = list(range(100)) 2947*da0073e9SAndroid Build Coastguard Worker 2948*da0073e9SAndroid Build Coastguard Worker def _create_dp(buffer_size): 2949*da0073e9SAndroid Build Coastguard Worker input_ds = dp.iter.IterableWrapper(exp) 2950*da0073e9SAndroid Build Coastguard Worker return input_ds.shuffle(buffer_size=buffer_size).sharding_filter() 2951*da0073e9SAndroid Build Coastguard Worker 2952*da0073e9SAndroid Build Coastguard Worker for bs in (5, 20, 33): 2953*da0073e9SAndroid Build Coastguard Worker # Test Deterministic 2954*da0073e9SAndroid Build Coastguard Worker for num_workers, pw in itertools.product((0, 1, 2), (True, False)): 2955*da0073e9SAndroid Build Coastguard Worker if num_workers == 0 and pw: 2956*da0073e9SAndroid Build Coastguard Worker continue 2957*da0073e9SAndroid Build Coastguard Worker 2958*da0073e9SAndroid Build Coastguard Worker shuffle_dp = _create_dp(bs) 2959*da0073e9SAndroid Build Coastguard Worker 2960*da0073e9SAndroid Build Coastguard Worker mp_ctx = "spawn" if num_workers > 0 else None 2961*da0073e9SAndroid Build Coastguard Worker dl = DataLoader( 2962*da0073e9SAndroid Build Coastguard Worker shuffle_dp, 2963*da0073e9SAndroid Build Coastguard Worker num_workers=num_workers, 2964*da0073e9SAndroid Build Coastguard Worker shuffle=True, 2965*da0073e9SAndroid Build Coastguard Worker multiprocessing_context=mp_ctx, 2966*da0073e9SAndroid Build Coastguard Worker persistent_workers=pw, 2967*da0073e9SAndroid Build Coastguard Worker ) 2968*da0073e9SAndroid Build Coastguard Worker 2969*da0073e9SAndroid Build Coastguard Worker # No seed 2970*da0073e9SAndroid Build Coastguard Worker dl_res_ns = list(dl) 2971*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(dl_res_ns), exp) 2972*da0073e9SAndroid Build Coastguard Worker 2973*da0073e9SAndroid Build Coastguard Worker # Same seeds 2974*da0073e9SAndroid Build Coastguard Worker dl_res = [] 
2975*da0073e9SAndroid Build Coastguard Worker for epoch in range(2): 2976*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 2977*da0073e9SAndroid Build Coastguard Worker dl_res.append(list(dl)) 2978*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dl_res[0], dl_res[1]) 2979*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(dl_res[0]), exp) 2980*da0073e9SAndroid Build Coastguard Worker 2981*da0073e9SAndroid Build Coastguard Worker # Different seeds 2982*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(321) 2983*da0073e9SAndroid Build Coastguard Worker dl_res.append(list(dl)) 2984*da0073e9SAndroid Build Coastguard Worker 2985*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dl_res[0]), len(dl_res[2])) 2986*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(dl_res[0], dl_res[2]) 2987*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(dl_res[0]), sorted(dl_res[2])) 2988*da0073e9SAndroid Build Coastguard Worker 2989*da0073e9SAndroid Build Coastguard Worker if dl._iterator is not None: 2990*da0073e9SAndroid Build Coastguard Worker dl._iterator._shutdown_workers() 2991*da0073e9SAndroid Build Coastguard Worker dl._iterator = None 2992*da0073e9SAndroid Build Coastguard Worker del dl 2993*da0073e9SAndroid Build Coastguard Worker 2994*da0073e9SAndroid Build Coastguard Worker 2995*da0073e9SAndroid Build Coastguard Workerclass StringDataset(Dataset): 2996*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2997*da0073e9SAndroid Build Coastguard Worker self.s = "12345" 2998*da0073e9SAndroid Build Coastguard Worker 2999*da0073e9SAndroid Build Coastguard Worker def __len__(self): 3000*da0073e9SAndroid Build Coastguard Worker return len(self.s) 3001*da0073e9SAndroid Build Coastguard Worker 3002*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, ndx): 3003*da0073e9SAndroid Build Coastguard Worker return (self.s[ndx], ndx) 3004*da0073e9SAndroid Build Coastguard Worker 
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStringDataLoader(TestCase):
    """``DataLoader`` behavior on samples that mix a string with an integer index."""

    def setUp(self):
        super().setUp()
        self.dataset = StringDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        """With shuffling, workers and ``pin_memory=True``: strings stay ``str``, the index batch is pinned."""
        loader = DataLoader(
            self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
        )
        for s, n in loader:
            self.assertIsInstance(s[0], str)
            self.assertTrue(n.is_pinned())


class DictDataset(Dataset):
    """Four samples; each is a dict holding a (4, 2) tensor filled with the index and a nested dict with the index."""

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return {
            "a_tensor": torch.empty(4, 2).fill_(ndx),
            "another_dict": {"a_number": ndx},
        }


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestDictDataLoader(TestCase):
    """Collation and pinning of nested-dict samples produced by ``DictDataset``."""

    def setUp(self):
        super().setUp()
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        """Unshuffled batches of two arrive in index order with the nested dict layout intact."""
        for persistent_workers in (False, True):
            loader = DataLoader(
                self.dataset,
                batch_size=2,
                shuffle=False,
                persistent_workers=persistent_workers,
                # The persistent case runs with one worker process; the other
                # case keeps DataLoader's default of 0 (load in this process).
                num_workers=1 if persistent_workers else 0,
            )
            batch_size = loader.batch_size
            for i, sample in enumerate(loader):
                idx = i * batch_size  # index of the first sample in this batch
                self.assertEqual(set(sample.keys()), {"a_tensor", "another_dict"})
                self.assertEqual(set(sample["another_dict"].keys()), {"a_number"})

                t = sample["a_tensor"]
                self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
                self.assertTrue((t[0] == idx).all())
                self.assertTrue((t[1] == idx + 1).all())

                n = sample["another_dict"]["a_number"]
                self.assertEqual(n.size(), torch.Size([batch_size]))
                self.assertEqual(n[0], idx)
                self.assertEqual(n[1], idx + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        """``pin_memory=True`` pins tensors at both nesting levels of the sample dict."""
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for sample in loader:
            self.assertTrue(sample["a_tensor"].is_pinned())
            self.assertTrue(sample["another_dict"]["a_number"].is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_device(self):
        """``pin_memory=True`` with ``pin_memory_device="cuda"`` pins both tensors for that device."""
        loader = DataLoader(
            self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
        )
        for sample in loader:
            self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
            self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda"))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_with_only_device(self):
        """``pin_memory_device`` alone (without ``pin_memory=True``) leaves tensors unpinned."""
        loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
        for sample in loader:
            self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
            self.assertFalse(
                sample["another_dict"]["a_number"].is_pinned(device="cuda")
            )


class DummyDataset(torch.utils.data.Dataset):
    """Ten integers; ``__getitem__`` asserts that ``start`` is still 0.

    ``start`` is not set in ``__init__``; the test assigns it on the instance
    before iterating.
    """

    def __init__(self) -> None:
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The persistent workers always maintain the original
        # dataset through the dataloader lifetime
        # so the attributes will remain the same as the
        # first time the workers were spawned (dataloader iteration)
        assert self.start == 0
        return self.data[idx]


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(
    TEST_WITH_ASAN,
    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
)
class TestDataLoaderPersistentWorkers(TestDataLoader):
    """Re-runs the inherited ``TestDataLoader`` tests with ``self.persistent_workers = True``."""

    def setUp(self):
        super().setUp()
        self.persistent_workers = True

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
    def test_fd_limit_exceeded(self):
        """Run a child interpreter that lowers ``RLIMIT_NOFILE`` to 100 and keeps every batch alive.

        If the child hits a ``RuntimeError``, its message must mention both
        ``ulimit -n`` and ``set_sharing_strategy``; ``check_output`` raises
        here when the child exits with a non-zero status.
        """
        # See NOTE [ DataLoader on Linux and open files limit ]
        import subprocess

        subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import torch
import resource
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

try:
    keep_fds_alive = []
    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
                               num_workers=1, persistent_workers=True):
        random_t.max(dim=0)
        keep_fds_alive.append(random_t)
except RuntimeError as e:
    assert "ulimit -n" in str(e)
    assert "set_sharing_strategy" in str(e)
""",
            ]
        )

    def test_dataset_not_reset(self):
        """Mutating ``dataset.start`` between epochs must not reach the workers' copy of the dataset."""
        dataset = DummyDataset()
        pin_memory_configs = [False]
        if TEST_CUDA:
            pin_memory_configs.append(True)
        for pin_memory in pin_memory_configs:
            dataloader = self._get_data_loader(
                dataset, num_workers=2, pin_memory=pin_memory
            )
            dataset.start = 0
            for i in range(10):
                # Drain one full epoch; DummyDataset.__getitem__ asserts start == 0.
                for _ in dataloader:
                    pass
                # Changing the start value here doesn't have any effect in the dataset
                # cached by the workers. since they are not recreated between epochs
                # and can cache values safely
                dataset.start = i

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "Needs fork")
    def test_early_exit(self):
        """Run a child interpreter that breaks out of the first epoch of a persistent, pinned, forked loader.

        ``check_output`` raises here if the child exits with a non-zero status.
        """
        import subprocess

        subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import torch
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

if __name__ == '__main__':
    dl = DataLoader(
        RandomDataset(64, (28, 28)),
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        multiprocessing_context="fork",
    )

    for _ in dl:
        break
""",
            ]
        )
3245*da0073e9SAndroid Build Coastguard Workerclass NamedTupleDataset(Dataset): 3246*da0073e9SAndroid Build Coastguard Worker from collections import namedtuple 3247*da0073e9SAndroid Build Coastguard Worker 3248*da0073e9SAndroid Build Coastguard Worker Batch = namedtuple("Batch", ["data", "label", "random_tensor"]) 3249*da0073e9SAndroid Build Coastguard Worker Data = namedtuple("Data", ["positive", "negative"]) 3250*da0073e9SAndroid Build Coastguard Worker 3251*da0073e9SAndroid Build Coastguard Worker def __len__(self): 3252*da0073e9SAndroid Build Coastguard Worker return 4 3253*da0073e9SAndroid Build Coastguard Worker 3254*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, ndx): 3255*da0073e9SAndroid Build Coastguard Worker return self.Batch( 3256*da0073e9SAndroid Build Coastguard Worker data=self.Data(positive=ndx, negative=-ndx), 3257*da0073e9SAndroid Build Coastguard Worker label=str(ndx), 3258*da0073e9SAndroid Build Coastguard Worker random_tensor=torch.randn(3), 3259*da0073e9SAndroid Build Coastguard Worker ) 3260*da0073e9SAndroid Build Coastguard Worker 3261*da0073e9SAndroid Build Coastguard Worker 3262*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 3263*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 3264*da0073e9SAndroid Build Coastguard Worker "Fails with TSAN with the following error: starting new threads after multi-threaded " 3265*da0073e9SAndroid Build Coastguard Worker "fork is not supported. 
Dying (set die_after_fork=0 to override)", 3266*da0073e9SAndroid Build Coastguard Worker) 3267*da0073e9SAndroid Build Coastguard Workerclass TestNamedTupleDataLoader(TestCase): 3268*da0073e9SAndroid Build Coastguard Worker def setUp(self): 3269*da0073e9SAndroid Build Coastguard Worker super().setUp() 3270*da0073e9SAndroid Build Coastguard Worker self.dataset = NamedTupleDataset() 3271*da0073e9SAndroid Build Coastguard Worker 3272*da0073e9SAndroid Build Coastguard Worker def test_dataloader_with_namedtuple(self): 3273*da0073e9SAndroid Build Coastguard Worker # auto-collation 3274*da0073e9SAndroid Build Coastguard Worker loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA) 3275*da0073e9SAndroid Build Coastguard Worker for batch in loader: 3276*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(batch, NamedTupleDataset.Batch) 3277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA) 3278*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(batch.data, NamedTupleDataset.Data) 3279*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(batch.data.positive, torch.Tensor) 3280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA) 3281*da0073e9SAndroid Build Coastguard Worker # no auto-collation 3282*da0073e9SAndroid Build Coastguard Worker loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA) 3283*da0073e9SAndroid Build Coastguard Worker for batch in loader: 3284*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(batch, NamedTupleDataset.Batch) 3285*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA) 3286*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(batch.data, NamedTupleDataset.Data) 3287*da0073e9SAndroid Build Coastguard Worker self.assertNotIsInstance(batch.data.positive, torch.Tensor) 3288*da0073e9SAndroid Build Coastguard Worker 
class SimpleCustomBatch:
    """Custom batch type that implements its own ``pin_memory``/``is_pinned``.

    ``data`` is a sequence of samples whose first two elements are tensors;
    the first elements are stacked into ``inp`` and the second into ``tgt``.
    """

    def __init__(self, data):
        # Transpose [(inp0, tgt0), (inp1, tgt1), ...] into per-field columns.
        columns = list(zip(*data))
        self.inp = torch.stack(columns[0], 0)
        self.tgt = torch.stack(columns[1], 0)

    def pin_memory(self):
        """Pin both tensors and return ``self``."""
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

    def is_pinned(self):
        """Return True only when both ``inp`` and ``tgt`` are pinned."""
        return self.inp.is_pinned() and self.tgt.is_pinned()


# Workaround for https://github.com/pytorch/pytorch/issues/50661
# Classes from `__main__` can not be correctly unpickled from spawned module
# See https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming
self_module = __import__(os.path.splitext(os.path.basename(__file__))[0])


def collate_wrapper(batch):
    """Collate ``batch`` into a ``SimpleCustomBatch``.

    The class is looked up through ``self_module`` (see the workaround note
    next to its definition) instead of being referenced directly.
    """
    return self_module.SimpleCustomBatch(batch)


def collate_into_packed_sequence(batch):
    """Collate the first element of every sample into a time-major PackedSequence."""
    # Stacking along dim 1 makes the batch dimension second.
    data = torch.stack([sample[0] for sample in batch], 1)
    t, b = data.size()
    # randint's upper bound is exclusive, so lengths fall in [1, t - 1].
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)


def collate_into_packed_sequence_batch_first(batch):
    """Collate the first element of every sample into a batch-first PackedSequence."""
    data = torch.stack([sample[0] for sample in batch], 0)
    b, t = data.size()
    # randint's upper bound is exclusive, so lengths fall in [1, t - 1].
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(
        data, lengths, batch_first=True, enforce_sorted=False
    )


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestCustomPinFn(TestCase):
    """Checks that DataLoader pins batches of custom and PackedSequence types."""

    def setUp(self):
        super().setUp()
        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        self.dataset = TensorDataset(inps, tgts)

    def _assert_batches_pinned(self, **loader_kwargs):
        """Run every collate function and assert each batch has the expected type and is pinned.

        ``loader_kwargs`` are forwarded to ``DataLoader`` on top of the fixed
        ``batch_size=2`` / ``pin_memory=True`` settings.
        """
        test_cases = [
            (collate_wrapper, self_module.SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (
                collate_into_packed_sequence_batch_first,
                torch.nn.utils.rnn.PackedSequence,
            ),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(
                self.dataset,
                batch_size=2,
                collate_fn=collate_fn,
                pin_memory=True,
                **loader_kwargs,
            )
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin(self):
        """Pinning with the default (in-process) loading path."""
        self._assert_batches_pinned()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin_worker(self):
        """Pinning when batches are produced by a worker process."""
        self._assert_batches_pinned(num_workers=1)


class TestWorkerQueueDataset(Dataset):
    """Dataset that tags every item with the id of the worker that loaded it."""

    def __init__(self, data):
        self.data = data
        # Filled in by ``worker_init_fn``; stays None until it is called.
        self.worker_id = None

    def worker_init_fn(self, worker_id):
        """Record ``worker_id``; intended to be passed as DataLoader's ``worker_init_fn``."""
        self.worker_id = worker_id

    def __getitem__(self, item):
        return self.worker_id, self.data[item]

    def __len__(self):
        return len(self.data)


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(
    TEST_WITH_ASAN,
    "Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727",
)
class TestIndividualWorkerQueue(TestCase):
    """Checks that batches arrive in order and are served by workers round-robin."""

    def setUp(self):
        super().setUp()
        self.dataset = TestWorkerQueueDataset(list(range(128)))

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        """Assert batch ``i`` holds samples ``[i*bs, (i+1)*bs)`` and came from worker ``i % num_workers``."""
        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            timeout=5,
            worker_init_fn=self.dataset.worker_init_fn,
        )
        for i, (worker_ids, sample) in enumerate(loader):
            expected_worker = i % num_workers
            self.assertEqual(worker_ids.tolist(), [expected_worker] * batch_size)
            self.assertEqual(
                sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size))
            )

    @staticmethod
    def _max_num_workers():
        """Return an upper bound on worker count for this host; always at least 1."""
        limit = None
        if hasattr(os, "sched_getaffinity"):
            try:
                limit = len(os.sched_getaffinity(0))
            except Exception:
                # Deliberate best-effort: fall back to the CPU count below.
                pass
        if limit is None:
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                # Use half number of CPUs
                limit = cpu_count // 2
        if limit is None:
            limit = 1
        # A single-CPU host gives ``1 // 2 == 0``; clamp so that at least the
        # one-worker configuration is exercised instead of silently none.
        return max(1, limit)

    def test_ind_worker_queue(self):
        """Exercise several batch sizes with 1 up to min(6, host limit) workers."""
        worker_cap = min(6, self._max_num_workers())
        for batch_size in (8, 16, 32, 64):
            for num_workers in range(1, worker_cap + 1):
                self._run_ind_worker_queue_test(
                    batch_size=batch_size, num_workers=num_workers
                )


class SetAffinityDataset(IterableDataset):
    """Iterable dataset that yields the CPU ids in the calling process's affinity mask."""

    def __iter__(self):
        # NOTE(review): presumably a torch op is run first to check that it
        # does not reset the affinity set in worker_init_fn — confirm.
        torch.randperm(1)
        after = os.sched_getaffinity(0)
        return iter(after)


@unittest.skipIf(
    not hasattr(os, "sched_setaffinity"), "os.sched_setaffinity is not available"
)
class TestSetAffinity(TestCase):
    """Checks that an affinity mask set in ``worker_init_fn`` is what the worker observes."""

    def test_set_affinity_in_worker_init(self):
        # Query the current affinity mask to avoid setting a disallowed one
        old_affinity = os.sched_getaffinity(0)
        if not old_affinity:
            self.skipTest("No affinity information")
        # Choose any
        expected_affinity = list(old_affinity)[-1]

        def worker_set_affinity(_):
            os.sched_setaffinity(0, [expected_affinity])

        dataset = SetAffinityDataset()

        dataloader = torch.utils.data.DataLoader(
            dataset, num_workers=2, worker_init_fn=worker_set_affinity
        )
        # Each worker yields its own mask, which must be the single chosen CPU.
        for sample in dataloader:
            self.assertEqual(sample, [expected_affinity])


class ConvDataset(Dataset):
    """Single-item dataset whose item is a conv1d result; also runs the conv once at construction."""

    def __init__(self) -> None:
        self.x = torch.ones(1, 1, 24000)
        # Call convolution on parent process
        self[0]

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return torch.nn.functional.conv1d(self.x, torch.ones(1, 1, 2))


@unittest.skipIf(IS_WINDOWS, "Needs fork")
@unittest.skipIf(
    TEST_WITH_ASAN,
    "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492",
)
class TestConvAfterFork(TestCase):
    # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565
    def test_conv_after_fork(self):
        """Run the convolution in a worker after it already ran in the parent."""
        loader = DataLoader(ConvDataset(), num_workers=1)
        for x in loader:
            # conv1d of (1, 1, 24000) with a width-2 kernel gives (1, 1, 23999);
            # the loader's default batch_size=1 adds the leading dimension.
            self.assertEqual(x.shape, (1, 1, 1, 23999))


instantiate_device_type_tests(TestDataLoaderDeviceType, globals())


if __name__ == "__main__":
    run_tests()