xref: /aosp_15_r20/external/pytorch/test/test_dataloader.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dataloader"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport ctypes
4*da0073e9SAndroid Build Coastguard Workerimport errno
5*da0073e9SAndroid Build Coastguard Workerimport faulthandler
6*da0073e9SAndroid Build Coastguard Workerimport functools
7*da0073e9SAndroid Build Coastguard Workerimport gc
8*da0073e9SAndroid Build Coastguard Workerimport itertools
9*da0073e9SAndroid Build Coastguard Workerimport math
10*da0073e9SAndroid Build Coastguard Workerimport operator
11*da0073e9SAndroid Build Coastguard Workerimport os
12*da0073e9SAndroid Build Coastguard Workerimport signal
13*da0073e9SAndroid Build Coastguard Workerimport sys
14*da0073e9SAndroid Build Coastguard Workerimport tempfile
15*da0073e9SAndroid Build Coastguard Workerimport time
16*da0073e9SAndroid Build Coastguard Workerimport unittest
17*da0073e9SAndroid Build Coastguard Workerimport warnings
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Workerimport torch
20*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.datapipes as dp
21*da0073e9SAndroid Build Coastguard Workerfrom torch import multiprocessing as mp
22*da0073e9SAndroid Build Coastguard Workerfrom torch._utils import ExceptionWrapper
23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests
24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
25*da0073e9SAndroid Build Coastguard Worker    IS_CI,
26*da0073e9SAndroid Build Coastguard Worker    IS_JETSON,
27*da0073e9SAndroid Build Coastguard Worker    IS_MACOS,
28*da0073e9SAndroid Build Coastguard Worker    IS_SANDCASTLE,
29*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
30*da0073e9SAndroid Build Coastguard Worker    load_tests,
31*da0073e9SAndroid Build Coastguard Worker    NO_MULTIPROCESSING_SPAWN,
32*da0073e9SAndroid Build Coastguard Worker    parametrize,
33*da0073e9SAndroid Build Coastguard Worker    run_tests,
34*da0073e9SAndroid Build Coastguard Worker    skipIfNoDill,
35*da0073e9SAndroid Build Coastguard Worker    skipIfRocm,
36*da0073e9SAndroid Build Coastguard Worker    slowTest,
37*da0073e9SAndroid Build Coastguard Worker    TEST_CUDA,
38*da0073e9SAndroid Build Coastguard Worker    TEST_NUMPY,
39*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
40*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM,
41*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
42*da0073e9SAndroid Build Coastguard Worker    TestCase,
43*da0073e9SAndroid Build Coastguard Worker)
44*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import (
45*da0073e9SAndroid Build Coastguard Worker    _utils,
46*da0073e9SAndroid Build Coastguard Worker    ChainDataset,
47*da0073e9SAndroid Build Coastguard Worker    ConcatDataset,
48*da0073e9SAndroid Build Coastguard Worker    DataLoader,
49*da0073e9SAndroid Build Coastguard Worker    Dataset,
50*da0073e9SAndroid Build Coastguard Worker    IterableDataset,
51*da0073e9SAndroid Build Coastguard Worker    IterDataPipe,
52*da0073e9SAndroid Build Coastguard Worker    StackDataset,
53*da0073e9SAndroid Build Coastguard Worker    Subset,
54*da0073e9SAndroid Build Coastguard Worker    TensorDataset,
55*da0073e9SAndroid Build Coastguard Worker)
56*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
57*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.iter import IterableWrapper
58*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.dataset import random_split
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Workertry:
62*da0073e9SAndroid Build Coastguard Worker    import psutil
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker    HAS_PSUTIL = True
65*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError:
66*da0073e9SAndroid Build Coastguard Worker    HAS_PSUTIL = False
67*da0073e9SAndroid Build Coastguard Worker    psutil = None
68*da0073e9SAndroid Build Coastguard Worker    err_msg = (
69*da0073e9SAndroid Build Coastguard Worker        "psutil not found. Some critical data loader tests relying on it "
70*da0073e9SAndroid Build Coastguard Worker        "(e.g., TestDataLoader.test_proper_exit) will not run."
71*da0073e9SAndroid Build Coastguard Worker    )
72*da0073e9SAndroid Build Coastguard Worker    if IS_CI:
73*da0073e9SAndroid Build Coastguard Worker        raise ModuleNotFoundError(err_msg) from None
74*da0073e9SAndroid Build Coastguard Worker    else:
75*da0073e9SAndroid Build Coastguard Worker        warnings.warn(err_msg)
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker
# NumPy is optional: record whether it imported and keep `np` defined either
# way so later references resolve to the module or to None.
try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False
    np = None
# Decorator that skips the decorated test when NumPy could not be imported.
skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
88*da0073e9SAndroid Build Coastguard Worker# sharding on sandcastle. This line silences flake warnings
89*da0073e9SAndroid Build Coastguard Workerload_tests = load_tests
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard WorkerTEST_CUDA_IPC = (
92*da0073e9SAndroid Build Coastguard Worker    torch.cuda.is_available()
93*da0073e9SAndroid Build Coastguard Worker    and sys.platform != "darwin"
94*da0073e9SAndroid Build Coastguard Worker    and sys.platform != "win32"
95*da0073e9SAndroid Build Coastguard Worker    and not IS_JETSON
96*da0073e9SAndroid Build Coastguard Worker    and not TEST_WITH_ROCM
97*da0073e9SAndroid Build Coastguard Worker)  # https://github.com/pytorch/pytorch/issues/90940
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard WorkerTEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Workerif not NO_MULTIPROCESSING_SPAWN:
102*da0073e9SAndroid Build Coastguard Worker    # We want to use `spawn` if able because some of our tests check that the
103*da0073e9SAndroid Build Coastguard Worker    # data loader terminiates gracefully. To prevent hanging in the testing
104*da0073e9SAndroid Build Coastguard Worker    # process, such data loaders are run in a separate subprocess.
105*da0073e9SAndroid Build Coastguard Worker    #
106*da0073e9SAndroid Build Coastguard Worker    # We also want to test the `pin_memory=True` configuration, thus `spawn` is
107*da0073e9SAndroid Build Coastguard Worker    # required to launch such processes and they initialize the CUDA context.
108*da0073e9SAndroid Build Coastguard Worker    #
109*da0073e9SAndroid Build Coastguard Worker    # Mixing different start method is a recipe for disaster (e.g., using a fork
110*da0073e9SAndroid Build Coastguard Worker    # `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally
111*da0073e9SAndroid Build Coastguard Worker    # to avoid bugs.
112*da0073e9SAndroid Build Coastguard Worker    #
113*da0073e9SAndroid Build Coastguard Worker    # Get a multiprocessing context because some test / third party library will
114*da0073e9SAndroid Build Coastguard Worker    # set start_method when imported, and setting again triggers `RuntimeError`.
115*da0073e9SAndroid Build Coastguard Worker    mp = mp.get_context(method="spawn")
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker
# 60s of timeout?
# Yes, in environments where physical CPU resources are shared, e.g., CI, the
# time for an inter-process communication can be highly varying.  With 15~17s
# of timeout, we have observed flakiness in some CI builds (see
# pytorch/pytorch#14501, pytorch/pytorch#16608).  We follow the CPython
# multiprocessing setup and set the timeout to 60s here:
#
# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
JOIN_TIMEOUT = 60.0  # seconds
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker
# `None` followed by every start method name reported by torch.multiprocessing.
supported_multiprocessing_contexts = [
    None,
    *torch.multiprocessing.get_all_start_methods(),
]
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
135*da0073e9SAndroid Build Coastguard Workerdef _clone_collate(b):
136*da0073e9SAndroid Build Coastguard Worker    return [x.clone() for x in b]
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
140*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
141*da0073e9SAndroid Build Coastguard Worker    "Fails with TSAN with the following error: starting new threads after multi-threaded "
142*da0073e9SAndroid Build Coastguard Worker    "fork is not supported. Dying (set die_after_fork=0 to override)",
143*da0073e9SAndroid Build Coastguard Worker)
144*da0073e9SAndroid Build Coastguard Workerclass TestDatasetRandomSplit(TestCase):
145*da0073e9SAndroid Build Coastguard Worker    def test_lengths_must_equal_dataset_size(self):
146*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
147*da0073e9SAndroid Build Coastguard Worker            random_split([1, 2, 3, 4], [1, 2])
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker    def test_splits_have_correct_size(self):
150*da0073e9SAndroid Build Coastguard Worker        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
151*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits), 2)
152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[0]), 2)
153*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[1]), 4)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5])
156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits), 2)
157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[0]), 3)
158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[1]), 3)
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker        # Odd size splits
161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
162*da0073e9SAndroid Build Coastguard Worker            len(
163*da0073e9SAndroid Build Coastguard Worker                random_split(
164*da0073e9SAndroid Build Coastguard Worker                    range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1)
165*da0073e9SAndroid Build Coastguard Worker                )
166*da0073e9SAndroid Build Coastguard Worker            ),
167*da0073e9SAndroid Build Coastguard Worker            2,
168*da0073e9SAndroid Build Coastguard Worker        )
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        # Odd sized round-robin splits
171*da0073e9SAndroid Build Coastguard Worker        splits = random_split(
172*da0073e9SAndroid Build Coastguard Worker            range(106), [0.1, 0.2, 0.3, 0.4], generator=torch.Generator().manual_seed(1)
173*da0073e9SAndroid Build Coastguard Worker        )
174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[0]), 11)
175*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[1]), 22)
176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[2]), 31)
177*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(splits[3]), 42)
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker    def test_splits_are_mutually_exclusive(self):
180*da0073e9SAndroid Build Coastguard Worker        data = [5, 2, 3, 4, 1, 6]
181*da0073e9SAndroid Build Coastguard Worker        splits = random_split(data, [2, 4])
182*da0073e9SAndroid Build Coastguard Worker        all_values = []
183*da0073e9SAndroid Build Coastguard Worker        all_values.extend(list(splits[0]))
184*da0073e9SAndroid Build Coastguard Worker        all_values.extend(list(splits[1]))
185*da0073e9SAndroid Build Coastguard Worker        data.sort()
186*da0073e9SAndroid Build Coastguard Worker        all_values.sort()
187*da0073e9SAndroid Build Coastguard Worker        self.assertListEqual(data, all_values)
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker        splits = random_split(data, [0.33, 0.67])
190*da0073e9SAndroid Build Coastguard Worker        all_values = []
191*da0073e9SAndroid Build Coastguard Worker        all_values.extend(list(splits[0]))
192*da0073e9SAndroid Build Coastguard Worker        all_values.extend(list(splits[1]))
193*da0073e9SAndroid Build Coastguard Worker        data.sort()
194*da0073e9SAndroid Build Coastguard Worker        all_values.sort()
195*da0073e9SAndroid Build Coastguard Worker        self.assertListEqual(data, all_values)
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker        data = [1, 2, 3, 4]
198*da0073e9SAndroid Build Coastguard Worker        splits = random_split(data, [0.25, 0.75])
199*da0073e9SAndroid Build Coastguard Worker        all_values = []
200*da0073e9SAndroid Build Coastguard Worker        all_values.extend(list(splits[0]))
201*da0073e9SAndroid Build Coastguard Worker        all_values.extend(list(splits[1]))
202*da0073e9SAndroid Build Coastguard Worker        data.sort()
203*da0073e9SAndroid Build Coastguard Worker        all_values.sort()
204*da0073e9SAndroid Build Coastguard Worker        self.assertListEqual(data, all_values)
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker    def test_splits_indexing_type(self):
207*da0073e9SAndroid Build Coastguard Worker        r"""Indices generated by random_split
208*da0073e9SAndroid Build Coastguard Worker        should be of integer type
209*da0073e9SAndroid Build Coastguard Worker        """
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker        class CustomDataset:
212*da0073e9SAndroid Build Coastguard Worker            def __init__(self, test_object, custom_list):
213*da0073e9SAndroid Build Coastguard Worker                self.data = custom_list
214*da0073e9SAndroid Build Coastguard Worker                self.test_object = test_object
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker            def __getitem__(self, key):
217*da0073e9SAndroid Build Coastguard Worker                self.test_object.assertEqual(type(key), int)
218*da0073e9SAndroid Build Coastguard Worker                return self.data[key]
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
221*da0073e9SAndroid Build Coastguard Worker                return len(self.data)
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker        x = [1, 2, 3, 4, 5]
224*da0073e9SAndroid Build Coastguard Worker        dataset = CustomDataset(self, x)
225*da0073e9SAndroid Build Coastguard Worker        dataset = random_split(dataset, [5])[0]
226*da0073e9SAndroid Build Coastguard Worker        data_loader = DataLoader(dataset)
227*da0073e9SAndroid Build Coastguard Worker        for batch in data_loader:
228*da0073e9SAndroid Build Coastguard Worker            pass
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        # fractional splitting
231*da0073e9SAndroid Build Coastguard Worker        dataset = CustomDataset(self, x)
232*da0073e9SAndroid Build Coastguard Worker        dataset = random_split(dataset, [1.0])[0]
233*da0073e9SAndroid Build Coastguard Worker        data_loader = DataLoader(dataset)
234*da0073e9SAndroid Build Coastguard Worker        for batch in data_loader:
235*da0073e9SAndroid Build Coastguard Worker            pass
236*da0073e9SAndroid Build Coastguard Worker
237*da0073e9SAndroid Build Coastguard Worker    def test_splits_reproducibility(self):
238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
239*da0073e9SAndroid Build Coastguard Worker            [
240*da0073e9SAndroid Build Coastguard Worker                list(x)
241*da0073e9SAndroid Build Coastguard Worker                for x in random_split(
242*da0073e9SAndroid Build Coastguard Worker                    range(10), [3, 7], generator=torch.Generator().manual_seed(1)
243*da0073e9SAndroid Build Coastguard Worker                )
244*da0073e9SAndroid Build Coastguard Worker            ],
245*da0073e9SAndroid Build Coastguard Worker            [[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]],
246*da0073e9SAndroid Build Coastguard Worker        )
247*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
248*da0073e9SAndroid Build Coastguard Worker            random_split(
249*da0073e9SAndroid Build Coastguard Worker                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
250*da0073e9SAndroid Build Coastguard Worker            ),
251*da0073e9SAndroid Build Coastguard Worker            random_split(
252*da0073e9SAndroid Build Coastguard Worker                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
253*da0073e9SAndroid Build Coastguard Worker            ),
254*da0073e9SAndroid Build Coastguard Worker        )
255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
256*da0073e9SAndroid Build Coastguard Worker            random_split(
257*da0073e9SAndroid Build Coastguard Worker                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
258*da0073e9SAndroid Build Coastguard Worker            ),
259*da0073e9SAndroid Build Coastguard Worker            random_split(
260*da0073e9SAndroid Build Coastguard Worker                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
261*da0073e9SAndroid Build Coastguard Worker            ),
262*da0073e9SAndroid Build Coastguard Worker        )
263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
264*da0073e9SAndroid Build Coastguard Worker            random_split(
265*da0073e9SAndroid Build Coastguard Worker                range(100),
266*da0073e9SAndroid Build Coastguard Worker                [0.33, 0.33, 0.34],
267*da0073e9SAndroid Build Coastguard Worker                generator=torch.Generator().manual_seed(42),
268*da0073e9SAndroid Build Coastguard Worker            ),
269*da0073e9SAndroid Build Coastguard Worker            random_split(
270*da0073e9SAndroid Build Coastguard Worker                range(100),
271*da0073e9SAndroid Build Coastguard Worker                [0.33, 0.33, 0.34],
272*da0073e9SAndroid Build Coastguard Worker                generator=torch.Generator().manual_seed(42),
273*da0073e9SAndroid Build Coastguard Worker            ),
274*da0073e9SAndroid Build Coastguard Worker        )
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker    def test_incomplete_fractional_splits(self):
277*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
278*da0073e9SAndroid Build Coastguard Worker            # should raise since the sum of fractions is not 1
279*da0073e9SAndroid Build Coastguard Worker            random_split([1, 2, 3, 4], [0.1])
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
282*da0073e9SAndroid Build Coastguard Worker            # should raise since fraction > 1
283*da0073e9SAndroid Build Coastguard Worker            random_split([1, 2, 3, 4], [1.1])
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker    def test_splits_generator(self):
286*da0073e9SAndroid Build Coastguard Worker        # A random_split without a specific generator should affect the default one
287*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
288*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10)
289*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
290*da0073e9SAndroid Build Coastguard Worker        random_split(range(10), [5, 5])
291*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(10)
292*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(a, b)
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker        # A random_split with a specific generator should not affect the default one
295*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
296*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10)
297*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
298*da0073e9SAndroid Build Coastguard Worker        random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42))
299*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(10)
300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b)
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker    def test_slicing_of_subset_of_dataset(self):
303*da0073e9SAndroid Build Coastguard Worker        # Testing slicing a subset initialized with a dataset
304*da0073e9SAndroid Build Coastguard Worker        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
305*da0073e9SAndroid Build Coastguard Worker        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
306*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_dataset[:], dataset[:])
307*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_dataset[1:2], dataset[1:2])
308*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2])
309*da0073e9SAndroid Build Coastguard Worker        # Testing slicing of subset from random split
310*da0073e9SAndroid Build Coastguard Worker        subset1, subset2 = random_split(dataset, [3, 2])
311*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset1[:], dataset[subset1.indices[:]])
312*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]])
313*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]])
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker    def test_slicing_of_subset_of_subset(self):
316*da0073e9SAndroid Build Coastguard Worker        # Testing slicing a subset initialized with a subset
317*da0073e9SAndroid Build Coastguard Worker        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
318*da0073e9SAndroid Build Coastguard Worker        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
319*da0073e9SAndroid Build Coastguard Worker        subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4])
320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_subset[:], dataset[:])
321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_subset[0:2], dataset[0:2])
322*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2])
323*da0073e9SAndroid Build Coastguard Worker        # Testing slicing of subset of subset from random split
324*da0073e9SAndroid Build Coastguard Worker        subset1, subset2 = random_split(dataset, [4, 1])
325*da0073e9SAndroid Build Coastguard Worker        subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1])
326*da0073e9SAndroid Build Coastguard Worker        idx = [subset1.indices[i] for i in subset_of_subset1.indices]
327*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_subset1[:], dataset[idx.copy()])
328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]])
329*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]])
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker
class CUDACountingDataset(Dataset):
    """Map-style dataset of length ``n`` whose item ``i`` is ``i`` as a CUDA tensor."""

    def __init__(self, n):
        super().__init__()
        self.n = n  # value reported by __len__

    def __getitem__(self, i):
        # The tensor is created on the "cuda" device, so a CUDA device is needed.
        return torch.as_tensor(i, device="cuda")

    def __len__(self):
        return self.n
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker
class CountingDataset(Dataset):
    """Map-style dataset of length ``n`` whose item ``i`` is the index ``i`` itself."""

    def __init__(self, n):
        super().__init__()
        self.n = n  # value reported by __len__

    def __getitem__(self, i):
        # No bounds check: the index is echoed back unchanged.
        return i

    def __len__(self):
        return self.n
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker
class CountingIterableDataset(IterableDataset):
    """Iterable-style dataset that yields ``0 .. n-1`` and reports length ``n``."""

    def __init__(self, n):
        super().__init__()
        self.n = n  # number of items yielded and value reported by __len__

    def __iter__(self):
        # A fresh iterator on every call, so the dataset can be iterated repeatedly.
        return iter(range(self.n))

    def __len__(self):
        return self.n
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
369*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
370*da0073e9SAndroid Build Coastguard Worker    "Fails with TSAN with the following error: starting new threads after multi-threaded "
371*da0073e9SAndroid Build Coastguard Worker    "fork is not supported. Dying (set die_after_fork=0 to override)",
372*da0073e9SAndroid Build Coastguard Worker)
373*da0073e9SAndroid Build Coastguard Workerclass TestTensorDataset(TestCase):
374*da0073e9SAndroid Build Coastguard Worker    def test_len(self):
375*da0073e9SAndroid Build Coastguard Worker        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
376*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(source), 15)
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker    def test_getitem(self):
379*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(15, 10, 2, 3, 4, 5)
380*da0073e9SAndroid Build Coastguard Worker        l = torch.randn(15, 10)
381*da0073e9SAndroid Build Coastguard Worker        source = TensorDataset(t, l)
382*da0073e9SAndroid Build Coastguard Worker        for i in range(15):
383*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[i], source[i][0])
384*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(l[i], source[i][1])
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker    def test_getitem_1d(self):
387*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(15)
388*da0073e9SAndroid Build Coastguard Worker        l = torch.randn(15)
389*da0073e9SAndroid Build Coastguard Worker        source = TensorDataset(t, l)
390*da0073e9SAndroid Build Coastguard Worker        for i in range(15):
391*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[i], source[i][0])
392*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(l[i], source[i][1])
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker    def test_single_tensor(self):
395*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(5, 10)
396*da0073e9SAndroid Build Coastguard Worker        source = TensorDataset(t)
397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(source), 5)
398*da0073e9SAndroid Build Coastguard Worker        for i in range(5):
399*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[i], source[i][0])
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker    def test_many_tensors(self):
402*da0073e9SAndroid Build Coastguard Worker        t0 = torch.randn(5, 10, 2, 3, 4, 5)
403*da0073e9SAndroid Build Coastguard Worker        t1 = torch.randn(5, 10)
404*da0073e9SAndroid Build Coastguard Worker        t2 = torch.randn(5, 10, 2, 5)
405*da0073e9SAndroid Build Coastguard Worker        t3 = torch.randn(5, 10, 3, 7)
406*da0073e9SAndroid Build Coastguard Worker        source = TensorDataset(t0, t1, t2, t3)
407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(source), 5)
408*da0073e9SAndroid Build Coastguard Worker        for i in range(5):
409*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t0[i], source[i][0])
410*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t1[i], source[i][1])
411*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t2[i], source[i][2])
412*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t3[i], source[i][3])
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStackDataset(TestCase):
    """Tests for ``StackDataset`` built from positional or keyword datasets."""

    def test_empty(self):
        # A stack needs at least one dataset.
        expected_message = "At least one dataset should be passed"
        with self.assertRaisesRegex(ValueError, expected_message):
            StackDataset()

    def test_mixed(self):
        # Positional and keyword datasets may not be combined.
        positional = TensorDataset(torch.randn(15, 10))
        keyword = TensorDataset(torch.randn(10, 15))
        with self.assertRaisesRegex(ValueError, "Supported either"):
            StackDataset(positional, a=keyword)

    def test_size_mismatch(self):
        # Datasets of different lengths are rejected in both calling styles.
        expected_message = "Size mismatch between datasets"
        longer = TensorDataset(torch.randn(15, 10))
        shorter = TensorDataset(torch.randn(10, 15))
        with self.assertRaisesRegex(ValueError, expected_message):
            StackDataset(longer, shorter)
        with self.assertRaisesRegex(ValueError, expected_message):
            StackDataset(a=longer, b=shorter)

    def test_len(self):
        # The stack reports the common length of its members.
        matrix = TensorDataset(torch.randn(15, 10))
        vector = TensorDataset(torch.randn(15))
        stacks = (
            StackDataset(matrix, vector),
            StackDataset(matrix),
            StackDataset(a=matrix, b=vector),
            StackDataset(a=matrix),
        )
        for stacked in stacks:
            self.assertEqual(len(stacked), 15)

    def test_single(self):
        # A single member is reachable by position or by keyword.
        member = TensorDataset(torch.randn(15, 10))

        as_tuple = StackDataset(member)
        for idx in range(15):
            self.assertEqual(member[idx], as_tuple[idx][0])

        as_dict = StackDataset(a=member)
        for idx in range(15):
            self.assertEqual(member[idx], as_dict[idx]["a"])

    def test_getitem(self):
        # Each sample exposes the matching sample of every member.
        first = TensorDataset(torch.randn(15, 10))
        second = TensorDataset(torch.randn(15, 5, 4))

        as_tuple = StackDataset(first, second)
        for idx in range(15):
            self.assertEqual(first[idx], as_tuple[idx][0])
            self.assertEqual(second[idx], as_tuple[idx][1])

        as_dict = StackDataset(a=first, b=second)
        for idx in range(15):
            self.assertEqual(first[idx], as_dict[idx]["a"])
            self.assertEqual(second[idx], as_dict[idx]["b"])

    def test_getitems(self):
        # Batched access mixes a member with `__getitems__` and a plain list.
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __len__(self):
                return 4

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items]

        batched_member = GetItemsDataset()
        plain_member = [1, 2, 3, 4]
        wanted = [0, 1, 2, 3]

        as_tuple = StackDataset(batched_member, plain_member)
        batch = as_tuple.__getitems__(wanted)
        for idx in range(4):
            self.assertEqual(batched_member[idx], batch[idx][0])
            self.assertEqual(plain_member[idx], batch[idx][1])

        as_dict = StackDataset(t=batched_member, l=plain_member)
        batch = as_dict.__getitems__(wanted)
        for idx in range(4):
            self.assertEqual(batched_member[idx], batch[idx]["t"])
            self.assertEqual(plain_member[idx], batch[idx]["l"])

    def test_getitems_raises_index_error(self):
        # An out-of-range index in a batched request surfaces as IndexError.
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __len__(self):
                return 4

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items]

        stacked = StackDataset(GetItemsDataset(), [1, 2, 3, 4])

        with self.assertRaises(IndexError):
            stacked.__getitems__([0, 4])

    def test_getitems_value_error(self):
        # A member returning too few items from `__getitems__` is reported.
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __len__(self):
                return 4

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items][:-1]  # return less

        stacked = StackDataset(GetItemsDataset(), [1, 2, 3, 4])

        expected_message = "Nested dataset's output size mismatch. Expected 4, got 3"
        with self.assertRaisesRegex(ValueError, expected_message):
            stacked.__getitems__([0, 1, 2, 3])
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestConcatDataset(TestCase):
    """Tests for ``ConcatDataset`` and the ``+`` operator on datasets."""

    def test_concat_two_singletons(self):
        combined = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(combined))
        for position in (0, 1):
            self.assertEqual(position, combined[position])

    def test_concat_two_non_singletons(self):
        first_half = [0, 1, 2, 3, 4]
        second_half = [5, 6, 7, 8, 9]
        combined = ConcatDataset([first_half, second_half])
        self.assertEqual(10, len(combined))
        for position in (0, 5):
            self.assertEqual(position, combined[position])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        first_half = [0, 1, 2, 3, 4]
        second_half = [5, 6, 7, 8, 9]
        combined = ConcatDataset([first_half, [], second_half])
        self.assertEqual(10, len(combined))
        for position in (0, 5):
            self.assertEqual(position, combined[position])

    def test_concat_raises_index_error(self):
        combined = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            _ = combined[11]

    def test_add_dataset(self):
        # Three 7-sample datasets chained with `+` start at offsets 0, 7, 14.
        parts = [
            TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) for _ in range(3)
        ]
        d1, d2, d3 = parts
        combined = d1 + d2 + d3
        self.assertEqual(21, len(combined))
        for offset, part in zip((0, 7, 14), parts):
            difference = (part[0][0] - combined[offset][0]).abs().sum()
            self.assertEqual(0, difference)

    def test_iterable_dataset_err(self):
        # Any IterableDataset among the members makes construction fail.
        map_style = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        short_iterable = CountingIterableDataset(5)
        long_iterable = CountingIterableDataset(10)

        rejected_member_lists = (
            [map_style, long_iterable, short_iterable],
            [long_iterable],
            [short_iterable, map_style],
        )
        for members in rejected_member_lists:
            with self.assertRaisesRegex(
                AssertionError, "does not support IterableDataset"
            ):
                ConcatDataset(members)
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker
# takes in dummy var so this can also be used as a `worker_init_fn`
def set_faulthander_if_available(_=None):
    """Enable faulthandler output on the original stderr stream.

    When not on Windows, ``SIGUSR1`` is additionally registered so that the
    signal triggers a traceback dump. The single argument is ignored.
    """
    trace_stream = sys.__stderr__
    faulthandler.enable(trace_stream)
    if IS_WINDOWS:
        # windows does not have faulthandler.register
        return
    # chain=False prevents the default behavior of killing the process
    faulthandler.register(signal.SIGUSR1, file=trace_stream, chain=False)


# Install the handlers for this process as soon as the module is loaded.
set_faulthander_if_available()
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker
# Process `pid` must have called `set_faulthander_if_available`
def print_traces_of_all_threads(pid):
    """Signal process ``pid`` so it dumps its thread traces, then wait 5 seconds."""
    # Use the custom signal if available. Otherwise we can still use the
    # handler given by faulthandler.enable(), at the cost of killing the
    # process. Only the selected branch is evaluated, so SIGUSR1 is never
    # looked up on Windows.
    dump_signal = signal.SIGSEGV if IS_WINDOWS else signal.SIGUSR1
    os.kill(pid, dump_signal)

    # wait in parent process to give subprocess some time to print
    time.sleep(5)
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker# The following `ErrorTrackingProcess` stores the first encountered exception in
638*da0073e9SAndroid Build Coastguard Worker# its `.exception` attribute.
639*da0073e9SAndroid Build Coastguard Worker# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):
    """Process that reports the outcome of its target back to the parent.

    The child end of a pipe receives ``None`` when the target finishes
    normally, or an ``ExceptionWrapper`` describing the exception it raised.
    The parent reads that message through the :attr:`exception` property.
    """

    # Why no *args?
    #   py2 doesn't support def fn(x, *args, key=val, **kwargs)
    # Setting disable_stderr=False may generate a lot of unrelated error outputs
    # but could be helpful for debugging.
    def __init__(self, disable_stderr=True, **kwargs):
        """Create the process.

        Args:
            disable_stderr: when true, the child redirects its stderr to the
                null device before running the target.
            **kwargs: forwarded unchanged to ``mp.Process``.
        """
        super().__init__(**kwargs)
        # _pconn is polled by the parent, _cconn is written by the child.
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None
        self.disable_stderr = disable_stderr

    def run(self):
        """Run the target in the child and send its outcome over the pipe."""
        set_faulthander_if_available()
        if self.disable_stderr:
            # Flush first: dup2 works on the descriptor and knows nothing about
            # text still sitting in Python-level buffers.
            sys.stderr.flush()
            # Disable polluting stderr with errors that are supposed to happen.
            with open(os.devnull, "w") as devnull:
                os.dup2(devnull.fileno(), sys.stderr.fileno())
        try:
            super().run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    def print_traces_of_all_threads(self):
        """Ask the (still running) child to dump the traces of its threads."""
        assert (
            self.is_alive()
        ), "can only use print_traces_of_all_threads if the process is alive"
        assert (
            not self.disable_stderr
        ), "do not disable stderr if you use print_traces_of_all_threads"
        # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
        # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
        # the process. So let's poll the exception first
        _ = self.exception
        print_traces_of_all_threads(self.pid)

    @property
    def exception(self):
        """Exception reported by the child, or ``None`` if there is none.

        Any pending pipe message is consumed and cached first. A cached
        wrapper is rebuilt on every access as ``exc_type(exc_msg)``.
        """
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill can't find a live process with that pid
    def send_signal(self, signum, ignore_ESRCH=False):
        """Send ``signum`` to the child.

        Args:
            signum: signal number passed to ``os.kill``.
            ignore_ESRCH: when true, an ``OSError`` whose errno is ``ESRCH``
                is swallowed; every other ``OSError`` is re-raised.
        """
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker
class ErrorDataset(Dataset):
    """Dataset that only reports a length; no ``__getitem__`` is defined here."""

    def __init__(self, size):
        # Number of samples this dataset claims to contain.
        self.size = size

    def __len__(self):
        reported_length = self.size
        return reported_length
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker
class SegfaultDataset(Dataset):
    """Dataset whose item access reads a C string from address 0."""

    def __init__(self, size):
        # Number of samples this dataset claims to contain.
        self.size = size

    def __len__(self):
        reported_length = self.size
        return reported_length

    def __getitem__(self, idx):
        # Reading from the null address; `idx` is not used.
        null_address = 0
        return ctypes.string_at(null_address)
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker
class SleepDataset(Dataset):
    """Dataset that sleeps once, on the first item access of an instance."""

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec
        # Flipped to True after the one-time sleep has happened.
        self.sleeped = False

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if self.sleeped:
            return idx
        time.sleep(self.sleep_sec)
        self.sleeped = True
        return idx
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker
class SeedDataset(Dataset):
    """Dataset whose every item is the value of ``torch.initial_seed()``."""

    def __init__(self, size):
        # Number of samples this dataset claims to contain.
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # `idx` is ignored: the result depends only on the current seed.
        current_seed = torch.initial_seed()
        return current_seed
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker
class WorkerSpecificIterableDataset(IterableDataset):
    """Iterable dataset whose length of output is chosen per worker id."""

    def __init__(self, sizes_for_all_workers):
        # Indexed by worker id: how many items that worker yields.
        self.sizes_for_all_workers = sizes_for_all_workers

    def __len__(self):
        return sum(self.sizes_for_all_workers)

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        # Worker info must be available when iteration starts.
        assert info is not None
        own_size = self.sizes_for_all_workers[info.id]
        return iter(range(own_size))
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker
# Inspired by https://stackoverflow.com/a/26703365
# If all workers call `sync_once`, each of them blocks until every worker has
# reached the call (i.e., it acts like a barrier).
# This can be used to ensure that each worker processes at least one sample.
class SynchronizedDataset(Dataset):
    """Base dataset offering ``sync_once``, a rendezvous across workers.

    Subclasses supply ``__getitem__``; the base version raises
    ``NotImplementedError``.
    """

    def __init__(self, size, batch_size, num_workers):
        assert num_workers * batch_size <= size
        self.size = size
        self.num_workers = num_workers
        # Shared counter of callers that have reached `sync_once`.
        self.count = mp.Value("i", 0, lock=True)
        # Starts at zero; released once the counter reaches `num_workers`.
        self.barrier = mp.Semaphore(0)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        raise NotImplementedError

    def sync_once(self):
        arrivals = self.count
        with arrivals.get_lock():
            arrivals.value += 1
            everyone_arrived = arrivals.value == self.num_workers
            if everyone_arrived:
                self.barrier.release()
        # Take the permit, then hand it straight on to the next waiter.
        self.barrier.acquire()
        self.barrier.release()
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker
class EmptyTensorDataset(torch.utils.data.Dataset):
    """Dataset of a configurable length whose every item is an empty tensor."""

    def __init__(self, len):
        # Length reported by `__len__`.
        self.len = len

    def __getitem__(self, any):
        # The index is ignored; a fresh zero-element tensor is returned.
        empty_item = torch.empty(0)
        return empty_item

    def __len__(self):
        return self.len
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker
class SynchronizedSeedDataset(SynchronizedDataset):
    """Synchronized dataset whose items are the value of ``torch.initial_seed()``."""

    def __getitem__(self, idx):
        # Rendezvous with the other workers before reading the seed.
        self.sync_once()
        current_seed = torch.initial_seed()
        return current_seed
796*da0073e9SAndroid Build Coastguard Worker
797*da0073e9SAndroid Build Coastguard Worker
def _test_timeout(persistent_workers):
    """Fetch one batch from a loader whose dataset sleeps longer than its timeout."""
    # The dataset sleeps 3 seconds on first access; the loader waits 1 second.
    loader_options = dict(
        batch_size=2,
        num_workers=2,
        timeout=1,
        persistent_workers=persistent_workers,
    )
    loader = DataLoader(SleepDataset(10, 3), **loader_options)
    _ = next(iter(loader))
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker
def _test_timeout_pin_memory(persistent_workers):
    """Same as the timeout scenario, but with ``pin_memory=True`` on the loader."""
    # The dataset sleeps 3 seconds on first access; the loader waits 1 second.
    loader_options = dict(
        batch_size=2,
        num_workers=2,
        timeout=1,
        pin_memory=True,
        persistent_workers=persistent_workers,
    )
    loader = DataLoader(SleepDataset(10, 3), **loader_options)
    _ = next(iter(loader))
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker
def _test_large_sampler_indices(persistent_workers):
    """Raise ``RuntimeError("My Error")`` right after the first batch arrives."""
    # See
    #   test_large_sampler_indices
    #   https://github.com/pytorch/pytorch/issues/48666

    loader_options = dict(
        batch_size=40960,
        persistent_workers=persistent_workers,
        num_workers=1,
    )
    dataloader = torch.utils.data.DataLoader(
        EmptyTensorDataset(10000000), **loader_options
    )

    it = iter(dataloader)

    for batch in it:
        # Every item is an empty tensor, so the collated batch is empty too.
        assert batch.numel() == 0
        raise RuntimeError("My Error")
840*da0073e9SAndroid Build Coastguard Worker
841*da0073e9SAndroid Build Coastguard Worker
def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since worker signal handler prints with low-level write(),
    this has to be done on OS level via dup.

    This is used as worker_init_fn for test_segfault.

    Args:
        worker_id: unused; present so the function matches the
            ``worker_init_fn`` call signature.
    """
    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
    # A with-block is fine here: dup2 makes the stderr descriptor refer to the
    # null device in its own right, so closing the temporary handle when the
    # block ends does not undo the redirection.
    with open(os.devnull, "w") as devnull:
        os.dup2(devnull.fileno(), sys.stderr.fileno())
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker
def _test_segfault():
    """Request the first batch from a two-worker loader over SegfaultDataset.

    Workers are started with ``disable_stderr`` as their init function so that
    anything they write to stderr is discarded.
    NOTE(review): SegfaultDataset is defined elsewhere in this file; presumably
    reading from it crashes the worker -- confirm against its definition.
    """
    loader = DataLoader(
        SegfaultDataset(10),
        batch_size=2,
        num_workers=2,
        worker_init_fn=disable_stderr,
    )
    _ = next(iter(loader))
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker
def _test_no_segfault():
    """Fetch one batch from a forked worker after fixing torch's thread count.

    The thread count is explicitly (re)set to at least four before the
    "fork" multiprocessing context is created and the single worker starts.
    """
    # Same effect as the two-branch form: raise to 4 when lower, otherwise
    # re-apply the current value.
    torch.set_num_threads(max(torch.get_num_threads(), 4))

    fork_ctx = torch.multiprocessing.get_context(method="fork")
    loader = DataLoader(
        [1, 2, 3],
        num_workers=1,
        worker_init_fn=disable_stderr,
        multiprocessing_context=fork_ctx,
    )
    _ = next(iter(loader))
880*da0073e9SAndroid Build Coastguard Worker
881*da0073e9SAndroid Build Coastguard Worker
class TestProperExitDataset(Dataset):
    """Map-style dataset whose last worker can be made to fail on demand.

    Args:
        size: value reported by ``len()``.
        error_event: ``None`` or an event-like object with ``is_set()``; once
            it is set, item access in the highest-numbered worker raises.
    """

    def __init__(self, size, error_event):
        self.error_event = error_event
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``tensor([idx])`` unless this is the failing worker."""
        worker_info = torch.utils.data.get_worker_info()
        event = self.error_event
        # only error in the last worker
        should_fail = (
            event is not None
            and event.is_set()
            and worker_info.id == worker_info.num_workers - 1
        )
        if should_fail:
            raise RuntimeError("Worker error")
        return torch.tensor([idx])
900*da0073e9SAndroid Build Coastguard Worker
901*da0073e9SAndroid Build Coastguard Worker
class TestProperExitIterableDataset(IterableDataset):
    """Iterable dataset that acts as its own iterator and can fail on demand.

    Args:
        size: value reported by ``len()`` and the number of items yielded.
        error_event: ``None`` or an event-like object with ``is_set()``; once
            it is set, ``__next__`` in the highest-numbered worker raises.
    """

    def __init__(self, size, error_event):
        self.size = size
        self.remaining = size  # countdown consumed by __next__
        self.error_event = error_event

    def __len__(self):
        return self.size

    def __iter__(self):
        return self

    def __next__(self):
        """Return ``tensor(-1000)`` until the countdown is exhausted."""
        worker_info = torch.utils.data.get_worker_info()
        event = self.error_event
        if event is not None and event.is_set():
            # only error in the last worker
            if worker_info.id == worker_info.num_workers - 1:
                raise RuntimeError("Worker error")
        self.remaining -= 1
        if self.remaining >= 0:
            return torch.tensor(-1000)
        raise StopIteration
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker
929*da0073e9SAndroid Build Coastguard Worker# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(
    is_iterable_dataset,
    use_workers,
    pin_memory,
    exit_method,
    hold_iter_reference,
    loader_setup_event,
    tester_setup_event,
    persistent_workers,
):
    """Iterate a DataLoader and inject one of several failure modes.

    Args:
        is_iterable_dataset: use ``TestProperExitIterableDataset(7, ...)``
            instead of ``TestProperExitDataset(12, ...)``.
        use_workers: load with two workers rather than in-process.
        pin_memory: forwarded to the DataLoader.
        exit_method: ``"worker_error"`` (an event makes the last worker
            raise), ``"loader_error"`` (raise here), ``"loader_kill"`` (kill
            this process) or ``"worker_kill"`` (kill the last worker); any
            other value injects nothing.
        hold_iter_reference: when false, the local references to the iterator
            and the loader are dropped after the first batch and a garbage
            collection is forced at the end.
        loader_setup_event: set once the first batch has been received.
        tester_setup_event: waited on right after ``loader_setup_event`` is set.
        persistent_workers: forwarded to the DataLoader.
    """
    num_workers = 2 if use_workers else 0

    if exit_method in ("worker_error", "worker_kill"):
        assert use_workers is True

    worker_error_event = mp.Event() if exit_method == "worker_error" else None

    if is_iterable_dataset:
        ds = TestProperExitIterableDataset(7, worker_error_event)
    else:
        ds = TestProperExitDataset(12, worker_error_event)

    loader = DataLoader(
        ds,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=set_faulthander_if_available,
        persistent_workers=persistent_workers,
    )

    # Iteration index at which the loader-side / kill failures are injected.
    error_it = 2

    # Make sure there is enough data left to iterate past the failure point.
    # 2 is the magical per-worker prefetch number...
    # FIXME: change this after the number becomes configurable.
    if use_workers and is_iterable_dataset:
        assert len(ds) * num_workers > (error_it + 2 + 1)
    elif use_workers:
        assert len(loader) > (error_it + 2 + 1) * num_workers
    elif is_iterable_dataset:
        assert len(ds) > error_it + 1
    else:
        assert len(loader) > error_it + 1

    it = iter(loader)
    if use_workers:
        workers = it._workers

    def kill_pid(pid):
        # Kill the process, wait for it to go away, and verify it is gone.
        proc = psutil.Process(pid)
        proc.kill()
        proc.wait(JOIN_TIMEOUT)
        assert not proc.is_running()

    for i, _ in enumerate(it):
        if i == 0:
            if not hold_iter_reference:
                del it
                del loader
            loader_setup_event.set()
            tester_setup_event.wait()
            # ensure that the workers are still alive
            if use_workers:
                for worker in workers:
                    assert worker.is_alive()
            if worker_error_event is not None:
                worker_error_event.set()

        if i != error_it:
            continue
        if exit_method == "loader_error":
            raise RuntimeError("Loader error")
        elif exit_method == "loader_kill":
            kill_pid(os.getpid())
        elif exit_method == "worker_kill":
            kill_pid(workers[-1].pid)  # kill last worker

    if not hold_iter_reference:
        # Tries to trigger the __del__ clean-up rather than the automatic
        # exiting of daemonic children. Technically it should be automatically
        # triggered, but I don't want to rely on the implementation detail of
        # Python gc.
        gc.collect()
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker
class TestWorkerInfoDataset(SynchronizedDataset):
    """Dataset whose every item is its ``value`` attribute as a tensor.

    ``value`` is not assigned by this class; in this file it is set on the
    worker's dataset copy by ``_test_worker_info_init_fn``.
    """

    def __getitem__(self, idx):
        self.sync_once()
        payload = self.value
        return torch.tensor(payload)
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker# Should be used as worker_init_fn with TestWorkerInfoDataset.
1027*da0073e9SAndroid Build Coastguard Worker# See _test_get_worker_info below for usage.
def _test_worker_info_init_fn(worker_id):
    """Validate ``get_worker_info()`` inside a worker, then tag the dataset.

    Checks that the worker info agrees with ``worker_id`` and the torch seed,
    that it carries a ``TestWorkerInfoDataset`` copy with no ``value`` yet,
    that its attributes cannot be assigned, and that its ``repr`` names the
    standard fields. Finally stores ``[worker_id, pid]`` as ``dataset.value``.

    Args:
        worker_id: index of the worker running this function.

    Raises:
        AssertionError: if any check fails, including when assigning to the
            worker info unexpectedly succeeds.
    """
    worker_info = torch.utils.data.get_worker_info()
    assert (
        worker_id == worker_info.id
    ), "worker_init_fn and worker_info should have consistent id"
    assert (
        worker_id < worker_info.num_workers
    ), "worker_init_fn and worker_info should have valid id"
    assert (
        worker_info.seed == torch.initial_seed()
    ), "worker_init_fn and worker_info should have consistent seed"
    dataset = worker_info.dataset
    assert isinstance(
        dataset, TestWorkerInfoDataset
    ), "worker_info should have correct dataset copy"
    assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy"
    # test that WorkerInfo attributes are read-only, for both an existing
    # attribute and a brand-new one
    for attr, new_value in (("id", 3999), ("a", 3)):
        try:
            setattr(worker_info, attr, new_value)
        except RuntimeError as e:
            assert str(e) == "Cannot assign attributes to WorkerInfo objects"
        else:
            # Without this branch a silently successful assignment would let
            # the read-only check pass vacuously.
            raise AssertionError(
                f"assigning WorkerInfo.{attr} should raise RuntimeError"
            )
    for k in ["id", "num_workers", "seed", "dataset"]:
        assert f"{k}=" in repr(worker_info)
    dataset.value = [worker_id, os.getpid()]
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker
def _test_get_worker_info():
    """Check ``get_worker_info()`` in the main process and across workers.

    Loads a ``TestWorkerInfoDataset`` with two workers, verifies that every
    returned ``[worker_id, worker_pid]`` pair matches the loader's worker
    processes, and that the main-process dataset never received ``value``.

    Raises:
        RuntimeError: if indexing the main-process dataset does not raise
            ``AttributeError``.
    """
    # get_worker_info returns None in main proc
    assert torch.utils.data.get_worker_info() is None
    num_workers = 2
    batch_size = 2
    dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=_test_worker_info_init_fn,
    )
    it = iter(dataloader)
    batches = []
    for batch in it:
        batches.append(batch)  # noqa: PERF402
    worker_pids = [w.pid for w in it._workers]
    for pair in torch.cat(batches, 0):
        # each `pair` is a [worker_id, worker_pid] pair, which is set in
        # _test_worker_info_init_fn
        assert pair[1] == worker_pids[pair[0]]
    # get_worker_info returns None in main proc after data loading
    assert torch.utils.data.get_worker_info() is None
    # main proc dataset was never assigned this attribute
    assert not hasattr(dataset, "value")
    try:
        _ = dataset[0]
    except AttributeError:
        pass
    else:
        raise RuntimeError("Expected AttributeError")
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Worker# test custom init function
def init_fn(worker_id):
    """Seed torch's RNG with a fixed value; ``worker_id`` is ignored."""
    fixed_seed = 12345
    torch.manual_seed(fixed_seed)
1094*da0073e9SAndroid Build Coastguard Worker
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker# used with test_error_in_init
class ErrorIterableDataset(IterableDataset):
    """Iterable dataset whose ``__iter__`` always raises ``RuntimeError``."""

    def __iter__(self):
        message = "Error in __iter__"
        raise RuntimeError(message)
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker
1102*da0073e9SAndroid Build Coastguard Worker# used with test_error_in_init
def error_worker_init_fn(_):
    """Worker init function that always fails; its argument is ignored."""
    message = "Error in worker_init_fn"
    raise RuntimeError(message)
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker
class BulkLoadingDataset(Dataset):
    """Dataset that is indexed with a whole list/tuple of indices at once.

    Args:
        length: value reported by ``len()``.
    """

    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, indices):
        """Return the given indices as a tensor; a bare index is rejected."""
        accepted_types = (list, tuple)
        assert isinstance(indices, accepted_types)
        return torch.as_tensor(indices)
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker
class BulkLoadingSampler(torch.utils.data.Sampler):
    """Sampler yielding lists of shuffled indices, ``batch_size`` at a time.

    Args:
        dataset: any sized object; only ``len(dataset)`` is used.
        batch_size: maximum number of indices per yielded list.
    """

    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size
        self.dataset = dataset

    def __iter__(self):
        shuffled = torch.randperm(len(self.dataset))
        yield from (chunk.tolist() for chunk in shuffled.split(self.batch_size))

    def __len__(self):
        # Number of chunks, counting a final partial one.
        total = len(self.dataset)
        return math.ceil(total / float(self.batch_size))
1130*da0073e9SAndroid Build Coastguard Worker
1131*da0073e9SAndroid Build Coastguard Worker
class TestMultiEpochDataset(IterableDataset):
    """Iterable dataset in which each worker repeatedly yields its own id.

    Args:
        length: value reported by ``len()``; each worker yields
            ``length // num_workers`` items.
    """

    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # Iteration is only valid where worker info is available.
        assert worker_info is not None
        per_worker = self.length // worker_info.num_workers
        worker_id = worker_info.id
        for _ in range(per_worker):
            yield worker_id
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker
class CustomList(list):
    """Plain ``list`` subclass that adds no behaviour of its own."""
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker
class CustomDict(dict):
    """Plain ``dict`` subclass that adds no behaviour of its own."""
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker
def row_processor(row):
    """Return ``row`` with one added to it via ``np.add``."""
    incremented = np.add(row, 1)
    return incremented
1157*da0073e9SAndroid Build Coastguard Worker
1158*da0073e9SAndroid Build Coastguard Worker
def filter_len(row):
    """Return True exactly when ``row`` has four elements."""
    wanted_length = 4
    return len(row) == wanted_length
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker
1163*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
1164*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
1165*da0073e9SAndroid Build Coastguard Worker    "Fails with TSAN with the following error: starting new threads after multi-threaded "
1166*da0073e9SAndroid Build Coastguard Worker    "fork is not supported. Dying (set die_after_fork=0 to override)",
1167*da0073e9SAndroid Build Coastguard Worker)
1168*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
1169*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
1170*da0073e9SAndroid Build Coastguard Worker    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
1171*da0073e9SAndroid Build Coastguard Worker)
1172*da0073e9SAndroid Build Coastguard Workerclass TestDataLoader(TestCase):
1173*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
1174*da0073e9SAndroid Build Coastguard Worker        super().setUp()
1175*da0073e9SAndroid Build Coastguard Worker        self.data = torch.randn(100, 2, 3, 5)
1176*da0073e9SAndroid Build Coastguard Worker        self.labels = torch.randperm(50).repeat(2)
1177*da0073e9SAndroid Build Coastguard Worker        self.dataset = TensorDataset(self.data, self.labels)
1178*da0073e9SAndroid Build Coastguard Worker        self.persistent_workers = False
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker    def _get_data_loader(self, dataset, **kwargs):
1181*da0073e9SAndroid Build Coastguard Worker        persistent_workers = kwargs.get("persistent_workers", self.persistent_workers)
1182*da0073e9SAndroid Build Coastguard Worker        if persistent_workers and kwargs.get("num_workers", 0) == 0:
1183*da0073e9SAndroid Build Coastguard Worker            persistent_workers = False
1184*da0073e9SAndroid Build Coastguard Worker        kwargs["persistent_workers"] = persistent_workers
1185*da0073e9SAndroid Build Coastguard Worker        return DataLoader(dataset, **kwargs)
1186*da0073e9SAndroid Build Coastguard Worker
1187*da0073e9SAndroid Build Coastguard Worker    def _test_sequential(self, loader):
1188*da0073e9SAndroid Build Coastguard Worker        batch_size = loader.batch_size
1189*da0073e9SAndroid Build Coastguard Worker        if batch_size is None:
1190*da0073e9SAndroid Build Coastguard Worker            for idx, (sample, target) in enumerate(loader):
1191*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sample, self.data[idx])
1192*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(target, self.labels[idx])
1193*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx, len(self.dataset) - 1)
1194*da0073e9SAndroid Build Coastguard Worker        else:
1195*da0073e9SAndroid Build Coastguard Worker            for i, (sample, target) in enumerate(loader):
1196*da0073e9SAndroid Build Coastguard Worker                idx = i * batch_size
1197*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sample, self.data[idx : idx + batch_size])
1198*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(target, self.labels[idx : idx + batch_size])
1199*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker    def _test_shuffle(self, loader):
1202*da0073e9SAndroid Build Coastguard Worker        found_data = dict.fromkeys(range(self.data.size(0)), 0)
1203*da0073e9SAndroid Build Coastguard Worker        found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
1204*da0073e9SAndroid Build Coastguard Worker        batch_size = loader.batch_size
1205*da0073e9SAndroid Build Coastguard Worker        if batch_size is None:
1206*da0073e9SAndroid Build Coastguard Worker            for i, (batch_samples, batch_targets) in enumerate(loader):
1207*da0073e9SAndroid Build Coastguard Worker                sample, target = (batch_samples, batch_targets)
1208*da0073e9SAndroid Build Coastguard Worker                for data_point_idx, data_point in enumerate(self.data):
1209*da0073e9SAndroid Build Coastguard Worker                    if data_point.eq(sample).all():
1210*da0073e9SAndroid Build Coastguard Worker                        self.assertFalse(found_data[data_point_idx])
1211*da0073e9SAndroid Build Coastguard Worker                        found_data[data_point_idx] += 1
1212*da0073e9SAndroid Build Coastguard Worker                        break
1213*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(target, self.labels[data_point_idx])
1214*da0073e9SAndroid Build Coastguard Worker                found_labels[data_point_idx] += 1
1215*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sum(found_data.values()), (i + 1))
1216*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sum(found_labels.values()), (i + 1))
1217*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, (len(self.dataset) - 1))
1218*da0073e9SAndroid Build Coastguard Worker        else:
1219*da0073e9SAndroid Build Coastguard Worker            for i, (batch_samples, batch_targets) in enumerate(loader):
1220*da0073e9SAndroid Build Coastguard Worker                for sample, target in zip(batch_samples, batch_targets):
1221*da0073e9SAndroid Build Coastguard Worker                    for data_point_idx, data_point in enumerate(self.data):
1222*da0073e9SAndroid Build Coastguard Worker                        if data_point.eq(sample).all():
1223*da0073e9SAndroid Build Coastguard Worker                            self.assertFalse(found_data[data_point_idx])
1224*da0073e9SAndroid Build Coastguard Worker                            found_data[data_point_idx] += 1
1225*da0073e9SAndroid Build Coastguard Worker                            break
1226*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(target, self.labels[data_point_idx])
1227*da0073e9SAndroid Build Coastguard Worker                    found_labels[data_point_idx] += 1
1228*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
1229*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
1230*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker    def _test_error(self, loader):
1233*da0073e9SAndroid Build Coastguard Worker        it = iter(loader)
1234*da0073e9SAndroid Build Coastguard Worker        errors = 0
1235*da0073e9SAndroid Build Coastguard Worker        while True:
1236*da0073e9SAndroid Build Coastguard Worker            try:
1237*da0073e9SAndroid Build Coastguard Worker                next(it)
1238*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError:
1239*da0073e9SAndroid Build Coastguard Worker                errors += 1
1240*da0073e9SAndroid Build Coastguard Worker            except StopIteration:
1241*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1242*da0073e9SAndroid Build Coastguard Worker                    errors, math.ceil(float(len(loader.dataset)) / loader.batch_size)
1243*da0073e9SAndroid Build Coastguard Worker                )
1244*da0073e9SAndroid Build Coastguard Worker                return
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker    def test_error_in_init(self):
1247*da0073e9SAndroid Build Coastguard Worker        for num_workers in [0, 2]:
1248*da0073e9SAndroid Build Coastguard Worker            loader = self._get_data_loader(
1249*da0073e9SAndroid Build Coastguard Worker                ErrorIterableDataset(), num_workers=num_workers
1250*da0073e9SAndroid Build Coastguard Worker            )
1251*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Error in __iter__"):
1252*da0073e9SAndroid Build Coastguard Worker                list(iter(loader))
1253*da0073e9SAndroid Build Coastguard Worker
1254*da0073e9SAndroid Build Coastguard Worker        loader = self._get_data_loader(
1255*da0073e9SAndroid Build Coastguard Worker            self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn
1256*da0073e9SAndroid Build Coastguard Worker        )
1257*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Error in worker_init_fn"):
1258*da0073e9SAndroid Build Coastguard Worker            list(iter(loader))
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker    def test_typing(self):
1261*da0073e9SAndroid Build Coastguard Worker        from typing import List
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker        # Make sure there is no TypeError
1264*da0073e9SAndroid Build Coastguard Worker
1265*da0073e9SAndroid Build Coastguard Worker        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
1266*da0073e9SAndroid Build Coastguard Worker            pass
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
1269*da0073e9SAndroid Build Coastguard Worker            pass
1270*da0073e9SAndroid Build Coastguard Worker
1271*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
1272*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
1273*da0073e9SAndroid Build Coastguard Worker    def test_fd_limit_exceeded(self):
1274*da0073e9SAndroid Build Coastguard Worker        # See NOTE [ DataLoader on Linux and open files limit ]
1275*da0073e9SAndroid Build Coastguard Worker        import subprocess
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker        subprocess.check_output(
1278*da0073e9SAndroid Build Coastguard Worker            [
1279*da0073e9SAndroid Build Coastguard Worker                sys.executable,
1280*da0073e9SAndroid Build Coastguard Worker                "-c",
1281*da0073e9SAndroid Build Coastguard Worker                """\
1282*da0073e9SAndroid Build Coastguard Workerimport torch
1283*da0073e9SAndroid Build Coastguard Workerimport resource
1284*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import DataLoader, IterableDataset
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Workerclass RandomDataset(IterableDataset):
1287*da0073e9SAndroid Build Coastguard Worker    def __init__(self, len, size):
1288*da0073e9SAndroid Build Coastguard Worker        super(RandomDataset).__init__()
1289*da0073e9SAndroid Build Coastguard Worker        self.len = len
1290*da0073e9SAndroid Build Coastguard Worker        self.size = size
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
1293*da0073e9SAndroid Build Coastguard Worker        return self
1294*da0073e9SAndroid Build Coastguard Worker
1295*da0073e9SAndroid Build Coastguard Worker    def __next__(self):
1296*da0073e9SAndroid Build Coastguard Worker        if self.len <= 0:
1297*da0073e9SAndroid Build Coastguard Worker            raise StopIteration
1298*da0073e9SAndroid Build Coastguard Worker        self.len -= 1
1299*da0073e9SAndroid Build Coastguard Worker        return torch.randn(self.size)
1300*da0073e9SAndroid Build Coastguard Worker
1301*da0073e9SAndroid Build Coastguard Workertry:
1302*da0073e9SAndroid Build Coastguard Worker    keep_fds_alive = []
1303*da0073e9SAndroid Build Coastguard Worker    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
1304*da0073e9SAndroid Build Coastguard Worker    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
1305*da0073e9SAndroid Build Coastguard Worker                               num_workers=1):
1306*da0073e9SAndroid Build Coastguard Worker      random_t.max(dim=0)
1307*da0073e9SAndroid Build Coastguard Worker      keep_fds_alive.append(random_t)
1308*da0073e9SAndroid Build Coastguard Workerexcept RuntimeError as e:
1309*da0073e9SAndroid Build Coastguard Worker    assert "ulimit -n" in str(e)
1310*da0073e9SAndroid Build Coastguard Worker    assert "set_sharing_strategy" in str(e)
1311*da0073e9SAndroid Build Coastguard Worker""",
1312*da0073e9SAndroid Build Coastguard Worker            ]
1313*da0073e9SAndroid Build Coastguard Worker        )
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker    def test_invalid_assign_after_init(self):
1316*da0073e9SAndroid Build Coastguard Worker        dl = self._get_data_loader(self.dataset)
1317*da0073e9SAndroid Build Coastguard Worker        for attr in ("batch_size", "sampler", "batch_sampler", "drop_last", "dataset"):
1318*da0073e9SAndroid Build Coastguard Worker
1319*da0073e9SAndroid Build Coastguard Worker            def fn():
1320*da0073e9SAndroid Build Coastguard Worker                setattr(dl, attr, {})
1321*da0073e9SAndroid Build Coastguard Worker
1322*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(ValueError, fn)
1323*da0073e9SAndroid Build Coastguard Worker
1324*da0073e9SAndroid Build Coastguard Worker    def test_sequential_nonbatch(self):
1325*da0073e9SAndroid Build Coastguard Worker        self._test_sequential(self._get_data_loader(self.dataset, batch_size=None))
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker    def test_sequential_batch(self):
1328*da0073e9SAndroid Build Coastguard Worker        self._test_sequential(self._get_data_loader(self.dataset))
1329*da0073e9SAndroid Build Coastguard Worker        self._test_sequential(self._get_data_loader(self.dataset, batch_size=2))
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker    def test_bulk_loading_nobatch(self):
1332*da0073e9SAndroid Build Coastguard Worker        n = 35
1333*da0073e9SAndroid Build Coastguard Worker        bs = 4
1334*da0073e9SAndroid Build Coastguard Worker        ds = BulkLoadingDataset(n)
1335*da0073e9SAndroid Build Coastguard Worker        sampler = BulkLoadingSampler(ds, batch_size=4)
1336*da0073e9SAndroid Build Coastguard Worker
1337*da0073e9SAndroid Build Coastguard Worker        for num_workers in [0, 4]:
1338*da0073e9SAndroid Build Coastguard Worker            dl = self._get_data_loader(
1339*da0073e9SAndroid Build Coastguard Worker                ds,
1340*da0073e9SAndroid Build Coastguard Worker                num_workers=num_workers,
1341*da0073e9SAndroid Build Coastguard Worker                batch_size=None,
1342*da0073e9SAndroid Build Coastguard Worker                sampler=sampler,
1343*da0073e9SAndroid Build Coastguard Worker                pin_memory=TEST_CUDA,
1344*da0073e9SAndroid Build Coastguard Worker            )
1345*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(dl._auto_collation)
1346*da0073e9SAndroid Build Coastguard Worker            samples = list(dl)
1347*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
1348*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))
1349*da0073e9SAndroid Build Coastguard Worker
1350*da0073e9SAndroid Build Coastguard Worker    def test_growing_dataset(self):
1351*da0073e9SAndroid Build Coastguard Worker        dataset = [torch.ones(4) for _ in range(4)]
1352*da0073e9SAndroid Build Coastguard Worker        dataloader_seq = self._get_data_loader(dataset, shuffle=False)
1353*da0073e9SAndroid Build Coastguard Worker        dataloader_shuffle = self._get_data_loader(dataset, shuffle=True)
1354*da0073e9SAndroid Build Coastguard Worker        dataset.append(torch.ones(4))
1355*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(dataloader_seq), 5)
1356*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(dataloader_shuffle), 5)
1357*da0073e9SAndroid Build Coastguard Worker
1358*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1359*da0073e9SAndroid Build Coastguard Worker    def test_sequential_pin_memory(self):
1360*da0073e9SAndroid Build Coastguard Worker        loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
1361*da0073e9SAndroid Build Coastguard Worker        for input, target in loader:
1362*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(input.is_pinned())
1363*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(target.is_pinned())
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1366*da0073e9SAndroid Build Coastguard Worker    def test_multiple_dataloaders(self):
1367*da0073e9SAndroid Build Coastguard Worker        for multiprocessing_context in supported_multiprocessing_contexts:
1368*da0073e9SAndroid Build Coastguard Worker            loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1))
1369*da0073e9SAndroid Build Coastguard Worker            loader2_it = iter(
1370*da0073e9SAndroid Build Coastguard Worker                self._get_data_loader(
1371*da0073e9SAndroid Build Coastguard Worker                    self.dataset,
1372*da0073e9SAndroid Build Coastguard Worker                    num_workers=2,
1373*da0073e9SAndroid Build Coastguard Worker                    multiprocessing_context=multiprocessing_context,
1374*da0073e9SAndroid Build Coastguard Worker                )
1375*da0073e9SAndroid Build Coastguard Worker            )
1376*da0073e9SAndroid Build Coastguard Worker            next(loader1_it)
1377*da0073e9SAndroid Build Coastguard Worker            next(loader1_it)
1378*da0073e9SAndroid Build Coastguard Worker            next(loader2_it)
1379*da0073e9SAndroid Build Coastguard Worker            next(loader2_it)
1380*da0073e9SAndroid Build Coastguard Worker            next(loader1_it)
1381*da0073e9SAndroid Build Coastguard Worker            next(loader2_it)
1382*da0073e9SAndroid Build Coastguard Worker            del loader1_it
1383*da0073e9SAndroid Build Coastguard Worker            del loader2_it
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(True, "This test is disabled in pytorch/pytorch")
1386*da0073e9SAndroid Build Coastguard Worker    def test_segfault(self):
1387*da0073e9SAndroid Build Coastguard Worker        p = ErrorTrackingProcess(target=_test_segfault)
1388*da0073e9SAndroid Build Coastguard Worker        p.start()
1389*da0073e9SAndroid Build Coastguard Worker        p.join(JOIN_TIMEOUT)
1390*da0073e9SAndroid Build Coastguard Worker        try:
1391*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.is_alive())
1392*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(p.exitcode, 0)
1393*da0073e9SAndroid Build Coastguard Worker            if IS_WINDOWS:
1394*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(p.exception, OSError)
1395*da0073e9SAndroid Build Coastguard Worker                self.assertRegex(str(p.exception), r"access violation reading ")
1396*da0073e9SAndroid Build Coastguard Worker            else:
1397*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(p.exception, RuntimeError)
1398*da0073e9SAndroid Build Coastguard Worker                self.assertRegex(
1399*da0073e9SAndroid Build Coastguard Worker                    str(p.exception),
1400*da0073e9SAndroid Build Coastguard Worker                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
1401*da0073e9SAndroid Build Coastguard Worker                )
1402*da0073e9SAndroid Build Coastguard Worker        finally:
1403*da0073e9SAndroid Build Coastguard Worker            p.terminate()
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Worker    # Tests if the child process forked by the DataLoader segfaults due to having more than 3 threads
1406*da0073e9SAndroid Build Coastguard Worker    # in the parent process after at least one set_num_threads invocation in the parent process.
1407*da0073e9SAndroid Build Coastguard Worker    # After forking, set_num_threads(1) in the child process entails handling some inherited data-structures
1408*da0073e9SAndroid Build Coastguard Worker    # of the Caffe2 thread-pool of the parent process, culminating in a segfault.
1409*da0073e9SAndroid Build Coastguard Worker    # Reference: https://github.com/pytorch/pytorch/issues/54752
1410*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Needs fork")
1411*da0073e9SAndroid Build Coastguard Worker    def test_no_segfault(self):
1412*da0073e9SAndroid Build Coastguard Worker        p = ErrorTrackingProcess(target=_test_no_segfault)
1413*da0073e9SAndroid Build Coastguard Worker        p.start()
1414*da0073e9SAndroid Build Coastguard Worker        p.join(JOIN_TIMEOUT)
1415*da0073e9SAndroid Build Coastguard Worker        try:
1416*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.is_alive())
1417*da0073e9SAndroid Build Coastguard Worker            if p.exception:
1418*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(p.exception, RuntimeError)
1419*da0073e9SAndroid Build Coastguard Worker                self.assertRegex(
1420*da0073e9SAndroid Build Coastguard Worker                    str(p.exception),
1421*da0073e9SAndroid Build Coastguard Worker                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
1422*da0073e9SAndroid Build Coastguard Worker                )
1423*da0073e9SAndroid Build Coastguard Worker                self.fail("Segfault occurred in worker process after fork")
1424*da0073e9SAndroid Build Coastguard Worker        finally:
1425*da0073e9SAndroid Build Coastguard Worker            p.terminate()
1426*da0073e9SAndroid Build Coastguard Worker
1427*da0073e9SAndroid Build Coastguard Worker    def test_timeout(self):
1428*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
1429*da0073e9SAndroid Build Coastguard Worker            # This test runs in a subprocess, which can only initialize CUDA with spawn.
1430*da0073e9SAndroid Build Coastguard Worker            # _test_timeout_pin_memory with pin_memory=True initializes CUDA when the iterator is
1431*da0073e9SAndroid Build Coastguard Worker            # constructed.
1432*da0073e9SAndroid Build Coastguard Worker            targets = (_test_timeout, _test_timeout_pin_memory)
1433*da0073e9SAndroid Build Coastguard Worker        else:
1434*da0073e9SAndroid Build Coastguard Worker            targets = (_test_timeout,)
1435*da0073e9SAndroid Build Coastguard Worker        for target in targets:
1436*da0073e9SAndroid Build Coastguard Worker            p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,))
1437*da0073e9SAndroid Build Coastguard Worker            p.start()
1438*da0073e9SAndroid Build Coastguard Worker            p.join(JOIN_TIMEOUT)
1439*da0073e9SAndroid Build Coastguard Worker            try:
1440*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(p.is_alive())
1441*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(p.exitcode, 0)
1442*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(p.exception, RuntimeError)
1443*da0073e9SAndroid Build Coastguard Worker                self.assertRegex(
1444*da0073e9SAndroid Build Coastguard Worker                    str(p.exception), r"DataLoader timed out after \d+ seconds"
1445*da0073e9SAndroid Build Coastguard Worker                )
1446*da0073e9SAndroid Build Coastguard Worker            finally:
1447*da0073e9SAndroid Build Coastguard Worker                p.terminate()
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker    def test_large_sampler_indices(self):
1450*da0073e9SAndroid Build Coastguard Worker        # Test that the data loader cleanly exit when the process errors
1451*da0073e9SAndroid Build Coastguard Worker        #   1. having an reference to the iterator
1452*da0073e9SAndroid Build Coastguard Worker        #   2. using a sampler that yields big elements s.t. _index_queues putters block
1453*da0073e9SAndroid Build Coastguard Worker        #
1454*da0073e9SAndroid Build Coastguard Worker        # More context: https://github.com/pytorch/pytorch/issues/48666
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker        p = ErrorTrackingProcess(
1457*da0073e9SAndroid Build Coastguard Worker            target=_test_large_sampler_indices, args=(self.persistent_workers,)
1458*da0073e9SAndroid Build Coastguard Worker        )
1459*da0073e9SAndroid Build Coastguard Worker        p.start()
1460*da0073e9SAndroid Build Coastguard Worker        p.join(JOIN_TIMEOUT)
1461*da0073e9SAndroid Build Coastguard Worker        try:
1462*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.is_alive())
1463*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(p.exitcode, 0)
1464*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(p.exception, RuntimeError)
1465*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(p.exception), r"My Error")
1466*da0073e9SAndroid Build Coastguard Worker        finally:
1467*da0073e9SAndroid Build Coastguard Worker            p.terminate()
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker    def test_invalid_ctor_args_combinations(self):
1470*da0073e9SAndroid Build Coastguard Worker        # general
1471*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1472*da0073e9SAndroid Build Coastguard Worker            ValueError, "num_workers option should be non-negative"
1473*da0073e9SAndroid Build Coastguard Worker        ):
1474*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(self.dataset, num_workers=-1)
1475*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1476*da0073e9SAndroid Build Coastguard Worker            ValueError, "timeout option should be non-negative"
1477*da0073e9SAndroid Build Coastguard Worker        ):
1478*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(self.dataset, timeout=-1)
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        # disable auto-batching
1481*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1482*da0073e9SAndroid Build Coastguard Worker            ValueError,
1483*da0073e9SAndroid Build Coastguard Worker            "batch_size=None option disables auto-batching and is mutually exclusive",
1484*da0073e9SAndroid Build Coastguard Worker        ):
1485*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(self.dataset, batch_size=None, drop_last=True)
1486*da0073e9SAndroid Build Coastguard Worker
1487*da0073e9SAndroid Build Coastguard Worker        valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
1488*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1489*da0073e9SAndroid Build Coastguard Worker            ValueError, r"multi-process loading \(num_workers > 0\), but got"
1490*da0073e9SAndroid Build Coastguard Worker        ):
1491*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1492*da0073e9SAndroid Build Coastguard Worker                self.dataset, num_workers=0, multiprocessing_context=valid_ctx
1493*da0073e9SAndroid Build Coastguard Worker            )
1494*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1495*da0073e9SAndroid Build Coastguard Worker            ValueError, "should specify a valid start method in"
1496*da0073e9SAndroid Build Coastguard Worker        ):
1497*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1498*da0073e9SAndroid Build Coastguard Worker                self.dataset, num_workers=1, multiprocessing_context="bad"
1499*da0073e9SAndroid Build Coastguard Worker            )
1500*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1501*da0073e9SAndroid Build Coastguard Worker            TypeError, "multiprocessing_context option should be a valid context "
1502*da0073e9SAndroid Build Coastguard Worker        ):
1503*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1504*da0073e9SAndroid Build Coastguard Worker                self.dataset, num_workers=1, multiprocessing_context=object()
1505*da0073e9SAndroid Build Coastguard Worker            )
1506*da0073e9SAndroid Build Coastguard Worker
1507*da0073e9SAndroid Build Coastguard Worker        # map-style
1508*da0073e9SAndroid Build Coastguard Worker        sampler = torch.utils.data.SequentialSampler(self.dataset)
1509*da0073e9SAndroid Build Coastguard Worker        batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
1510*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1511*da0073e9SAndroid Build Coastguard Worker            ValueError, "sampler option is mutually exclusive with shuffle"
1512*da0073e9SAndroid Build Coastguard Worker        ):
1513*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1514*da0073e9SAndroid Build Coastguard Worker                self.dataset, batch_size=11, sampler=sampler, shuffle=True
1515*da0073e9SAndroid Build Coastguard Worker            )
1516*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1517*da0073e9SAndroid Build Coastguard Worker            ValueError, "sampler option is mutually exclusive with shuffle"
1518*da0073e9SAndroid Build Coastguard Worker        ):
1519*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1520*da0073e9SAndroid Build Coastguard Worker                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True
1521*da0073e9SAndroid Build Coastguard Worker            )
1522*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1523*da0073e9SAndroid Build Coastguard Worker            ValueError, "sampler option is mutually exclusive with shuffle"
1524*da0073e9SAndroid Build Coastguard Worker        ):
1525*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1526*da0073e9SAndroid Build Coastguard Worker                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3
1527*da0073e9SAndroid Build Coastguard Worker            )
1528*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1529*da0073e9SAndroid Build Coastguard Worker            ValueError, "batch_sampler option is mutually exclusive with"
1530*da0073e9SAndroid Build Coastguard Worker        ):
1531*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1532*da0073e9SAndroid Build Coastguard Worker                self.dataset, batch_size=11, batch_sampler=batch_sampler
1533*da0073e9SAndroid Build Coastguard Worker            )
1534*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1535*da0073e9SAndroid Build Coastguard Worker            ValueError, "batch_sampler option is mutually exclusive with"
1536*da0073e9SAndroid Build Coastguard Worker        ):
1537*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1538*da0073e9SAndroid Build Coastguard Worker                self.dataset, shuffle=True, batch_sampler=batch_sampler
1539*da0073e9SAndroid Build Coastguard Worker            )
1540*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1541*da0073e9SAndroid Build Coastguard Worker            ValueError, "batch_sampler option is mutually exclusive with"
1542*da0073e9SAndroid Build Coastguard Worker        ):
1543*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1544*da0073e9SAndroid Build Coastguard Worker                self.dataset, drop_last=True, batch_sampler=batch_sampler
1545*da0073e9SAndroid Build Coastguard Worker            )
1546*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1547*da0073e9SAndroid Build Coastguard Worker            ValueError, "batch_sampler option is mutually exclusive with"
1548*da0073e9SAndroid Build Coastguard Worker        ):
1549*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1550*da0073e9SAndroid Build Coastguard Worker                self.dataset, drop_last=3, batch_sampler=batch_sampler
1551*da0073e9SAndroid Build Coastguard Worker            )
1552*da0073e9SAndroid Build Coastguard Worker
1553*da0073e9SAndroid Build Coastguard Worker        # iterable-style
1554*da0073e9SAndroid Build Coastguard Worker        dataset = CountingIterableDataset(20)
1555*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1556*da0073e9SAndroid Build Coastguard Worker            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1557*da0073e9SAndroid Build Coastguard Worker        ):
1558*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(dataset, shuffle=True)
1559*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1560*da0073e9SAndroid Build Coastguard Worker            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1561*da0073e9SAndroid Build Coastguard Worker        ):
1562*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(dataset, shuffle=3)
1563*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1564*da0073e9SAndroid Build Coastguard Worker            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1565*da0073e9SAndroid Build Coastguard Worker        ):
1566*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1567*da0073e9SAndroid Build Coastguard Worker                dataset, sampler=torch.utils.data.SequentialSampler(dataset)
1568*da0073e9SAndroid Build Coastguard Worker            )
1569*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1570*da0073e9SAndroid Build Coastguard Worker            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1571*da0073e9SAndroid Build Coastguard Worker        ):
1572*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(dataset, sampler=3)
1573*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1574*da0073e9SAndroid Build Coastguard Worker            ValueError,
1575*da0073e9SAndroid Build Coastguard Worker            "DataLoader with IterableDataset: expected unspecified batch_sampler",
1576*da0073e9SAndroid Build Coastguard Worker        ):
1577*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
1578*da0073e9SAndroid Build Coastguard Worker                dataset,
1579*da0073e9SAndroid Build Coastguard Worker                batch_sampler=torch.utils.data.BatchSampler(
1580*da0073e9SAndroid Build Coastguard Worker                    torch.utils.data.SequentialSampler(dataset), 3, False
1581*da0073e9SAndroid Build Coastguard Worker                ),
1582*da0073e9SAndroid Build Coastguard Worker            )
1583*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1584*da0073e9SAndroid Build Coastguard Worker            ValueError,
1585*da0073e9SAndroid Build Coastguard Worker            "DataLoader with IterableDataset: expected unspecified batch_sampler",
1586*da0073e9SAndroid Build Coastguard Worker        ):
1587*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(dataset, batch_sampler=3)
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Worker    def test_builtin_collection_conversion(self):
1590*da0073e9SAndroid Build Coastguard Worker        for coll_ty in (list, tuple):
1591*da0073e9SAndroid Build Coastguard Worker            for num_workers in (0, 1):
1592*da0073e9SAndroid Build Coastguard Worker                # map-style dataset
1593*da0073e9SAndroid Build Coastguard Worker                dataset = CountingDataset(20)
1594*da0073e9SAndroid Build Coastguard Worker                # no auto-batching
1595*da0073e9SAndroid Build Coastguard Worker                fetched = coll_ty(
1596*da0073e9SAndroid Build Coastguard Worker                    self._get_data_loader(
1597*da0073e9SAndroid Build Coastguard Worker                        dataset, batch_size=None, num_workers=num_workers
1598*da0073e9SAndroid Build Coastguard Worker                    )
1599*da0073e9SAndroid Build Coastguard Worker                )
1600*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(fetched, coll_ty(range(20)))
1601*da0073e9SAndroid Build Coastguard Worker                # auto-batching
1602*da0073e9SAndroid Build Coastguard Worker                fetched = coll_ty(
1603*da0073e9SAndroid Build Coastguard Worker                    self._get_data_loader(
1604*da0073e9SAndroid Build Coastguard Worker                        dataset, batch_size=2, num_workers=num_workers
1605*da0073e9SAndroid Build Coastguard Worker                    )
1606*da0073e9SAndroid Build Coastguard Worker                )
1607*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1608*da0073e9SAndroid Build Coastguard Worker                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1609*da0073e9SAndroid Build Coastguard Worker                )
1610*da0073e9SAndroid Build Coastguard Worker
1611*da0073e9SAndroid Build Coastguard Worker                # iterable-style dataset
1612*da0073e9SAndroid Build Coastguard Worker                dataset = CountingIterableDataset(20)
1613*da0073e9SAndroid Build Coastguard Worker                # no auto-batching
1614*da0073e9SAndroid Build Coastguard Worker                fetched = coll_ty(
1615*da0073e9SAndroid Build Coastguard Worker                    self._get_data_loader(
1616*da0073e9SAndroid Build Coastguard Worker                        dataset, batch_size=None, num_workers=num_workers
1617*da0073e9SAndroid Build Coastguard Worker                    )
1618*da0073e9SAndroid Build Coastguard Worker                )
1619*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(fetched, coll_ty(range(20)))
1620*da0073e9SAndroid Build Coastguard Worker                # auto-batching
1621*da0073e9SAndroid Build Coastguard Worker                # this IterableDataset isn't configured for each worker, so for
1622*da0073e9SAndroid Build Coastguard Worker                # the equality test below to be valid, we cannot have more than 1 workers.
1623*da0073e9SAndroid Build Coastguard Worker                assert num_workers in [0, 1], "invalid test"
1624*da0073e9SAndroid Build Coastguard Worker                fetched = coll_ty(
1625*da0073e9SAndroid Build Coastguard Worker                    self._get_data_loader(
1626*da0073e9SAndroid Build Coastguard Worker                        dataset, batch_size=2, num_workers=num_workers
1627*da0073e9SAndroid Build Coastguard Worker                    )
1628*da0073e9SAndroid Build Coastguard Worker                )
1629*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1630*da0073e9SAndroid Build Coastguard Worker                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1631*da0073e9SAndroid Build Coastguard Worker                )
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker    def test_iterable_style_dataset(self):
1634*da0073e9SAndroid Build Coastguard Worker        # [no auto-batching] single process loading
1635*da0073e9SAndroid Build Coastguard Worker        dataset = CountingIterableDataset(20)
1636*da0073e9SAndroid Build Coastguard Worker        dataloader = self._get_data_loader(dataset, batch_size=None)
1637*da0073e9SAndroid Build Coastguard Worker        fetched = list(dataloader)
1638*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(fetched), 20)
1639*da0073e9SAndroid Build Coastguard Worker        for i, d in enumerate(fetched):
1640*da0073e9SAndroid Build Coastguard Worker            # non-batched should not convert ints into tensors
1641*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(d, int)
1642*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(d, i)
1643*da0073e9SAndroid Build Coastguard Worker        # DataLoader should match len of the iterable-style dataset (if implemented)
1644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(dataloader), len(dataset))
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker        # [no auto-batching] multiprocessing loading
1647*da0073e9SAndroid Build Coastguard Worker        num_workers = 3
1648*da0073e9SAndroid Build Coastguard Worker        sizes_for_all_workers = [0, 4, 20]
1649*da0073e9SAndroid Build Coastguard Worker        expected = sorted(
1650*da0073e9SAndroid Build Coastguard Worker            functools.reduce(
1651*da0073e9SAndroid Build Coastguard Worker                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1652*da0073e9SAndroid Build Coastguard Worker            )
1653*da0073e9SAndroid Build Coastguard Worker        )
1654*da0073e9SAndroid Build Coastguard Worker        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1655*da0073e9SAndroid Build Coastguard Worker        for prefetch_factor in [2, 3, 4]:
1656*da0073e9SAndroid Build Coastguard Worker            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1657*da0073e9SAndroid Build Coastguard Worker            dataloader = self._get_data_loader(
1658*da0073e9SAndroid Build Coastguard Worker                dataset,
1659*da0073e9SAndroid Build Coastguard Worker                num_workers=num_workers,
1660*da0073e9SAndroid Build Coastguard Worker                batch_size=None,
1661*da0073e9SAndroid Build Coastguard Worker                worker_init_fn=set_faulthander_if_available,
1662*da0073e9SAndroid Build Coastguard Worker                prefetch_factor=prefetch_factor,
1663*da0073e9SAndroid Build Coastguard Worker            )
1664*da0073e9SAndroid Build Coastguard Worker            dataloader_iter = iter(dataloader)
1665*da0073e9SAndroid Build Coastguard Worker            fetched = sorted(dataloader_iter)
1666*da0073e9SAndroid Build Coastguard Worker            for a, b in zip(fetched, expected):
1667*da0073e9SAndroid Build Coastguard Worker                # non-batched should not convert ints into tensors
1668*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(a, int)
1669*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, b)
1670*da0073e9SAndroid Build Coastguard Worker            # DataLoader should match len of the iterable-style dataset (if implemented)
1671*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(dataloader), len(dataset))
1672*da0073e9SAndroid Build Coastguard Worker            # When loading more than len(dataset) data, after accessing len(dataloader),
1673*da0073e9SAndroid Build Coastguard Worker            # we should get a warning. See NOTE [ IterableDataset and __len__ ].
1674*da0073e9SAndroid Build Coastguard Worker            dataset = CountingIterableDataset(20)
1675*da0073e9SAndroid Build Coastguard Worker            dataloader = self._get_data_loader(
1676*da0073e9SAndroid Build Coastguard Worker                dataset,
1677*da0073e9SAndroid Build Coastguard Worker                num_workers=num_workers,
1678*da0073e9SAndroid Build Coastguard Worker                worker_init_fn=set_faulthander_if_available,
1679*da0073e9SAndroid Build Coastguard Worker                prefetch_factor=prefetch_factor,
1680*da0073e9SAndroid Build Coastguard Worker            )
1681*da0073e9SAndroid Build Coastguard Worker            it = iter(dataloader)
1682*da0073e9SAndroid Build Coastguard Worker            for _ in range(40):
1683*da0073e9SAndroid Build Coastguard Worker                self.assertNotWarn(
1684*da0073e9SAndroid Build Coastguard Worker                    lambda: next(it), "Should not warn before accessing len(dataloader)"
1685*da0073e9SAndroid Build Coastguard Worker                )
1686*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(dataloader), len(dataset))
1687*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(dataloader), 20)
1688*da0073e9SAndroid Build Coastguard Worker            it = iter(dataloader)
1689*da0073e9SAndroid Build Coastguard Worker            for _ in range(20):
1690*da0073e9SAndroid Build Coastguard Worker                self.assertNotWarn(
1691*da0073e9SAndroid Build Coastguard Worker                    lambda: next(it), "Should not warn before exceeding length"
1692*da0073e9SAndroid Build Coastguard Worker                )
1693*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
1694*da0073e9SAndroid Build Coastguard Worker                with self.assertWarnsRegex(
1695*da0073e9SAndroid Build Coastguard Worker                    UserWarning,
1696*da0073e9SAndroid Build Coastguard Worker                    r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
1697*da0073e9SAndroid Build Coastguard Worker                    msg="Should always warn after exceeding length",
1698*da0073e9SAndroid Build Coastguard Worker                ):
1699*da0073e9SAndroid Build Coastguard Worker                    next(it)
1700*da0073e9SAndroid Build Coastguard Worker        # [no auto-batching] test that workers exit gracefully
1701*da0073e9SAndroid Build Coastguard Worker        workers = dataloader_iter._workers
1702*da0073e9SAndroid Build Coastguard Worker        del dataloader_iter
1703*da0073e9SAndroid Build Coastguard Worker        del dataloader
1704*da0073e9SAndroid Build Coastguard Worker        try:
1705*da0073e9SAndroid Build Coastguard Worker            for w in workers:
1706*da0073e9SAndroid Build Coastguard Worker                w.join(JOIN_TIMEOUT)
1707*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(w.is_alive())
1708*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(w.exitcode, 0)
1709*da0073e9SAndroid Build Coastguard Worker        finally:
1710*da0073e9SAndroid Build Coastguard Worker            for w in workers:
1711*da0073e9SAndroid Build Coastguard Worker                w.terminate()
1712*da0073e9SAndroid Build Coastguard Worker
1713*da0073e9SAndroid Build Coastguard Worker        # [auto-batching] single process loading
1714*da0073e9SAndroid Build Coastguard Worker        dataset = CountingIterableDataset(20)
1715*da0073e9SAndroid Build Coastguard Worker        fetched = list(self._get_data_loader(dataset, batch_size=7))
1716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(fetched), 3)
1717*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fetched[0].tolist(), list(range(7)))
1718*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fetched[2].tolist(), list(range(14, 20)))
1720*da0073e9SAndroid Build Coastguard Worker
1721*da0073e9SAndroid Build Coastguard Worker        # [auto-batching] multiprocessing loading
1722*da0073e9SAndroid Build Coastguard Worker        num_workers = 3
1723*da0073e9SAndroid Build Coastguard Worker        sizes_for_all_workers = [0, 4, 20]
1724*da0073e9SAndroid Build Coastguard Worker        expected = sorted(
1725*da0073e9SAndroid Build Coastguard Worker            functools.reduce(
1726*da0073e9SAndroid Build Coastguard Worker                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1727*da0073e9SAndroid Build Coastguard Worker            )
1728*da0073e9SAndroid Build Coastguard Worker        )
1729*da0073e9SAndroid Build Coastguard Worker        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1730*da0073e9SAndroid Build Coastguard Worker        for prefetch_factor in [2, 3, 4]:
1731*da0073e9SAndroid Build Coastguard Worker            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1732*da0073e9SAndroid Build Coastguard Worker            # worker 0 should return 0 batches
1733*da0073e9SAndroid Build Coastguard Worker            # worker 1 should return 1 batches
1734*da0073e9SAndroid Build Coastguard Worker            # worker 2 should return 3 batches
1735*da0073e9SAndroid Build Coastguard Worker            dataloader = self._get_data_loader(
1736*da0073e9SAndroid Build Coastguard Worker                dataset,
1737*da0073e9SAndroid Build Coastguard Worker                num_workers=num_workers,
1738*da0073e9SAndroid Build Coastguard Worker                batch_size=7,
1739*da0073e9SAndroid Build Coastguard Worker                prefetch_factor=prefetch_factor,
1740*da0073e9SAndroid Build Coastguard Worker            )
1741*da0073e9SAndroid Build Coastguard Worker            dataloader_iter = iter(dataloader)
1742*da0073e9SAndroid Build Coastguard Worker            fetched = list(dataloader_iter)
1743*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(fetched), 4)
1744*da0073e9SAndroid Build Coastguard Worker            fetched = {tuple(t.tolist()) for t in fetched}
1745*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1746*da0073e9SAndroid Build Coastguard Worker                fetched,
1747*da0073e9SAndroid Build Coastguard Worker                {
1748*da0073e9SAndroid Build Coastguard Worker                    tuple(range(4)),
1749*da0073e9SAndroid Build Coastguard Worker                    tuple(range(7)),
1750*da0073e9SAndroid Build Coastguard Worker                    tuple(range(7, 14)),
1751*da0073e9SAndroid Build Coastguard Worker                    tuple(range(14, 20)),
1752*da0073e9SAndroid Build Coastguard Worker                },
1753*da0073e9SAndroid Build Coastguard Worker            )
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker            # [auto-batching] test that workers exit gracefully
1756*da0073e9SAndroid Build Coastguard Worker            workers = dataloader_iter._workers
1757*da0073e9SAndroid Build Coastguard Worker            del dataloader_iter
1758*da0073e9SAndroid Build Coastguard Worker            del dataloader
1759*da0073e9SAndroid Build Coastguard Worker            try:
1760*da0073e9SAndroid Build Coastguard Worker                for w in workers:
1761*da0073e9SAndroid Build Coastguard Worker                    w.join(JOIN_TIMEOUT)
1762*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(w.is_alive())
1763*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(w.exitcode, 0)
1764*da0073e9SAndroid Build Coastguard Worker            finally:
1765*da0073e9SAndroid Build Coastguard Worker                for w in workers:
1766*da0073e9SAndroid Build Coastguard Worker                    w.terminate()
1767*da0073e9SAndroid Build Coastguard Worker        # [auto-batching & drop_last] single process loading
1768*da0073e9SAndroid Build Coastguard Worker        dataset = CountingIterableDataset(20)
1769*da0073e9SAndroid Build Coastguard Worker        fetched = list(self._get_data_loader(dataset, batch_size=7, drop_last=True))
1770*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(fetched), 2)
1771*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fetched[0].tolist(), list(range(7)))
1772*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1773*da0073e9SAndroid Build Coastguard Worker
1774*da0073e9SAndroid Build Coastguard Worker        # [auto-batching & drop_last] multiprocessing loading
1775*da0073e9SAndroid Build Coastguard Worker        num_workers = 3
1776*da0073e9SAndroid Build Coastguard Worker        sizes_for_all_workers = [0, 4, 20]
1777*da0073e9SAndroid Build Coastguard Worker        expected = sorted(
1778*da0073e9SAndroid Build Coastguard Worker            functools.reduce(
1779*da0073e9SAndroid Build Coastguard Worker                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1780*da0073e9SAndroid Build Coastguard Worker            )
1781*da0073e9SAndroid Build Coastguard Worker        )
1782*da0073e9SAndroid Build Coastguard Worker        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1783*da0073e9SAndroid Build Coastguard Worker        for prefetch_factor in [2, 3, 4]:
1784*da0073e9SAndroid Build Coastguard Worker            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1785*da0073e9SAndroid Build Coastguard Worker            # worker 0 should return 0 batches
1786*da0073e9SAndroid Build Coastguard Worker            # worker 1 should return 1 batches
1787*da0073e9SAndroid Build Coastguard Worker            # worker 2 should return 3 batches
1788*da0073e9SAndroid Build Coastguard Worker            dataloader = self._get_data_loader(
1789*da0073e9SAndroid Build Coastguard Worker                dataset,
1790*da0073e9SAndroid Build Coastguard Worker                num_workers=num_workers,
1791*da0073e9SAndroid Build Coastguard Worker                batch_size=7,
1792*da0073e9SAndroid Build Coastguard Worker                drop_last=True,
1793*da0073e9SAndroid Build Coastguard Worker                worker_init_fn=set_faulthander_if_available,
1794*da0073e9SAndroid Build Coastguard Worker                prefetch_factor=prefetch_factor,
1795*da0073e9SAndroid Build Coastguard Worker            )
1796*da0073e9SAndroid Build Coastguard Worker            dataloader_iter = iter(dataloader)
1797*da0073e9SAndroid Build Coastguard Worker            fetched = list(dataloader_iter)
1798*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(fetched), 2)
1799*da0073e9SAndroid Build Coastguard Worker            fetched = {tuple(t.tolist()) for t in fetched}
1800*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
1801*da0073e9SAndroid Build Coastguard Worker
1802*da0073e9SAndroid Build Coastguard Worker            # [auto-batching & drop_last] test that workers exit gracefully
1803*da0073e9SAndroid Build Coastguard Worker            workers = dataloader_iter._workers
1804*da0073e9SAndroid Build Coastguard Worker            del dataloader_iter
1805*da0073e9SAndroid Build Coastguard Worker            del dataloader
1806*da0073e9SAndroid Build Coastguard Worker            try:
1807*da0073e9SAndroid Build Coastguard Worker                for w in workers:
1808*da0073e9SAndroid Build Coastguard Worker                    w.join(JOIN_TIMEOUT)
1809*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(w.is_alive())
1810*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(w.exitcode, 0)
1811*da0073e9SAndroid Build Coastguard Worker            finally:
1812*da0073e9SAndroid Build Coastguard Worker                for w in workers:
1813*da0073e9SAndroid Build Coastguard Worker                    w.terminate()
1814*da0073e9SAndroid Build Coastguard Worker
1815*da0073e9SAndroid Build Coastguard Worker    def test_chain_iterable_style_dataset(self):
1816*da0073e9SAndroid Build Coastguard Worker        # chaining (concatenation)
1817*da0073e9SAndroid Build Coastguard Worker        dataset1 = CountingIterableDataset(20)
1818*da0073e9SAndroid Build Coastguard Worker        dataset2 = CountingIterableDataset(15)
1819*da0073e9SAndroid Build Coastguard Worker        expected = list(range(20)) + list(range(15))
1820*da0073e9SAndroid Build Coastguard Worker        for num_workers in [0, 1]:
1821*da0073e9SAndroid Build Coastguard Worker            for chained_dataset in [
1822*da0073e9SAndroid Build Coastguard Worker                dataset1 + dataset2,
1823*da0073e9SAndroid Build Coastguard Worker                ChainDataset([dataset1, dataset2]),
1824*da0073e9SAndroid Build Coastguard Worker            ]:
1825*da0073e9SAndroid Build Coastguard Worker                fetched = list(
1826*da0073e9SAndroid Build Coastguard Worker                    self._get_data_loader(chained_dataset, num_workers=num_workers)
1827*da0073e9SAndroid Build Coastguard Worker                )
1828*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(fetched), len(expected))
1829*da0073e9SAndroid Build Coastguard Worker                for e, d in zip(expected, fetched):
1830*da0073e9SAndroid Build Coastguard Worker                    self.assertIsInstance(d, torch.Tensor)
1831*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(e, d)
1832*da0073e9SAndroid Build Coastguard Worker
1833*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1834*da0073e9SAndroid Build Coastguard Worker            AssertionError, "ChainDataset only supports IterableDataset"
1835*da0073e9SAndroid Build Coastguard Worker        ):
1836*da0073e9SAndroid Build Coastguard Worker            list(iter(dataset1 + self.dataset))
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1839*da0073e9SAndroid Build Coastguard Worker            AssertionError, "ChainDataset only supports IterableDataset"
1840*da0073e9SAndroid Build Coastguard Worker        ):
1841*da0073e9SAndroid Build Coastguard Worker            list(iter(ChainDataset([dataset1, self.dataset])))
1842*da0073e9SAndroid Build Coastguard Worker
1843*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Not working on macos")
1844*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1845*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/90940
1846*da0073e9SAndroid Build Coastguard Worker    def test_multiprocessing_contexts(self):
1847*da0073e9SAndroid Build Coastguard Worker        reference = [
1848*da0073e9SAndroid Build Coastguard Worker            torch.arange(3),
1849*da0073e9SAndroid Build Coastguard Worker            torch.arange(3, 6),
1850*da0073e9SAndroid Build Coastguard Worker            torch.arange(6, 9),
1851*da0073e9SAndroid Build Coastguard Worker            torch.arange(9, 11),
1852*da0073e9SAndroid Build Coastguard Worker        ]
1853*da0073e9SAndroid Build Coastguard Worker        counting_ds_n = 11
1854*da0073e9SAndroid Build Coastguard Worker        dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA))
1855*da0073e9SAndroid Build Coastguard Worker        for ctx in supported_multiprocessing_contexts:
1856*da0073e9SAndroid Build Coastguard Worker            # windows and jetson devices don't support sharing cuda tensor; ROCm does not yet fully support IPC
1857*da0073e9SAndroid Build Coastguard Worker            if (
1858*da0073e9SAndroid Build Coastguard Worker                ctx in ["spawn", "forkserver"]
1859*da0073e9SAndroid Build Coastguard Worker                and TEST_CUDA
1860*da0073e9SAndroid Build Coastguard Worker                and not IS_WINDOWS
1861*da0073e9SAndroid Build Coastguard Worker                and not IS_JETSON
1862*da0073e9SAndroid Build Coastguard Worker            ):
1863*da0073e9SAndroid Build Coastguard Worker                ds_cls = CUDACountingDataset
1864*da0073e9SAndroid Build Coastguard Worker            else:
1865*da0073e9SAndroid Build Coastguard Worker                ds_cls = CountingDataset
1866*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1867*da0073e9SAndroid Build Coastguard Worker                reference,
1868*da0073e9SAndroid Build Coastguard Worker                list(
1869*da0073e9SAndroid Build Coastguard Worker                    self._get_data_loader(
1870*da0073e9SAndroid Build Coastguard Worker                        ds_cls(counting_ds_n),
1871*da0073e9SAndroid Build Coastguard Worker                        multiprocessing_context=ctx,
1872*da0073e9SAndroid Build Coastguard Worker                        **dl_common_args,
1873*da0073e9SAndroid Build Coastguard Worker                    )
1874*da0073e9SAndroid Build Coastguard Worker                ),
1875*da0073e9SAndroid Build Coastguard Worker            )
1876*da0073e9SAndroid Build Coastguard Worker            if ctx is not None:
1877*da0073e9SAndroid Build Coastguard Worker                # test ctx object
1878*da0073e9SAndroid Build Coastguard Worker                ctx = mp.get_context(ctx)
1879*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1880*da0073e9SAndroid Build Coastguard Worker                    reference,
1881*da0073e9SAndroid Build Coastguard Worker                    list(
1882*da0073e9SAndroid Build Coastguard Worker                        self._get_data_loader(
1883*da0073e9SAndroid Build Coastguard Worker                            ds_cls(counting_ds_n),
1884*da0073e9SAndroid Build Coastguard Worker                            multiprocessing_context=ctx,
1885*da0073e9SAndroid Build Coastguard Worker                            **dl_common_args,
1886*da0073e9SAndroid Build Coastguard Worker                        )
1887*da0073e9SAndroid Build Coastguard Worker                    ),
1888*da0073e9SAndroid Build Coastguard Worker                )
1889*da0073e9SAndroid Build Coastguard Worker
1890*da0073e9SAndroid Build Coastguard Worker    def _test_multiprocessing_iterdatapipe(self, with_dill):
1891*da0073e9SAndroid Build Coastguard Worker        # Testing to make sure that function from global scope (e.g. imported from library) can be serialized
1892*da0073e9SAndroid Build Coastguard Worker        # and used with multiprocess DataLoader
1893*da0073e9SAndroid Build Coastguard Worker
1894*da0073e9SAndroid Build Coastguard Worker        reference = [
1895*da0073e9SAndroid Build Coastguard Worker            torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1896*da0073e9SAndroid Build Coastguard Worker            torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1897*da0073e9SAndroid Build Coastguard Worker        ]
1898*da0073e9SAndroid Build Coastguard Worker        datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]])
1899*da0073e9SAndroid Build Coastguard Worker        datapipe = datapipe.map(row_processor)
1900*da0073e9SAndroid Build Coastguard Worker        datapipe = (
1901*da0073e9SAndroid Build Coastguard Worker            datapipe.filter(lambda row: len(row) == 4)
1902*da0073e9SAndroid Build Coastguard Worker            if with_dill
1903*da0073e9SAndroid Build Coastguard Worker            else datapipe.filter(filter_len)
1904*da0073e9SAndroid Build Coastguard Worker        )
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Worker        dl_common_args = dict(
1907*da0073e9SAndroid Build Coastguard Worker            num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA)
1908*da0073e9SAndroid Build Coastguard Worker        )
1909*da0073e9SAndroid Build Coastguard Worker        for ctx in supported_multiprocessing_contexts:
1910*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1911*da0073e9SAndroid Build Coastguard Worker                reference,
1912*da0073e9SAndroid Build Coastguard Worker                [
1913*da0073e9SAndroid Build Coastguard Worker                    t.type(torch.int64)
1914*da0073e9SAndroid Build Coastguard Worker                    for t in self._get_data_loader(
1915*da0073e9SAndroid Build Coastguard Worker                        datapipe, multiprocessing_context=ctx, **dl_common_args
1916*da0073e9SAndroid Build Coastguard Worker                    )
1917*da0073e9SAndroid Build Coastguard Worker                ],
1918*da0073e9SAndroid Build Coastguard Worker            )
1919*da0073e9SAndroid Build Coastguard Worker            if ctx is not None:
1920*da0073e9SAndroid Build Coastguard Worker                # test ctx object
1921*da0073e9SAndroid Build Coastguard Worker                ctx = mp.get_context(ctx)
1922*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1923*da0073e9SAndroid Build Coastguard Worker                    reference,
1924*da0073e9SAndroid Build Coastguard Worker                    [
1925*da0073e9SAndroid Build Coastguard Worker                        t.type(torch.int64)
1926*da0073e9SAndroid Build Coastguard Worker                        for t in self._get_data_loader(
1927*da0073e9SAndroid Build Coastguard Worker                            datapipe, multiprocessing_context=ctx, **dl_common_args
1928*da0073e9SAndroid Build Coastguard Worker                        )
1929*da0073e9SAndroid Build Coastguard Worker                    ],
1930*da0073e9SAndroid Build Coastguard Worker                )
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard Worker    @skipIfNoNumpy
1933*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1934*da0073e9SAndroid Build Coastguard Worker    def test_multiprocessing_iterdatapipe(self):
1935*da0073e9SAndroid Build Coastguard Worker        self._test_multiprocessing_iterdatapipe(with_dill=False)
1936*da0073e9SAndroid Build Coastguard Worker
1937*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
1938*da0073e9SAndroid Build Coastguard Worker    @skipIfNoNumpy
1939*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1940*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
1941*da0073e9SAndroid Build Coastguard Worker    def test_multiprocessing_iterdatapipe_with_dill(self):
1942*da0073e9SAndroid Build Coastguard Worker        self._test_multiprocessing_iterdatapipe(with_dill=True)
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker    def test_worker_seed(self):
1945*da0073e9SAndroid Build Coastguard Worker        num_workers = 6
1946*da0073e9SAndroid Build Coastguard Worker        batch_size = 1
1947*da0073e9SAndroid Build Coastguard Worker        dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1948*da0073e9SAndroid Build Coastguard Worker        dataloader = self._get_data_loader(
1949*da0073e9SAndroid Build Coastguard Worker            dataset, batch_size=batch_size, num_workers=num_workers
1950*da0073e9SAndroid Build Coastguard Worker        )
1951*da0073e9SAndroid Build Coastguard Worker        seeds = set()
1952*da0073e9SAndroid Build Coastguard Worker        seeds.update(batch[0] for batch in dataloader)
1953*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(seeds), num_workers)
1954*da0073e9SAndroid Build Coastguard Worker
1955*da0073e9SAndroid Build Coastguard Worker    def test_worker_seed_reproducibility(self):
1956*da0073e9SAndroid Build Coastguard Worker        def get_dataloader():
1957*da0073e9SAndroid Build Coastguard Worker            return DataLoader(
1958*da0073e9SAndroid Build Coastguard Worker                dataset,
1959*da0073e9SAndroid Build Coastguard Worker                batch_size=batch_size,
1960*da0073e9SAndroid Build Coastguard Worker                num_workers=num_workers,
1961*da0073e9SAndroid Build Coastguard Worker                generator=torch.Generator().manual_seed(42),
1962*da0073e9SAndroid Build Coastguard Worker            )
1963*da0073e9SAndroid Build Coastguard Worker
1964*da0073e9SAndroid Build Coastguard Worker        num_workers = 6
1965*da0073e9SAndroid Build Coastguard Worker        batch_size = 1
1966*da0073e9SAndroid Build Coastguard Worker        dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1967*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1968*da0073e9SAndroid Build Coastguard Worker            {int(batch) for batch in get_dataloader()},
1969*da0073e9SAndroid Build Coastguard Worker            {int(batch) for batch in get_dataloader()},
1970*da0073e9SAndroid Build Coastguard Worker        )
1971*da0073e9SAndroid Build Coastguard Worker
1972*da0073e9SAndroid Build Coastguard Worker    def test_multi_epochs_reproducibility(self):
1973*da0073e9SAndroid Build Coastguard Worker        num_workers = 2
1974*da0073e9SAndroid Build Coastguard Worker        batch_size = 10
1975*da0073e9SAndroid Build Coastguard Worker        num_epochs = 3
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker        dataset = TestMultiEpochDataset(batch_size * num_workers)
1978*da0073e9SAndroid Build Coastguard Worker        dataloader = self._get_data_loader(
1979*da0073e9SAndroid Build Coastguard Worker            dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
1980*da0073e9SAndroid Build Coastguard Worker        )
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker        for ind in range(num_epochs):
1983*da0073e9SAndroid Build Coastguard Worker            for batch_idx, sample in enumerate(dataloader):
1984*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1985*da0073e9SAndroid Build Coastguard Worker                    sample.tolist(), [batch_idx % num_workers] * batch_size
1986*da0073e9SAndroid Build Coastguard Worker                )
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker    def test_worker_init_fn(self):
1989*da0073e9SAndroid Build Coastguard Worker        dataset = SeedDataset(4)
1990*da0073e9SAndroid Build Coastguard Worker        dataloader = self._get_data_loader(
1991*da0073e9SAndroid Build Coastguard Worker            dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn
1992*da0073e9SAndroid Build Coastguard Worker        )
1993*da0073e9SAndroid Build Coastguard Worker        for batch in dataloader:
1994*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(12345, batch[0])
1995*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(12345, batch[1])
1996*da0073e9SAndroid Build Coastguard Worker
1997*da0073e9SAndroid Build Coastguard Worker    def test_get_worker_info(self):
1998*da0073e9SAndroid Build Coastguard Worker        p = ErrorTrackingProcess(target=_test_get_worker_info)
1999*da0073e9SAndroid Build Coastguard Worker        p.start()
2000*da0073e9SAndroid Build Coastguard Worker        p.join(JOIN_TIMEOUT)
2001*da0073e9SAndroid Build Coastguard Worker        try:
2002*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.is_alive())
2003*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p.exitcode, 0)
2004*da0073e9SAndroid Build Coastguard Worker        finally:
2005*da0073e9SAndroid Build Coastguard Worker            p.terminate()
2006*da0073e9SAndroid Build Coastguard Worker
2007*da0073e9SAndroid Build Coastguard Worker    def test_shuffle(self):
2008*da0073e9SAndroid Build Coastguard Worker        self._test_shuffle(self._get_data_loader(self.dataset, shuffle=True))
2009*da0073e9SAndroid Build Coastguard Worker
2010*da0073e9SAndroid Build Coastguard Worker    def test_shuffle_batch_none(self):
2011*da0073e9SAndroid Build Coastguard Worker        self._test_shuffle(DataLoader(self.dataset, batch_size=None, shuffle=True))
2012*da0073e9SAndroid Build Coastguard Worker
2013*da0073e9SAndroid Build Coastguard Worker    def test_shuffle_batch(self):
2014*da0073e9SAndroid Build Coastguard Worker        self._test_shuffle(
2015*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(self.dataset, batch_size=2, shuffle=True)
2016*da0073e9SAndroid Build Coastguard Worker        )
2017*da0073e9SAndroid Build Coastguard Worker
def test_shuffle_reproducibility(self):
    """Two loaders seeded identically must yield identical shuffled output.

    Checked both in-process (num_workers=0) and with two worker processes.
    """

    def make_loader(num_workers):
        # A fresh generator per loader, always seeded with the same value.
        return DataLoader(
            self.dataset,
            shuffle=True,
            num_workers=num_workers,
            generator=torch.Generator().manual_seed(42),
        )

    for num_workers in (0, 2):
        first_pass = list(make_loader(num_workers))
        second_pass = list(make_loader(num_workers))
        self.assertEqual(first_pass, second_pass)
2034*da0073e9SAndroid Build Coastguard Worker
def test_sequential_workers(self):
    """Run the shared sequential-order check with four workers."""
    loader = self._get_data_loader(self.dataset, num_workers=4)
    self._test_sequential(loader)
2037*da0073e9SAndroid Build Coastguard Worker
def test_seqential_batch_workers(self):
    """Run the shared sequential-order check with batch_size=2 and four workers."""
    loader = self._get_data_loader(self.dataset, batch_size=2, num_workers=4)
    self._test_sequential(loader)
2042*da0073e9SAndroid Build Coastguard Worker
def test_seqential_batch_workers_prefetch(self):
    """Run the shared sequential-order check with four workers and prefetch_factor=3."""
    loader_options = dict(batch_size=2, num_workers=4, prefetch_factor=3)
    self._test_sequential(DataLoader(self.dataset, **loader_options))
2047*da0073e9SAndroid Build Coastguard Worker
def test_shuffle_workers(self):
    """Run the shared shuffle check with four workers."""
    loader = self._get_data_loader(self.dataset, shuffle=True, num_workers=4)
    self._test_shuffle(loader)
2052*da0073e9SAndroid Build Coastguard Worker
def test_shuffle_batch_workers(self):
    """Run the shared shuffle check with batch_size=2 and four workers."""
    loader_options = dict(batch_size=2, shuffle=True, num_workers=4)
    loader = self._get_data_loader(self.dataset, **loader_options)
    self._test_shuffle(loader)
2059*da0073e9SAndroid Build Coastguard Worker
def test_shuffle_batch_workers_prefetch(self):
    """Run the shared shuffle check with four workers and prefetch_factor=3."""
    loader_options = dict(
        batch_size=2,
        shuffle=True,
        num_workers=4,
        prefetch_factor=3,
    )
    self._test_shuffle(DataLoader(self.dataset, **loader_options))
2070*da0073e9SAndroid Build Coastguard Worker
2071*da0073e9SAndroid Build Coastguard Worker    def test_random_sampler(self):
2072*da0073e9SAndroid Build Coastguard Worker        from collections import Counter
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data import RandomSampler
2075*da0073e9SAndroid Build Coastguard Worker
2076*da0073e9SAndroid Build Coastguard Worker        def sample_stat(sampler, num_samples):
2077*da0073e9SAndroid Build Coastguard Worker            counts = Counter(sampler)
2078*da0073e9SAndroid Build Coastguard Worker            count_repeated = sum(val > 1 for val in counts.values())
2079*da0073e9SAndroid Build Coastguard Worker            return (
2080*da0073e9SAndroid Build Coastguard Worker                count_repeated,
2081*da0073e9SAndroid Build Coastguard Worker                min(counts.keys()),
2082*da0073e9SAndroid Build Coastguard Worker                max(counts.keys()),
2083*da0073e9SAndroid Build Coastguard Worker                sum(counts.values()),
2084*da0073e9SAndroid Build Coastguard Worker            )
2085*da0073e9SAndroid Build Coastguard Worker
2086*da0073e9SAndroid Build Coastguard Worker        # test sample with replacement
2087*da0073e9SAndroid Build Coastguard Worker        n = len(self.dataset) + 1  # ensure at least one sample is drawn more than once
2088*da0073e9SAndroid Build Coastguard Worker        sampler_with_replacement = RandomSampler(
2089*da0073e9SAndroid Build Coastguard Worker            self.dataset, replacement=True, num_samples=n
2090*da0073e9SAndroid Build Coastguard Worker        )
2091*da0073e9SAndroid Build Coastguard Worker        count_repeated, minval, maxval, count_total = sample_stat(
2092*da0073e9SAndroid Build Coastguard Worker            sampler_with_replacement, n
2093*da0073e9SAndroid Build Coastguard Worker        )
2094*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_repeated > 0)
2095*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(minval >= 0)
2096*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(maxval < len(self.dataset))
2097*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_total == n)
2098*da0073e9SAndroid Build Coastguard Worker
2099*da0073e9SAndroid Build Coastguard Worker        # test sample without replacement and without specified num_samples
2100*da0073e9SAndroid Build Coastguard Worker        sampler_without_replacement = RandomSampler(self.dataset)
2101*da0073e9SAndroid Build Coastguard Worker        count_repeated, minval, maxval, count_total = sample_stat(
2102*da0073e9SAndroid Build Coastguard Worker            sampler_without_replacement, len(self.dataset)
2103*da0073e9SAndroid Build Coastguard Worker        )
2104*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_repeated == 0)
2105*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(minval == 0)
2106*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(maxval == len(self.dataset) - 1)
2107*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_total == len(self.dataset))
2108*da0073e9SAndroid Build Coastguard Worker
2109*da0073e9SAndroid Build Coastguard Worker        # test sample without replacement and with specified num_samples
2110*da0073e9SAndroid Build Coastguard Worker        n = len(self.dataset) * 2
2111*da0073e9SAndroid Build Coastguard Worker        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2112*da0073e9SAndroid Build Coastguard Worker        count_repeated, minval, maxval, count_total = sample_stat(
2113*da0073e9SAndroid Build Coastguard Worker            sampler_without_replacement, len(self.dataset)
2114*da0073e9SAndroid Build Coastguard Worker        )
2115*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_repeated == len(self.dataset))
2116*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(minval == 0)
2117*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(maxval == len(self.dataset) - 1)
2118*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_total == n)
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker        n = len(self.dataset) - 1
2121*da0073e9SAndroid Build Coastguard Worker        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2122*da0073e9SAndroid Build Coastguard Worker        count_repeated, minval, maxval, count_total = sample_stat(
2123*da0073e9SAndroid Build Coastguard Worker            sampler_without_replacement, len(self.dataset)
2124*da0073e9SAndroid Build Coastguard Worker        )
2125*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_repeated == 0)
2126*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(minval >= 0)
2127*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(maxval < len(self.dataset))
2128*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_total == n)
2129*da0073e9SAndroid Build Coastguard Worker
2130*da0073e9SAndroid Build Coastguard Worker        n = len(self.dataset) + 1
2131*da0073e9SAndroid Build Coastguard Worker        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2132*da0073e9SAndroid Build Coastguard Worker        count_repeated, minval, maxval, count_total = sample_stat(
2133*da0073e9SAndroid Build Coastguard Worker            sampler_without_replacement, len(self.dataset)
2134*da0073e9SAndroid Build Coastguard Worker        )
2135*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_repeated == 1)
2136*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(minval == 0)
2137*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(maxval == len(self.dataset) - 1)
2138*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(count_total == n)
2139*da0073e9SAndroid Build Coastguard Worker
2140*da0073e9SAndroid Build Coastguard Worker        # raise error when replacement is non-boolean
2141*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2142*da0073e9SAndroid Build Coastguard Worker            TypeError, "replacement should be a boolean value, but got replacement=0"
2143*da0073e9SAndroid Build Coastguard Worker        ):
2144*da0073e9SAndroid Build Coastguard Worker            RandomSampler(self.dataset, replacement=0)
2145*da0073e9SAndroid Build Coastguard Worker
def test_random_sampler_len_with_replacement(self):
    """Check the length of an oversampling `RandomSampler` (replacement=True).

    The sampler is asked for five more samples than the dataset holds; its
    `len()`, the number of indices it yields, and the `len()` of loaders
    built on top of it must all agree with that request.
    """
    from torch.utils.data import RandomSampler

    # add 5 extra samples
    num_samples = len(self.dataset) + 5
    sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples)

    # len() must report the requested number of samples ...
    self.assertEqual(num_samples, len(sampler))

    # ... and iterating must produce exactly that many indices.
    drawn = 0
    for _ in sampler:
        drawn += 1
    self.assertEqual(num_samples, drawn)

    # batch_size = 1: one batch per sample.
    one_by_one = self._get_data_loader(self.dataset, batch_size=1, sampler=sampler)
    self.assertEqual(num_samples, len(one_by_one))

    # batch_size = 6: the expected count is rounded up.
    batch_size = 6
    batched = self._get_data_loader(
        self.dataset, batch_size=batch_size, sampler=sampler
    )
    expected_batches = int(math.ceil(float(num_samples) / batch_size))
    self.assertEqual(expected_batches, len(batched))
2175*da0073e9SAndroid Build Coastguard Worker
def test_random_sampler_len_without_replacement(self):
    """Check the length of an oversampling `RandomSampler` (replacement=False).

    The sampler is asked for five more samples than the dataset holds; its
    `len()`, the number of indices it yields, and the `len()` of loaders
    built on top of it must all agree with that request.
    """
    from torch.utils.data import RandomSampler

    # add 5 extra samples
    num_samples = len(self.dataset) + 5
    sampler = RandomSampler(
        self.dataset, replacement=False, num_samples=num_samples
    )

    # len() must report the requested number of samples ...
    self.assertEqual(num_samples, len(sampler))

    # ... and iterating must produce exactly that many indices.
    drawn = 0
    for _ in sampler:
        drawn += 1
    self.assertEqual(num_samples, drawn)

    # batch_size = 1: one batch per sample.
    one_by_one = self._get_data_loader(self.dataset, batch_size=1, sampler=sampler)
    self.assertEqual(num_samples, len(one_by_one))

    # batch_size = 6: full batches plus one more if there is a remainder.
    batch_size = 6
    batched = self._get_data_loader(
        self.dataset, batch_size=batch_size, sampler=sampler
    )
    full_batches, leftover = divmod(num_samples, batch_size)
    self.assertEqual(full_batches + (leftover > 0), len(batched))
2207*da0073e9SAndroid Build Coastguard Worker
2208*da0073e9SAndroid Build Coastguard Worker    def test_distributed_sampler_invalid_rank(self):
2209*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data.distributed import DistributedSampler
2210*da0073e9SAndroid Build Coastguard Worker
2211*da0073e9SAndroid Build Coastguard Worker        dataset = torch.IntTensor(range(10))
2212*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Invalid rank"):
2213*da0073e9SAndroid Build Coastguard Worker            sampler = DistributedSampler(dataset, 3, 3)
2214*da0073e9SAndroid Build Coastguard Worker
2215*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Invalid rank"):
2216*da0073e9SAndroid Build Coastguard Worker            sampler = DistributedSampler(dataset, 3, -1)
2217*da0073e9SAndroid Build Coastguard Worker
def test_duplicating_data_with_drop_last(self):
    """Data gathered from every rank's loader with drop_last=True must be unique.

    Each of the four ranks gets its own `DistributedSampler` over the same
    nine-element tensor; the concatenation of all batches read by all ranks
    must contain no repeated value.
    """
    from torch.utils.data.distributed import DistributedSampler

    num_processes = 4
    num_batches = 9
    data_set = torch.IntTensor(range(num_batches))
    per_rank_batch_size = int(num_batches / num_processes)

    # Start from an empty tensor so the concatenation also works with no batches.
    collected = [torch.IntTensor([])]
    for rank in range(num_processes):
        rank_loader = self._get_data_loader(
            data_set,
            batch_size=per_rank_batch_size,
            drop_last=True,
            sampler=DistributedSampler(data_set, num_processes, rank),
        )
        collected.extend(rank_loader)
    scanned_data = torch.cat(collected, 0)

    self.assertEqual(scanned_data.size(), scanned_data.unique().size())
2237*da0073e9SAndroid Build Coastguard Worker
def test_sampler_reproducibility(self):
    """Random samplers must be reproducible under a fixed seed.

    First part: two samplers built with identically seeded generators yield
    the same indices. Second part: for samplers without an explicit
    generator, re-seeding the global RNG reproduces the same sequences, both
    for back-to-back passes and for two interleaved iterators.
    """
    from torch.utils.data import (
        RandomSampler,
        SubsetRandomSampler,
        WeightedRandomSampler,
    )

    weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]

    def seeded():
        # A brand-new generator on every call, always with the same seed.
        return torch.Generator().manual_seed(42)

    factories = (
        lambda: RandomSampler(
            self.dataset, num_samples=5, replacement=True, generator=seeded()
        ),
        lambda: RandomSampler(self.dataset, replacement=False, generator=seeded()),
        lambda: WeightedRandomSampler(
            weights, num_samples=5, replacement=True, generator=seeded()
        ),
        lambda: WeightedRandomSampler(
            weights, num_samples=5, replacement=False, generator=seeded()
        ),
        lambda: SubsetRandomSampler(range(10), generator=seeded()),
    )
    for factory in factories:
        first_draw = list(factory())
        second_draw = list(factory())
        self.assertEqual(first_draw, second_draw)

    unseeded_samplers = (
        RandomSampler(self.dataset, num_samples=5, replacement=True),
        RandomSampler(self.dataset, replacement=False),
        WeightedRandomSampler(weights, num_samples=5, replacement=True),
        WeightedRandomSampler(weights, num_samples=5, replacement=False),
        SubsetRandomSampler(range(10)),
    )
    for sampler in unseeded_samplers:
        # Two consecutive passes after seeding, repeated after re-seeding.
        torch.manual_seed(0)
        l1 = list(sampler) + list(sampler)

        torch.manual_seed(0)
        l2 = list(sampler) + list(sampler)
        self.assertEqual(l1, l2)

        # Two iterators advanced in lockstep; the global seed is reset right
        # before each iterator produces its first element.
        first_it, second_it = iter(sampler), iter(sampler)
        first_seq, second_seq = [], []
        for step in range(len(sampler)):
            for iterator, seq in ((first_it, first_seq), (second_it, second_seq)):
                if step == 0:
                    torch.manual_seed(0)
                seq.append(next(iterator))
        self.assertEqual(first_seq, second_seq)
2298*da0073e9SAndroid Build Coastguard Worker
def _test_sampler(self, **kwargs):
    """Check a loader whose `sampler` is a plain `range(2, 12)`.

    Extra keyword arguments are forwarded to `self._get_data_loader`. The
    loader must report five batches, and batch `k` must equal the slice of
    `self.data` starting at index `2 * k + 2` with length two.
    """
    sampled_indices = range(2, 12)  # using a regular iterable
    loader = self._get_data_loader(
        self.dataset, sampler=sampled_indices, batch_size=2, **kwargs
    )
    self.assertEqual(len(loader), 5)
    for batch_idx, (inputs, _target) in enumerate(loader):
        start = batch_idx * 2 + 2
        self.assertEqual(len(inputs), 2)
        self.assertEqual(inputs, self.data[start : start + 2])
2308*da0073e9SAndroid Build Coastguard Worker
2309*da0073e9SAndroid Build Coastguard Worker    def test_sampler(self):
2310*da0073e9SAndroid Build Coastguard Worker        self._test_sampler()
2311*da0073e9SAndroid Build Coastguard Worker        self._test_sampler(num_workers=4)
2312*da0073e9SAndroid Build Coastguard Worker        if not NO_MULTIPROCESSING_SPAWN:
2313*da0073e9SAndroid Build Coastguard Worker            self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn")
2314*da0073e9SAndroid Build Coastguard Worker
def _test_batch_sampler(self, **kwargs):
    """Check a loader driven by an explicit list of index batches.

    Extra keyword arguments are forwarded to `self._get_data_loader`. The
    batches alternate between two and three consecutive indices, and each
    batch read back must equal the matching slice of `self.data`.
    """
    # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
    batches = []  # using a regular iterable
    for start in range(0, 20, 5):
        batches += [
            tuple(range(start, start + 2)),
            tuple(range(start + 2, start + 5)),
        ]

    loader = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs)
    self.assertEqual(len(loader), 8)
    for batch_idx, (inputs, _target) in enumerate(loader):
        # Even positions hold the 2-element batches, odd positions the 3-element ones.
        expected_len = 2 if batch_idx % 2 == 0 else 3
        offset = batch_idx * 5 // 2
        self.assertEqual(len(inputs), expected_len)
        self.assertEqual(inputs, self.data[offset : offset + expected_len])
2333*da0073e9SAndroid Build Coastguard Worker
def test_batch_sampler(self):
    """Run the batch-sampler check in-process, with workers, and under spawn.

    The spawn variant only runs when `NO_MULTIPROCESSING_SPAWN` is false.
    """
    configurations = [{}, {"num_workers": 4}]
    if not NO_MULTIPROCESSING_SPAWN:
        configurations.append(
            {"num_workers": 4, "multiprocessing_context": "spawn"}
        )
    for loader_kwargs in configurations:
        self._test_batch_sampler(**loader_kwargs)
2339*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_shuffle_pin_memory(self):
    """With pin_memory=True every input and target batch must be pinned."""
    loader_options = dict(batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
    loader = self._get_data_loader(self.dataset, **loader_options)
    for inputs, targets in loader:
        for tensor in (inputs, targets):
            self.assertTrue(tensor.is_pinned())
2348*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_numpy(self):
    """Batches of NumPy arrays must come out as one stacked `torch.DoubleTensor`."""
    import numpy as np

    sample_shape = (2, 3, 4)

    class TestDataset(torch.utils.data.Dataset):
        # Item `i` is an array of the sample shape filled with the value `i`.
        def __getitem__(self, i):
            ones = np.ones(sample_shape)
            return ones * i

        def __len__(self):
            return 1000

    loader = self._get_data_loader(TestDataset(), batch_size=12)
    first_batch = next(iter(loader))
    self.assertIsInstance(first_batch, torch.DoubleTensor)
    # The batch dimension is prepended to the per-sample shape.
    self.assertEqual(first_batch.size(), torch.Size([12, 2, 3, 4]))
2364*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_numpy_gen_state(self):
    """`_generate_state` must reproduce the recorded reference states."""
    from torch.utils.data._utils.worker import _generate_state

    # Using NumPy generated states as the reference to test `_generate_state`
    # having the same result.
    # Mapping: (worker_id, base_seed) -> expected_state
    reference_states = {
        (4, 13434589827475259383): (2884386318, 1088094898, 3523808998, 3860348662),
        (1, 15014285634777110771): (1934848465, 763213760, 2959016433, 179751970),
        (10, 978296274032934101): (1759791917, 3550927336, 1225977135, 1036538043),
        (12, 11868770762134256968): (3974661794, 3331131333, 3630387033, 2885815368),
        (9, 15378787925219019706): (3815056996, 3162224466, 2735102421, 3190253477),
        (5, 9055612723125076328): (3522565701, 3368424109, 959377806, 621878693),
        (15, 14617792358407278405): (3402479508, 1588702753, 1169536393, 3675067356),
        (9, 17363320784006640087): (957989458, 2518334477, 1421725660, 3086155459),
        (12, 480002904169484764): (2732851467, 1762620729, 4055801988, 1277640511),
        (15, 16803975943592702950): (3479415043, 4022359553, 295994005, 3358606349),
        (9, 11704776406047813044): (1968928009, 710113752, 2442656196, 1587420279),
        (10, 16357891985431864516): (1271733898, 4197047399, 3727213786, 2338547348),
        (2, 17423369006318065007): (544294336, 1911284083, 3299147734, 3231058347),
        (2, 2889492011444113593): (3721591783, 2595811276, 2212881745, 977682627),
        (0, 8979703111668486195): (4276723937, 2556068849, 2962827292, 233130238),
        (6, 6269787272229682235): (2548857855, 1216457374, 1012973562, 2999759647),
    }

    for (worker_id, base_seed), expected_state in reference_states.items():
        # Note the argument order: the seed comes first, then the worker id.
        actual_state = _generate_state(base_seed, worker_id)
        self.assertEqual(expected_state, actual_state)
2429*da0073e9SAndroid Build Coastguard Worker
2430*da0073e9SAndroid Build Coastguard Worker    def test_error(self):
2431*da0073e9SAndroid Build Coastguard Worker        self._test_error(
2432*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True)
2433*da0073e9SAndroid Build Coastguard Worker        )
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker    def test_error_workers(self):
2436*da0073e9SAndroid Build Coastguard Worker        self._test_error(
2437*da0073e9SAndroid Build Coastguard Worker            self._get_data_loader(
2438*da0073e9SAndroid Build Coastguard Worker                ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4
2439*da0073e9SAndroid Build Coastguard Worker            )
2440*da0073e9SAndroid Build Coastguard Worker        )
2441*da0073e9SAndroid Build Coastguard Worker
2442*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
2443*da0073e9SAndroid Build Coastguard Worker    def test_partial_workers(self):
2444*da0073e9SAndroid Build Coastguard Worker        r"""Check that workers exit even if the iterator is not exhausted."""
2445*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
2446*da0073e9SAndroid Build Coastguard Worker            pin_memory_configs = (True, False)
2447*da0073e9SAndroid Build Coastguard Worker        else:
2448*da0073e9SAndroid Build Coastguard Worker            pin_memory_configs = (False,)
2449*da0073e9SAndroid Build Coastguard Worker
2450*da0073e9SAndroid Build Coastguard Worker        for pin_memory in pin_memory_configs:
2451*da0073e9SAndroid Build Coastguard Worker            loader = iter(
2452*da0073e9SAndroid Build Coastguard Worker                self._get_data_loader(
2453*da0073e9SAndroid Build Coastguard Worker                    self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory
2454*da0073e9SAndroid Build Coastguard Worker                )
2455*da0073e9SAndroid Build Coastguard Worker            )
2456*da0073e9SAndroid Build Coastguard Worker            workers = loader._workers
2457*da0073e9SAndroid Build Coastguard Worker            if pin_memory:
2458*da0073e9SAndroid Build Coastguard Worker                pin_memory_thread = loader._pin_memory_thread
2459*da0073e9SAndroid Build Coastguard Worker            for i, _ in enumerate(loader):
2460*da0073e9SAndroid Build Coastguard Worker                if i == 10:
2461*da0073e9SAndroid Build Coastguard Worker                    break
2462*da0073e9SAndroid Build Coastguard Worker            assert i == 10
2463*da0073e9SAndroid Build Coastguard Worker            del loader
2464*da0073e9SAndroid Build Coastguard Worker            for w in workers:
2465*da0073e9SAndroid Build Coastguard Worker                w.join(JOIN_TIMEOUT)
2466*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(w.is_alive(), "subprocess not terminated")
2467*da0073e9SAndroid Build Coastguard Worker            if pin_memory:
2468*da0073e9SAndroid Build Coastguard Worker                pin_memory_thread.join(JOIN_TIMEOUT)
2469*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(pin_memory_thread.is_alive())
2470*da0073e9SAndroid Build Coastguard Worker
2471*da0073e9SAndroid Build Coastguard Worker    # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
2472*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
2473*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
2474*da0073e9SAndroid Build Coastguard Worker    @slowTest
2475*da0073e9SAndroid Build Coastguard Worker    def test_proper_exit(self):
2476*da0073e9SAndroid Build Coastguard Worker        (
2477*da0073e9SAndroid Build Coastguard Worker            r"""There might be ConnectionResetError or leaked semaphore warning """
2478*da0073e9SAndroid Build Coastguard Worker            r"""(due to dirty process exit), but they are all safe to ignore"""
2479*da0073e9SAndroid Build Coastguard Worker        )
2480*da0073e9SAndroid Build Coastguard Worker
2481*da0073e9SAndroid Build Coastguard Worker        # TODO: test the case where the pin_memory_thread triggers an
2482*da0073e9SAndroid Build Coastguard Worker        #       error/fatal signal. I haven't found out how to properly do that.
2483*da0073e9SAndroid Build Coastguard Worker
2484*da0073e9SAndroid Build Coastguard Worker        for (
2485*da0073e9SAndroid Build Coastguard Worker            is_iterable_dataset,
2486*da0073e9SAndroid Build Coastguard Worker            use_workers,
2487*da0073e9SAndroid Build Coastguard Worker            pin_memory,
2488*da0073e9SAndroid Build Coastguard Worker            hold_iter_reference,
2489*da0073e9SAndroid Build Coastguard Worker        ) in itertools.product([True, False], repeat=4):
2490*da0073e9SAndroid Build Coastguard Worker            # `hold_iter_reference` specifies whether we hold a reference to the
2491*da0073e9SAndroid Build Coastguard Worker            # iterator. This is interesting because Python3 error traces holds a
2492*da0073e9SAndroid Build Coastguard Worker            # reference to the frames, which hold references to all the local
2493*da0073e9SAndroid Build Coastguard Worker            # variables including the iterator, and then the iterator dtor may
2494*da0073e9SAndroid Build Coastguard Worker            # not be called before process end. It is important to see that the
2495*da0073e9SAndroid Build Coastguard Worker            # processes still exit in both cases.
2496*da0073e9SAndroid Build Coastguard Worker
2497*da0073e9SAndroid Build Coastguard Worker            if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS):
2498*da0073e9SAndroid Build Coastguard Worker                # This test runs in a subprocess, which can only initialize CUDA with spawn.
2499*da0073e9SAndroid Build Coastguard Worker                # DataLoader with pin_memory=True initializes CUDA when its iterator is constructed.
2500*da0073e9SAndroid Build Coastguard Worker                # For windows, pin_memory sometimes causes CUDA oom.
2501*da0073e9SAndroid Build Coastguard Worker                continue
2502*da0073e9SAndroid Build Coastguard Worker
2503*da0073e9SAndroid Build Coastguard Worker            # `exit_method` controls the way the loader process ends.
2504*da0073e9SAndroid Build Coastguard Worker            #   - `*_kill` means that `*` is killed by OS.
2505*da0073e9SAndroid Build Coastguard Worker            #   - `*_error` means that `*` raises an error.
2506*da0073e9SAndroid Build Coastguard Worker            #   - `None` means that no error happens.
2507*da0073e9SAndroid Build Coastguard Worker            # In all cases, all processes should end properly.
2508*da0073e9SAndroid Build Coastguard Worker            if use_workers:
2509*da0073e9SAndroid Build Coastguard Worker                # TODO: Fix test for 'loader_kill' that would cause running out of shared memory.
2510*da0073e9SAndroid Build Coastguard Worker                # Killing loader process would prevent DataLoader iterator clean up all queues
2511*da0073e9SAndroid Build Coastguard Worker                # and worker processes
2512*da0073e9SAndroid Build Coastguard Worker                exit_methods = [None, "loader_error", "worker_error", "worker_kill"]
2513*da0073e9SAndroid Build Coastguard Worker                persistent_workers = self.persistent_workers
2514*da0073e9SAndroid Build Coastguard Worker            else:
2515*da0073e9SAndroid Build Coastguard Worker                exit_methods = [None, "loader_error", "loader_kill"]
2516*da0073e9SAndroid Build Coastguard Worker                persistent_workers = False
2517*da0073e9SAndroid Build Coastguard Worker
2518*da0073e9SAndroid Build Coastguard Worker            for exit_method in exit_methods:
2519*da0073e9SAndroid Build Coastguard Worker                if exit_method == "worker_kill":
2520*da0073e9SAndroid Build Coastguard Worker                    # FIXME: This sometimes hangs. See #16608.
2521*da0073e9SAndroid Build Coastguard Worker                    continue
2522*da0073e9SAndroid Build Coastguard Worker
2523*da0073e9SAndroid Build Coastguard Worker                desc = []
2524*da0073e9SAndroid Build Coastguard Worker                desc.append(f"is_iterable_dataset={is_iterable_dataset}")
2525*da0073e9SAndroid Build Coastguard Worker                desc.append(f"use_workers={use_workers}")
2526*da0073e9SAndroid Build Coastguard Worker                desc.append(f"pin_memory={pin_memory}")
2527*da0073e9SAndroid Build Coastguard Worker                desc.append(f"hold_iter_reference={hold_iter_reference}")
2528*da0073e9SAndroid Build Coastguard Worker                desc.append(f"exit_method={exit_method}")
2529*da0073e9SAndroid Build Coastguard Worker                desc = "test_proper_exit with " + ", ".join(desc)
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker                # Event that the loader process uses to signal testing process
2532*da0073e9SAndroid Build Coastguard Worker                # that various things are setup, including that the worker pids
2533*da0073e9SAndroid Build Coastguard Worker                # are specified in `worker_pids` array.
2534*da0073e9SAndroid Build Coastguard Worker                loader_setup_event = mp.Event()
2535*da0073e9SAndroid Build Coastguard Worker
2536*da0073e9SAndroid Build Coastguard Worker                # Event that this process has finished setting up, and the
2537*da0073e9SAndroid Build Coastguard Worker                # loader process can now proceed to trigger error events or
2538*da0073e9SAndroid Build Coastguard Worker                # finish normally.
2539*da0073e9SAndroid Build Coastguard Worker                tester_setup_event = mp.Event()
2540*da0073e9SAndroid Build Coastguard Worker
2541*da0073e9SAndroid Build Coastguard Worker                loader_p = ErrorTrackingProcess(
2542*da0073e9SAndroid Build Coastguard Worker                    target=_test_proper_exit,
2543*da0073e9SAndroid Build Coastguard Worker                    args=(
2544*da0073e9SAndroid Build Coastguard Worker                        is_iterable_dataset,
2545*da0073e9SAndroid Build Coastguard Worker                        use_workers,
2546*da0073e9SAndroid Build Coastguard Worker                        pin_memory,
2547*da0073e9SAndroid Build Coastguard Worker                        exit_method,
2548*da0073e9SAndroid Build Coastguard Worker                        hold_iter_reference,
2549*da0073e9SAndroid Build Coastguard Worker                        loader_setup_event,
2550*da0073e9SAndroid Build Coastguard Worker                        tester_setup_event,
2551*da0073e9SAndroid Build Coastguard Worker                        persistent_workers,
2552*da0073e9SAndroid Build Coastguard Worker                    ),
2553*da0073e9SAndroid Build Coastguard Worker                    disable_stderr=False,
2554*da0073e9SAndroid Build Coastguard Worker                )
2555*da0073e9SAndroid Build Coastguard Worker                loader_p.start()
2556*da0073e9SAndroid Build Coastguard Worker                loader_psutil_p = psutil.Process(loader_p.pid)
2557*da0073e9SAndroid Build Coastguard Worker
2558*da0073e9SAndroid Build Coastguard Worker                # Wait for loader process to set everything up, e.g., starting
2559*da0073e9SAndroid Build Coastguard Worker                # workers.
2560*da0073e9SAndroid Build Coastguard Worker                loader_setup_event.wait(timeout=JOIN_TIMEOUT)
2561*da0073e9SAndroid Build Coastguard Worker                if not loader_setup_event.is_set():
2562*da0073e9SAndroid Build Coastguard Worker                    fail_msg = (
2563*da0073e9SAndroid Build Coastguard Worker                        desc + ": loader process failed to setup within given time"
2564*da0073e9SAndroid Build Coastguard Worker                    )
2565*da0073e9SAndroid Build Coastguard Worker                    if loader_p.exception is not None:
2566*da0073e9SAndroid Build Coastguard Worker                        fail_msg += f", and had exception {loader_p.exception}"
2567*da0073e9SAndroid Build Coastguard Worker                    elif not loader_p.is_alive():
2568*da0073e9SAndroid Build Coastguard Worker                        fail_msg += f", and exited with code {loader_p.exitcode} but had no exception"
2569*da0073e9SAndroid Build Coastguard Worker                    else:
2570*da0073e9SAndroid Build Coastguard Worker                        fail_msg += ", and is still alive."
2571*da0073e9SAndroid Build Coastguard Worker                    if loader_p.is_alive():
2572*da0073e9SAndroid Build Coastguard Worker                        # this may kill the process, needs to run after the above lines
2573*da0073e9SAndroid Build Coastguard Worker                        loader_p.print_traces_of_all_threads()
2574*da0073e9SAndroid Build Coastguard Worker                    self.fail(fail_msg)
2575*da0073e9SAndroid Build Coastguard Worker
2576*da0073e9SAndroid Build Coastguard Worker                # We are certain that the workers have started now.
2577*da0073e9SAndroid Build Coastguard Worker                worker_psutil_ps = loader_psutil_p.children()
2578*da0073e9SAndroid Build Coastguard Worker
2579*da0073e9SAndroid Build Coastguard Worker                def fail(reason):
2580*da0073e9SAndroid Build Coastguard Worker                    report_psutil_attrs = [
2581*da0073e9SAndroid Build Coastguard Worker                        "pid",
2582*da0073e9SAndroid Build Coastguard Worker                        "name",
2583*da0073e9SAndroid Build Coastguard Worker                        "cpu_times",
2584*da0073e9SAndroid Build Coastguard Worker                        "io_counters",
2585*da0073e9SAndroid Build Coastguard Worker                        "memory_full_info",
2586*da0073e9SAndroid Build Coastguard Worker                        "num_ctx_switches",
2587*da0073e9SAndroid Build Coastguard Worker                        "open_files",
2588*da0073e9SAndroid Build Coastguard Worker                        "threads",
2589*da0073e9SAndroid Build Coastguard Worker                        "status",
2590*da0073e9SAndroid Build Coastguard Worker                        "nice",
2591*da0073e9SAndroid Build Coastguard Worker                        "ionice",
2592*da0073e9SAndroid Build Coastguard Worker                    ]
2593*da0073e9SAndroid Build Coastguard Worker                    if reason is None:
2594*da0073e9SAndroid Build Coastguard Worker                        err_msg = desc
2595*da0073e9SAndroid Build Coastguard Worker                    else:
2596*da0073e9SAndroid Build Coastguard Worker                        err_msg = f"{desc}: {reason}"
2597*da0073e9SAndroid Build Coastguard Worker                    err_msg += "\nLoader info:\n\t"
2598*da0073e9SAndroid Build Coastguard Worker                    if loader_psutil_p.is_running():
2599*da0073e9SAndroid Build Coastguard Worker                        err_msg += str(
2600*da0073e9SAndroid Build Coastguard Worker                            loader_psutil_p.as_dict(attrs=report_psutil_attrs)
2601*da0073e9SAndroid Build Coastguard Worker                        )
2602*da0073e9SAndroid Build Coastguard Worker                        # this may kill the process, needs to run after the above line
2603*da0073e9SAndroid Build Coastguard Worker                        loader_p.print_traces_of_all_threads()
2604*da0073e9SAndroid Build Coastguard Worker                    else:
2605*da0073e9SAndroid Build Coastguard Worker                        err_msg += f"exited with code {loader_p.exitcode}"
2606*da0073e9SAndroid Build Coastguard Worker                    if use_workers:
2607*da0073e9SAndroid Build Coastguard Worker                        err_msg += "\nWorker(s) info:"
2608*da0073e9SAndroid Build Coastguard Worker                        for idx, worker_psutil_p in enumerate(worker_psutil_ps):
2609*da0073e9SAndroid Build Coastguard Worker                            err_msg += f"\n\tWorker {idx}:\n\t\t"
2610*da0073e9SAndroid Build Coastguard Worker                            if worker_psutil_p.is_running():
2611*da0073e9SAndroid Build Coastguard Worker                                err_msg += str(
2612*da0073e9SAndroid Build Coastguard Worker                                    worker_psutil_p.as_dict(attrs=report_psutil_attrs)
2613*da0073e9SAndroid Build Coastguard Worker                                )
2614*da0073e9SAndroid Build Coastguard Worker                                # this may kill the process, needs to run after the above line
2615*da0073e9SAndroid Build Coastguard Worker                                print_traces_of_all_threads(worker_psutil_p.pid)
2616*da0073e9SAndroid Build Coastguard Worker                            else:
2617*da0073e9SAndroid Build Coastguard Worker                                err_msg += "exited with unknown code"
2618*da0073e9SAndroid Build Coastguard Worker                    self.fail(err_msg)
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker                tester_setup_event.set()
2621*da0073e9SAndroid Build Coastguard Worker
2622*da0073e9SAndroid Build Coastguard Worker                try:
2623*da0073e9SAndroid Build Coastguard Worker                    loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
2624*da0073e9SAndroid Build Coastguard Worker                    if loader_p.is_alive():
2625*da0073e9SAndroid Build Coastguard Worker                        fail_reason = "loader process did not terminate"
2626*da0073e9SAndroid Build Coastguard Worker                        if loader_p.exception is not None:
2627*da0073e9SAndroid Build Coastguard Worker                            fail(
2628*da0073e9SAndroid Build Coastguard Worker                                fail_reason
2629*da0073e9SAndroid Build Coastguard Worker                                + f", and had exception {loader_p.exception}"
2630*da0073e9SAndroid Build Coastguard Worker                            )
2631*da0073e9SAndroid Build Coastguard Worker                        else:
2632*da0073e9SAndroid Build Coastguard Worker                            fail(fail_reason + ", and had no exception")
2633*da0073e9SAndroid Build Coastguard Worker                    _, alive = psutil.wait_procs(
2634*da0073e9SAndroid Build Coastguard Worker                        worker_psutil_ps,
2635*da0073e9SAndroid Build Coastguard Worker                        timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT),
2636*da0073e9SAndroid Build Coastguard Worker                    )
2637*da0073e9SAndroid Build Coastguard Worker                    if len(alive) > 0:
2638*da0073e9SAndroid Build Coastguard Worker                        fail(
2639*da0073e9SAndroid Build Coastguard Worker                            "worker process (pid(s) {}) did not terminate".format(
2640*da0073e9SAndroid Build Coastguard Worker                                ", ".join(str(p.pid) for p in alive)
2641*da0073e9SAndroid Build Coastguard Worker                            )
2642*da0073e9SAndroid Build Coastguard Worker                        )
2643*da0073e9SAndroid Build Coastguard Worker                    if exit_method is None:
2644*da0073e9SAndroid Build Coastguard Worker                        if loader_p.exitcode != 0:
2645*da0073e9SAndroid Build Coastguard Worker                            fail(
2646*da0073e9SAndroid Build Coastguard Worker                                f"loader process had nonzero exitcode {loader_p.exitcode}"
2647*da0073e9SAndroid Build Coastguard Worker                            )
2648*da0073e9SAndroid Build Coastguard Worker                    else:
2649*da0073e9SAndroid Build Coastguard Worker                        if loader_p.exitcode == 0:
2650*da0073e9SAndroid Build Coastguard Worker                            fail("loader process had zero exitcode")
2651*da0073e9SAndroid Build Coastguard Worker                        if exit_method == "loader_error":
2652*da0073e9SAndroid Build Coastguard Worker                            if not isinstance(
2653*da0073e9SAndroid Build Coastguard Worker                                loader_p.exception, RuntimeError
2654*da0073e9SAndroid Build Coastguard Worker                            ) or "Loader error" not in str(loader_p.exception):
2655*da0073e9SAndroid Build Coastguard Worker                                fail(
2656*da0073e9SAndroid Build Coastguard Worker                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2657*da0073e9SAndroid Build Coastguard Worker                                )
2658*da0073e9SAndroid Build Coastguard Worker                        elif exit_method == "worker_kill":
2659*da0073e9SAndroid Build Coastguard Worker                            if isinstance(loader_p.exception, RuntimeError):
2660*da0073e9SAndroid Build Coastguard Worker                                if "DataLoader worker (pid" not in str(
2661*da0073e9SAndroid Build Coastguard Worker                                    loader_p.exception
2662*da0073e9SAndroid Build Coastguard Worker                                ):
2663*da0073e9SAndroid Build Coastguard Worker                                    fail(
2664*da0073e9SAndroid Build Coastguard Worker                                        f"loader process did not raise expected exception, but had {loader_p.exception}"
2665*da0073e9SAndroid Build Coastguard Worker                                    )
2666*da0073e9SAndroid Build Coastguard Worker                            elif isinstance(loader_p.exception, ConnectionRefusedError):
2667*da0073e9SAndroid Build Coastguard Worker                                # Sometimes, when the worker is being killed and is freeing its
2668*da0073e9SAndroid Build Coastguard Worker                                # resources, the unpickling in loader process will be met an
2669*da0073e9SAndroid Build Coastguard Worker                                # a `ConnectionRefusedError` as it can not open a socket to receive
2670*da0073e9SAndroid Build Coastguard Worker                                # resource. In such cases, the worker may not have fully exited,
2671*da0073e9SAndroid Build Coastguard Worker                                # and the loader can't know this via `is_alive` check or `SIGCHLD`
2672*da0073e9SAndroid Build Coastguard Worker                                # handler. So we permit this as an allowed error as well.
2673*da0073e9SAndroid Build Coastguard Worker                                # After all, we are happy as long as it terminates.
2674*da0073e9SAndroid Build Coastguard Worker                                pass
2675*da0073e9SAndroid Build Coastguard Worker                            else:
2676*da0073e9SAndroid Build Coastguard Worker                                fail(
2677*da0073e9SAndroid Build Coastguard Worker                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2678*da0073e9SAndroid Build Coastguard Worker                                )
2679*da0073e9SAndroid Build Coastguard Worker                        elif exit_method == "worker_error":
2680*da0073e9SAndroid Build Coastguard Worker                            if not isinstance(
2681*da0073e9SAndroid Build Coastguard Worker                                loader_p.exception, RuntimeError
2682*da0073e9SAndroid Build Coastguard Worker                            ) or "Worker error" not in str(loader_p.exception):
2683*da0073e9SAndroid Build Coastguard Worker                                fail(
2684*da0073e9SAndroid Build Coastguard Worker                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2685*da0073e9SAndroid Build Coastguard Worker                                )
2686*da0073e9SAndroid Build Coastguard Worker                finally:
2687*da0073e9SAndroid Build Coastguard Worker                    loader_p.terminate()
2688*da0073e9SAndroid Build Coastguard Worker
2689*da0073e9SAndroid Build Coastguard Worker    def test_len(self):
2690*da0073e9SAndroid Build Coastguard Worker        def check_len(dl, expected):
2691*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(dl), expected)
2692*da0073e9SAndroid Build Coastguard Worker            n = 0
2693*da0073e9SAndroid Build Coastguard Worker            for _ in dl:
2694*da0073e9SAndroid Build Coastguard Worker                n += 1
2695*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(n, expected)
2696*da0073e9SAndroid Build Coastguard Worker
2697*da0073e9SAndroid Build Coastguard Worker        check_len(self.dataset, 100)
2698*da0073e9SAndroid Build Coastguard Worker        check_len(self._get_data_loader(self.dataset, batch_size=2), 50)
2699*da0073e9SAndroid Build Coastguard Worker        check_len(self._get_data_loader(self.dataset, batch_size=3), 34)
2700*da0073e9SAndroid Build Coastguard Worker
2701*da0073e9SAndroid Build Coastguard Worker    def test_iterabledataset_len(self):
2702*da0073e9SAndroid Build Coastguard Worker        class IterableDataset(torch.utils.data.IterableDataset):
2703*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
2704*da0073e9SAndroid Build Coastguard Worker                return 10
2705*da0073e9SAndroid Build Coastguard Worker
2706*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
2707*da0073e9SAndroid Build Coastguard Worker                return iter(range(10))
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker        iterable_loader = DataLoader(IterableDataset(), batch_size=1)
2710*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(iterable_loader), 10)
2711*da0073e9SAndroid Build Coastguard Worker        iterable_loader = DataLoader(IterableDataset(), batch_size=1, drop_last=True)
2712*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(iterable_loader), 10)
2713*da0073e9SAndroid Build Coastguard Worker
2714*da0073e9SAndroid Build Coastguard Worker        iterable_loader = DataLoader(IterableDataset(), batch_size=2)
2715*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(iterable_loader), 5)
2716*da0073e9SAndroid Build Coastguard Worker        iterable_loader = DataLoader(IterableDataset(), batch_size=2, drop_last=True)
2717*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(iterable_loader), 5)
2718*da0073e9SAndroid Build Coastguard Worker
2719*da0073e9SAndroid Build Coastguard Worker        iterable_loader = DataLoader(IterableDataset(), batch_size=3)
2720*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(iterable_loader), 4)
2721*da0073e9SAndroid Build Coastguard Worker        iterable_loader = DataLoader(IterableDataset(), batch_size=3, drop_last=True)
2722*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(iterable_loader), 3)
2723*da0073e9SAndroid Build Coastguard Worker
2724*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2725*da0073e9SAndroid Build Coastguard Worker    def test_numpy_scalars(self):
2726*da0073e9SAndroid Build Coastguard Worker        import numpy as np
2727*da0073e9SAndroid Build Coastguard Worker
2728*da0073e9SAndroid Build Coastguard Worker        class ScalarDataset(torch.utils.data.Dataset):
2729*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dtype):
2730*da0073e9SAndroid Build Coastguard Worker                self.dtype = dtype
2731*da0073e9SAndroid Build Coastguard Worker
2732*da0073e9SAndroid Build Coastguard Worker            def __getitem__(self, i):
2733*da0073e9SAndroid Build Coastguard Worker                return self.dtype()
2734*da0073e9SAndroid Build Coastguard Worker
2735*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
2736*da0073e9SAndroid Build Coastguard Worker                return 4
2737*da0073e9SAndroid Build Coastguard Worker
2738*da0073e9SAndroid Build Coastguard Worker        dtypes = {
2739*da0073e9SAndroid Build Coastguard Worker            np.float64: torch.DoubleTensor,
2740*da0073e9SAndroid Build Coastguard Worker            np.float32: torch.FloatTensor,
2741*da0073e9SAndroid Build Coastguard Worker            np.float16: torch.HalfTensor,
2742*da0073e9SAndroid Build Coastguard Worker            np.int64: torch.LongTensor,
2743*da0073e9SAndroid Build Coastguard Worker            np.int32: torch.IntTensor,
2744*da0073e9SAndroid Build Coastguard Worker            np.int16: torch.ShortTensor,
2745*da0073e9SAndroid Build Coastguard Worker            np.int8: torch.CharTensor,
2746*da0073e9SAndroid Build Coastguard Worker            np.uint8: torch.ByteTensor,
2747*da0073e9SAndroid Build Coastguard Worker        }
2748*da0073e9SAndroid Build Coastguard Worker        for dt, tt in dtypes.items():
2749*da0073e9SAndroid Build Coastguard Worker            dset = ScalarDataset(dt)
2750*da0073e9SAndroid Build Coastguard Worker            loader = self._get_data_loader(dset, batch_size=2)
2751*da0073e9SAndroid Build Coastguard Worker            batch = next(iter(loader))
2752*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(batch, tt)
2753*da0073e9SAndroid Build Coastguard Worker
2754*da0073e9SAndroid Build Coastguard Worker    def test_default_convert_mapping_keep_type(self):
2755*da0073e9SAndroid Build Coastguard Worker        data = CustomDict({"a": 1, "b": 2})
2756*da0073e9SAndroid Build Coastguard Worker        converted = _utils.collate.default_convert(data)
2757*da0073e9SAndroid Build Coastguard Worker
2758*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(converted, data)
2759*da0073e9SAndroid Build Coastguard Worker
2760*da0073e9SAndroid Build Coastguard Worker    def test_default_convert_sequence_keep_type(self):
2761*da0073e9SAndroid Build Coastguard Worker        data = CustomList([1, 2, 3])
2762*da0073e9SAndroid Build Coastguard Worker        converted = _utils.collate.default_convert(data)
2763*da0073e9SAndroid Build Coastguard Worker
2764*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(converted, data)
2765*da0073e9SAndroid Build Coastguard Worker
2766*da0073e9SAndroid Build Coastguard Worker    def test_default_convert_sequence_dont_keep_type(self):
2767*da0073e9SAndroid Build Coastguard Worker        data = range(2)
2768*da0073e9SAndroid Build Coastguard Worker        converted = _utils.collate.default_convert(data)
2769*da0073e9SAndroid Build Coastguard Worker
2770*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(converted, [0, 1])
2771*da0073e9SAndroid Build Coastguard Worker
2772*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_dtype(self):
2773*da0073e9SAndroid Build Coastguard Worker        arr = [1, 2, -1]
2774*da0073e9SAndroid Build Coastguard Worker        collated = _utils.collate.default_collate(arr)
2775*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated, torch.tensor(arr))
2776*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated.dtype, torch.int64)
2777*da0073e9SAndroid Build Coastguard Worker
2778*da0073e9SAndroid Build Coastguard Worker        arr = [1.1, 2.3, -0.9]
2779*da0073e9SAndroid Build Coastguard Worker        collated = _utils.collate.default_collate(arr)
2780*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))
2781*da0073e9SAndroid Build Coastguard Worker
2782*da0073e9SAndroid Build Coastguard Worker        arr = [True, False]
2783*da0073e9SAndroid Build Coastguard Worker        collated = _utils.collate.default_collate(arr)
2784*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated, torch.tensor(arr))
2785*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated.dtype, torch.bool)
2786*da0073e9SAndroid Build Coastguard Worker
2787*da0073e9SAndroid Build Coastguard Worker        # Should be a no-op
2788*da0073e9SAndroid Build Coastguard Worker        arr = ["a", "b", "c"]
2789*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(arr, _utils.collate.default_collate(arr))
2790*da0073e9SAndroid Build Coastguard Worker
2791*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_mapping_keep_type(self):
2792*da0073e9SAndroid Build Coastguard Worker        batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
2793*da0073e9SAndroid Build Coastguard Worker        collated = _utils.collate.default_collate(batch)
2794*da0073e9SAndroid Build Coastguard Worker
2795*da0073e9SAndroid Build Coastguard Worker        expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
2796*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated, expected)
2797*da0073e9SAndroid Build Coastguard Worker
2798*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_sequence_keep_type(self):
2799*da0073e9SAndroid Build Coastguard Worker        batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
2800*da0073e9SAndroid Build Coastguard Worker        collated = _utils.collate.default_collate(batch)
2801*da0073e9SAndroid Build Coastguard Worker
2802*da0073e9SAndroid Build Coastguard Worker        expected = CustomList(
2803*da0073e9SAndroid Build Coastguard Worker            [
2804*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1, 4]),
2805*da0073e9SAndroid Build Coastguard Worker                torch.tensor([2, 5]),
2806*da0073e9SAndroid Build Coastguard Worker                torch.tensor([3, 6]),
2807*da0073e9SAndroid Build Coastguard Worker            ]
2808*da0073e9SAndroid Build Coastguard Worker        )
2809*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated, expected)
2810*da0073e9SAndroid Build Coastguard Worker
2811*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_sequence_dont_keep_type(self):
2812*da0073e9SAndroid Build Coastguard Worker        batch = [range(2), range(2)]
2813*da0073e9SAndroid Build Coastguard Worker        collated = _utils.collate.default_collate(batch)
2814*da0073e9SAndroid Build Coastguard Worker
2815*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])
2816*da0073e9SAndroid Build Coastguard Worker
2817*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2818*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_bad_numpy_types(self):
2819*da0073e9SAndroid Build Coastguard Worker        import numpy as np
2820*da0073e9SAndroid Build Coastguard Worker
2821*da0073e9SAndroid Build Coastguard Worker        # Should be a no-op
2822*da0073e9SAndroid Build Coastguard Worker        arr = np.array(["a", "b", "c"])
2823*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(arr, _utils.collate.default_collate(arr))
2824*da0073e9SAndroid Build Coastguard Worker
2825*da0073e9SAndroid Build Coastguard Worker        arr = np.array([[["a", "b", "c"]]])
2826*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2827*da0073e9SAndroid Build Coastguard Worker
2828*da0073e9SAndroid Build Coastguard Worker        arr = np.array([object(), object(), object()])
2829*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2830*da0073e9SAndroid Build Coastguard Worker
2831*da0073e9SAndroid Build Coastguard Worker        arr = np.array([[[object(), object(), object()]]])
2832*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2833*da0073e9SAndroid Build Coastguard Worker
2834*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2835*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_numpy_memmap(self):
2836*da0073e9SAndroid Build Coastguard Worker        import numpy as np
2837*da0073e9SAndroid Build Coastguard Worker
2838*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryFile() as f:
2839*da0073e9SAndroid Build Coastguard Worker            arr = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
2840*da0073e9SAndroid Build Coastguard Worker            arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
2841*da0073e9SAndroid Build Coastguard Worker            arr_memmap[:] = arr[:]
2842*da0073e9SAndroid Build Coastguard Worker            arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
2843*da0073e9SAndroid Build Coastguard Worker            tensor = _utils.collate.default_collate(list(arr_new))
2844*da0073e9SAndroid Build Coastguard Worker
2845*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
2846*da0073e9SAndroid Build Coastguard Worker            (tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item()
2847*da0073e9SAndroid Build Coastguard Worker        )
2848*da0073e9SAndroid Build Coastguard Worker
2849*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_bad_sequence_type(self):
2850*da0073e9SAndroid Build Coastguard Worker        batch = [["X"], ["X", "X"]]
2851*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
2852*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
2853*da0073e9SAndroid Build Coastguard Worker            RuntimeError, lambda: _utils.collate.default_collate(batch[::-1])
2854*da0073e9SAndroid Build Coastguard Worker        )
2855*da0073e9SAndroid Build Coastguard Worker
2856*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2857*da0073e9SAndroid Build Coastguard Worker    def test_default_collate_shared_tensor(self):
2858*da0073e9SAndroid Build Coastguard Worker        import numpy as np
2859*da0073e9SAndroid Build Coastguard Worker
2860*da0073e9SAndroid Build Coastguard Worker        t_in = torch.zeros(1)
2861*da0073e9SAndroid Build Coastguard Worker        n_in = np.zeros(1)
2862*da0073e9SAndroid Build Coastguard Worker
2863*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t_in.is_shared(), False)
2864*da0073e9SAndroid Build Coastguard Worker
2865*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
2866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)
2867*da0073e9SAndroid Build Coastguard Worker
2868*da0073e9SAndroid Build Coastguard Worker        # FIXME: fix the following hack that makes `default_collate` believe
2869*da0073e9SAndroid Build Coastguard Worker        #        that it is in a worker process (since it tests
2870*da0073e9SAndroid Build Coastguard Worker        #        `get_worker_info() != None`), even though it is not.
2871*da0073e9SAndroid Build Coastguard Worker        old = _utils.worker._worker_info
2872*da0073e9SAndroid Build Coastguard Worker        try:
2873*da0073e9SAndroid Build Coastguard Worker            _utils.worker._worker_info = "x"
2874*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
2875*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
2876*da0073e9SAndroid Build Coastguard Worker        finally:
2877*da0073e9SAndroid Build Coastguard Worker            _utils.worker._worker_info = old
2878*da0073e9SAndroid Build Coastguard Worker
2879*da0073e9SAndroid Build Coastguard Worker    def test_excessive_thread_creation_warning(self):
2880*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
2881*da0073e9SAndroid Build Coastguard Worker            UserWarning,
2882*da0073e9SAndroid Build Coastguard Worker            r"excessive worker creation might get DataLoader running slow or even freeze",
2883*da0073e9SAndroid Build Coastguard Worker        ):
2884*da0073e9SAndroid Build Coastguard Worker            dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
2885*da0073e9SAndroid Build Coastguard Worker
2886*da0073e9SAndroid Build Coastguard Worker
class TestDataLoaderDeviceType(TestCase):
    """``DataLoader`` tests that take a ``device`` argument."""

    @parametrize(
        "context",
        [ctx for ctx in supported_multiprocessing_contexts if ctx is not None],
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_nested_tensor_multiprocessing(self, device, context):
        """Nested tensors pass through worker processes with a custom ``collate_fn``.

        Also checks that the default ``collate_fn`` rejects nested tensors
        with a ``RuntimeError``.
        """
        # The 'fork' multiprocessing context doesn't work for CUDA, so report
        # this combination as skipped rather than letting it pass silently.
        if "cuda" in device and context == "fork":
            self.skipTest("'fork' multiprocessing context doesn't work for CUDA")

        dataset = [
            torch.nested.nested_tensor([torch.randn(5)], device=device)
            for _ in range(10)
        ]

        pin_memory_settings = [False]
        if device == "cpu" and torch.cuda.is_available():
            pin_memory_settings.append(True)

        for pin_memory in pin_memory_settings:
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=1,
                num_workers=4,
                collate_fn=_clone_collate,
                pin_memory=pin_memory,
                multiprocessing_context=context,
            )

            num_batches = 0
            for i, batch in enumerate(loader):
                self.assertEqual(batch[0], dataset[i])
                num_batches += 1
            # With batch_size=1 every sample must arrive as its own batch; this
            # guards against the loop above passing vacuously on an empty loader.
            self.assertEqual(num_batches, len(dataset))

        # Error case: default collate_fn doesn't currently support batches of nested tensors.
        # Following the current semantics, we'd need to stack them, which isn't possible atm.
        with self.assertRaisesRegex(
            RuntimeError, "not currently supported by the default collate_fn"
        ):
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=1,
                num_workers=4,
                multiprocessing_context=context,
            )

            next(iter(loader))
2934*da0073e9SAndroid Build Coastguard Worker
2935*da0073e9SAndroid Build Coastguard Worker
class IntegrationTestDataLoaderDataPipe(TestCase):
    r"""
    Verify the behavior of a certain ``DataPipes`` with ``DataLoader``
    """

    def test_shuffler_iterdatapipe(self):
        r"""
        Verify ``IterDataPipe.shuffle`` is controlled by ``DataLoader``
        to generate different seeds deterministically per epoch.
        """
        exp = list(range(100))

        def _create_dp(buffer_size):
            # Shuffle the wrapped range, then apply the sharding filter.
            input_ds = dp.iter.IterableWrapper(exp)
            return input_ds.shuffle(buffer_size=buffer_size).sharding_filter()

        for bs in (5, 20, 33):
            # Test Deterministic
            for num_workers, pw in itertools.product((0, 1, 2), (True, False)):
                # Persistent workers need at least one worker process.
                if num_workers == 0 and pw:
                    continue

                shuffle_dp = _create_dp(bs)

                mp_ctx = "spawn" if num_workers > 0 else None
                dl = DataLoader(
                    shuffle_dp,
                    num_workers=num_workers,
                    shuffle=True,
                    multiprocessing_context=mp_ctx,
                    persistent_workers=pw,
                )

                try:
                    # No seed
                    dl_res_ns = list(dl)
                    self.assertEqual(sorted(dl_res_ns), exp)

                    # Same seeds
                    dl_res = []
                    for _ in range(2):
                        torch.manual_seed(123)
                        dl_res.append(list(dl))
                    self.assertEqual(dl_res[0], dl_res[1])
                    self.assertEqual(sorted(dl_res[0]), exp)

                    # Different seeds
                    torch.manual_seed(321)
                    dl_res.append(list(dl))

                    self.assertEqual(len(dl_res[0]), len(dl_res[2]))
                    self.assertNotEqual(dl_res[0], dl_res[2])
                    self.assertEqual(sorted(dl_res[0]), sorted(dl_res[2]))
                finally:
                    # Release any worker iterator the loader still holds, even
                    # when one of the assertions above failed.
                    if dl._iterator is not None:
                        dl._iterator._shutdown_workers()
                        dl._iterator = None
                    del dl
2992*da0073e9SAndroid Build Coastguard Worker                del dl
2993*da0073e9SAndroid Build Coastguard Worker
2994*da0073e9SAndroid Build Coastguard Worker
class StringDataset(Dataset):
    """Map-style dataset whose items are ``(character, index)`` pairs of a fixed string."""

    def __init__(self) -> None:
        self.s = "12345"

    def __len__(self):
        # One item per character of the backing string.
        return len(self.s)

    def __getitem__(self, ndx):
        character = self.s[ndx]
        return character, ndx
3004*da0073e9SAndroid Build Coastguard Worker
3005*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStringDataLoader(TestCase):
    """``DataLoader`` behaviour for samples that mix strings and integers."""

    def setUp(self):
        super().setUp()
        self.dataset = StringDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        """String fields stay ``str`` while the index tensor is pinned."""
        loader = DataLoader(
            self.dataset,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )
        for strings, indices in loader:
            self.assertIsInstance(strings[0], str)
            self.assertTrue(indices.is_pinned())
3024*da0073e9SAndroid Build Coastguard Worker
3025*da0073e9SAndroid Build Coastguard Worker
class DictDataset(Dataset):
    """Four-item dataset whose samples are nested dicts built from the index."""

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        # A (4, 2) tensor with every entry set to the requested index.
        filled = torch.empty(4, 2).fill_(ndx)
        nested = {"a_number": ndx}
        return {"a_tensor": filled, "another_dict": nested}
3035*da0073e9SAndroid Build Coastguard Worker
3036*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestDictDataLoader(TestCase):
    """``DataLoader`` behaviour for samples that are nested dictionaries."""

    def setUp(self):
        super().setUp()
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        """Unshuffled batches keep the dict structure and arrive in index order."""
        for persistent_workers in (False, True):
            # A worker count is only passed for the persistent-worker case;
            # otherwise the loader's own default is used.
            worker_kwargs = {"num_workers": 1} if persistent_workers else {}
            loader = DataLoader(
                self.dataset,
                batch_size=2,
                shuffle=False,
                persistent_workers=persistent_workers,
                **worker_kwargs,
            )
            batch_size = loader.batch_size
            for batch_index, sample in enumerate(loader):
                first = batch_index * batch_size
                self.assertEqual(set(sample.keys()), {"a_tensor", "another_dict"})
                self.assertEqual(set(sample["another_dict"].keys()), {"a_number"})

                tensors = sample["a_tensor"]
                self.assertEqual(tensors.size(), torch.Size([batch_size, 4, 2]))
                self.assertTrue((tensors[0] == first).all())
                self.assertTrue((tensors[1] == first + 1).all())

                numbers = sample["another_dict"]["a_number"]
                self.assertEqual(numbers.size(), torch.Size([batch_size]))
                self.assertEqual(numbers[0], first)
                self.assertEqual(numbers[1], first + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        """With ``pin_memory=True`` every tensor in the nested sample is pinned."""
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for sample in loader:
            nested_number = sample["another_dict"]["a_number"]
            self.assertTrue(sample["a_tensor"].is_pinned())
            self.assertTrue(nested_number.is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_device(self):
        """Pinning with an explicit ``pin_memory_device`` is visible for that device."""
        loader = DataLoader(
            self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
        )
        for sample in loader:
            nested_number = sample["another_dict"]["a_number"]
            self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
            self.assertTrue(nested_number.is_pinned(device="cuda"))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_with_only_device(self):
        """``pin_memory_device`` without ``pin_memory=True`` leaves tensors unpinned."""
        loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
        for sample in loader:
            nested_number = sample["another_dict"]["a_number"]
            self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
            self.assertFalse(nested_number.is_pinned(device="cuda"))
3104*da0073e9SAndroid Build Coastguard Worker
3105*da0073e9SAndroid Build Coastguard Worker
class DummyDataset(torch.utils.data.Dataset):
    """Ten-item dataset that asserts its ``start`` attribute is 0 on every read."""

    def __init__(self) -> None:
        self.data = list(range(10))
        # `__getitem__` reads `start`; give it a defined value up front so a
        # freshly constructed dataset can be indexed without an AttributeError.
        self.start = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Accept tensor indices by converting them to plain Python values.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The persistent workers always maintain the original
        # dataset through the dataloader lifetime
        # so the attributes will remain the same as the
        # first time the workers were spawned (dataloader iteration)
        assert self.start == 0
        return self.data[idx]
3122*da0073e9SAndroid Build Coastguard Worker
3123*da0073e9SAndroid Build Coastguard Worker
3124*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
3125*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
3126*da0073e9SAndroid Build Coastguard Worker    "Fails with TSAN with the following error: starting new threads after multi-threaded "
3127*da0073e9SAndroid Build Coastguard Worker    "fork is not supported. Dying (set die_after_fork=0 to override)",
3128*da0073e9SAndroid Build Coastguard Worker)
3129*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
3130*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
3131*da0073e9SAndroid Build Coastguard Worker    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
3132*da0073e9SAndroid Build Coastguard Worker)
3133*da0073e9SAndroid Build Coastguard Workerclass TestDataLoaderPersistentWorkers(TestDataLoader):
    def setUp(self):
        """Run the inherited fixture setup, then enable ``persistent_workers``."""
        super().setUp()
        # Assigned after the inherited setUp has run.
        self.persistent_workers = True
3137*da0073e9SAndroid Build Coastguard Worker
3138*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
3139*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
3140*da0073e9SAndroid Build Coastguard Worker    def test_fd_limit_exceeded(self):
3141*da0073e9SAndroid Build Coastguard Worker        # See NOTE [ DataLoader on Linux and open files limit ]
3142*da0073e9SAndroid Build Coastguard Worker        import subprocess
3143*da0073e9SAndroid Build Coastguard Worker
3144*da0073e9SAndroid Build Coastguard Worker        subprocess.check_output(
3145*da0073e9SAndroid Build Coastguard Worker            [
3146*da0073e9SAndroid Build Coastguard Worker                sys.executable,
3147*da0073e9SAndroid Build Coastguard Worker                "-c",
3148*da0073e9SAndroid Build Coastguard Worker                """\
3149*da0073e9SAndroid Build Coastguard Workerimport torch
3150*da0073e9SAndroid Build Coastguard Workerimport resource
3151*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import DataLoader, IterableDataset
3152*da0073e9SAndroid Build Coastguard Worker
3153*da0073e9SAndroid Build Coastguard Workerclass RandomDataset(IterableDataset):
3154*da0073e9SAndroid Build Coastguard Worker    def __init__(self, len, size):
3155*da0073e9SAndroid Build Coastguard Worker        super(RandomDataset).__init__()
3156*da0073e9SAndroid Build Coastguard Worker        self.len = len
3157*da0073e9SAndroid Build Coastguard Worker        self.size = size
3158*da0073e9SAndroid Build Coastguard Worker
3159*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
3160*da0073e9SAndroid Build Coastguard Worker        return self
3161*da0073e9SAndroid Build Coastguard Worker
3162*da0073e9SAndroid Build Coastguard Worker    def __next__(self):
3163*da0073e9SAndroid Build Coastguard Worker        if self.len <= 0:
3164*da0073e9SAndroid Build Coastguard Worker            raise StopIteration
3165*da0073e9SAndroid Build Coastguard Worker        self.len -= 1
3166*da0073e9SAndroid Build Coastguard Worker        return torch.randn(self.size)
3167*da0073e9SAndroid Build Coastguard Worker
3168*da0073e9SAndroid Build Coastguard Workertry:
3169*da0073e9SAndroid Build Coastguard Worker    keep_fds_alive = []
3170*da0073e9SAndroid Build Coastguard Worker    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
3171*da0073e9SAndroid Build Coastguard Worker    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
3172*da0073e9SAndroid Build Coastguard Worker                               num_workers=1, persistent_workers=True):
3173*da0073e9SAndroid Build Coastguard Worker      random_t.max(dim=0)
3174*da0073e9SAndroid Build Coastguard Worker      keep_fds_alive.append(random_t)
3175*da0073e9SAndroid Build Coastguard Workerexcept RuntimeError as e:
3176*da0073e9SAndroid Build Coastguard Worker    assert "ulimit -n" in str(e)
3177*da0073e9SAndroid Build Coastguard Worker    assert "set_sharing_strategy" in str(e)
3178*da0073e9SAndroid Build Coastguard Worker""",
3179*da0073e9SAndroid Build Coastguard Worker            ]
3180*da0073e9SAndroid Build Coastguard Worker        )
3181*da0073e9SAndroid Build Coastguard Worker
3182*da0073e9SAndroid Build Coastguard Worker    def test_dataset_not_reset(self):
3183*da0073e9SAndroid Build Coastguard Worker        dataset = DummyDataset()
3184*da0073e9SAndroid Build Coastguard Worker        pin_memory_configs = [False]
3185*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
3186*da0073e9SAndroid Build Coastguard Worker            pin_memory_configs.append(True)
3187*da0073e9SAndroid Build Coastguard Worker        for pin_memory in pin_memory_configs:
3188*da0073e9SAndroid Build Coastguard Worker            dataloader = self._get_data_loader(
3189*da0073e9SAndroid Build Coastguard Worker                dataset, num_workers=2, pin_memory=pin_memory
3190*da0073e9SAndroid Build Coastguard Worker            )
3191*da0073e9SAndroid Build Coastguard Worker            dataset.start = 0
3192*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
3193*da0073e9SAndroid Build Coastguard Worker                for x in dataloader:
3194*da0073e9SAndroid Build Coastguard Worker                    pass
3195*da0073e9SAndroid Build Coastguard Worker                # Changing the start value here doesn't have any effect in the dataset
3196*da0073e9SAndroid Build Coastguard Worker                # cached by the workers. since they are not recreated between epochs
3197*da0073e9SAndroid Build Coastguard Worker                # and can cache values safely
3198*da0073e9SAndroid Build Coastguard Worker                dataset.start = i
3199*da0073e9SAndroid Build Coastguard Worker
3200*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
3201*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Needs fork")
3202*da0073e9SAndroid Build Coastguard Worker    def test_early_exit(self):
3203*da0073e9SAndroid Build Coastguard Worker        import subprocess
3204*da0073e9SAndroid Build Coastguard Worker
3205*da0073e9SAndroid Build Coastguard Worker        proc = subprocess.check_output(
3206*da0073e9SAndroid Build Coastguard Worker            [
3207*da0073e9SAndroid Build Coastguard Worker                sys.executable,
3208*da0073e9SAndroid Build Coastguard Worker                "-c",
3209*da0073e9SAndroid Build Coastguard Worker                """\
3210*da0073e9SAndroid Build Coastguard Workerimport torch
3211*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import DataLoader, IterableDataset
3212*da0073e9SAndroid Build Coastguard Worker
3213*da0073e9SAndroid Build Coastguard Workerclass RandomDataset(IterableDataset):
3214*da0073e9SAndroid Build Coastguard Worker    def __init__(self, len, size):
3215*da0073e9SAndroid Build Coastguard Worker        super(RandomDataset).__init__()
3216*da0073e9SAndroid Build Coastguard Worker        self.len = len
3217*da0073e9SAndroid Build Coastguard Worker        self.size = size
3218*da0073e9SAndroid Build Coastguard Worker
3219*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
3220*da0073e9SAndroid Build Coastguard Worker        return self
3221*da0073e9SAndroid Build Coastguard Worker
3222*da0073e9SAndroid Build Coastguard Worker    def __next__(self):
3223*da0073e9SAndroid Build Coastguard Worker        if self.len <= 0:
3224*da0073e9SAndroid Build Coastguard Worker            raise StopIteration
3225*da0073e9SAndroid Build Coastguard Worker        self.len -= 1
3226*da0073e9SAndroid Build Coastguard Worker        return torch.randn(self.size)
3227*da0073e9SAndroid Build Coastguard Worker
3228*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
3229*da0073e9SAndroid Build Coastguard Worker    dl = DataLoader(
3230*da0073e9SAndroid Build Coastguard Worker        RandomDataset(64, (28, 28)),
3231*da0073e9SAndroid Build Coastguard Worker        batch_size=16,
3232*da0073e9SAndroid Build Coastguard Worker        num_workers=2,
3233*da0073e9SAndroid Build Coastguard Worker        pin_memory=True,
3234*da0073e9SAndroid Build Coastguard Worker        persistent_workers=True,
3235*da0073e9SAndroid Build Coastguard Worker        multiprocessing_context="fork",
3236*da0073e9SAndroid Build Coastguard Worker    )
3237*da0073e9SAndroid Build Coastguard Worker
3238*da0073e9SAndroid Build Coastguard Worker    for _ in dl:
3239*da0073e9SAndroid Build Coastguard Worker        break
3240*da0073e9SAndroid Build Coastguard Worker""",
3241*da0073e9SAndroid Build Coastguard Worker            ]
3242*da0073e9SAndroid Build Coastguard Worker        )
3243*da0073e9SAndroid Build Coastguard Worker
3244*da0073e9SAndroid Build Coastguard Worker
class NamedTupleDataset(Dataset):
    """Four-element map-style dataset whose samples are nested namedtuples.

    Each sample is a ``Batch`` holding a ``Data`` pair built from the index,
    the index rendered as a string, and a fresh random tensor of length 3.
    """

    from collections import namedtuple

    Batch = namedtuple("Batch", ["data", "label", "random_tensor"])
    Data = namedtuple("Data", ["positive", "negative"])

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        # Build the pieces in the same order the fields are declared.
        signed_pair = self.Data(positive=ndx, negative=-ndx)
        text_label = str(ndx)
        noise = torch.randn(3)
        return self.Batch(data=signed_pair, label=text_label, random_tensor=noise)
3260*da0073e9SAndroid Build Coastguard Worker
3261*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestNamedTupleDataLoader(TestCase):
    """Checks that DataLoader output keeps the namedtuple types of its samples."""

    def setUp(self):
        super().setUp()
        self.dataset = NamedTupleDataset()

    def test_dataloader_with_namedtuple(self):
        batch_cls = NamedTupleDataset.Batch
        data_cls = NamedTupleDataset.Data

        # auto-collation: fields are collated into tensors, pinned iff CUDA
        # is available
        collated = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA)
        for item in collated:
            self.assertIsInstance(item, batch_cls)
            self.assertEqual(item.random_tensor.is_pinned(), TEST_CUDA)
            self.assertIsInstance(item.data, data_cls)
            self.assertIsInstance(item.data.positive, torch.Tensor)
            self.assertEqual(item.data.positive.is_pinned(), TEST_CUDA)

        # no auto-collation: non-tensor fields are not converted to tensors
        uncollated = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA)
        for item in uncollated:
            self.assertIsInstance(item, batch_cls)
            self.assertEqual(item.random_tensor.is_pinned(), TEST_CUDA)
            self.assertIsInstance(item.data, data_cls)
            self.assertNotIsInstance(item.data.positive, torch.Tensor)
3288*da0073e9SAndroid Build Coastguard Worker
3289*da0073e9SAndroid Build Coastguard Worker
class SimpleCustomBatch:
    """Custom batch container exposing ``pin_memory`` and ``is_pinned``.

    ``data`` is an iterable of samples; the first element of every sample is
    stacked into ``inp`` and the second into ``tgt`` along a new dim 0.
    """

    def __init__(self, data):
        columns = list(zip(*data))
        self.inp = torch.stack(columns[0], dim=0)
        self.tgt = torch.stack(columns[1], dim=0)

    def pin_memory(self):
        # Replace both tensors with their pinned copies and return self.
        for field in ("inp", "tgt"):
            setattr(self, field, getattr(self, field).pin_memory())
        return self

    def is_pinned(self):
        # True only when both held tensors report being pinned.
        return all(tensor.is_pinned() for tensor in (self.inp, self.tgt))
3303*da0073e9SAndroid Build Coastguard Worker
3304*da0073e9SAndroid Build Coastguard Worker
# Workaround for https://github.com/pytorch/pytorch/issues/50661
# Classes from `__main__` cannot be correctly unpickled from a spawned module.
# See https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming
# Import this file under its own module name (the file's basename without the
# extension) so classes can be referenced through that module instead of
# through `__main__`.
self_module = __import__(os.path.splitext(os.path.basename(__file__))[0])
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker
def collate_wrapper(batch):
    """Collate ``batch`` into a ``SimpleCustomBatch`` looked up via ``self_module``."""
    batch_cls = self_module.SimpleCustomBatch
    return batch_cls(batch)
3313*da0073e9SAndroid Build Coastguard Worker
3314*da0073e9SAndroid Build Coastguard Worker
def collate_into_packed_sequence(batch):
    """Collate the first field of each sample into a time-major ``PackedSequence``.

    Samples are stacked along dim 1, giving a ``(t, b)`` tensor. Each sequence
    gets a random length drawn from ``[1, t)``.
    """
    rnn_utils = torch.nn.utils.rnn
    stacked = torch.stack(tuple(sample[0] for sample in batch), dim=1)
    num_steps, num_seqs = stacked.size()
    seq_lengths = torch.randint(1, num_steps, size=(num_seqs,), dtype=torch.int64)
    return rnn_utils.pack_padded_sequence(stacked, seq_lengths, enforce_sorted=False)
3320*da0073e9SAndroid Build Coastguard Worker
3321*da0073e9SAndroid Build Coastguard Worker
def collate_into_packed_sequence_batch_first(batch):
    """Collate the first field of each sample into a batch-first ``PackedSequence``.

    Samples are stacked along dim 0, giving a ``(b, t)`` tensor. Each sequence
    gets a random length drawn from ``[1, t)``.
    """
    rnn_utils = torch.nn.utils.rnn
    stacked = torch.stack(tuple(sample[0] for sample in batch), dim=0)
    num_seqs, num_steps = stacked.size()
    seq_lengths = torch.randint(1, num_steps, size=(num_seqs,), dtype=torch.int64)
    return rnn_utils.pack_padded_sequence(
        stacked, seq_lengths, batch_first=True, enforce_sorted=False
    )
3329*da0073e9SAndroid Build Coastguard Worker
3330*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestCustomPinFn(TestCase):
    """Checks that batches produced by custom collate functions come out pinned."""

    def setUp(self):
        super().setUp()

        def make_grid():
            # A fresh 10x5 float32 tensor holding 0..49.
            return torch.arange(10 * 5, dtype=torch.float32).view(10, 5)

        self.dataset = TensorDataset(make_grid(), make_grid())

    def _check_custom_pin(self, **extra_loader_kwargs):
        # Shared body of both tests: for each collate function, every batch
        # must have the expected type and report being pinned.
        packed_cls = torch.nn.utils.rnn.PackedSequence
        expectations = [
            (collate_wrapper, self_module.SimpleCustomBatch),
            (collate_into_packed_sequence, packed_cls),
            (collate_into_packed_sequence_batch_first, packed_cls),
        ]
        for collate_fn, expected_cls in expectations:
            loader = DataLoader(
                self.dataset,
                batch_size=2,
                collate_fn=collate_fn,
                pin_memory=True,
                **extra_loader_kwargs,
            )
            for batch in loader:
                self.assertIsInstance(batch, expected_cls)
                self.assertTrue(batch.is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin(self):
        # Default DataLoader settings (no num_workers argument).
        self._check_custom_pin()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin_worker(self):
        # Same checks with a single worker.
        self._check_custom_pin(num_workers=1)
3382*da0073e9SAndroid Build Coastguard Worker
3383*da0073e9SAndroid Build Coastguard Worker
class TestWorkerQueueDataset(Dataset):
    """Map-style dataset that tags each item with the recorded worker id.

    ``worker_id`` stays ``None`` until ``worker_init_fn`` stores a value on
    this instance.
    """

    def __init__(self, data):
        self.data = data
        self.worker_id = None

    def worker_init_fn(self, worker_id):
        # Remember which worker this dataset instance belongs to.
        self.worker_id = worker_id

    def __getitem__(self, item):
        value = self.data[item]
        return self.worker_id, value

    def __len__(self):
        return len(self.data)
3397*da0073e9SAndroid Build Coastguard Worker
3398*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(
    TEST_WITH_ASAN,
    "Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727",
)
class TestIndividualWorkerQueue(TestCase):
    """Checks that consecutive batches are produced by workers in cyclic order."""

    def setUp(self):
        super().setUp()
        self.dataset = TestWorkerQueueDataset(list(range(128)))

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        # Iterate the dataset in order and assert that batch ``i`` was produced
        # entirely by worker ``i % num_workers`` and holds the expected
        # consecutive samples.
        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            timeout=5,
            worker_init_fn=self.dataset.worker_init_fn,
        )
        current_worker_idx = 0
        for i, (worker_ids, sample) in enumerate(loader):
            self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
            self.assertEqual(
                sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size))
            )
            current_worker_idx = (current_worker_idx + 1) % num_workers

    def test_ind_worker_queue(self):
        max_num_workers = None
        if hasattr(os, "sched_getaffinity"):
            try:
                max_num_workers = len(os.sched_getaffinity(0))
            except Exception:
                # Best effort only: fall back to the CPU count below.
                pass
        if max_num_workers is None:
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                # Use half number of CPUs
                max_num_workers = cpu_count // 2

        # Always exercise at least one worker. Without this clamp a single-CPU
        # machine yields ``1 // 2 == 0`` and the loops below would run nothing,
        # letting the test pass vacuously.
        if max_num_workers is None or max_num_workers < 1:
            max_num_workers = 1

        for batch_size in (8, 16, 32, 64):
            for num_workers in range(1, min(6, max_num_workers) + 1):
                self._run_ind_worker_queue_test(
                    batch_size=batch_size, num_workers=num_workers
                )
3453*da0073e9SAndroid Build Coastguard Worker
3454*da0073e9SAndroid Build Coastguard Worker
class SetAffinityDataset(IterableDataset):
    """Iterable dataset that yields the CPU ids in the calling process's affinity mask."""

    def __iter__(self):
        # Run a torch op before reading the mask (presumably to check that it
        # does not disturb the affinity -- TODO confirm).
        torch.randperm(1)
        affinity_mask = os.sched_getaffinity(0)
        return iter(affinity_mask)
3460*da0073e9SAndroid Build Coastguard Worker
3461*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(
    not hasattr(os, "sched_setaffinity"), "os.sched_setaffinity is not available"
)
class TestSetAffinity(TestCase):
    """Checks that an affinity set in ``worker_init_fn`` is what the worker reports."""

    def test_set_affinity_in_worker_init(self):
        # Query the current affinity mask to avoid setting a disallowed one
        allowed_cpus = os.sched_getaffinity(0)
        if not allowed_cpus:
            self.skipTest("No affinity information")
        # Any allowed CPU will do; take the last one in iteration order.
        pinned_cpu = list(allowed_cpus)[-1]

        def pin_worker_to_cpu(_):
            os.sched_setaffinity(0, [pinned_cpu])

        loader = torch.utils.data.DataLoader(
            SetAffinityDataset(), num_workers=2, worker_init_fn=pin_worker_to_cpu
        )
        for reported in loader:
            self.assertEqual(reported, [pinned_cpu])
3484*da0073e9SAndroid Build Coastguard Worker
3485*da0073e9SAndroid Build Coastguard Worker
class ConvDataset(Dataset):
    """Single-item dataset whose item is a 1-D convolution over an all-ones signal."""

    def __init__(self) -> None:
        self.x = torch.ones(1, 1, 24000)
        # Call convolution on parent process
        self.__getitem__(0)

    def __len__(self):
        return 1

    def __getitem__(self, index):
        # The index is ignored; every lookup convolves the same signal with a
        # width-2 all-ones kernel.
        kernel = torch.ones(1, 1, 2)
        return torch.nn.functional.conv1d(self.x, kernel)
3497*da0073e9SAndroid Build Coastguard Worker
3498*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(IS_WINDOWS, "Needs fork")
@unittest.skipIf(
    TEST_WITH_ASAN,
    "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492",
)
class TestConvAfterFork(TestCase):
    # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565
    def test_conv_after_fork(self):
        # ConvDataset already ran a convolution in this process; the worker
        # must be able to run it again and return the default-collated result.
        expected_shape = (1, 1, 1, 23999)
        for batch in DataLoader(ConvDataset(), num_workers=1):
            self.assertEqual(batch.shape, expected_shape)
3510*da0073e9SAndroid Build Coastguard Worker
3511*da0073e9SAndroid Build Coastguard Worker
# Instantiate the device-specific variants of TestDataLoaderDeviceType into
# this module's globals (presumably so the runner can discover them -- see
# torch.testing._internal.common_device_type).
instantiate_device_type_tests(TestDataLoaderDeviceType, globals())


# Run the tests only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
3517