xref: /aosp_15_r20/external/pytorch/test/test_datapipe.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: ignore-errors
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dataloader"]
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport copy
6*da0073e9SAndroid Build Coastguard Workerimport itertools
7*da0073e9SAndroid Build Coastguard Workerimport os
8*da0073e9SAndroid Build Coastguard Workerimport os.path
9*da0073e9SAndroid Build Coastguard Workerimport pickle
10*da0073e9SAndroid Build Coastguard Workerimport pydoc
11*da0073e9SAndroid Build Coastguard Workerimport random
12*da0073e9SAndroid Build Coastguard Workerimport sys
13*da0073e9SAndroid Build Coastguard Workerimport tempfile
14*da0073e9SAndroid Build Coastguard Workerimport warnings
15*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
16*da0073e9SAndroid Build Coastguard Workerfrom typing import (
17*da0073e9SAndroid Build Coastguard Worker    Any,
18*da0073e9SAndroid Build Coastguard Worker    Awaitable,
19*da0073e9SAndroid Build Coastguard Worker    Dict,
20*da0073e9SAndroid Build Coastguard Worker    Generic,
21*da0073e9SAndroid Build Coastguard Worker    Iterator,
22*da0073e9SAndroid Build Coastguard Worker    List,
23*da0073e9SAndroid Build Coastguard Worker    Optional,
24*da0073e9SAndroid Build Coastguard Worker    Set,
25*da0073e9SAndroid Build Coastguard Worker    Tuple,
26*da0073e9SAndroid Build Coastguard Worker    Type,
27*da0073e9SAndroid Build Coastguard Worker    TYPE_CHECKING,
28*da0073e9SAndroid Build Coastguard Worker    TypeVar,
29*da0073e9SAndroid Build Coastguard Worker    Union,
30*da0073e9SAndroid Build Coastguard Worker)
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Workerif not TYPE_CHECKING:
33*da0073e9SAndroid Build Coastguard Worker    # pyre isn't treating this the same as a typing.NamedTuple
34*da0073e9SAndroid Build Coastguard Worker    from typing_extensions import NamedTuple
35*da0073e9SAndroid Build Coastguard Workerelse:
36*da0073e9SAndroid Build Coastguard Worker    from typing import NamedTuple
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Workerimport operator
39*da0073e9SAndroid Build Coastguard Workerfrom unittest import skipIf
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Workerimport numpy as np
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Workerimport torch
44*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
45*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.datapipes as dp
46*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.graph
47*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.graph_settings
48*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
49*da0073e9SAndroid Build Coastguard Worker    run_tests,
50*da0073e9SAndroid Build Coastguard Worker    skipIfNoDill,
51*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
52*da0073e9SAndroid Build Coastguard Worker    suppress_warnings,
53*da0073e9SAndroid Build Coastguard Worker    TEST_DILL,
54*da0073e9SAndroid Build Coastguard Worker    TestCase,
55*da0073e9SAndroid Build Coastguard Worker)
56*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._import_utils import import_dill
57*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import (
58*da0073e9SAndroid Build Coastguard Worker    argument_validation,
59*da0073e9SAndroid Build Coastguard Worker    DataChunk,
60*da0073e9SAndroid Build Coastguard Worker    DataLoader,
61*da0073e9SAndroid Build Coastguard Worker    IterDataPipe,
62*da0073e9SAndroid Build Coastguard Worker    MapDataPipe,
63*da0073e9SAndroid Build Coastguard Worker    RandomSampler,
64*da0073e9SAndroid Build Coastguard Worker    runtime_validation,
65*da0073e9SAndroid Build Coastguard Worker    runtime_validation_disabled,
66*da0073e9SAndroid Build Coastguard Worker)
67*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.dataframe import (
68*da0073e9SAndroid Build Coastguard Worker    CaptureDataFrame,
69*da0073e9SAndroid Build Coastguard Worker    dataframe_wrapper as df_wrapper,
70*da0073e9SAndroid Build Coastguard Worker)
71*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
72*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.utils.common import StreamWrapper
73*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.utils.decoder import (
74*da0073e9SAndroid Build Coastguard Worker    basichandlers as decoder_basichandlers,
75*da0073e9SAndroid Build Coastguard Worker)
76*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
77*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.graph import traverse_dps
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Workerdill = import_dill()
80*da0073e9SAndroid Build Coastguard WorkerHAS_DILL = TEST_DILL
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Workertry:
83*da0073e9SAndroid Build Coastguard Worker    import pandas  # type: ignore[import]  # noqa: F401 F403
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker    HAS_PANDAS = True
86*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
87*da0073e9SAndroid Build Coastguard Worker    HAS_PANDAS = False
88*da0073e9SAndroid Build Coastguard WorkerskipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)")
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard WorkerskipTyping = skipIf(True, "TODO: Fix typing bug")
91*da0073e9SAndroid Build Coastguard WorkerT_co = TypeVar("T_co", covariant=True)
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Workerdef create_temp_dir_and_files():
95*da0073e9SAndroid Build Coastguard Worker    # The temp dir and files within it will be released and deleted in tearDown().
96*da0073e9SAndroid Build Coastguard Worker    # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function.
97*da0073e9SAndroid Build Coastguard Worker    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
98*da0073e9SAndroid Build Coastguard Worker    temp_dir_path = temp_dir.name
99*da0073e9SAndroid Build Coastguard Worker    with tempfile.NamedTemporaryFile(
100*da0073e9SAndroid Build Coastguard Worker        dir=temp_dir_path, delete=False, suffix=".txt"
101*da0073e9SAndroid Build Coastguard Worker    ) as f:
102*da0073e9SAndroid Build Coastguard Worker        temp_file1_name = f.name
103*da0073e9SAndroid Build Coastguard Worker    with tempfile.NamedTemporaryFile(
104*da0073e9SAndroid Build Coastguard Worker        dir=temp_dir_path, delete=False, suffix=".byte"
105*da0073e9SAndroid Build Coastguard Worker    ) as f:
106*da0073e9SAndroid Build Coastguard Worker        temp_file2_name = f.name
107*da0073e9SAndroid Build Coastguard Worker    with tempfile.NamedTemporaryFile(
108*da0073e9SAndroid Build Coastguard Worker        dir=temp_dir_path, delete=False, suffix=".empty"
109*da0073e9SAndroid Build Coastguard Worker    ) as f:
110*da0073e9SAndroid Build Coastguard Worker        temp_file3_name = f.name
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker    with open(temp_file1_name, "w") as f1:
113*da0073e9SAndroid Build Coastguard Worker        f1.write("0123456789abcdef")
114*da0073e9SAndroid Build Coastguard Worker    with open(temp_file2_name, "wb") as f2:
115*da0073e9SAndroid Build Coastguard Worker        f2.write(b"0123456789abcdef")
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker    temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir_path)  # noqa: P201
118*da0073e9SAndroid Build Coastguard Worker    temp_sub_dir_path = temp_sub_dir.name
119*da0073e9SAndroid Build Coastguard Worker    with tempfile.NamedTemporaryFile(
120*da0073e9SAndroid Build Coastguard Worker        dir=temp_sub_dir_path, delete=False, suffix=".txt"
121*da0073e9SAndroid Build Coastguard Worker    ) as f:
122*da0073e9SAndroid Build Coastguard Worker        temp_sub_file1_name = f.name
123*da0073e9SAndroid Build Coastguard Worker    with tempfile.NamedTemporaryFile(
124*da0073e9SAndroid Build Coastguard Worker        dir=temp_sub_dir_path, delete=False, suffix=".byte"
125*da0073e9SAndroid Build Coastguard Worker    ) as f:
126*da0073e9SAndroid Build Coastguard Worker        temp_sub_file2_name = f.name
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker    with open(temp_sub_file1_name, "w") as f1:
129*da0073e9SAndroid Build Coastguard Worker        f1.write("0123456789abcdef")
130*da0073e9SAndroid Build Coastguard Worker    with open(temp_sub_file2_name, "wb") as f2:
131*da0073e9SAndroid Build Coastguard Worker        f2.write(b"0123456789abcdef")
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker    return [
134*da0073e9SAndroid Build Coastguard Worker        (temp_dir, temp_file1_name, temp_file2_name, temp_file3_name),
135*da0073e9SAndroid Build Coastguard Worker        (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name),
136*da0073e9SAndroid Build Coastguard Worker    ]
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Workerdef reset_after_n_next_calls(
140*da0073e9SAndroid Build Coastguard Worker    datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]], n: int
141*da0073e9SAndroid Build Coastguard Worker) -> Tuple[List[T_co], List[T_co]]:
142*da0073e9SAndroid Build Coastguard Worker    """
143*da0073e9SAndroid Build Coastguard Worker    Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list
144*da0073e9SAndroid Build Coastguard Worker    Then, reset the DataPipe and return a tuple of two lists
145*da0073e9SAndroid Build Coastguard Worker        1. A list of elements yielded before the reset
146*da0073e9SAndroid Build Coastguard Worker        2. A list of all elements of the DataPipe after the reset
147*da0073e9SAndroid Build Coastguard Worker    """
148*da0073e9SAndroid Build Coastguard Worker    it = iter(datapipe)
149*da0073e9SAndroid Build Coastguard Worker    res_before_reset = []
150*da0073e9SAndroid Build Coastguard Worker    for _ in range(n):
151*da0073e9SAndroid Build Coastguard Worker        res_before_reset.append(next(it))
152*da0073e9SAndroid Build Coastguard Worker    return res_before_reset, list(datapipe)
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Workerdef odd_or_even(x: int) -> int:
156*da0073e9SAndroid Build Coastguard Worker    return x % 2
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Workerclass TestDataChunk(TestCase):
160*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
161*da0073e9SAndroid Build Coastguard Worker        self.elements = list(range(10))
162*da0073e9SAndroid Build Coastguard Worker        random.shuffle(self.elements)
163*da0073e9SAndroid Build Coastguard Worker        self.chunk: DataChunk[int] = DataChunk(self.elements)
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker    def test_getitem(self):
166*da0073e9SAndroid Build Coastguard Worker        for i in range(10):
167*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(self.elements[i], self.chunk[i])
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker    def test_iter(self):
170*da0073e9SAndroid Build Coastguard Worker        for ele, dc in zip(self.elements, iter(self.chunk)):
171*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ele, dc)
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker    def test_len(self):
174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(self.elements), len(self.chunk))
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker    def test_as_string(self):
177*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(self.chunk), str(self.elements))
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker        batch = [self.elements] * 3
180*da0073e9SAndroid Build Coastguard Worker        chunks: List[DataChunk[int]] = [DataChunk(self.elements)] * 3
181*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(batch), str(chunks))
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker    def test_sort(self):
184*da0073e9SAndroid Build Coastguard Worker        chunk: DataChunk[int] = DataChunk(self.elements)
185*da0073e9SAndroid Build Coastguard Worker        chunk.sort()
186*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(chunk, DataChunk))
187*da0073e9SAndroid Build Coastguard Worker        for i, d in enumerate(chunk):
188*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, d)
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker    def test_reverse(self):
191*da0073e9SAndroid Build Coastguard Worker        chunk: DataChunk[int] = DataChunk(self.elements)
192*da0073e9SAndroid Build Coastguard Worker        chunk.reverse()
193*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(chunk, DataChunk))
194*da0073e9SAndroid Build Coastguard Worker        for i in range(10):
195*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(chunk[i], self.elements[9 - i])
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker    def test_random_shuffle(self):
198*da0073e9SAndroid Build Coastguard Worker        elements = list(range(10))
199*da0073e9SAndroid Build Coastguard Worker        chunk: DataChunk[int] = DataChunk(elements)
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker        rng = random.Random(0)
202*da0073e9SAndroid Build Coastguard Worker        rng.shuffle(chunk)
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker        rng = random.Random(0)
205*da0073e9SAndroid Build Coastguard Worker        rng.shuffle(elements)
206*da0073e9SAndroid Build Coastguard Worker
207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chunk, elements)
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Workerclass TestStreamWrapper(TestCase):
211*da0073e9SAndroid Build Coastguard Worker    class _FakeFD:
212*da0073e9SAndroid Build Coastguard Worker        def __init__(self, filepath):
213*da0073e9SAndroid Build Coastguard Worker            self.filepath = filepath
214*da0073e9SAndroid Build Coastguard Worker            self.opened = False
215*da0073e9SAndroid Build Coastguard Worker            self.closed = False
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker        def open(self):
218*da0073e9SAndroid Build Coastguard Worker            self.opened = True
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker        def read(self):
221*da0073e9SAndroid Build Coastguard Worker            if self.opened:
222*da0073e9SAndroid Build Coastguard Worker                return "".join(self)
223*da0073e9SAndroid Build Coastguard Worker            else:
224*da0073e9SAndroid Build Coastguard Worker                raise OSError("Cannot read from un-opened file descriptor")
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker        def __iter__(self):
227*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
228*da0073e9SAndroid Build Coastguard Worker                yield str(i)
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        def close(self):
231*da0073e9SAndroid Build Coastguard Worker            if self.opened:
232*da0073e9SAndroid Build Coastguard Worker                self.opened = False
233*da0073e9SAndroid Build Coastguard Worker                self.closed = True
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker        def __repr__(self):
236*da0073e9SAndroid Build Coastguard Worker            return "FakeFD"
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker    def test_dir(self):
239*da0073e9SAndroid Build Coastguard Worker        fd = TestStreamWrapper._FakeFD("")
240*da0073e9SAndroid Build Coastguard Worker        wrap_fd = StreamWrapper(fd)
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker        s = set(dir(wrap_fd))
243*da0073e9SAndroid Build Coastguard Worker        for api in ["open", "read", "close"]:
244*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(api in s)
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
247*da0073e9SAndroid Build Coastguard Worker    def test_api(self):
248*da0073e9SAndroid Build Coastguard Worker        fd = TestStreamWrapper._FakeFD("")
249*da0073e9SAndroid Build Coastguard Worker        wrap_fd = StreamWrapper(fd)
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(fd.opened)
252*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(fd.closed)
253*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IOError, "Cannot read from"):
254*da0073e9SAndroid Build Coastguard Worker            wrap_fd.read()
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker        wrap_fd.open()
257*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(fd.opened)
258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual("01234", wrap_fd.read())
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        del wrap_fd
261*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(fd.opened)
262*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(fd.closed)
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker    def test_pickle(self):
265*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryFile() as f:
266*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(TypeError) as ctx1:
267*da0073e9SAndroid Build Coastguard Worker                pickle.dumps(f)
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker            wrap_f = StreamWrapper(f)
270*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(TypeError) as ctx2:
271*da0073e9SAndroid Build Coastguard Worker                pickle.dumps(wrap_f)
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker            # Same exception when pickle
274*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(str(ctx1.exception), str(ctx2.exception))
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker        fd = TestStreamWrapper._FakeFD("")
277*da0073e9SAndroid Build Coastguard Worker        wrap_fd = StreamWrapper(fd)
278*da0073e9SAndroid Build Coastguard Worker        _ = pickle.loads(pickle.dumps(wrap_fd))
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker    def test_repr(self):
281*da0073e9SAndroid Build Coastguard Worker        fd = TestStreamWrapper._FakeFD("")
282*da0073e9SAndroid Build Coastguard Worker        wrap_fd = StreamWrapper(fd)
283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>")
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryFile() as f:
286*da0073e9SAndroid Build Coastguard Worker            wrap_f = StreamWrapper(f)
287*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Workerclass TestIterableDataPipeBasic(TestCase):
291*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
292*da0073e9SAndroid Build Coastguard Worker        ret = create_temp_dir_and_files()
293*da0073e9SAndroid Build Coastguard Worker        self.temp_dir = ret[0][0]
294*da0073e9SAndroid Build Coastguard Worker        self.temp_files = ret[0][1:]
295*da0073e9SAndroid Build Coastguard Worker        self.temp_sub_dir = ret[1][0]
296*da0073e9SAndroid Build Coastguard Worker        self.temp_sub_files = ret[1][1:]
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
299*da0073e9SAndroid Build Coastguard Worker        try:
300*da0073e9SAndroid Build Coastguard Worker            self.temp_sub_dir.cleanup()
301*da0073e9SAndroid Build Coastguard Worker            self.temp_dir.cleanup()
302*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
303*da0073e9SAndroid Build Coastguard Worker            warnings.warn(
304*da0073e9SAndroid Build Coastguard Worker                f"TestIterableDatasetBasic was not able to cleanup temp dir due to {str(e)}"
305*da0073e9SAndroid Build Coastguard Worker            )
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker    def test_listdirfiles_iterable_datapipe(self):
308*da0073e9SAndroid Build Coastguard Worker        temp_dir = self.temp_dir.name
309*da0073e9SAndroid Build Coastguard Worker        datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, "")
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker        count = 0
312*da0073e9SAndroid Build Coastguard Worker        for pathname in datapipe:
313*da0073e9SAndroid Build Coastguard Worker            count = count + 1
314*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(pathname in self.temp_files)
315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, len(self.temp_files))
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker        count = 0
318*da0073e9SAndroid Build Coastguard Worker        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
319*da0073e9SAndroid Build Coastguard Worker        for pathname in datapipe:
320*da0073e9SAndroid Build Coastguard Worker            count = count + 1
321*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
322*da0073e9SAndroid Build Coastguard Worker                (pathname in self.temp_files) or (pathname in self.temp_sub_files)
323*da0073e9SAndroid Build Coastguard Worker            )
324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))
325*da0073e9SAndroid Build Coastguard Worker
326*da0073e9SAndroid Build Coastguard Worker        temp_files = self.temp_files
327*da0073e9SAndroid Build Coastguard Worker        datapipe = dp.iter.FileLister([temp_dir, *temp_files])
328*da0073e9SAndroid Build Coastguard Worker        count = 0
329*da0073e9SAndroid Build Coastguard Worker        for pathname in datapipe:
330*da0073e9SAndroid Build Coastguard Worker            count += 1
331*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(pathname in self.temp_files)
332*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, 2 * len(self.temp_files))
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker        # test functional API
335*da0073e9SAndroid Build Coastguard Worker        datapipe = datapipe.list_files()
336*da0073e9SAndroid Build Coastguard Worker        count = 0
337*da0073e9SAndroid Build Coastguard Worker        for pathname in datapipe:
338*da0073e9SAndroid Build Coastguard Worker            count += 1
339*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(pathname in self.temp_files)
340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, 2 * len(self.temp_files))
341*da0073e9SAndroid Build Coastguard Worker
342*da0073e9SAndroid Build Coastguard Worker    def test_listdirfilesdeterministic_iterable_datapipe(self):
343*da0073e9SAndroid Build Coastguard Worker        temp_dir = self.temp_dir.name
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker        datapipe = dp.iter.FileLister(temp_dir, "")
346*da0073e9SAndroid Build Coastguard Worker        # The output order should be always the same.
347*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(datapipe), list(datapipe))
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
350*da0073e9SAndroid Build Coastguard Worker        # The output order should be always the same.
351*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(datapipe), list(datapipe))
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker    def test_openfilesfromdisk_iterable_datapipe(self):
354*da0073e9SAndroid Build Coastguard Worker        # test import datapipe class directly
355*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data.datapipes.iter import FileLister, FileOpener
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker        temp_dir = self.temp_dir.name
358*da0073e9SAndroid Build Coastguard Worker        datapipe1 = FileLister(temp_dir, "")
359*da0073e9SAndroid Build Coastguard Worker        datapipe2 = FileOpener(datapipe1, mode="b")
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker        count = 0
362*da0073e9SAndroid Build Coastguard Worker        for rec in datapipe2:
363*da0073e9SAndroid Build Coastguard Worker            count = count + 1
364*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(rec[0] in self.temp_files)
365*da0073e9SAndroid Build Coastguard Worker            with open(rec[0], "rb") as f:
366*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(rec[1].read(), f.read())
367*da0073e9SAndroid Build Coastguard Worker                rec[1].close()
368*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, len(self.temp_files))
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        # functional API
371*da0073e9SAndroid Build Coastguard Worker        datapipe3 = datapipe1.open_files(mode="b")
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker        count = 0
374*da0073e9SAndroid Build Coastguard Worker        for rec in datapipe3:
375*da0073e9SAndroid Build Coastguard Worker            count = count + 1
376*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(rec[0] in self.temp_files)
377*da0073e9SAndroid Build Coastguard Worker            with open(rec[0], "rb") as f:
378*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(rec[1].read(), f.read())
379*da0073e9SAndroid Build Coastguard Worker                rec[1].close()
380*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, len(self.temp_files))
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker        # __len__ Test
383*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
384*da0073e9SAndroid Build Coastguard Worker            len(datapipe3)
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker    def test_routeddecoder_iterable_datapipe(self):
387*da0073e9SAndroid Build Coastguard Worker        temp_dir = self.temp_dir.name
388*da0073e9SAndroid Build Coastguard Worker        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
389*da0073e9SAndroid Build Coastguard Worker        png_data = np.array(
390*da0073e9SAndroid Build Coastguard Worker            [[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
391*da0073e9SAndroid Build Coastguard Worker            dtype=np.single,
392*da0073e9SAndroid Build Coastguard Worker        )
393*da0073e9SAndroid Build Coastguard Worker        np.save(temp_pngfile_pathname, png_data)
394*da0073e9SAndroid Build Coastguard Worker        datapipe1 = dp.iter.FileLister(temp_dir, ["*.png", "*.txt"])
395*da0073e9SAndroid Build Coastguard Worker        datapipe2 = dp.iter.FileOpener(datapipe1, mode="b")
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker        def _png_decoder(extension, data):
398*da0073e9SAndroid Build Coastguard Worker            if extension != "png":
399*da0073e9SAndroid Build Coastguard Worker                return None
400*da0073e9SAndroid Build Coastguard Worker            return np.load(data)
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker        def _helper(prior_dp, dp, channel_first=False):
403*da0073e9SAndroid Build Coastguard Worker            # Byte stream is not closed
404*da0073e9SAndroid Build Coastguard Worker            for inp in prior_dp:
405*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(inp[1].closed)
406*da0073e9SAndroid Build Coastguard Worker            for inp, rec in zip(prior_dp, dp):
407*da0073e9SAndroid Build Coastguard Worker                ext = os.path.splitext(rec[0])[1]
408*da0073e9SAndroid Build Coastguard Worker                if ext == ".png":
409*da0073e9SAndroid Build Coastguard Worker                    expected = np.array(
410*da0073e9SAndroid Build Coastguard Worker                        [
411*da0073e9SAndroid Build Coastguard Worker                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
412*da0073e9SAndroid Build Coastguard Worker                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
413*da0073e9SAndroid Build Coastguard Worker                        ],
414*da0073e9SAndroid Build Coastguard Worker                        dtype=np.single,
415*da0073e9SAndroid Build Coastguard Worker                    )
416*da0073e9SAndroid Build Coastguard Worker                    if channel_first:
417*da0073e9SAndroid Build Coastguard Worker                        expected = expected.transpose(2, 0, 1)
418*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(rec[1], expected)
419*da0073e9SAndroid Build Coastguard Worker                else:
420*da0073e9SAndroid Build Coastguard Worker                    with open(rec[0], "rb") as f:
421*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(rec[1], f.read().decode("utf-8"))
422*da0073e9SAndroid Build Coastguard Worker                # Corresponding byte stream is closed by Decoder
423*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(inp[1].closed)
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker        cached = list(datapipe2)
426*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
427*da0073e9SAndroid Build Coastguard Worker            datapipe3 = dp.iter.RoutedDecoder(cached, _png_decoder)
428*da0073e9SAndroid Build Coastguard Worker        datapipe3.add_handler(decoder_basichandlers)
429*da0073e9SAndroid Build Coastguard Worker        _helper(cached, datapipe3)
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker        cached = list(datapipe2)
432*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
433*da0073e9SAndroid Build Coastguard Worker            datapipe4 = dp.iter.RoutedDecoder(cached, decoder_basichandlers)
434*da0073e9SAndroid Build Coastguard Worker        datapipe4.add_handler(_png_decoder)
435*da0073e9SAndroid Build Coastguard Worker        _helper(cached, datapipe4, channel_first=True)
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker    def test_groupby_iterable_datapipe(self):
438*da0073e9SAndroid Build Coastguard Worker        file_list = [
439*da0073e9SAndroid Build Coastguard Worker            "a.png",
440*da0073e9SAndroid Build Coastguard Worker            "b.png",
441*da0073e9SAndroid Build Coastguard Worker            "c.json",
442*da0073e9SAndroid Build Coastguard Worker            "a.json",
443*da0073e9SAndroid Build Coastguard Worker            "c.png",
444*da0073e9SAndroid Build Coastguard Worker            "b.json",
445*da0073e9SAndroid Build Coastguard Worker            "d.png",
446*da0073e9SAndroid Build Coastguard Worker            "d.json",
447*da0073e9SAndroid Build Coastguard Worker            "e.png",
448*da0073e9SAndroid Build Coastguard Worker            "f.json",
449*da0073e9SAndroid Build Coastguard Worker            "g.png",
450*da0073e9SAndroid Build Coastguard Worker            "f.png",
451*da0073e9SAndroid Build Coastguard Worker            "g.json",
452*da0073e9SAndroid Build Coastguard Worker            "e.json",
453*da0073e9SAndroid Build Coastguard Worker            "h.txt",
454*da0073e9SAndroid Build Coastguard Worker            "h.json",
455*da0073e9SAndroid Build Coastguard Worker        ]
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker        import io
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker        datapipe1 = dp.iter.IterableWrapper(
460*da0073e9SAndroid Build Coastguard Worker            [(filename, io.BytesIO(b"12345abcde")) for filename in file_list]
461*da0073e9SAndroid Build Coastguard Worker        )
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker        def group_fn(data):
464*da0073e9SAndroid Build Coastguard Worker            filepath, _ = data
465*da0073e9SAndroid Build Coastguard Worker            return os.path.basename(filepath).split(".")[0]
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker        datapipe2 = dp.iter.Grouper(datapipe1, group_key_fn=group_fn, group_size=2)
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker        def order_fn(data):
470*da0073e9SAndroid Build Coastguard Worker            data.sort(key=lambda f: f[0], reverse=True)
471*da0073e9SAndroid Build Coastguard Worker            return data
472*da0073e9SAndroid Build Coastguard Worker
473*da0073e9SAndroid Build Coastguard Worker        datapipe3 = dp.iter.Mapper(datapipe2, fn=order_fn)  # type: ignore[var-annotated]
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker        expected_result = [
476*da0073e9SAndroid Build Coastguard Worker            ("a.png", "a.json"),
477*da0073e9SAndroid Build Coastguard Worker            ("c.png", "c.json"),
478*da0073e9SAndroid Build Coastguard Worker            ("b.png", "b.json"),
479*da0073e9SAndroid Build Coastguard Worker            ("d.png", "d.json"),
480*da0073e9SAndroid Build Coastguard Worker            ("f.png", "f.json"),
481*da0073e9SAndroid Build Coastguard Worker            ("g.png", "g.json"),
482*da0073e9SAndroid Build Coastguard Worker            ("e.png", "e.json"),
483*da0073e9SAndroid Build Coastguard Worker            ("h.txt", "h.json"),
484*da0073e9SAndroid Build Coastguard Worker        ]
485*da0073e9SAndroid Build Coastguard Worker
486*da0073e9SAndroid Build Coastguard Worker        count = 0
487*da0073e9SAndroid Build Coastguard Worker        for rec, expected in zip(datapipe3, expected_result):
488*da0073e9SAndroid Build Coastguard Worker            count = count + 1
489*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(os.path.basename(rec[0][0]), expected[0])
490*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(os.path.basename(rec[1][0]), expected[1])
491*da0073e9SAndroid Build Coastguard Worker            for i in [0, 1]:
492*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(rec[i][1].read(), b"12345abcde")
493*da0073e9SAndroid Build Coastguard Worker                rec[i][1].close()
494*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, 8)
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker        # testing the keep_key option
497*da0073e9SAndroid Build Coastguard Worker        datapipe4 = dp.iter.Grouper(
498*da0073e9SAndroid Build Coastguard Worker            datapipe1, group_key_fn=group_fn, keep_key=True, group_size=2
499*da0073e9SAndroid Build Coastguard Worker        )
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker        def order_fn(data):
502*da0073e9SAndroid Build Coastguard Worker            data[1].sort(key=lambda f: f[0], reverse=True)
503*da0073e9SAndroid Build Coastguard Worker            return data
504*da0073e9SAndroid Build Coastguard Worker
505*da0073e9SAndroid Build Coastguard Worker        datapipe5 = dp.iter.Mapper(datapipe4, fn=order_fn)  # type: ignore[var-annotated]
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker        expected_result = [
508*da0073e9SAndroid Build Coastguard Worker            ("a", ("a.png", "a.json")),
509*da0073e9SAndroid Build Coastguard Worker            ("c", ("c.png", "c.json")),
510*da0073e9SAndroid Build Coastguard Worker            ("b", ("b.png", "b.json")),
511*da0073e9SAndroid Build Coastguard Worker            ("d", ("d.png", "d.json")),
512*da0073e9SAndroid Build Coastguard Worker            ("f", ("f.png", "f.json")),
513*da0073e9SAndroid Build Coastguard Worker            ("g", ("g.png", "g.json")),
514*da0073e9SAndroid Build Coastguard Worker            ("e", ("e.png", "e.json")),
515*da0073e9SAndroid Build Coastguard Worker            ("h", ("h.txt", "h.json")),
516*da0073e9SAndroid Build Coastguard Worker        ]
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker        count = 0
519*da0073e9SAndroid Build Coastguard Worker        for rec, expected in zip(datapipe5, expected_result):
520*da0073e9SAndroid Build Coastguard Worker            count = count + 1
521*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rec[0], expected[0])
522*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rec[1][0][0], expected[1][0])
523*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rec[1][1][0], expected[1][1])
524*da0073e9SAndroid Build Coastguard Worker            for i in [0, 1]:
525*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(rec[1][i][1].read(), b"12345abcde")
526*da0073e9SAndroid Build Coastguard Worker                rec[1][i][1].close()
527*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, 8)
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker    def test_demux_mux_datapipe(self):
530*da0073e9SAndroid Build Coastguard Worker        numbers = NumbersDataset(10)
531*da0073e9SAndroid Build Coastguard Worker        n1, n2 = numbers.demux(2, lambda x: x % 2)
532*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([0, 2, 4, 6, 8], list(n1))
533*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 3, 5, 7, 9], list(n2))
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker        # Functional Test: demux and mux works sequentially as expected
536*da0073e9SAndroid Build Coastguard Worker        numbers = NumbersDataset(10)
537*da0073e9SAndroid Build Coastguard Worker        n1, n2, n3 = numbers.demux(3, lambda x: x % 3)
538*da0073e9SAndroid Build Coastguard Worker        n = n1.mux(n2, n3)
539*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(9)), list(n))
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Uneven DataPipes
542*da0073e9SAndroid Build Coastguard Worker        source_numbers = list(range(0, 10)) + [10, 12]
543*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(source_numbers)
544*da0073e9SAndroid Build Coastguard Worker        n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
545*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
546*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 3, 5, 7, 9], list(n2))
547*da0073e9SAndroid Build Coastguard Worker        n = n1.mux(n2)
548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(n))
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings  # Suppress warning for lambda fn
551*da0073e9SAndroid Build Coastguard Worker    def test_map_with_col_file_handle_datapipe(self):
552*da0073e9SAndroid Build Coastguard Worker        temp_dir = self.temp_dir.name
553*da0073e9SAndroid Build Coastguard Worker        datapipe1 = dp.iter.FileLister(temp_dir, "")
554*da0073e9SAndroid Build Coastguard Worker        datapipe2 = dp.iter.FileOpener(datapipe1)
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker        def _helper(datapipe):
557*da0073e9SAndroid Build Coastguard Worker            dp1 = datapipe.map(lambda x: x.read(), input_col=1)
558*da0073e9SAndroid Build Coastguard Worker            dp2 = datapipe.map(lambda x: (x[0], x[1].read()))
559*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(dp1), list(dp2))
560*da0073e9SAndroid Build Coastguard Worker
561*da0073e9SAndroid Build Coastguard Worker        # tuple
562*da0073e9SAndroid Build Coastguard Worker        _helper(datapipe2)
563*da0073e9SAndroid Build Coastguard Worker        # list
564*da0073e9SAndroid Build Coastguard Worker        datapipe3 = datapipe2.map(lambda x: list(x))
565*da0073e9SAndroid Build Coastguard Worker        _helper(datapipe3)
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker@skipIfNoDataFrames
569*da0073e9SAndroid Build Coastguard Workerclass TestCaptureDataFrame(TestCase):
570*da0073e9SAndroid Build Coastguard Worker    def get_new_df(self):
571*da0073e9SAndroid Build Coastguard Worker        return df_wrapper.create_dataframe([[1, 2]], columns=["a", "b"])
572*da0073e9SAndroid Build Coastguard Worker
573*da0073e9SAndroid Build Coastguard Worker    def compare_capture_and_eager(self, operations):
574*da0073e9SAndroid Build Coastguard Worker        cdf = CaptureDataFrame()
575*da0073e9SAndroid Build Coastguard Worker        cdf = operations(cdf)
576*da0073e9SAndroid Build Coastguard Worker        df = self.get_new_df()
577*da0073e9SAndroid Build Coastguard Worker        cdf = cdf.apply_ops(df)
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker        df = self.get_new_df()
580*da0073e9SAndroid Build Coastguard Worker        df = operations(df)
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(df.equals(cdf))
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker    def test_basic_capture(self):
585*da0073e9SAndroid Build Coastguard Worker        def operations(df):
586*da0073e9SAndroid Build Coastguard Worker            df["c"] = df.b + df["a"] * 7
587*da0073e9SAndroid Build Coastguard Worker            # somehow swallows pandas UserWarning when `df.c = df.b + df['a'] * 7`
588*da0073e9SAndroid Build Coastguard Worker            return df
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker        self.compare_capture_and_eager(operations)
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker
593*da0073e9SAndroid Build Coastguard Workerclass TestDataFramesPipes(TestCase):
594*da0073e9SAndroid Build Coastguard Worker    """
595*da0073e9SAndroid Build Coastguard Worker    Most of test will fail if pandas instaled, but no dill available.
596*da0073e9SAndroid Build Coastguard Worker    Need to rework them to avoid multiple skips.
597*da0073e9SAndroid Build Coastguard Worker    """
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker    def _get_datapipe(self, range=10, dataframe_size=7):
600*da0073e9SAndroid Build Coastguard Worker        return NumbersDataset(range).map(lambda i: (i, i % 3))
601*da0073e9SAndroid Build Coastguard Worker
602*da0073e9SAndroid Build Coastguard Worker    def _get_dataframes_pipe(self, range=10, dataframe_size=7):
603*da0073e9SAndroid Build Coastguard Worker        return (
604*da0073e9SAndroid Build Coastguard Worker            NumbersDataset(range)
605*da0073e9SAndroid Build Coastguard Worker            .map(lambda i: (i, i % 3))
606*da0073e9SAndroid Build Coastguard Worker            ._to_dataframes_pipe(columns=["i", "j"], dataframe_size=dataframe_size)
607*da0073e9SAndroid Build Coastguard Worker        )
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDataFrames
610*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill  # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map
611*da0073e9SAndroid Build Coastguard Worker    def test_capture(self):
612*da0073e9SAndroid Build Coastguard Worker        dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
613*da0073e9SAndroid Build Coastguard Worker        df_numbers = self._get_dataframes_pipe()
614*da0073e9SAndroid Build Coastguard Worker        df_numbers["k"] = df_numbers["j"] + df_numbers.i * 3
615*da0073e9SAndroid Build Coastguard Worker        expected = list(dp_numbers)
616*da0073e9SAndroid Build Coastguard Worker        actual = list(df_numbers)
617*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDataFrames
620*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
621*da0073e9SAndroid Build Coastguard Worker    def test_shuffle(self):
622*da0073e9SAndroid Build Coastguard Worker        #  With non-zero (but extremely low) probability (when shuffle do nothing),
623*da0073e9SAndroid Build Coastguard Worker        #  this test fails, so feel free to restart
624*da0073e9SAndroid Build Coastguard Worker        df_numbers = self._get_dataframes_pipe(range=1000).shuffle()
625*da0073e9SAndroid Build Coastguard Worker        dp_numbers = self._get_datapipe(range=1000)
626*da0073e9SAndroid Build Coastguard Worker        df_result = [tuple(item) for item in df_numbers]
627*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(list(dp_numbers), df_result)
628*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(dp_numbers), sorted(df_result))
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDataFrames
631*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
632*da0073e9SAndroid Build Coastguard Worker    def test_batch(self):
633*da0073e9SAndroid Build Coastguard Worker        df_numbers = self._get_dataframes_pipe(range=100).batch(8)
634*da0073e9SAndroid Build Coastguard Worker        df_numbers_list = list(df_numbers)
635*da0073e9SAndroid Build Coastguard Worker        last_batch = df_numbers_list[-1]
636*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(4, len(last_batch))
637*da0073e9SAndroid Build Coastguard Worker        unpacked_batch = [tuple(row) for row in last_batch]
638*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)
639*da0073e9SAndroid Build Coastguard Worker
640*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDataFrames
641*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
642*da0073e9SAndroid Build Coastguard Worker    def test_unbatch(self):
643*da0073e9SAndroid Build Coastguard Worker        df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3)
644*da0073e9SAndroid Build Coastguard Worker        dp_numbers = self._get_datapipe(range=100)
645*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))
646*da0073e9SAndroid Build Coastguard Worker
647*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDataFrames
648*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
649*da0073e9SAndroid Build Coastguard Worker    def test_filter(self):
650*da0073e9SAndroid Build Coastguard Worker        df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5)
651*da0073e9SAndroid Build Coastguard Worker        actual = list(df_numbers)
652*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], actual)
653*da0073e9SAndroid Build Coastguard Worker
654*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDataFrames
655*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
656*da0073e9SAndroid Build Coastguard Worker    def test_collate(self):
657*da0073e9SAndroid Build Coastguard Worker        def collate_i(column):
658*da0073e9SAndroid Build Coastguard Worker            return column.sum()
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Worker        def collate_j(column):
661*da0073e9SAndroid Build Coastguard Worker            return column.prod()
662*da0073e9SAndroid Build Coastguard Worker
663*da0073e9SAndroid Build Coastguard Worker        df_numbers = self._get_dataframes_pipe(range=30).batch(3)
664*da0073e9SAndroid Build Coastguard Worker        df_numbers = df_numbers.collate({"j": collate_j, "i": collate_i})
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker        expected_i = [
667*da0073e9SAndroid Build Coastguard Worker            3,
668*da0073e9SAndroid Build Coastguard Worker            12,
669*da0073e9SAndroid Build Coastguard Worker            21,
670*da0073e9SAndroid Build Coastguard Worker            30,
671*da0073e9SAndroid Build Coastguard Worker            39,
672*da0073e9SAndroid Build Coastguard Worker            48,
673*da0073e9SAndroid Build Coastguard Worker            57,
674*da0073e9SAndroid Build Coastguard Worker            66,
675*da0073e9SAndroid Build Coastguard Worker            75,
676*da0073e9SAndroid Build Coastguard Worker            84,
677*da0073e9SAndroid Build Coastguard Worker        ]
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker        actual_i = []
680*da0073e9SAndroid Build Coastguard Worker        for i, j in df_numbers:
681*da0073e9SAndroid Build Coastguard Worker            actual_i.append(i)
682*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_i, actual_i)
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Worker        actual_i = []
685*da0073e9SAndroid Build Coastguard Worker        for item in df_numbers:
686*da0073e9SAndroid Build Coastguard Worker            actual_i.append(item.i)
687*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_i, actual_i)
688*da0073e9SAndroid Build Coastguard Worker
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Workerclass IDP_NoLen(IterDataPipe):
691*da0073e9SAndroid Build Coastguard Worker    def __init__(self, input_dp):
692*da0073e9SAndroid Build Coastguard Worker        super().__init__()
693*da0073e9SAndroid Build Coastguard Worker        self.input_dp = input_dp
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker    # Prevent in-place modification
696*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
697*da0073e9SAndroid Build Coastguard Worker        input_dp = (
698*da0073e9SAndroid Build Coastguard Worker            self.input_dp
699*da0073e9SAndroid Build Coastguard Worker            if isinstance(self.input_dp, IterDataPipe)
700*da0073e9SAndroid Build Coastguard Worker            else copy.deepcopy(self.input_dp)
701*da0073e9SAndroid Build Coastguard Worker        )
702*da0073e9SAndroid Build Coastguard Worker        yield from input_dp
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Workerdef _fake_fn(data):
706*da0073e9SAndroid Build Coastguard Worker    return data
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Workerdef _fake_add(constant, data):
710*da0073e9SAndroid Build Coastguard Worker    return constant + data
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Workerdef _fake_filter_fn(data):
714*da0073e9SAndroid Build Coastguard Worker    return True
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Workerdef _simple_filter_fn(data):
718*da0073e9SAndroid Build Coastguard Worker    return data >= 5
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Workerdef _fake_filter_fn_constant(constant, data):
722*da0073e9SAndroid Build Coastguard Worker    return data >= constant
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Workerdef _mul_10(x):
726*da0073e9SAndroid Build Coastguard Worker    return x * 10
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Workerdef _mod_3_test(x):
730*da0073e9SAndroid Build Coastguard Worker    return x % 3 == 1
731*da0073e9SAndroid Build Coastguard Worker
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Workerdef _to_list(x):
734*da0073e9SAndroid Build Coastguard Worker    return [x]
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Workerlambda_fn1 = lambda x: x  # noqa: E731
738*da0073e9SAndroid Build Coastguard Workerlambda_fn2 = lambda x: x % 2  # noqa: E731
739*da0073e9SAndroid Build Coastguard Workerlambda_fn3 = lambda x: x >= 5  # noqa: E731
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Workerclass Add1Module(nn.Module):
743*da0073e9SAndroid Build Coastguard Worker    def forward(self, x):
744*da0073e9SAndroid Build Coastguard Worker        return x + 1
745*da0073e9SAndroid Build Coastguard Worker
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Workerclass Add1Callable:
748*da0073e9SAndroid Build Coastguard Worker    def __call__(self, x):
749*da0073e9SAndroid Build Coastguard Worker        return x + 1
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalIterDataPipe(TestCase):
753*da0073e9SAndroid Build Coastguard Worker    def _serialization_test_helper(self, datapipe, use_dill):
754*da0073e9SAndroid Build Coastguard Worker        if use_dill:
755*da0073e9SAndroid Build Coastguard Worker            serialized_dp = dill.dumps(datapipe)
756*da0073e9SAndroid Build Coastguard Worker            deserialized_dp = dill.loads(serialized_dp)
757*da0073e9SAndroid Build Coastguard Worker        else:
758*da0073e9SAndroid Build Coastguard Worker            serialized_dp = pickle.dumps(datapipe)
759*da0073e9SAndroid Build Coastguard Worker            deserialized_dp = pickle.loads(serialized_dp)
760*da0073e9SAndroid Build Coastguard Worker        try:
761*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(datapipe), list(deserialized_dp))
762*da0073e9SAndroid Build Coastguard Worker        except AssertionError as e:
763*da0073e9SAndroid Build Coastguard Worker            print(f"{datapipe} is failing.")
764*da0073e9SAndroid Build Coastguard Worker            raise e
765*da0073e9SAndroid Build Coastguard Worker
766*da0073e9SAndroid Build Coastguard Worker    def _serialization_test_for_single_dp(self, dp, use_dill=False):
767*da0073e9SAndroid Build Coastguard Worker        # 1. Testing for serialization before any iteration starts
768*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp, use_dill)
769*da0073e9SAndroid Build Coastguard Worker        # 2. Testing for serialization after DataPipe is partially read
770*da0073e9SAndroid Build Coastguard Worker        it = iter(dp)
771*da0073e9SAndroid Build Coastguard Worker        _ = next(it)
772*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp, use_dill)
773*da0073e9SAndroid Build Coastguard Worker        # 3. Testing for serialization after DataPipe is fully read
774*da0073e9SAndroid Build Coastguard Worker        it = iter(dp)
775*da0073e9SAndroid Build Coastguard Worker        _ = list(it)
776*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp, use_dill)
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker    def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill=False):
779*da0073e9SAndroid Build Coastguard Worker        # 1. Testing for serialization before any iteration starts
780*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp1, use_dill)
781*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp2, use_dill)
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker        # 2. Testing for serialization after DataPipe is partially read
784*da0073e9SAndroid Build Coastguard Worker        it1, it2 = iter(dp1), iter(dp2)
785*da0073e9SAndroid Build Coastguard Worker        _, _ = next(it1), next(it2)
786*da0073e9SAndroid Build Coastguard Worker        # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning
787*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
788*da0073e9SAndroid Build Coastguard Worker            self._serialization_test_helper(dp1, use_dill)
789*da0073e9SAndroid Build Coastguard Worker            self._serialization_test_helper(dp2, use_dill)
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker        # 2.5. Testing for serialization after one child DataPipe is fully read
792*da0073e9SAndroid Build Coastguard Worker        #      (Only for DataPipes with children DataPipes)
793*da0073e9SAndroid Build Coastguard Worker        it1 = iter(dp1)
794*da0073e9SAndroid Build Coastguard Worker        _ = list(it1)  # fully read one child
795*da0073e9SAndroid Build Coastguard Worker        # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning
796*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
797*da0073e9SAndroid Build Coastguard Worker            self._serialization_test_helper(dp1, use_dill)
798*da0073e9SAndroid Build Coastguard Worker            self._serialization_test_helper(dp2, use_dill)
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker        # 3. Testing for serialization after DataPipe is fully read
801*da0073e9SAndroid Build Coastguard Worker        it2 = iter(dp2)
802*da0073e9SAndroid Build Coastguard Worker        _ = list(it2)  # fully read the other child
803*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp1, use_dill)
804*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp2, use_dill)
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker    def test_serializable(self):
807*da0073e9SAndroid Build Coastguard Worker        picklable_datapipes: List = [
808*da0073e9SAndroid Build Coastguard Worker            (
809*da0073e9SAndroid Build Coastguard Worker                dp.iter.Batcher,
810*da0073e9SAndroid Build Coastguard Worker                None,
811*da0073e9SAndroid Build Coastguard Worker                (
812*da0073e9SAndroid Build Coastguard Worker                    3,
813*da0073e9SAndroid Build Coastguard Worker                    True,
814*da0073e9SAndroid Build Coastguard Worker                ),
815*da0073e9SAndroid Build Coastguard Worker                {},
816*da0073e9SAndroid Build Coastguard Worker            ),
817*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Collator, None, (_fake_fn,), {}),
818*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}),
819*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Demultiplexer, None, (2, _simple_filter_fn), {}),
820*da0073e9SAndroid Build Coastguard Worker            (dp.iter.FileLister, ".", (), {}),
821*da0073e9SAndroid Build Coastguard Worker            (dp.iter.FileOpener, None, (), {}),
822*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Filter, None, (_fake_filter_fn,), {}),
823*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}),
824*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Forker, None, (2,), {}),
825*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Forker, None, (2,), {"copy": "shallow"}),
826*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}),
827*da0073e9SAndroid Build Coastguard Worker            (dp.iter.IterableWrapper, range(10), (), {}),
828*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Mapper, None, (_fake_fn,), {}),
829*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Mapper, None, (partial(_fake_add, 1),), {}),
830*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Multiplexer, None, (dp.iter.IterableWrapper(range(10)),), {}),
831*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Sampler, None, (), {}),
832*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Shuffler, dp.iter.IterableWrapper([0] * 10), (), {}),
833*da0073e9SAndroid Build Coastguard Worker            (dp.iter.StreamReader, None, (), {}),
834*da0073e9SAndroid Build Coastguard Worker            (dp.iter.UnBatcher, None, (0,), {}),
835*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Zipper, None, (dp.iter.IterableWrapper(range(10)),), {}),
836*da0073e9SAndroid Build Coastguard Worker        ]
837*da0073e9SAndroid Build Coastguard Worker        # Skipping comparison for these DataPipes
838*da0073e9SAndroid Build Coastguard Worker        dp_skip_comparison = {dp.iter.FileOpener, dp.iter.StreamReader}
839*da0073e9SAndroid Build Coastguard Worker        # These DataPipes produce multiple DataPipes as outputs and those should be compared
840*da0073e9SAndroid Build Coastguard Worker        dp_compare_children = {dp.iter.Demultiplexer, dp.iter.Forker}
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
843*da0073e9SAndroid Build Coastguard Worker            if custom_input is None:
844*da0073e9SAndroid Build Coastguard Worker                custom_input = dp.iter.IterableWrapper(range(10))
845*da0073e9SAndroid Build Coastguard Worker            if (
846*da0073e9SAndroid Build Coastguard Worker                dpipe in dp_skip_comparison
847*da0073e9SAndroid Build Coastguard Worker            ):  # Merely make sure they are picklable and loadable (no value comparison)
848*da0073e9SAndroid Build Coastguard Worker                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
849*da0073e9SAndroid Build Coastguard Worker                serialized_dp = pickle.dumps(datapipe)
850*da0073e9SAndroid Build Coastguard Worker                _ = pickle.loads(serialized_dp)
851*da0073e9SAndroid Build Coastguard Worker            elif dpipe in dp_compare_children:  # DataPipes that have children
852*da0073e9SAndroid Build Coastguard Worker                dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
853*da0073e9SAndroid Build Coastguard Worker                self._serialization_test_for_dp_with_children(dp1, dp2)
854*da0073e9SAndroid Build Coastguard Worker            else:  # Single DataPipe that requires comparison
855*da0073e9SAndroid Build Coastguard Worker                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
856*da0073e9SAndroid Build Coastguard Worker                self._serialization_test_for_single_dp(datapipe)
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Dict with function as keys")
859*da0073e9SAndroid Build Coastguard Worker    def test_serializable_with_dill(self):
860*da0073e9SAndroid Build Coastguard Worker        """Only for DataPipes that take in a function as argument"""
861*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(range(10))
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker        datapipes_with_lambda_fn: List[
864*da0073e9SAndroid Build Coastguard Worker            Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]
865*da0073e9SAndroid Build Coastguard Worker        ] = [
866*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Collator, (lambda_fn1,), {}),
867*da0073e9SAndroid Build Coastguard Worker            (
868*da0073e9SAndroid Build Coastguard Worker                dp.iter.Demultiplexer,
869*da0073e9SAndroid Build Coastguard Worker                (
870*da0073e9SAndroid Build Coastguard Worker                    2,
871*da0073e9SAndroid Build Coastguard Worker                    lambda_fn2,
872*da0073e9SAndroid Build Coastguard Worker                ),
873*da0073e9SAndroid Build Coastguard Worker                {},
874*da0073e9SAndroid Build Coastguard Worker            ),
875*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Filter, (lambda_fn3,), {}),
876*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Grouper, (lambda_fn3,), {}),
877*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Mapper, (lambda_fn1,), {}),
878*da0073e9SAndroid Build Coastguard Worker        ]
879*da0073e9SAndroid Build Coastguard Worker
880*da0073e9SAndroid Build Coastguard Worker        def _local_fns():
881*da0073e9SAndroid Build Coastguard Worker            def _fn1(x):
882*da0073e9SAndroid Build Coastguard Worker                return x
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker            def _fn2(x):
885*da0073e9SAndroid Build Coastguard Worker                return x % 2
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker            def _fn3(x):
888*da0073e9SAndroid Build Coastguard Worker                return x >= 5
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker            return _fn1, _fn2, _fn3
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        fn1, fn2, fn3 = _local_fns()
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker        datapipes_with_local_fn: List[
895*da0073e9SAndroid Build Coastguard Worker            Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]
896*da0073e9SAndroid Build Coastguard Worker        ] = [
897*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Collator, (fn1,), {}),
898*da0073e9SAndroid Build Coastguard Worker            (
899*da0073e9SAndroid Build Coastguard Worker                dp.iter.Demultiplexer,
900*da0073e9SAndroid Build Coastguard Worker                (
901*da0073e9SAndroid Build Coastguard Worker                    2,
902*da0073e9SAndroid Build Coastguard Worker                    fn2,
903*da0073e9SAndroid Build Coastguard Worker                ),
904*da0073e9SAndroid Build Coastguard Worker                {},
905*da0073e9SAndroid Build Coastguard Worker            ),
906*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Filter, (fn3,), {}),
907*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Grouper, (fn3,), {}),
908*da0073e9SAndroid Build Coastguard Worker            (dp.iter.Mapper, (fn1,), {}),
909*da0073e9SAndroid Build Coastguard Worker        ]
910*da0073e9SAndroid Build Coastguard Worker
911*da0073e9SAndroid Build Coastguard Worker        dp_compare_children = {dp.iter.Demultiplexer}
912*da0073e9SAndroid Build Coastguard Worker
913*da0073e9SAndroid Build Coastguard Worker        if HAS_DILL:
914*da0073e9SAndroid Build Coastguard Worker            for dpipe, dp_args, dp_kwargs in (
915*da0073e9SAndroid Build Coastguard Worker                datapipes_with_lambda_fn + datapipes_with_local_fn
916*da0073e9SAndroid Build Coastguard Worker            ):
917*da0073e9SAndroid Build Coastguard Worker                if dpipe in dp_compare_children:
918*da0073e9SAndroid Build Coastguard Worker                    dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
919*da0073e9SAndroid Build Coastguard Worker                    self._serialization_test_for_dp_with_children(
920*da0073e9SAndroid Build Coastguard Worker                        dp1, dp2, use_dill=True
921*da0073e9SAndroid Build Coastguard Worker                    )
922*da0073e9SAndroid Build Coastguard Worker                else:
923*da0073e9SAndroid Build Coastguard Worker                    datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
924*da0073e9SAndroid Build Coastguard Worker                    self._serialization_test_for_single_dp(datapipe, use_dill=True)
925*da0073e9SAndroid Build Coastguard Worker        else:
926*da0073e9SAndroid Build Coastguard Worker            msgs = (
927*da0073e9SAndroid Build Coastguard Worker                r"^Lambda function is not supported by pickle",
928*da0073e9SAndroid Build Coastguard Worker                r"^Local function is not supported by pickle",
929*da0073e9SAndroid Build Coastguard Worker            )
930*da0073e9SAndroid Build Coastguard Worker            for dps, msg in zip(
931*da0073e9SAndroid Build Coastguard Worker                (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs
932*da0073e9SAndroid Build Coastguard Worker            ):
933*da0073e9SAndroid Build Coastguard Worker                for dpipe, dp_args, dp_kwargs in dps:
934*da0073e9SAndroid Build Coastguard Worker                    with self.assertWarnsRegex(UserWarning, msg):
935*da0073e9SAndroid Build Coastguard Worker                        datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
936*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises((pickle.PicklingError, AttributeError)):
937*da0073e9SAndroid Build Coastguard Worker                        pickle.dumps(datapipe)
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard Worker    def test_docstring(self):
940*da0073e9SAndroid Build Coastguard Worker        """
941*da0073e9SAndroid Build Coastguard Worker        Ensure functional form of IterDataPipe has the correct docstring from
942*da0073e9SAndroid Build Coastguard Worker        the class form.
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker        Regression test for https://github.com/pytorch/data/issues/792.
945*da0073e9SAndroid Build Coastguard Worker        """
946*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(range(10))
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker        for dp_funcname in [
949*da0073e9SAndroid Build Coastguard Worker            "batch",
950*da0073e9SAndroid Build Coastguard Worker            "collate",
951*da0073e9SAndroid Build Coastguard Worker            "concat",
952*da0073e9SAndroid Build Coastguard Worker            "demux",
953*da0073e9SAndroid Build Coastguard Worker            "filter",
954*da0073e9SAndroid Build Coastguard Worker            "fork",
955*da0073e9SAndroid Build Coastguard Worker            "map",
956*da0073e9SAndroid Build Coastguard Worker            "mux",
957*da0073e9SAndroid Build Coastguard Worker            "read_from_stream",
958*da0073e9SAndroid Build Coastguard Worker            # "sampler",
959*da0073e9SAndroid Build Coastguard Worker            "shuffle",
960*da0073e9SAndroid Build Coastguard Worker            "unbatch",
961*da0073e9SAndroid Build Coastguard Worker            "zip",
962*da0073e9SAndroid Build Coastguard Worker        ]:
963*da0073e9SAndroid Build Coastguard Worker            if sys.version_info >= (3, 9):
964*da0073e9SAndroid Build Coastguard Worker                docstring = pydoc.render_doc(
965*da0073e9SAndroid Build Coastguard Worker                    thing=getattr(input_dp, dp_funcname), forceload=True
966*da0073e9SAndroid Build Coastguard Worker                )
967*da0073e9SAndroid Build Coastguard Worker            elif sys.version_info < (3, 9):
968*da0073e9SAndroid Build Coastguard Worker                # pydoc works differently on Python 3.8, see
969*da0073e9SAndroid Build Coastguard Worker                # https://docs.python.org/3/whatsnew/3.9.html#pydoc
970*da0073e9SAndroid Build Coastguard Worker                docstring = getattr(input_dp, dp_funcname).__doc__
971*da0073e9SAndroid Build Coastguard Worker
972*da0073e9SAndroid Build Coastguard Worker            assert f"(functional name: ``{dp_funcname}``)" in docstring
973*da0073e9SAndroid Build Coastguard Worker            assert "Args:" in docstring
974*da0073e9SAndroid Build Coastguard Worker            assert "Example:" in docstring or "Examples:" in docstring
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker    def test_iterable_wrapper_datapipe(self):
977*da0073e9SAndroid Build Coastguard Worker        input_ls = list(range(10))
978*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(input_ls)
979*da0073e9SAndroid Build Coastguard Worker
980*da0073e9SAndroid Build Coastguard Worker        # Functional Test: values are unchanged and in the same order
981*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input_ls, list(input_dp))
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker        # Functional Test: deep copy by default when an iterator is initialized (first element is read)
984*da0073e9SAndroid Build Coastguard Worker        it = iter(input_dp)
985*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
986*da0073e9SAndroid Build Coastguard Worker            0, next(it)
987*da0073e9SAndroid Build Coastguard Worker        )  # The deep copy only happens when the first element is read
988*da0073e9SAndroid Build Coastguard Worker        input_ls.append(50)
989*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(1, 10)), list(it))
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker        # Functional Test: shallow copy
992*da0073e9SAndroid Build Coastguard Worker        input_ls2 = [1, 2, 3]
993*da0073e9SAndroid Build Coastguard Worker        input_dp_shallow = dp.iter.IterableWrapper(input_ls2, deepcopy=False)
994*da0073e9SAndroid Build Coastguard Worker        input_ls2.append(10)
995*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 2, 3, 10], list(input_dp_shallow))
996*da0073e9SAndroid Build Coastguard Worker
997*da0073e9SAndroid Build Coastguard Worker        # Reset Test: reset the DataPipe
998*da0073e9SAndroid Build Coastguard Worker        input_ls = list(range(10))
999*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(input_ls)
1000*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 5
1001*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
1002*da0073e9SAndroid Build Coastguard Worker            input_dp, n_elements_before_reset
1003*da0073e9SAndroid Build Coastguard Worker        )
1004*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input_ls[:n_elements_before_reset], res_before_reset)
1005*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input_ls, res_after_reset)
1006*da0073e9SAndroid Build Coastguard Worker
1007*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: inherits length from sequence
1008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_ls), len(input_dp))
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker    def test_concat_iterdatapipe(self):
1011*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.iter.IterableWrapper(range(10))
1012*da0073e9SAndroid Build Coastguard Worker        input_dp2 = dp.iter.IterableWrapper(range(5))
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Raises exception for empty input
1015*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
1016*da0073e9SAndroid Build Coastguard Worker            dp.iter.Concater()
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Raises exception for non-IterDataPipe input
1019*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1020*da0073e9SAndroid Build Coastguard Worker            TypeError, r"Expected all inputs to be `IterDataPipe`"
1021*da0073e9SAndroid Build Coastguard Worker        ):
1022*da0073e9SAndroid Build Coastguard Worker            dp.iter.Concater(input_dp1, ())  # type: ignore[arg-type]
1023*da0073e9SAndroid Build Coastguard Worker
1024*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Concatenate DataPipes as expected
1025*da0073e9SAndroid Build Coastguard Worker        concat_dp = input_dp1.concat(input_dp2)
1026*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(concat_dp), 15)
1027*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        # Reset Test: reset the DataPipe
1030*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 5
1031*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
1032*da0073e9SAndroid Build Coastguard Worker            concat_dp, n_elements_before_reset
1033*da0073e9SAndroid Build Coastguard Worker        )
1034*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5)), res_before_reset)
1035*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)) + list(range(5)), res_after_reset)
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: inherits length from source DataPipe
1038*da0073e9SAndroid Build Coastguard Worker        input_dp_nl = IDP_NoLen(range(5))
1039*da0073e9SAndroid Build Coastguard Worker        concat_dp = input_dp1.concat(input_dp_nl)
1040*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1041*da0073e9SAndroid Build Coastguard Worker            len(concat_dp)
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
1044*da0073e9SAndroid Build Coastguard Worker
1045*da0073e9SAndroid Build Coastguard Worker    def test_fork_iterdatapipe(self):
1046*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(range(10))
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1049*da0073e9SAndroid Build Coastguard Worker            input_dp.fork(num_instances=0)
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker        dp0 = input_dp.fork(num_instances=1, buffer_size=0)
1052*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dp0, input_dp)
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker        # Functional Test: making sure all child DataPipe shares the same reference
1055*da0073e9SAndroid Build Coastguard Worker        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1056*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Worker        # Functional Test: one child DataPipe yields all value at a time
1059*da0073e9SAndroid Build Coastguard Worker        output1, output2, output3 = list(dp1), list(dp2), list(dp3)
1060*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), output1)
1061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), output2)
1062*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), output3)
1063*da0073e9SAndroid Build Coastguard Worker
1064*da0073e9SAndroid Build Coastguard Worker        # Functional Test: two child DataPipes yield value together
1065*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.fork(num_instances=2)
1066*da0073e9SAndroid Build Coastguard Worker        output = []
1067*da0073e9SAndroid Build Coastguard Worker        for n1, n2 in zip(dp1, dp2):
1068*da0073e9SAndroid Build Coastguard Worker            output.append((n1, n2))
1069*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(i, i) for i in range(10)], output)
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Worker        # Functional Test: one child DataPipe yields all value first, but buffer_size = 5 being too small
1072*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=4)
1073*da0073e9SAndroid Build Coastguard Worker        it1 = iter(dp1)
1074*da0073e9SAndroid Build Coastguard Worker        for _ in range(4):
1075*da0073e9SAndroid Build Coastguard Worker            next(it1)
1076*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(BufferError):
1077*da0073e9SAndroid Build Coastguard Worker            next(it1)
1078*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(BufferError):
1079*da0073e9SAndroid Build Coastguard Worker            list(dp2)
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
1082*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(BufferError):
1083*da0073e9SAndroid Build Coastguard Worker            list(dp2)
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker        # Functional Test: one child DataPipe yields all value first with unlimited buffer
1086*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
1087*da0073e9SAndroid Build Coastguard Worker            dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1)
1088*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
1089*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
1090*da0073e9SAndroid Build Coastguard Worker        l1, l2 = list(dp1), list(dp2)
1091*da0073e9SAndroid Build Coastguard Worker        for d1, d2 in zip(l1, l2):
1092*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(d1, d2)
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker        # Functional Test: two child DataPipes yield value together with buffer size 1
1095*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
1096*da0073e9SAndroid Build Coastguard Worker        output = []
1097*da0073e9SAndroid Build Coastguard Worker        for n1, n2 in zip(dp1, dp2):
1098*da0073e9SAndroid Build Coastguard Worker            output.append((n1, n2))
1099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(i, i) for i in range(10)], output)
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker        # Functional Test: two child DataPipes yield shallow copies with copy equals shallow
1102*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.map(_to_list).fork(num_instances=2, copy="shallow")
1103*da0073e9SAndroid Build Coastguard Worker        for n1, n2 in zip(dp1, dp2):
1104*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(n1, n2)
1105*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(n1, n2)
1106*da0073e9SAndroid Build Coastguard Worker
1107*da0073e9SAndroid Build Coastguard Worker        # Functional Test: two child DataPipes yield deep copies with copy equals deep
1108*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = (
1109*da0073e9SAndroid Build Coastguard Worker            input_dp.map(_to_list).map(_to_list).fork(num_instances=2, copy="deep")
1110*da0073e9SAndroid Build Coastguard Worker        )
1111*da0073e9SAndroid Build Coastguard Worker        for n1, n2 in zip(dp1, dp2):
1112*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(n1[0], n2[0])
1113*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(n1, n2)
1114*da0073e9SAndroid Build Coastguard Worker
1115*da0073e9SAndroid Build Coastguard Worker        # Functional Test: fork DataPipe raises error for unknown copy method
1116*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1117*da0073e9SAndroid Build Coastguard Worker            input_dp.fork(num_instances=2, copy="unknown")
1118*da0073e9SAndroid Build Coastguard Worker
1119*da0073e9SAndroid Build Coastguard Worker        # Functional Test: make sure logic related to slowest_ptr is working properly
1120*da0073e9SAndroid Build Coastguard Worker        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1121*da0073e9SAndroid Build Coastguard Worker        output1, output2, output3 = [], [], []
1122*da0073e9SAndroid Build Coastguard Worker        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
1123*da0073e9SAndroid Build Coastguard Worker            output1.append(n1)
1124*da0073e9SAndroid Build Coastguard Worker            output2.append(n2)
1125*da0073e9SAndroid Build Coastguard Worker            if i == 4:  # yield all of dp3 when halfway through dp1, dp2
1126*da0073e9SAndroid Build Coastguard Worker                output3 = list(dp3)
1127*da0073e9SAndroid Build Coastguard Worker                break
1128*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5)), output1)
1129*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5)), output2)
1130*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), output3)
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
1133*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.fork(num_instances=2)
1134*da0073e9SAndroid Build Coastguard Worker        _ = iter(dp1)
1135*da0073e9SAndroid Build Coastguard Worker        output2 = []
1136*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
1137*da0073e9SAndroid Build Coastguard Worker            for i, n2 in enumerate(dp2):
1138*da0073e9SAndroid Build Coastguard Worker                output2.append(n2)
1139*da0073e9SAndroid Build Coastguard Worker                if i == 4:
1140*da0073e9SAndroid Build Coastguard Worker                    with warnings.catch_warnings(record=True) as wa:
1141*da0073e9SAndroid Build Coastguard Worker                        _ = iter(dp1)  # This will reset all child DataPipes
1142*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(len(wa), 1)
1143*da0073e9SAndroid Build Coastguard Worker                        self.assertRegex(
1144*da0073e9SAndroid Build Coastguard Worker                            str(wa[0].message), r"child DataPipes are not exhausted"
1145*da0073e9SAndroid Build Coastguard Worker                        )
1146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5)), output2)
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe resets when some of it has been read
1149*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.fork(num_instances=2)
1150*da0073e9SAndroid Build Coastguard Worker        output1, output2 = [], []
1151*da0073e9SAndroid Build Coastguard Worker        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
1152*da0073e9SAndroid Build Coastguard Worker            output1.append(n1)
1153*da0073e9SAndroid Build Coastguard Worker            output2.append(n2)
1154*da0073e9SAndroid Build Coastguard Worker            if i == 4:
1155*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as wa:
1156*da0073e9SAndroid Build Coastguard Worker                    _ = iter(dp1)  # Reset both all child DataPipe
1157*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(len(wa), 1)
1158*da0073e9SAndroid Build Coastguard Worker                    self.assertRegex(
1159*da0073e9SAndroid Build Coastguard Worker                        str(wa[0].message), r"Some child DataPipes are not exhausted"
1160*da0073e9SAndroid Build Coastguard Worker                    )
1161*da0073e9SAndroid Build Coastguard Worker                break
1162*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
1163*da0073e9SAndroid Build Coastguard Worker            for i, (n1, n2) in enumerate(zip(dp1, dp2)):
1164*da0073e9SAndroid Build Coastguard Worker                output1.append(n1)
1165*da0073e9SAndroid Build Coastguard Worker                output2.append(n2)
1166*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
1167*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
1168*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5)) + list(range(10)), output1)
1169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5)) + list(range(10)), output2)
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe reset, even when some other child DataPipes are not read
1172*da0073e9SAndroid Build Coastguard Worker        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1173*da0073e9SAndroid Build Coastguard Worker        output1, output2 = list(dp1), list(dp2)
1174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), output1)
1175*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), output2)
1176*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
1177*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1178*da0073e9SAndroid Build Coastguard Worker                list(range(10)), list(dp1)
1179*da0073e9SAndroid Build Coastguard Worker            )  # Resets even though dp3 has not been read
1180*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
1181*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(
1182*da0073e9SAndroid Build Coastguard Worker                str(wa[0].message), r"Some child DataPipes are not exhausted"
1183*da0073e9SAndroid Build Coastguard Worker            )
1184*da0073e9SAndroid Build Coastguard Worker        output3 = []
1185*da0073e9SAndroid Build Coastguard Worker        for i, n3 in enumerate(dp3):
1186*da0073e9SAndroid Build Coastguard Worker            output3.append(n3)
1187*da0073e9SAndroid Build Coastguard Worker            if i == 4:
1188*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as wa:
1189*da0073e9SAndroid Build Coastguard Worker                    output1 = list(dp1)  # Resets even though dp3 is only partially read
1190*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(len(wa), 1)
1191*da0073e9SAndroid Build Coastguard Worker                    self.assertRegex(
1192*da0073e9SAndroid Build Coastguard Worker                        str(wa[0].message), r"Some child DataPipes are not exhausted"
1193*da0073e9SAndroid Build Coastguard Worker                    )
1194*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(list(range(5)), output3)
1195*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(list(range(10)), output1)
1196*da0073e9SAndroid Build Coastguard Worker                break
1197*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1198*da0073e9SAndroid Build Coastguard Worker            list(range(10)), list(dp3)
1199*da0073e9SAndroid Build Coastguard Worker        )  # dp3 has to read from the start again
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: Each DataPipe inherits the source datapipe's length
1202*da0073e9SAndroid Build Coastguard Worker        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1203*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_dp), len(dp1))
1204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_dp), len(dp2))
1205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_dp), len(dp3))
1206*da0073e9SAndroid Build Coastguard Worker
1207*da0073e9SAndroid Build Coastguard Worker        # Pickle Test:
1208*da0073e9SAndroid Build Coastguard Worker        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1209*da0073e9SAndroid Build Coastguard Worker        traverse_dps(dp1)  # This should not raise any error
1210*da0073e9SAndroid Build Coastguard Worker        for _ in zip(dp1, dp2, dp3):
1211*da0073e9SAndroid Build Coastguard Worker            pass
1212*da0073e9SAndroid Build Coastguard Worker        traverse_dps(dp2)  # This should not raise any error either
1213*da0073e9SAndroid Build Coastguard Worker
1214*da0073e9SAndroid Build Coastguard Worker    def test_mux_iterdatapipe(self):
1215*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted
1216*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.iter.IterableWrapper(range(4))
1217*da0073e9SAndroid Build Coastguard Worker        input_dp2 = dp.iter.IterableWrapper(range(4, 8))
1218*da0073e9SAndroid Build Coastguard Worker        input_dp3 = dp.iter.IterableWrapper(range(8, 12))
1219*da0073e9SAndroid Build Coastguard Worker        output_dp = input_dp1.mux(input_dp2, input_dp3)
1220*da0073e9SAndroid Build Coastguard Worker        expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
1221*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(expected_output), len(output_dp))
1222*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_output, list(output_dp))
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Uneven input Data Pipes
1225*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.iter.IterableWrapper([1, 2, 3, 4])
1226*da0073e9SAndroid Build Coastguard Worker        input_dp2 = dp.iter.IterableWrapper([10])
1227*da0073e9SAndroid Build Coastguard Worker        input_dp3 = dp.iter.IterableWrapper([100, 200, 300])
1228*da0073e9SAndroid Build Coastguard Worker        output_dp = input_dp1.mux(input_dp2, input_dp3)
1229*da0073e9SAndroid Build Coastguard Worker        expected_output = [1, 10, 100]
1230*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(expected_output), len(output_dp))
1231*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_output, list(output_dp))
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Empty Data Pipe
1234*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.iter.IterableWrapper([0, 1, 2, 3])
1235*da0073e9SAndroid Build Coastguard Worker        input_dp2 = dp.iter.IterableWrapper([])
1236*da0073e9SAndroid Build Coastguard Worker        output_dp = input_dp1.mux(input_dp2)
1237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_dp2), len(output_dp))
1238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(input_dp2), list(output_dp))
1239*da0073e9SAndroid Build Coastguard Worker
1240*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__
1241*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.iter.IterableWrapper(range(10))
1242*da0073e9SAndroid Build Coastguard Worker        input_dp_no_len = IDP_NoLen(range(10))
1243*da0073e9SAndroid Build Coastguard Worker        output_dp = input_dp1.mux(input_dp_no_len)
1244*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1245*da0073e9SAndroid Build Coastguard Worker            len(output_dp)
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker    def test_demux_iterdatapipe(self):
1248*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(range(10))
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1251*da0073e9SAndroid Build Coastguard Worker            input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker        # Functional Test: split into 2 DataPipes and output them one at a time
1254*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1255*da0073e9SAndroid Build Coastguard Worker        output1, output2 = list(dp1), list(dp2)
1256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(0, 10, 2)), output1)
1257*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(1, 10, 2)), output2)
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker        # Functional Test: split into 2 DataPipes and output them together
1260*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1261*da0073e9SAndroid Build Coastguard Worker        output = []
1262*da0073e9SAndroid Build Coastguard Worker        for n1, n2 in zip(dp1, dp2):
1263*da0073e9SAndroid Build Coastguard Worker            output.append((n1, n2))
1264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker        # Functional Test: values of the same classification are lumped together, and buffer_size = 3 being too small
1267*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(
1268*da0073e9SAndroid Build Coastguard Worker            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4
1269*da0073e9SAndroid Build Coastguard Worker        )
1270*da0073e9SAndroid Build Coastguard Worker        it1 = iter(dp1)
1271*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(BufferError):
1272*da0073e9SAndroid Build Coastguard Worker            next(
1273*da0073e9SAndroid Build Coastguard Worker                it1
1274*da0073e9SAndroid Build Coastguard Worker            )  # Buffer raises because first 5 elements all belong to the a different child
1275*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(BufferError):
1276*da0073e9SAndroid Build Coastguard Worker            list(dp2)
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker        # Functional Test: values of the same classification are lumped together, and buffer_size = 5 is just enough
1279*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(
1280*da0073e9SAndroid Build Coastguard Worker            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5
1281*da0073e9SAndroid Build Coastguard Worker        )
1282*da0073e9SAndroid Build Coastguard Worker        output1, output2 = list(dp1), list(dp2)
1283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5, 10)), output1)
1284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(0, 5)), output2)
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker        # Functional Test: values of the same classification are lumped together, and unlimited buffer
1287*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
1288*da0073e9SAndroid Build Coastguard Worker            dp1, dp2 = input_dp.demux(
1289*da0073e9SAndroid Build Coastguard Worker                num_instances=2,
1290*da0073e9SAndroid Build Coastguard Worker                classifier_fn=lambda x: 0 if x >= 5 else 1,
1291*da0073e9SAndroid Build Coastguard Worker                buffer_size=-1,
1292*da0073e9SAndroid Build Coastguard Worker            )
1293*da0073e9SAndroid Build Coastguard Worker            exp_l = 1 if HAS_DILL else 2
1294*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), exp_l)
1295*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set")
1296*da0073e9SAndroid Build Coastguard Worker        output1, output2 = list(dp1), list(dp2)
1297*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5, 10)), output1)
1298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(0, 5)), output2)
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker        # Functional Test: classifier returns a value outside of [0, num_instance - 1]
1301*da0073e9SAndroid Build Coastguard Worker        dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
1302*da0073e9SAndroid Build Coastguard Worker        it = iter(dp0[0])
1303*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1304*da0073e9SAndroid Build Coastguard Worker            next(it)
1305*da0073e9SAndroid Build Coastguard Worker            next(it)
1306*da0073e9SAndroid Build Coastguard Worker
1307*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
1308*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1309*da0073e9SAndroid Build Coastguard Worker        _ = iter(dp1)
1310*da0073e9SAndroid Build Coastguard Worker        output2 = []
1311*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
1312*da0073e9SAndroid Build Coastguard Worker            for i, n2 in enumerate(dp2):
1313*da0073e9SAndroid Build Coastguard Worker                output2.append(n2)
1314*da0073e9SAndroid Build Coastguard Worker                if i == 4:
1315*da0073e9SAndroid Build Coastguard Worker                    with warnings.catch_warnings(record=True) as wa:
1316*da0073e9SAndroid Build Coastguard Worker                        _ = iter(dp1)  # This will reset all child DataPipes
1317*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(len(wa), 1)
1318*da0073e9SAndroid Build Coastguard Worker                        self.assertRegex(
1319*da0073e9SAndroid Build Coastguard Worker                            str(wa[0].message), r"child DataPipes are not exhausted"
1320*da0073e9SAndroid Build Coastguard Worker                        )
1321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(1, 10, 2)), output2)
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe resets when some of it has been read
1324*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1325*da0073e9SAndroid Build Coastguard Worker        output1, output2 = [], []
1326*da0073e9SAndroid Build Coastguard Worker        for n1, n2 in zip(dp1, dp2):
1327*da0073e9SAndroid Build Coastguard Worker            output1.append(n1)
1328*da0073e9SAndroid Build Coastguard Worker            output2.append(n2)
1329*da0073e9SAndroid Build Coastguard Worker            if n1 == 4:
1330*da0073e9SAndroid Build Coastguard Worker                break
1331*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
1332*da0073e9SAndroid Build Coastguard Worker            i1 = iter(dp1)  # Reset all child DataPipes
1333*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
1334*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(
1335*da0073e9SAndroid Build Coastguard Worker                str(wa[0].message), r"Some child DataPipes are not exhausted"
1336*da0073e9SAndroid Build Coastguard Worker            )
1337*da0073e9SAndroid Build Coastguard Worker            for n1, n2 in zip(dp1, dp2):
1338*da0073e9SAndroid Build Coastguard Worker                output1.append(n1)
1339*da0073e9SAndroid Build Coastguard Worker                output2.append(n2)
1340*da0073e9SAndroid Build Coastguard Worker            self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1)
1341*da0073e9SAndroid Build Coastguard Worker            self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)
1342*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
1343*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
1344*da0073e9SAndroid Build Coastguard Worker
1345*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe reset, even when not all child DataPipes are exhausted
1346*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1347*da0073e9SAndroid Build Coastguard Worker        output1 = list(dp1)
1348*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(0, 10, 2)), output1)
1349*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
1350*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1351*da0073e9SAndroid Build Coastguard Worker                list(range(0, 10, 2)), list(dp1)
1352*da0073e9SAndroid Build Coastguard Worker            )  # Reset even when dp2 is not read
1353*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
1354*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(
1355*da0073e9SAndroid Build Coastguard Worker                str(wa[0].message), r"Some child DataPipes are not exhausted"
1356*da0073e9SAndroid Build Coastguard Worker            )
1357*da0073e9SAndroid Build Coastguard Worker        output2 = []
1358*da0073e9SAndroid Build Coastguard Worker        for i, n2 in enumerate(dp2):
1359*da0073e9SAndroid Build Coastguard Worker            output2.append(n2)
1360*da0073e9SAndroid Build Coastguard Worker            if i == 1:
1361*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(list(range(1, 5, 2)), output2)
1362*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as wa:
1363*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
1364*da0073e9SAndroid Build Coastguard Worker                        list(range(0, 10, 2)), list(dp1)
1365*da0073e9SAndroid Build Coastguard Worker                    )  # Can reset even when dp2 is partially read
1366*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(len(wa), 1)
1367*da0073e9SAndroid Build Coastguard Worker                    self.assertRegex(
1368*da0073e9SAndroid Build Coastguard Worker                        str(wa[0].message), r"Some child DataPipes are not exhausted"
1369*da0073e9SAndroid Build Coastguard Worker                    )
1370*da0073e9SAndroid Build Coastguard Worker                break
1371*da0073e9SAndroid Build Coastguard Worker        output2 = list(dp2)  # output2 has to read from beginning again
1372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(1, 10, 2)), output2)
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker        # Functional Test: drop_none = True
1375*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(
1376*da0073e9SAndroid Build Coastguard Worker            num_instances=2,
1377*da0073e9SAndroid Build Coastguard Worker            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
1378*da0073e9SAndroid Build Coastguard Worker            drop_none=True,
1379*da0073e9SAndroid Build Coastguard Worker        )
1380*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([2, 4, 6, 8], list(dp1))
1381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 3, 7, 9], list(dp2))
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker        # Functional Test: drop_none = False
1384*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(
1385*da0073e9SAndroid Build Coastguard Worker            num_instances=2,
1386*da0073e9SAndroid Build Coastguard Worker            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
1387*da0073e9SAndroid Build Coastguard Worker            drop_none=False,
1388*da0073e9SAndroid Build Coastguard Worker        )
1389*da0073e9SAndroid Build Coastguard Worker        it1 = iter(dp1)
1390*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1391*da0073e9SAndroid Build Coastguard Worker            next(it1)
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: __len__ not implemented
1394*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1395*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1396*da0073e9SAndroid Build Coastguard Worker            len(
1397*da0073e9SAndroid Build Coastguard Worker                dp1
1398*da0073e9SAndroid Build Coastguard Worker            )  # It is not implemented as we do not know length for each child in advance
1399*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1400*da0073e9SAndroid Build Coastguard Worker            len(dp2)
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker        # Pickle Test:
1403*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=odd_or_even)
1404*da0073e9SAndroid Build Coastguard Worker        traverse_dps(dp1)  # This should not raise any error
1405*da0073e9SAndroid Build Coastguard Worker        for _ in zip(dp1, dp2):
1406*da0073e9SAndroid Build Coastguard Worker            pass
1407*da0073e9SAndroid Build Coastguard Worker        traverse_dps(dp2)  # This should not raise any error either
1408*da0073e9SAndroid Build Coastguard Worker
1409*da0073e9SAndroid Build Coastguard Worker    def test_map_iterdatapipe(self):
1410*da0073e9SAndroid Build Coastguard Worker        target_length = 10
1411*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(range(target_length))
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker        def fn(item, dtype=torch.float, *, sum=False):
1414*da0073e9SAndroid Build Coastguard Worker            data = torch.tensor(item, dtype=dtype)
1415*da0073e9SAndroid Build Coastguard Worker            return data if not sum else data.sum()
1416*da0073e9SAndroid Build Coastguard Worker
1417*da0073e9SAndroid Build Coastguard Worker        # Functional Test: apply to each element correctly
1418*da0073e9SAndroid Build Coastguard Worker        map_dp = input_dp.map(fn)
1419*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(target_length, len(map_dp))
1420*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(map_dp, range(target_length)):
1421*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, torch.tensor(y, dtype=torch.float))
1422*da0073e9SAndroid Build Coastguard Worker
1423*da0073e9SAndroid Build Coastguard Worker        # Functional Test: works with partial function
1424*da0073e9SAndroid Build Coastguard Worker        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
1425*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(map_dp, range(target_length)):
1426*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: inherits length from source DataPipe
1429*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(target_length, len(map_dp))
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Worker        input_dp_nl = IDP_NoLen(range(target_length))
1432*da0073e9SAndroid Build Coastguard Worker        map_dp_nl = input_dp_nl.map(lambda x: x)
1433*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(map_dp_nl, range(target_length)):
1434*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, torch.tensor(y, dtype=torch.float))
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: inherits length from source DataPipe - raises error when invalid
1437*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1438*da0073e9SAndroid Build Coastguard Worker            len(map_dp_nl)
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe resets properly
1441*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 5
1442*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
1443*da0073e9SAndroid Build Coastguard Worker            map_dp, n_elements_before_reset
1444*da0073e9SAndroid Build Coastguard Worker        )
1445*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
1446*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), res_after_reset)
1447*da0073e9SAndroid Build Coastguard Worker
1448*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings  # Suppress warning for lambda fn
1449*da0073e9SAndroid Build Coastguard Worker    def test_map_tuple_list_with_col_iterdatapipe(self):
1450*da0073e9SAndroid Build Coastguard Worker        def fn_11(d):
1451*da0073e9SAndroid Build Coastguard Worker            return -d
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker        def fn_1n(d):
1454*da0073e9SAndroid Build Coastguard Worker            return -d, d
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker        def fn_n1(d0, d1):
1457*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker        def fn_nn(d0, d1):
1460*da0073e9SAndroid Build Coastguard Worker            return -d0, -d1, d0 + d1
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Worker        def fn_n1_def(d0, d1=1):
1463*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker        def fn_n1_kwargs(d0, d1, **kwargs):
1466*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker        def fn_n1_pos(d0, d1, *args):
1469*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1470*da0073e9SAndroid Build Coastguard Worker
1471*da0073e9SAndroid Build Coastguard Worker        def fn_n1_sep_pos(d0, *args, d1):
1472*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1473*da0073e9SAndroid Build Coastguard Worker
1474*da0073e9SAndroid Build Coastguard Worker        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
1475*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker        p_fn_n1 = partial(fn_n1, d1=1)
1478*da0073e9SAndroid Build Coastguard Worker        p_fn_cmplx = partial(fn_cmplx, d2=2)
1479*da0073e9SAndroid Build Coastguard Worker        p_fn_cmplx_large_arg = partial(
1480*da0073e9SAndroid Build Coastguard Worker            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
1481*da0073e9SAndroid Build Coastguard Worker        )
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Worker        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
1484*da0073e9SAndroid Build Coastguard Worker            for constr in (list, tuple):
1485*da0073e9SAndroid Build Coastguard Worker                datapipe = dp.iter.IterableWrapper(
1486*da0073e9SAndroid Build Coastguard Worker                    [constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]
1487*da0073e9SAndroid Build Coastguard Worker                )
1488*da0073e9SAndroid Build Coastguard Worker                if ref_fn is None:
1489*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(error):
1490*da0073e9SAndroid Build Coastguard Worker                        res_dp = datapipe.map(fn, input_col, output_col)
1491*da0073e9SAndroid Build Coastguard Worker                        list(res_dp)
1492*da0073e9SAndroid Build Coastguard Worker                else:
1493*da0073e9SAndroid Build Coastguard Worker                    res_dp = datapipe.map(fn, input_col, output_col)
1494*da0073e9SAndroid Build Coastguard Worker                    ref_dp = datapipe.map(ref_fn)
1495*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(list(res_dp), list(ref_dp))
1496*da0073e9SAndroid Build Coastguard Worker                    # Reset
1497*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(list(res_dp), list(ref_dp))
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, fn_n1_def, 0, 1)
1500*da0073e9SAndroid Build Coastguard Worker        _helper(
1501*da0073e9SAndroid Build Coastguard Worker            lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2
1502*da0073e9SAndroid Build Coastguard Worker        )
1503*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, p_fn_n1, 0, 1)
1504*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, p_fn_cmplx, 0, 1)
1505*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, p_fn_cmplx_large_arg, 0, 1)
1506*da0073e9SAndroid Build Coastguard Worker        _helper(
1507*da0073e9SAndroid Build Coastguard Worker            lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2
1508*da0073e9SAndroid Build Coastguard Worker        )
1509*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker        # Replacing with one input column and default output column
1512*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
1513*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
1514*da0073e9SAndroid Build Coastguard Worker        # The index of input column is out of range
1515*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_1n, 3, error=IndexError)
1516*da0073e9SAndroid Build Coastguard Worker        # Unmatched input columns with fn arguments
1517*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1, 1, error=ValueError)
1518*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1, [0, 1, 2], error=ValueError)
1519*da0073e9SAndroid Build Coastguard Worker        _helper(None, operator.add, 0, error=ValueError)
1520*da0073e9SAndroid Build Coastguard Worker        _helper(None, operator.add, [0, 1, 2], error=ValueError)
1521*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_cmplx, 0, 1, ValueError)
1522*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_pos, 1, error=ValueError)
1523*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
1524*da0073e9SAndroid Build Coastguard Worker        _helper(None, p_fn_n1, [0, 1], error=ValueError)
1525*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_1n, [1, 2], error=ValueError)
1526*da0073e9SAndroid Build Coastguard Worker        # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
1527*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
1528*da0073e9SAndroid Build Coastguard Worker        # Fn has keyword-only arguments
1529*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_kwargs, 1, error=ValueError)
1530*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_cmplx, [0, 1], 2, ValueError)
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker        # Replacing with multiple input columns and default output column (the left-most input column)
1533*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
1534*da0073e9SAndroid Build Coastguard Worker        _helper(
1535*da0073e9SAndroid Build Coastguard Worker            lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])),
1536*da0073e9SAndroid Build Coastguard Worker            fn_nn,
1537*da0073e9SAndroid Build Coastguard Worker            [2, 1],
1538*da0073e9SAndroid Build Coastguard Worker        )
1539*da0073e9SAndroid Build Coastguard Worker
1540*da0073e9SAndroid Build Coastguard Worker        # output_col can only be specified when input_col is not None
1541*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1, None, 1, error=ValueError)
1542*da0073e9SAndroid Build Coastguard Worker        # output_col can only be single-element list or tuple
1543*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1, None, [0, 1], error=ValueError)
1544*da0073e9SAndroid Build Coastguard Worker        # Single-element list as output_col
1545*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
1546*da0073e9SAndroid Build Coastguard Worker        # Replacing with one input column and single specified output column
1547*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
1548*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
1549*da0073e9SAndroid Build Coastguard Worker        # The index of output column is out of range
1550*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_1n, 1, 3, error=IndexError)
1551*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
1552*da0073e9SAndroid Build Coastguard Worker        _helper(
1553*da0073e9SAndroid Build Coastguard Worker            lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]),
1554*da0073e9SAndroid Build Coastguard Worker            fn_nn,
1555*da0073e9SAndroid Build Coastguard Worker            [1, 2],
1556*da0073e9SAndroid Build Coastguard Worker            0,
1557*da0073e9SAndroid Build Coastguard Worker        )
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker        # Appending the output at the end
1560*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
1561*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
1562*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
1563*da0073e9SAndroid Build Coastguard Worker        _helper(
1564*da0073e9SAndroid Build Coastguard Worker            lambda data: (*data, (-data[1], -data[2], data[1] + data[2])),
1565*da0073e9SAndroid Build Coastguard Worker            fn_nn,
1566*da0073e9SAndroid Build Coastguard Worker            [1, 2],
1567*da0073e9SAndroid Build Coastguard Worker            -1,
1568*da0073e9SAndroid Build Coastguard Worker        )
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker        # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
1571*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
1572*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0], data[1], int(data[2])), int, 2)
1573*da0073e9SAndroid Build Coastguard Worker
1574*da0073e9SAndroid Build Coastguard Worker        # Handle nn.Module and Callable (without __name__ implemented)
1575*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0)
1576*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0)
1577*da0073e9SAndroid Build Coastguard Worker
1578*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings  # Suppress warning for lambda fn
1579*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
1580*da0073e9SAndroid Build Coastguard Worker    def test_map_dict_with_col_iterdatapipe(self):
1581*da0073e9SAndroid Build Coastguard Worker        def fn_11(d):
1582*da0073e9SAndroid Build Coastguard Worker            return -d
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker        def fn_1n(d):
1585*da0073e9SAndroid Build Coastguard Worker            return -d, d
1586*da0073e9SAndroid Build Coastguard Worker
1587*da0073e9SAndroid Build Coastguard Worker        def fn_n1(d0, d1):
1588*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1589*da0073e9SAndroid Build Coastguard Worker
1590*da0073e9SAndroid Build Coastguard Worker        def fn_nn(d0, d1):
1591*da0073e9SAndroid Build Coastguard Worker            return -d0, -d1, d0 + d1
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker        def fn_n1_def(d0, d1=1):
1594*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1595*da0073e9SAndroid Build Coastguard Worker
1596*da0073e9SAndroid Build Coastguard Worker        p_fn_n1 = partial(fn_n1, d1=1)
1597*da0073e9SAndroid Build Coastguard Worker
1598*da0073e9SAndroid Build Coastguard Worker        def fn_n1_pos(d0, d1, *args):
1599*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1600*da0073e9SAndroid Build Coastguard Worker
1601*da0073e9SAndroid Build Coastguard Worker        def fn_n1_kwargs(d0, d1, **kwargs):
1602*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1603*da0073e9SAndroid Build Coastguard Worker
1604*da0073e9SAndroid Build Coastguard Worker        def fn_kwonly(*, d0, d1):
1605*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1606*da0073e9SAndroid Build Coastguard Worker
1607*da0073e9SAndroid Build Coastguard Worker        def fn_has_nondefault_kwonly(d0, *, d1):
1608*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1609*da0073e9SAndroid Build Coastguard Worker
1610*da0073e9SAndroid Build Coastguard Worker        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
1611*da0073e9SAndroid Build Coastguard Worker            return d0 + d1
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker        p_fn_cmplx = partial(fn_cmplx, d2=2)
1614*da0073e9SAndroid Build Coastguard Worker        p_fn_cmplx_large_arg = partial(
1615*da0073e9SAndroid Build Coastguard Worker            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
1616*da0073e9SAndroid Build Coastguard Worker        )
1617*da0073e9SAndroid Build Coastguard Worker
1618*da0073e9SAndroid Build Coastguard Worker        # Prevent modification in-place to support resetting
1619*da0073e9SAndroid Build Coastguard Worker        def _dict_update(data, newdata, remove_idx=None):
1620*da0073e9SAndroid Build Coastguard Worker            _data = dict(data)
1621*da0073e9SAndroid Build Coastguard Worker            _data.update(newdata)
1622*da0073e9SAndroid Build Coastguard Worker            if remove_idx:
1623*da0073e9SAndroid Build Coastguard Worker                for idx in remove_idx:
1624*da0073e9SAndroid Build Coastguard Worker                    del _data[idx]
1625*da0073e9SAndroid Build Coastguard Worker            return _data
1626*da0073e9SAndroid Build Coastguard Worker
1627*da0073e9SAndroid Build Coastguard Worker        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
1628*da0073e9SAndroid Build Coastguard Worker            datapipe = dp.iter.IterableWrapper(
1629*da0073e9SAndroid Build Coastguard Worker                [
1630*da0073e9SAndroid Build Coastguard Worker                    {"x": 0, "y": 1, "z": 2},
1631*da0073e9SAndroid Build Coastguard Worker                    {"x": 3, "y": 4, "z": 5},
1632*da0073e9SAndroid Build Coastguard Worker                    {"x": 6, "y": 7, "z": 8},
1633*da0073e9SAndroid Build Coastguard Worker                ]
1634*da0073e9SAndroid Build Coastguard Worker            )
1635*da0073e9SAndroid Build Coastguard Worker            if ref_fn is None:
1636*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(error):
1637*da0073e9SAndroid Build Coastguard Worker                    res_dp = datapipe.map(fn, input_col, output_col)
1638*da0073e9SAndroid Build Coastguard Worker                    list(res_dp)
1639*da0073e9SAndroid Build Coastguard Worker            else:
1640*da0073e9SAndroid Build Coastguard Worker                res_dp = datapipe.map(fn, input_col, output_col)
1641*da0073e9SAndroid Build Coastguard Worker                ref_dp = datapipe.map(ref_fn)
1642*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(list(res_dp), list(ref_dp))
1643*da0073e9SAndroid Build Coastguard Worker                # Reset
1644*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(list(res_dp), list(ref_dp))
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, fn_n1_def, "x", "y")
1647*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, p_fn_n1, "x", "y")
1648*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, p_fn_cmplx, "x", "y")
1649*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: data, p_fn_cmplx_large_arg, "x", "y")
1650*da0073e9SAndroid Build Coastguard Worker        _helper(
1651*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
1652*da0073e9SAndroid Build Coastguard Worker            p_fn_cmplx,
1653*da0073e9SAndroid Build Coastguard Worker            ["x", "y", "z"],
1654*da0073e9SAndroid Build Coastguard Worker            "z",
1655*da0073e9SAndroid Build Coastguard Worker        )
1656*da0073e9SAndroid Build Coastguard Worker
1657*da0073e9SAndroid Build Coastguard Worker        _helper(
1658*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
1659*da0073e9SAndroid Build Coastguard Worker            fn_n1_def,
1660*da0073e9SAndroid Build Coastguard Worker            ["x", "y"],
1661*da0073e9SAndroid Build Coastguard Worker            "z",
1662*da0073e9SAndroid Build Coastguard Worker        )
1663*da0073e9SAndroid Build Coastguard Worker
1664*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_pos, "x", error=ValueError)
1665*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_kwargs, "x", error=ValueError)
1666*da0073e9SAndroid Build Coastguard Worker        # non-default kw-only args
1667*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_kwonly, ["x", "y"], error=ValueError)
1668*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
1669*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_cmplx, ["x", "y"], error=ValueError)
1670*da0073e9SAndroid Build Coastguard Worker
1671*da0073e9SAndroid Build Coastguard Worker        # Replacing with one input column and default output column
1672*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
1673*da0073e9SAndroid Build Coastguard Worker        _helper(
1674*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y"
1675*da0073e9SAndroid Build Coastguard Worker        )
1676*da0073e9SAndroid Build Coastguard Worker        # The key of input column is not in dict
1677*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_1n, "a", error=KeyError)
1678*da0073e9SAndroid Build Coastguard Worker        # Unmatched input columns with fn arguments
1679*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1, "y", error=ValueError)
1680*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_1n, ["x", "y"], error=ValueError)
1681*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
1682*da0073e9SAndroid Build Coastguard Worker        _helper(None, p_fn_n1, ["x", "y"], error=ValueError)
1683*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
1684*da0073e9SAndroid Build Coastguard Worker        # Replacing with multiple input columns and default output column (the left-most input column)
1685*da0073e9SAndroid Build Coastguard Worker        _helper(
1686*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]),
1687*da0073e9SAndroid Build Coastguard Worker            fn_n1,
1688*da0073e9SAndroid Build Coastguard Worker            ["z", "x"],
1689*da0073e9SAndroid Build Coastguard Worker        )
1690*da0073e9SAndroid Build Coastguard Worker        _helper(
1691*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(
1692*da0073e9SAndroid Build Coastguard Worker                data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]
1693*da0073e9SAndroid Build Coastguard Worker            ),
1694*da0073e9SAndroid Build Coastguard Worker            fn_nn,
1695*da0073e9SAndroid Build Coastguard Worker            ["z", "y"],
1696*da0073e9SAndroid Build Coastguard Worker        )
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker        # output_col can only be specified when input_col is not None
1699*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1, None, "x", error=ValueError)
1700*da0073e9SAndroid Build Coastguard Worker        # output_col can only be single-element list or tuple
1701*da0073e9SAndroid Build Coastguard Worker        _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
1702*da0073e9SAndroid Build Coastguard Worker        # Single-element list as output_col
1703*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
1704*da0073e9SAndroid Build Coastguard Worker        # Replacing with one input column and single specified output column
1705*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
1706*da0073e9SAndroid Build Coastguard Worker        _helper(
1707*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}),
1708*da0073e9SAndroid Build Coastguard Worker            fn_1n,
1709*da0073e9SAndroid Build Coastguard Worker            "y",
1710*da0073e9SAndroid Build Coastguard Worker            "z",
1711*da0073e9SAndroid Build Coastguard Worker        )
1712*da0073e9SAndroid Build Coastguard Worker        _helper(
1713*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"y": data["x"] + data["z"]}),
1714*da0073e9SAndroid Build Coastguard Worker            fn_n1,
1715*da0073e9SAndroid Build Coastguard Worker            ["x", "z"],
1716*da0073e9SAndroid Build Coastguard Worker            "y",
1717*da0073e9SAndroid Build Coastguard Worker        )
1718*da0073e9SAndroid Build Coastguard Worker        _helper(
1719*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(
1720*da0073e9SAndroid Build Coastguard Worker                data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}
1721*da0073e9SAndroid Build Coastguard Worker            ),
1722*da0073e9SAndroid Build Coastguard Worker            fn_nn,
1723*da0073e9SAndroid Build Coastguard Worker            ["y", "z"],
1724*da0073e9SAndroid Build Coastguard Worker            "x",
1725*da0073e9SAndroid Build Coastguard Worker        )
1726*da0073e9SAndroid Build Coastguard Worker
1727*da0073e9SAndroid Build Coastguard Worker        # Adding new key to dict for the output
1728*da0073e9SAndroid Build Coastguard Worker        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
1729*da0073e9SAndroid Build Coastguard Worker        _helper(
1730*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}),
1731*da0073e9SAndroid Build Coastguard Worker            fn_1n,
1732*da0073e9SAndroid Build Coastguard Worker            "y",
1733*da0073e9SAndroid Build Coastguard Worker            "a",
1734*da0073e9SAndroid Build Coastguard Worker        )
1735*da0073e9SAndroid Build Coastguard Worker        _helper(
1736*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(data, {"a": data["x"] + data["z"]}),
1737*da0073e9SAndroid Build Coastguard Worker            fn_n1,
1738*da0073e9SAndroid Build Coastguard Worker            ["x", "z"],
1739*da0073e9SAndroid Build Coastguard Worker            "a",
1740*da0073e9SAndroid Build Coastguard Worker        )
1741*da0073e9SAndroid Build Coastguard Worker        _helper(
1742*da0073e9SAndroid Build Coastguard Worker            lambda data: _dict_update(
1743*da0073e9SAndroid Build Coastguard Worker                data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}
1744*da0073e9SAndroid Build Coastguard Worker            ),
1745*da0073e9SAndroid Build Coastguard Worker            fn_nn,
1746*da0073e9SAndroid Build Coastguard Worker            ["y", "z"],
1747*da0073e9SAndroid Build Coastguard Worker            "a",
1748*da0073e9SAndroid Build Coastguard Worker        )
1749*da0073e9SAndroid Build Coastguard Worker
1750*da0073e9SAndroid Build Coastguard Worker    def test_collate_iterdatapipe(self):
1751*da0073e9SAndroid Build Coastguard Worker        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1752*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(arrs)
1753*da0073e9SAndroid Build Coastguard Worker
1754*da0073e9SAndroid Build Coastguard Worker        def _collate_fn(batch, default_type=torch.float):
1755*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(sum(batch), dtype=default_type)
1756*da0073e9SAndroid Build Coastguard Worker
1757*da0073e9SAndroid Build Coastguard Worker        # Functional Test: defaults to the default collate function when a custom one is not specified
1758*da0073e9SAndroid Build Coastguard Worker        collate_dp = input_dp.collate()
1759*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(arrs, collate_dp):
1760*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.tensor(x), y)
1761*da0073e9SAndroid Build Coastguard Worker
1762*da0073e9SAndroid Build Coastguard Worker        # Functional Test: custom collate function
1763*da0073e9SAndroid Build Coastguard Worker        collate_dp = input_dp.collate(collate_fn=_collate_fn)
1764*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(arrs, collate_dp):
1765*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Worker        # Functional Test: custom, partial collate function
1768*da0073e9SAndroid Build Coastguard Worker        collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
1769*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(arrs, collate_dp):
1770*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker        # Reset Test: reset the DataPipe and results are still correct
1773*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 1
1774*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
1775*da0073e9SAndroid Build Coastguard Worker            collate_dp, n_elements_before_reset
1776*da0073e9SAndroid Build Coastguard Worker        )
1777*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset)
1778*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(arrs, res_after_reset):
1779*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
1780*da0073e9SAndroid Build Coastguard Worker
1781*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: __len__ is inherited
1782*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_dp), len(collate_dp))
1783*da0073e9SAndroid Build Coastguard Worker
1784*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: verify that it has no valid __len__ when the source doesn't have it
1785*da0073e9SAndroid Build Coastguard Worker        input_dp_nl = IDP_NoLen(arrs)
1786*da0073e9SAndroid Build Coastguard Worker        collate_dp_nl = input_dp_nl.collate()
1787*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1788*da0073e9SAndroid Build Coastguard Worker            len(collate_dp_nl)
1789*da0073e9SAndroid Build Coastguard Worker        for x, y in zip(arrs, collate_dp_nl):
1790*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.tensor(x), y)
1791*da0073e9SAndroid Build Coastguard Worker
1792*da0073e9SAndroid Build Coastguard Worker    def test_batch_iterdatapipe(self):
1793*da0073e9SAndroid Build Coastguard Worker        arrs = list(range(10))
1794*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(arrs)
1795*da0073e9SAndroid Build Coastguard Worker
1796*da0073e9SAndroid Build Coastguard Worker        # Functional Test: raise error when input argument `batch_size = 0`
1797*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
1798*da0073e9SAndroid Build Coastguard Worker            input_dp.batch(batch_size=0)
1799*da0073e9SAndroid Build Coastguard Worker
1800*da0073e9SAndroid Build Coastguard Worker        # Functional Test: by default, do not drop the last batch
1801*da0073e9SAndroid Build Coastguard Worker        bs = 3
1802*da0073e9SAndroid Build Coastguard Worker        batch_dp = input_dp.batch(batch_size=bs)
1803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(batch_dp), 4)
1804*da0073e9SAndroid Build Coastguard Worker        for i, batch in enumerate(batch_dp):
1805*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(batch), 1 if i == 3 else bs)
1806*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)])
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Drop the last batch when specified
1809*da0073e9SAndroid Build Coastguard Worker        bs = 4
1810*da0073e9SAndroid Build Coastguard Worker        batch_dp = input_dp.batch(batch_size=bs, drop_last=True)
1811*da0073e9SAndroid Build Coastguard Worker        for i, batch in enumerate(batch_dp):
1812*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)])
1813*da0073e9SAndroid Build Coastguard Worker
1814*da0073e9SAndroid Build Coastguard Worker        # __len__ test: verifying that the overall length and of each batch is correct
1815*da0073e9SAndroid Build Coastguard Worker        for i, batch in enumerate(batch_dp):
1816*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(batch), bs)
1817*da0073e9SAndroid Build Coastguard Worker
1818*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: the length is missing if the source DataPipe doesn't have length
1819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(batch_dp), 2)
1820*da0073e9SAndroid Build Coastguard Worker        input_dp_nl = IDP_NoLen(range(10))
1821*da0073e9SAndroid Build Coastguard Worker        batch_dp_nl = input_dp_nl.batch(batch_size=2)
1822*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1823*da0073e9SAndroid Build Coastguard Worker            len(batch_dp_nl)
1824*da0073e9SAndroid Build Coastguard Worker
1825*da0073e9SAndroid Build Coastguard Worker        # Reset Test: Ensures that the DataPipe can properly reset
1826*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 1
1827*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
1828*da0073e9SAndroid Build Coastguard Worker            batch_dp, n_elements_before_reset
1829*da0073e9SAndroid Build Coastguard Worker        )
1830*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([[0, 1, 2, 3]], res_before_reset)
1831*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([[0, 1, 2, 3], [4, 5, 6, 7]], res_after_reset)
1832*da0073e9SAndroid Build Coastguard Worker
1833*da0073e9SAndroid Build Coastguard Worker    def test_unbatch_iterdatapipe(self):
1834*da0073e9SAndroid Build Coastguard Worker        target_length = 6
1835*da0073e9SAndroid Build Coastguard Worker        prebatch_dp = dp.iter.IterableWrapper(range(target_length))
1836*da0073e9SAndroid Build Coastguard Worker
1837*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Unbatch DataPipe should be the same as pre-batch DataPipe
1838*da0073e9SAndroid Build Coastguard Worker        input_dp = prebatch_dp.batch(3)
1839*da0073e9SAndroid Build Coastguard Worker        unbatch_dp = input_dp.unbatch()
1840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(unbatch_dp)), target_length)  # __len__ is as expected
1841*da0073e9SAndroid Build Coastguard Worker        for i, res in zip(range(target_length), unbatch_dp):
1842*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, res)
1843*da0073e9SAndroid Build Coastguard Worker
1844*da0073e9SAndroid Build Coastguard Worker        # Functional Test: unbatch works for an input with nested levels
1845*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
1846*da0073e9SAndroid Build Coastguard Worker        unbatch_dp = input_dp.unbatch()
1847*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(unbatch_dp)), target_length)
1848*da0073e9SAndroid Build Coastguard Worker        for i, res in zip(range(target_length), unbatch_dp):
1849*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, res)
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
1852*da0073e9SAndroid Build Coastguard Worker
1853*da0073e9SAndroid Build Coastguard Worker        # Functional Test: unbatch works for an input with nested levels
1854*da0073e9SAndroid Build Coastguard Worker        unbatch_dp = input_dp.unbatch()
1855*da0073e9SAndroid Build Coastguard Worker        expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
1856*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(unbatch_dp)), 4)
1857*da0073e9SAndroid Build Coastguard Worker        for j, res in zip(expected_dp, unbatch_dp):
1858*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(j, res)
1859*da0073e9SAndroid Build Coastguard Worker
1860*da0073e9SAndroid Build Coastguard Worker        # Functional Test: unbatching multiple levels at the same time
1861*da0073e9SAndroid Build Coastguard Worker        unbatch_dp = input_dp.unbatch(unbatch_level=2)
1862*da0073e9SAndroid Build Coastguard Worker        expected_dp2 = [0, 1, 2, 3, 4, 5, 6, 7]
1863*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(unbatch_dp)), 8)
1864*da0073e9SAndroid Build Coastguard Worker        for i, res in zip(expected_dp2, unbatch_dp):
1865*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, res)
1866*da0073e9SAndroid Build Coastguard Worker
1867*da0073e9SAndroid Build Coastguard Worker        # Functional Test: unbatching all levels at the same time
1868*da0073e9SAndroid Build Coastguard Worker        unbatch_dp = input_dp.unbatch(unbatch_level=-1)
1869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(unbatch_dp)), 8)
1870*da0073e9SAndroid Build Coastguard Worker        for i, res in zip(expected_dp2, unbatch_dp):
1871*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(i, res)
1872*da0073e9SAndroid Build Coastguard Worker
1873*da0073e9SAndroid Build Coastguard Worker        # Functional Test: raises error when input unbatch_level is less than -1
1874*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
1875*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1876*da0073e9SAndroid Build Coastguard Worker            unbatch_dp = input_dp.unbatch(unbatch_level=-2)
1877*da0073e9SAndroid Build Coastguard Worker            for i in unbatch_dp:
1878*da0073e9SAndroid Build Coastguard Worker                print(i)
1879*da0073e9SAndroid Build Coastguard Worker
1880*da0073e9SAndroid Build Coastguard Worker        # Functional Test: raises error when input unbatch_level is too high
1881*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
1882*da0073e9SAndroid Build Coastguard Worker            unbatch_dp = input_dp.unbatch(unbatch_level=5)
1883*da0073e9SAndroid Build Coastguard Worker            for i in unbatch_dp:
1884*da0073e9SAndroid Build Coastguard Worker                print(i)
1885*da0073e9SAndroid Build Coastguard Worker
1886*da0073e9SAndroid Build Coastguard Worker        # Reset Test: unbatch_dp resets properly
1887*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
1888*da0073e9SAndroid Build Coastguard Worker        unbatch_dp = input_dp.unbatch(unbatch_level=-1)
1889*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 3
1890*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
1891*da0073e9SAndroid Build Coastguard Worker            unbatch_dp, n_elements_before_reset
1892*da0073e9SAndroid Build Coastguard Worker        )
1893*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([0, 1, 2], res_before_reset)
1894*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([0, 1, 2, 3, 4, 5], res_after_reset)
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker    def test_filter_datapipe(self):
1897*da0073e9SAndroid Build Coastguard Worker        input_ds = dp.iter.IterableWrapper(range(10))
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker        def _filter_fn(data, val):
1900*da0073e9SAndroid Build Coastguard Worker            return data >= val
1901*da0073e9SAndroid Build Coastguard Worker
1902*da0073e9SAndroid Build Coastguard Worker        # Functional Test: filter works with partial function
1903*da0073e9SAndroid Build Coastguard Worker        filter_dp = input_ds.filter(partial(_filter_fn, val=5))
1904*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(filter_dp), list(range(5, 10)))
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Worker        def _non_bool_fn(data):
1907*da0073e9SAndroid Build Coastguard Worker            return 1
1908*da0073e9SAndroid Build Coastguard Worker
1909*da0073e9SAndroid Build Coastguard Worker        # Functional Test: filter function must return bool
1910*da0073e9SAndroid Build Coastguard Worker        filter_dp = input_ds.filter(filter_fn=_non_bool_fn)
1911*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1912*da0073e9SAndroid Build Coastguard Worker            temp = list(filter_dp)
1913*da0073e9SAndroid Build Coastguard Worker
1914*da0073e9SAndroid Build Coastguard Worker        # Funtional Test: Specify input_col
1915*da0073e9SAndroid Build Coastguard Worker        tuple_input_ds = dp.iter.IterableWrapper([(d - 1, d, d + 1) for d in range(10)])
1916*da0073e9SAndroid Build Coastguard Worker
1917*da0073e9SAndroid Build Coastguard Worker        # Single input_col
1918*da0073e9SAndroid Build Coastguard Worker        input_col_1_dp = tuple_input_ds.filter(partial(_filter_fn, val=5), input_col=1)
1919*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1920*da0073e9SAndroid Build Coastguard Worker            list(input_col_1_dp), [(d - 1, d, d + 1) for d in range(5, 10)]
1921*da0073e9SAndroid Build Coastguard Worker        )
1922*da0073e9SAndroid Build Coastguard Worker
1923*da0073e9SAndroid Build Coastguard Worker        # Multiple input_col
1924*da0073e9SAndroid Build Coastguard Worker        def _mul_filter_fn(a, b):
1925*da0073e9SAndroid Build Coastguard Worker            return a + b < 10
1926*da0073e9SAndroid Build Coastguard Worker
1927*da0073e9SAndroid Build Coastguard Worker        input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2])
1928*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)])
1929*da0073e9SAndroid Build Coastguard Worker
1930*da0073e9SAndroid Build Coastguard Worker        # invalid input col
1931*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1932*da0073e9SAndroid Build Coastguard Worker            tuple_input_ds.filter(_mul_filter_fn, input_col=0)
1933*da0073e9SAndroid Build Coastguard Worker
1934*da0073e9SAndroid Build Coastguard Worker        p_mul_filter_fn = partial(_mul_filter_fn, b=1)
1935*da0073e9SAndroid Build Coastguard Worker        out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0)
1936*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
1937*da0073e9SAndroid Build Coastguard Worker
1938*da0073e9SAndroid Build Coastguard Worker        def _mul_filter_fn_with_defaults(a, b=1):
1939*da0073e9SAndroid Build Coastguard Worker            return a + b < 10
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker        out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0)
1942*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker        def _mul_filter_fn_with_kw_only(*, a, b):
1945*da0073e9SAndroid Build Coastguard Worker            return a + b < 10
1946*da0073e9SAndroid Build Coastguard Worker
1947*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1948*da0073e9SAndroid Build Coastguard Worker            tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0)
1949*da0073e9SAndroid Build Coastguard Worker
1950*da0073e9SAndroid Build Coastguard Worker        def _mul_filter_fn_with_kw_only_1_default(*, a, b=1):
1951*da0073e9SAndroid Build Coastguard Worker            return a + b < 10
1952*da0073e9SAndroid Build Coastguard Worker
1953*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1954*da0073e9SAndroid Build Coastguard Worker            tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0)
1955*da0073e9SAndroid Build Coastguard Worker
1956*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: DataPipe has no valid len
1957*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"has no len"):
1958*da0073e9SAndroid Build Coastguard Worker            len(filter_dp)
1959*da0073e9SAndroid Build Coastguard Worker
1960*da0073e9SAndroid Build Coastguard Worker        # Reset Test: DataPipe resets correctly
1961*da0073e9SAndroid Build Coastguard Worker        filter_dp = input_ds.filter(partial(_filter_fn, val=5))
1962*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 3
1963*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
1964*da0073e9SAndroid Build Coastguard Worker            filter_dp, n_elements_before_reset
1965*da0073e9SAndroid Build Coastguard Worker        )
1966*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5, 10))[:n_elements_before_reset], res_before_reset)
1967*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5, 10)), res_after_reset)
1968*da0073e9SAndroid Build Coastguard Worker
1969*da0073e9SAndroid Build Coastguard Worker    def test_sampler_iterdatapipe(self):
1970*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(range(10))
1971*da0073e9SAndroid Build Coastguard Worker        # Default SequentialSampler
1972*da0073e9SAndroid Build Coastguard Worker        sampled_dp = dp.iter.Sampler(input_dp)  # type: ignore[var-annotated]
1973*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(sampled_dp), 10)
1974*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(sampled_dp):
1975*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, i)
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker        # RandomSampler
1978*da0073e9SAndroid Build Coastguard Worker        random_sampled_dp = dp.iter.Sampler(
1979*da0073e9SAndroid Build Coastguard Worker            input_dp, sampler=RandomSampler, sampler_kwargs={"replacement": True}
1980*da0073e9SAndroid Build Coastguard Worker        )  # type: ignore[var-annotated] # noqa: B950
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker        # Requires `__len__` to build SamplerDataPipe
1983*da0073e9SAndroid Build Coastguard Worker        input_dp_nolen = IDP_NoLen(range(10))
1984*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
1985*da0073e9SAndroid Build Coastguard Worker            sampled_dp = dp.iter.Sampler(input_dp_nolen)
1986*da0073e9SAndroid Build Coastguard Worker
1987*da0073e9SAndroid Build Coastguard Worker    def test_stream_reader_iterdatapipe(self):
1988*da0073e9SAndroid Build Coastguard Worker        from io import StringIO
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(
1991*da0073e9SAndroid Build Coastguard Worker            [("f1", StringIO("abcde")), ("f2", StringIO("bcdef"))]
1992*da0073e9SAndroid Build Coastguard Worker        )
1993*da0073e9SAndroid Build Coastguard Worker        expected_res = ["abcde", "bcdef"]
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Read full chunk
1996*da0073e9SAndroid Build Coastguard Worker        dp1 = input_dp.read_from_stream()
1997*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([d[1] for d in dp1], expected_res)
1998*da0073e9SAndroid Build Coastguard Worker
1999*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Read full chunk
2000*da0073e9SAndroid Build Coastguard Worker        dp2 = input_dp.read_from_stream(chunk=1)
2001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([d[1] for d in dp2], [c for s in expected_res for c in s])
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker        # `__len__` Test
2004*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
2005*da0073e9SAndroid Build Coastguard Worker            len(dp1)
2006*da0073e9SAndroid Build Coastguard Worker
2007*da0073e9SAndroid Build Coastguard Worker    def test_shuffler_iterdatapipe(self):
2008*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.iter.IterableWrapper(list(range(10)))
2009*da0073e9SAndroid Build Coastguard Worker
2010*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2011*da0073e9SAndroid Build Coastguard Worker            shuffle_dp = input_dp.shuffle(buffer_size=0)
2012*da0073e9SAndroid Build Coastguard Worker
2013*da0073e9SAndroid Build Coastguard Worker        # Functional Test: No seed
2014*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp.shuffle()
2015*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(set(range(10)), set(shuffler_dp))
2016*da0073e9SAndroid Build Coastguard Worker
2017*da0073e9SAndroid Build Coastguard Worker        # Functional Test: With global seed
2018*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
2019*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp.shuffle()
2020*da0073e9SAndroid Build Coastguard Worker        res = list(shuffler_dp)
2021*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
2022*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(shuffler_dp), res)
2023*da0073e9SAndroid Build Coastguard Worker
2024*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Set seed
2025*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp.shuffle().set_seed(123)
2026*da0073e9SAndroid Build Coastguard Worker        res = list(shuffler_dp)
2027*da0073e9SAndroid Build Coastguard Worker        shuffler_dp.set_seed(123)
2028*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(shuffler_dp), res)
2029*da0073e9SAndroid Build Coastguard Worker
2030*da0073e9SAndroid Build Coastguard Worker        # Functional Test: deactivate shuffling via set_shuffle
2031*da0073e9SAndroid Build Coastguard Worker        unshuffled_dp = input_dp.shuffle().set_shuffle(False)
2032*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(unshuffled_dp), list(input_dp))
2033*da0073e9SAndroid Build Coastguard Worker
2034*da0073e9SAndroid Build Coastguard Worker        # Reset Test:
2035*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp.shuffle()
2036*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 5
2037*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
2038*da0073e9SAndroid Build Coastguard Worker            shuffler_dp, n_elements_before_reset
2039*da0073e9SAndroid Build Coastguard Worker        )
2040*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(5, len(res_before_reset))
2041*da0073e9SAndroid Build Coastguard Worker        for x in res_before_reset:
2042*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x in set(range(10)))
2043*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(set(range(10)), set(res_after_reset))
2044*da0073e9SAndroid Build Coastguard Worker
2045*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: returns the length of the input DataPipe
2046*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp.shuffle()
2047*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(10, len(shuffler_dp))
2048*da0073e9SAndroid Build Coastguard Worker        exp = list(range(100))
2049*da0073e9SAndroid Build Coastguard Worker
2050*da0073e9SAndroid Build Coastguard Worker        # Serialization Test
2051*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data.datapipes._hook_iterator import _SnapshotState
2052*da0073e9SAndroid Build Coastguard Worker
2053*da0073e9SAndroid Build Coastguard Worker        def _serialization_helper(bs):
2054*da0073e9SAndroid Build Coastguard Worker            shuffler_dp = input_dp.shuffle(buffer_size=bs)
2055*da0073e9SAndroid Build Coastguard Worker            it = iter(shuffler_dp)
2056*da0073e9SAndroid Build Coastguard Worker            for _ in range(2):
2057*da0073e9SAndroid Build Coastguard Worker                next(it)
2058*da0073e9SAndroid Build Coastguard Worker            shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp))
2059*da0073e9SAndroid Build Coastguard Worker            _simple_graph_snapshot_restoration(
2060*da0073e9SAndroid Build Coastguard Worker                shuffler_dp_copy.datapipe,
2061*da0073e9SAndroid Build Coastguard Worker                shuffler_dp.datapipe._number_of_samples_yielded,
2062*da0073e9SAndroid Build Coastguard Worker            )
2063*da0073e9SAndroid Build Coastguard Worker
2064*da0073e9SAndroid Build Coastguard Worker            exp = list(it)
2065*da0073e9SAndroid Build Coastguard Worker            shuffler_dp_copy._snapshot_state = _SnapshotState.Restored
2066*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(exp, list(shuffler_dp_copy))
2067*da0073e9SAndroid Build Coastguard Worker
2068*da0073e9SAndroid Build Coastguard Worker        buffer_sizes = [2, 5, 15]
2069*da0073e9SAndroid Build Coastguard Worker        for bs in buffer_sizes:
2070*da0073e9SAndroid Build Coastguard Worker            _serialization_helper(bs)
2071*da0073e9SAndroid Build Coastguard Worker
2072*da0073e9SAndroid Build Coastguard Worker    def test_zip_iterdatapipe(self):
2073*da0073e9SAndroid Build Coastguard Worker        # Functional Test: raises TypeError when an input is not of type `IterDataPipe`
2074*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
2075*da0073e9SAndroid Build Coastguard Worker            dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10)))  # type: ignore[arg-type]
2076*da0073e9SAndroid Build Coastguard Worker
2077*da0073e9SAndroid Build Coastguard Worker        # Functional Test: raises TypeError when an input does not have valid length
2078*da0073e9SAndroid Build Coastguard Worker        zipped_dp = dp.iter.Zipper(
2079*da0073e9SAndroid Build Coastguard Worker            dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5))
2080*da0073e9SAndroid Build Coastguard Worker        )  # type: ignore[var-annotated]
2081*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
2082*da0073e9SAndroid Build Coastguard Worker            len(zipped_dp)
2083*da0073e9SAndroid Build Coastguard Worker
2084*da0073e9SAndroid Build Coastguard Worker        # Functional Test: zips the results properly
2085*da0073e9SAndroid Build Coastguard Worker        exp = [(i, i) for i in range(5)]
2086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(zipped_dp), exp)
2087*da0073e9SAndroid Build Coastguard Worker
2088*da0073e9SAndroid Build Coastguard Worker        # Functional Test: zips the inputs properly even when lengths are different (zips to the shortest)
2089*da0073e9SAndroid Build Coastguard Worker        zipped_dp = dp.iter.Zipper(
2090*da0073e9SAndroid Build Coastguard Worker            dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5))
2091*da0073e9SAndroid Build Coastguard Worker        )
2092*da0073e9SAndroid Build Coastguard Worker
2093*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: length matches the length of the shortest input
2094*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(zipped_dp), 5)
2095*da0073e9SAndroid Build Coastguard Worker
2096*da0073e9SAndroid Build Coastguard Worker        # Reset Test:
2097*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 3
2098*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
2099*da0073e9SAndroid Build Coastguard Worker            zipped_dp, n_elements_before_reset
2100*da0073e9SAndroid Build Coastguard Worker        )
2101*da0073e9SAndroid Build Coastguard Worker        expected_res = [(i, i) for i in range(5)]
2102*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset)
2103*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res, res_after_reset)
2104*da0073e9SAndroid Build Coastguard Worker
2105*da0073e9SAndroid Build Coastguard Worker
2106*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalMapDataPipe(TestCase):
2107*da0073e9SAndroid Build Coastguard Worker    def _serialization_test_helper(self, datapipe, use_dill):
2108*da0073e9SAndroid Build Coastguard Worker        if use_dill:
2109*da0073e9SAndroid Build Coastguard Worker            serialized_dp = dill.dumps(datapipe)
2110*da0073e9SAndroid Build Coastguard Worker            deserialized_dp = dill.loads(serialized_dp)
2111*da0073e9SAndroid Build Coastguard Worker        else:
2112*da0073e9SAndroid Build Coastguard Worker            serialized_dp = pickle.dumps(datapipe)
2113*da0073e9SAndroid Build Coastguard Worker            deserialized_dp = pickle.loads(serialized_dp)
2114*da0073e9SAndroid Build Coastguard Worker        try:
2115*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(datapipe), list(deserialized_dp))
2116*da0073e9SAndroid Build Coastguard Worker        except AssertionError as e:
2117*da0073e9SAndroid Build Coastguard Worker            print(f"{datapipe} is failing.")
2118*da0073e9SAndroid Build Coastguard Worker            raise e
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker    def _serialization_test_for_single_dp(self, dp, use_dill=False):
2121*da0073e9SAndroid Build Coastguard Worker        # 1. Testing for serialization before any iteration starts
2122*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp, use_dill)
2123*da0073e9SAndroid Build Coastguard Worker        # 2. Testing for serialization after DataPipe is partially read
2124*da0073e9SAndroid Build Coastguard Worker        it = iter(dp)
2125*da0073e9SAndroid Build Coastguard Worker        _ = next(it)
2126*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp, use_dill)
2127*da0073e9SAndroid Build Coastguard Worker        # 3. Testing for serialization after DataPipe is fully read
2128*da0073e9SAndroid Build Coastguard Worker        _ = list(dp)
2129*da0073e9SAndroid Build Coastguard Worker        self._serialization_test_helper(dp, use_dill)
2130*da0073e9SAndroid Build Coastguard Worker
2131*da0073e9SAndroid Build Coastguard Worker    def test_serializable(self):
2132*da0073e9SAndroid Build Coastguard Worker        picklable_datapipes: List = [
2133*da0073e9SAndroid Build Coastguard Worker            (dp.map.Batcher, None, (2,), {}),
2134*da0073e9SAndroid Build Coastguard Worker            (dp.map.Concater, None, (dp.map.SequenceWrapper(range(10)),), {}),
2135*da0073e9SAndroid Build Coastguard Worker            (dp.map.Mapper, None, (), {}),
2136*da0073e9SAndroid Build Coastguard Worker            (dp.map.Mapper, None, (_fake_fn,), {}),
2137*da0073e9SAndroid Build Coastguard Worker            (dp.map.Mapper, None, (partial(_fake_add, 1),), {}),
2138*da0073e9SAndroid Build Coastguard Worker            (dp.map.SequenceWrapper, range(10), (), {}),
2139*da0073e9SAndroid Build Coastguard Worker            (dp.map.Shuffler, dp.map.SequenceWrapper([0] * 5), (), {}),
2140*da0073e9SAndroid Build Coastguard Worker            (dp.map.Zipper, None, (dp.map.SequenceWrapper(range(10)),), {}),
2141*da0073e9SAndroid Build Coastguard Worker        ]
2142*da0073e9SAndroid Build Coastguard Worker        for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
2143*da0073e9SAndroid Build Coastguard Worker            if custom_input is None:
2144*da0073e9SAndroid Build Coastguard Worker                custom_input = dp.map.SequenceWrapper(range(10))
2145*da0073e9SAndroid Build Coastguard Worker            datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
2146*da0073e9SAndroid Build Coastguard Worker            self._serialization_test_for_single_dp(datapipe)
2147*da0073e9SAndroid Build Coastguard Worker
2148*da0073e9SAndroid Build Coastguard Worker    def test_serializable_with_dill(self):
2149*da0073e9SAndroid Build Coastguard Worker        """Only for DataPipes that take in a function as argument"""
2150*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.map.SequenceWrapper(range(10))
2151*da0073e9SAndroid Build Coastguard Worker
2152*da0073e9SAndroid Build Coastguard Worker        datapipes_with_lambda_fn: List[
2153*da0073e9SAndroid Build Coastguard Worker            Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
2154*da0073e9SAndroid Build Coastguard Worker        ] = [
2155*da0073e9SAndroid Build Coastguard Worker            (dp.map.Mapper, (lambda_fn1,), {}),
2156*da0073e9SAndroid Build Coastguard Worker        ]
2157*da0073e9SAndroid Build Coastguard Worker
2158*da0073e9SAndroid Build Coastguard Worker        def _local_fns():
2159*da0073e9SAndroid Build Coastguard Worker            def _fn1(x):
2160*da0073e9SAndroid Build Coastguard Worker                return x
2161*da0073e9SAndroid Build Coastguard Worker
2162*da0073e9SAndroid Build Coastguard Worker            return _fn1
2163*da0073e9SAndroid Build Coastguard Worker
2164*da0073e9SAndroid Build Coastguard Worker        fn1 = _local_fns()
2165*da0073e9SAndroid Build Coastguard Worker
2166*da0073e9SAndroid Build Coastguard Worker        datapipes_with_local_fn: List[
2167*da0073e9SAndroid Build Coastguard Worker            Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
2168*da0073e9SAndroid Build Coastguard Worker        ] = [
2169*da0073e9SAndroid Build Coastguard Worker            (dp.map.Mapper, (fn1,), {}),
2170*da0073e9SAndroid Build Coastguard Worker        ]
2171*da0073e9SAndroid Build Coastguard Worker
2172*da0073e9SAndroid Build Coastguard Worker        if HAS_DILL:
2173*da0073e9SAndroid Build Coastguard Worker            for dpipe, dp_args, dp_kwargs in (
2174*da0073e9SAndroid Build Coastguard Worker                datapipes_with_lambda_fn + datapipes_with_local_fn
2175*da0073e9SAndroid Build Coastguard Worker            ):
2176*da0073e9SAndroid Build Coastguard Worker                _ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs))  # type: ignore[call-arg]
2177*da0073e9SAndroid Build Coastguard Worker        else:
2178*da0073e9SAndroid Build Coastguard Worker            msgs = (
2179*da0073e9SAndroid Build Coastguard Worker                r"^Lambda function is not supported by pickle",
2180*da0073e9SAndroid Build Coastguard Worker                r"^Local function is not supported by pickle",
2181*da0073e9SAndroid Build Coastguard Worker            )
2182*da0073e9SAndroid Build Coastguard Worker            for dps, msg in zip(
2183*da0073e9SAndroid Build Coastguard Worker                (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs
2184*da0073e9SAndroid Build Coastguard Worker            ):
2185*da0073e9SAndroid Build Coastguard Worker                for dpipe, dp_args, dp_kwargs in dps:
2186*da0073e9SAndroid Build Coastguard Worker                    with self.assertWarnsRegex(UserWarning, msg):
2187*da0073e9SAndroid Build Coastguard Worker                        datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
2188*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises((pickle.PicklingError, AttributeError)):
2189*da0073e9SAndroid Build Coastguard Worker                        pickle.dumps(datapipe)
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker    def test_docstring(self):
2192*da0073e9SAndroid Build Coastguard Worker        """
2193*da0073e9SAndroid Build Coastguard Worker        Ensure functional form of MapDataPipe has the correct docstring from
2194*da0073e9SAndroid Build Coastguard Worker        the class form.
2195*da0073e9SAndroid Build Coastguard Worker
2196*da0073e9SAndroid Build Coastguard Worker        Regression test for https://github.com/pytorch/data/issues/792.
2197*da0073e9SAndroid Build Coastguard Worker        """
2198*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.map.SequenceWrapper(range(10))
2199*da0073e9SAndroid Build Coastguard Worker
2200*da0073e9SAndroid Build Coastguard Worker        for dp_funcname in [
2201*da0073e9SAndroid Build Coastguard Worker            "batch",
2202*da0073e9SAndroid Build Coastguard Worker            "concat",
2203*da0073e9SAndroid Build Coastguard Worker            "map",
2204*da0073e9SAndroid Build Coastguard Worker            "shuffle",
2205*da0073e9SAndroid Build Coastguard Worker            "zip",
2206*da0073e9SAndroid Build Coastguard Worker        ]:
2207*da0073e9SAndroid Build Coastguard Worker            if sys.version_info >= (3, 9):
2208*da0073e9SAndroid Build Coastguard Worker                docstring = pydoc.render_doc(
2209*da0073e9SAndroid Build Coastguard Worker                    thing=getattr(input_dp, dp_funcname), forceload=True
2210*da0073e9SAndroid Build Coastguard Worker                )
2211*da0073e9SAndroid Build Coastguard Worker            elif sys.version_info < (3, 9):
2212*da0073e9SAndroid Build Coastguard Worker                # pydoc works differently on Python 3.8, see
2213*da0073e9SAndroid Build Coastguard Worker                # https://docs.python.org/3/whatsnew/3.9.html#pydoc
2214*da0073e9SAndroid Build Coastguard Worker                docstring = getattr(input_dp, dp_funcname).__doc__
2215*da0073e9SAndroid Build Coastguard Worker            assert f"(functional name: ``{dp_funcname}``)" in docstring
2216*da0073e9SAndroid Build Coastguard Worker            assert "Args:" in docstring
2217*da0073e9SAndroid Build Coastguard Worker            assert "Example:" in docstring or "Examples:" in docstring
2218*da0073e9SAndroid Build Coastguard Worker
2219*da0073e9SAndroid Build Coastguard Worker    def test_sequence_wrapper_datapipe(self):
2220*da0073e9SAndroid Build Coastguard Worker        seq = list(range(10))
2221*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.map.SequenceWrapper(seq)
2222*da0073e9SAndroid Build Coastguard Worker
2223*da0073e9SAndroid Build Coastguard Worker        # Functional Test: all elements are equal in the same order
2224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(seq, list(input_dp))
2225*da0073e9SAndroid Build Coastguard Worker
2226*da0073e9SAndroid Build Coastguard Worker        # Functional Test: confirm deepcopy works by default
2227*da0073e9SAndroid Build Coastguard Worker        seq.append(11)
2228*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(input_dp))  # input_dp shouldn't have 11
2229*da0073e9SAndroid Build Coastguard Worker
2230*da0073e9SAndroid Build Coastguard Worker        # Functional Test: non-deepcopy version is working
2231*da0073e9SAndroid Build Coastguard Worker        seq2 = [1, 2, 3]
2232*da0073e9SAndroid Build Coastguard Worker        input_dp_non_deep = dp.map.SequenceWrapper(seq2, deepcopy=False)
2233*da0073e9SAndroid Build Coastguard Worker        seq2.append(4)
2234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(seq2), list(input_dp_non_deep))  # should have 4
2235*da0073e9SAndroid Build Coastguard Worker
2236*da0073e9SAndroid Build Coastguard Worker        # Reset Test: reset the DataPipe
2237*da0073e9SAndroid Build Coastguard Worker        seq = list(range(10))
2238*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 5
2239*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
2240*da0073e9SAndroid Build Coastguard Worker            input_dp, n_elements_before_reset
2241*da0073e9SAndroid Build Coastguard Worker        )
2242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(5)), res_before_reset)
2243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(seq, res_after_reset)
2244*da0073e9SAndroid Build Coastguard Worker
2245*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: inherits length from sequence
2246*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(seq), len(input_dp))
2247*da0073e9SAndroid Build Coastguard Worker
2248*da0073e9SAndroid Build Coastguard Worker    def test_concat_mapdatapipe(self):
2249*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.map.SequenceWrapper(range(10))
2250*da0073e9SAndroid Build Coastguard Worker        input_dp2 = dp.map.SequenceWrapper(range(5))
2251*da0073e9SAndroid Build Coastguard Worker
2252*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
2253*da0073e9SAndroid Build Coastguard Worker            dp.map.Concater()
2254*da0073e9SAndroid Build Coastguard Worker
2255*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2256*da0073e9SAndroid Build Coastguard Worker            TypeError, r"Expected all inputs to be `MapDataPipe`"
2257*da0073e9SAndroid Build Coastguard Worker        ):
2258*da0073e9SAndroid Build Coastguard Worker            dp.map.Concater(input_dp1, ())  # type: ignore[arg-type]
2259*da0073e9SAndroid Build Coastguard Worker
2260*da0073e9SAndroid Build Coastguard Worker        concat_dp = input_dp1.concat(input_dp2)
2261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(concat_dp), 15)
2262*da0073e9SAndroid Build Coastguard Worker        for index in range(15):
2263*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2264*da0073e9SAndroid Build Coastguard Worker                concat_dp[index], (list(range(10)) + list(range(5)))[index]
2265*da0073e9SAndroid Build Coastguard Worker            )
2266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
2267*da0073e9SAndroid Build Coastguard Worker
2268*da0073e9SAndroid Build Coastguard Worker    def test_zip_mapdatapipe(self):
2269*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.map.SequenceWrapper(range(10))
2270*da0073e9SAndroid Build Coastguard Worker        input_dp2 = dp.map.SequenceWrapper(range(5))
2271*da0073e9SAndroid Build Coastguard Worker        input_dp3 = dp.map.SequenceWrapper(range(15))
2272*da0073e9SAndroid Build Coastguard Worker
2273*da0073e9SAndroid Build Coastguard Worker        # Functional Test: requires at least one input DataPipe
2274*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
2275*da0073e9SAndroid Build Coastguard Worker            dp.map.Zipper()
2276*da0073e9SAndroid Build Coastguard Worker
2277*da0073e9SAndroid Build Coastguard Worker        # Functional Test: all inputs must be MapDataPipes
2278*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2279*da0073e9SAndroid Build Coastguard Worker            TypeError, r"Expected all inputs to be `MapDataPipe`"
2280*da0073e9SAndroid Build Coastguard Worker        ):
2281*da0073e9SAndroid Build Coastguard Worker            dp.map.Zipper(input_dp1, ())  # type: ignore[arg-type]
2282*da0073e9SAndroid Build Coastguard Worker
2283*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Zip the elements up as a tuples
2284*da0073e9SAndroid Build Coastguard Worker        zip_dp = input_dp1.zip(input_dp2, input_dp3)
2285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([(i, i, i) for i in range(5)], [zip_dp[i] for i in range(5)])
2286*da0073e9SAndroid Build Coastguard Worker
2287*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Raise IndexError when index equal or exceed the length of the shortest DataPipe
2288*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, r"out of range"):
2289*da0073e9SAndroid Build Coastguard Worker            input_dp1.zip(input_dp2, input_dp3)[5]
2290*da0073e9SAndroid Build Coastguard Worker
2291*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Ensure `zip` can combine `Batcher` with others
2292*da0073e9SAndroid Build Coastguard Worker        dp1 = dp.map.SequenceWrapper(range(10))
2293*da0073e9SAndroid Build Coastguard Worker        shuffle_dp1 = dp1.batch(2)
2294*da0073e9SAndroid Build Coastguard Worker        dp2 = dp.map.SequenceWrapper(range(10))
2295*da0073e9SAndroid Build Coastguard Worker        shuffle_dp2 = dp2.batch(3)
2296*da0073e9SAndroid Build Coastguard Worker        zip_dp1 = shuffle_dp1.zip(shuffle_dp2)
2297*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(4, len(list(zip_dp1)))
2298*da0073e9SAndroid Build Coastguard Worker        zip_dp2 = shuffle_dp1.zip(dp2)
2299*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(5, len(list(zip_dp2)))
2300*da0073e9SAndroid Build Coastguard Worker
2301*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: returns the length of the shortest DataPipe
2302*da0073e9SAndroid Build Coastguard Worker        zip_dp = input_dp1.zip(input_dp2, input_dp3)
2303*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(5, len(zip_dp))
2304*da0073e9SAndroid Build Coastguard Worker
2305*da0073e9SAndroid Build Coastguard Worker    def test_shuffler_mapdatapipe(self):
2306*da0073e9SAndroid Build Coastguard Worker        input_dp1 = dp.map.SequenceWrapper(range(10))
2307*da0073e9SAndroid Build Coastguard Worker        input_dp2 = dp.map.SequenceWrapper({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
2308*da0073e9SAndroid Build Coastguard Worker
2309*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Assumes 0-index when indices is not given
2310*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp1.shuffle()
2311*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(set(range(10)), set(shuffler_dp))
2312*da0073e9SAndroid Build Coastguard Worker
2313*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Custom indices are working
2314*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp2.shuffle(indices=["a", "b", "c", "d", "e"])
2315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(set(range(1, 6)), set(shuffler_dp))
2316*da0073e9SAndroid Build Coastguard Worker
2317*da0073e9SAndroid Build Coastguard Worker        # Functional Test: With global seed
2318*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
2319*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp1.shuffle()
2320*da0073e9SAndroid Build Coastguard Worker        res = list(shuffler_dp)
2321*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(123)
2322*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(shuffler_dp), res)
2323*da0073e9SAndroid Build Coastguard Worker
2324*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Set seed
2325*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp1.shuffle().set_seed(123)
2326*da0073e9SAndroid Build Coastguard Worker        res = list(shuffler_dp)
2327*da0073e9SAndroid Build Coastguard Worker        shuffler_dp.set_seed(123)
2328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(shuffler_dp), res)
2329*da0073e9SAndroid Build Coastguard Worker
2330*da0073e9SAndroid Build Coastguard Worker        # Functional Test: deactivate shuffling via set_shuffle
2331*da0073e9SAndroid Build Coastguard Worker        unshuffled_dp = input_dp1.shuffle().set_shuffle(False)
2332*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(unshuffled_dp), list(input_dp1))
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker        # Reset Test:
2335*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp1.shuffle()
2336*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 5
2337*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
2338*da0073e9SAndroid Build Coastguard Worker            shuffler_dp, n_elements_before_reset
2339*da0073e9SAndroid Build Coastguard Worker        )
2340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(5, len(res_before_reset))
2341*da0073e9SAndroid Build Coastguard Worker        for x in res_before_reset:
2342*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x in set(range(10)))
2343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(set(range(10)), set(res_after_reset))
2344*da0073e9SAndroid Build Coastguard Worker
2345*da0073e9SAndroid Build Coastguard Worker        # __len__ Test: returns the length of the input DataPipe
2346*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp1.shuffle()
2347*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(10, len(shuffler_dp))
2348*da0073e9SAndroid Build Coastguard Worker
2349*da0073e9SAndroid Build Coastguard Worker        # Serialization Test
2350*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data.datapipes._hook_iterator import _SnapshotState
2351*da0073e9SAndroid Build Coastguard Worker
2352*da0073e9SAndroid Build Coastguard Worker        shuffler_dp = input_dp1.shuffle()
2353*da0073e9SAndroid Build Coastguard Worker        it = iter(shuffler_dp)
2354*da0073e9SAndroid Build Coastguard Worker        for _ in range(2):
2355*da0073e9SAndroid Build Coastguard Worker            next(it)
2356*da0073e9SAndroid Build Coastguard Worker        shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp))
2357*da0073e9SAndroid Build Coastguard Worker
2358*da0073e9SAndroid Build Coastguard Worker        exp = list(it)
2359*da0073e9SAndroid Build Coastguard Worker        shuffler_dp_copy._snapshot_state = _SnapshotState.Restored
2360*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(exp, list(shuffler_dp_copy))
2361*da0073e9SAndroid Build Coastguard Worker
2362*da0073e9SAndroid Build Coastguard Worker    def test_map_mapdatapipe(self):
2363*da0073e9SAndroid Build Coastguard Worker        arr = range(10)
2364*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.map.SequenceWrapper(arr)
2365*da0073e9SAndroid Build Coastguard Worker
2366*da0073e9SAndroid Build Coastguard Worker        def fn(item, dtype=torch.float, *, sum=False):
2367*da0073e9SAndroid Build Coastguard Worker            data = torch.tensor(item, dtype=dtype)
2368*da0073e9SAndroid Build Coastguard Worker            return data if not sum else data.sum()
2369*da0073e9SAndroid Build Coastguard Worker
2370*da0073e9SAndroid Build Coastguard Worker        map_dp = input_dp.map(fn)
2371*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_dp), len(map_dp))
2372*da0073e9SAndroid Build Coastguard Worker        for index in arr:
2373*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2374*da0073e9SAndroid Build Coastguard Worker                map_dp[index], torch.tensor(input_dp[index], dtype=torch.float)
2375*da0073e9SAndroid Build Coastguard Worker            )
2376*da0073e9SAndroid Build Coastguard Worker
2377*da0073e9SAndroid Build Coastguard Worker        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
2378*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(input_dp), len(map_dp))
2379*da0073e9SAndroid Build Coastguard Worker        for index in arr:
2380*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2381*da0073e9SAndroid Build Coastguard Worker                map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum()
2382*da0073e9SAndroid Build Coastguard Worker            )
2383*da0073e9SAndroid Build Coastguard Worker
2384*da0073e9SAndroid Build Coastguard Worker    def test_batch_mapdatapipe(self):
2385*da0073e9SAndroid Build Coastguard Worker        arr = list(range(13))
2386*da0073e9SAndroid Build Coastguard Worker        input_dp = dp.map.SequenceWrapper(arr)
2387*da0073e9SAndroid Build Coastguard Worker
2388*da0073e9SAndroid Build Coastguard Worker        # Functional Test: batches top level by default
2389*da0073e9SAndroid Build Coastguard Worker        batch_dp = dp.map.Batcher(input_dp, batch_size=2)
2390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2391*da0073e9SAndroid Build Coastguard Worker            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12]], list(batch_dp)
2392*da0073e9SAndroid Build Coastguard Worker        )
2393*da0073e9SAndroid Build Coastguard Worker
2394*da0073e9SAndroid Build Coastguard Worker        # Functional Test: drop_last on command
2395*da0073e9SAndroid Build Coastguard Worker        batch_dp = dp.map.Batcher(input_dp, batch_size=2, drop_last=True)
2396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2397*da0073e9SAndroid Build Coastguard Worker            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], list(batch_dp)
2398*da0073e9SAndroid Build Coastguard Worker        )
2399*da0073e9SAndroid Build Coastguard Worker
2400*da0073e9SAndroid Build Coastguard Worker        # Functional Test: nested batching
2401*da0073e9SAndroid Build Coastguard Worker        batch_dp_2 = batch_dp.batch(batch_size=3)
2402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2403*da0073e9SAndroid Build Coastguard Worker            [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], list(batch_dp_2)
2404*da0073e9SAndroid Build Coastguard Worker        )
2405*da0073e9SAndroid Build Coastguard Worker
2406*da0073e9SAndroid Build Coastguard Worker        # Reset Test:
2407*da0073e9SAndroid Build Coastguard Worker        n_elements_before_reset = 3
2408*da0073e9SAndroid Build Coastguard Worker        res_before_reset, res_after_reset = reset_after_n_next_calls(
2409*da0073e9SAndroid Build Coastguard Worker            batch_dp, n_elements_before_reset
2410*da0073e9SAndroid Build Coastguard Worker        )
2411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([[0, 1], [2, 3], [4, 5]], res_before_reset)
2412*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2413*da0073e9SAndroid Build Coastguard Worker            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], res_after_reset
2414*da0073e9SAndroid Build Coastguard Worker        )
2415*da0073e9SAndroid Build Coastguard Worker
2416*da0073e9SAndroid Build Coastguard Worker        # __len__ Test:
2417*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(6, len(batch_dp))
2418*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, len(batch_dp_2))
2419*da0073e9SAndroid Build Coastguard Worker
2420*da0073e9SAndroid Build Coastguard Worker
2421*da0073e9SAndroid Build Coastguard Worker# Metaclass conflict for Python 3.6
2422*da0073e9SAndroid Build Coastguard Worker# Multiple inheritance with NamedTuple is not supported for Python 3.9
2423*da0073e9SAndroid Build Coastguard Worker_generic_namedtuple_allowed = sys.version_info >= (3, 7) and sys.version_info < (3, 9)
2424*da0073e9SAndroid Build Coastguard Workerif _generic_namedtuple_allowed:
2425*da0073e9SAndroid Build Coastguard Worker
2426*da0073e9SAndroid Build Coastguard Worker    class InvalidData(NamedTuple, Generic[T_co]):
2427*da0073e9SAndroid Build Coastguard Worker        name: str
2428*da0073e9SAndroid Build Coastguard Worker        data: T_co
2429*da0073e9SAndroid Build Coastguard Worker
2430*da0073e9SAndroid Build Coastguard Worker
2431*da0073e9SAndroid Build Coastguard Workerclass TestTyping(TestCase):
2432*da0073e9SAndroid Build Coastguard Worker    def test_isinstance(self):
2433*da0073e9SAndroid Build Coastguard Worker        class A(IterDataPipe):
2434*da0073e9SAndroid Build Coastguard Worker            pass
2435*da0073e9SAndroid Build Coastguard Worker
2436*da0073e9SAndroid Build Coastguard Worker        class B(IterDataPipe):
2437*da0073e9SAndroid Build Coastguard Worker            pass
2438*da0073e9SAndroid Build Coastguard Worker
2439*da0073e9SAndroid Build Coastguard Worker        a = A()
2440*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(a, A))
2441*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(isinstance(a, B))
2442*da0073e9SAndroid Build Coastguard Worker
2443*da0073e9SAndroid Build Coastguard Worker    def test_protocol(self):
2444*da0073e9SAndroid Build Coastguard Worker        try:
2445*da0073e9SAndroid Build Coastguard Worker            from typing import Protocol  # type: ignore[attr-defined]
2446*da0073e9SAndroid Build Coastguard Worker        except ImportError:
2447*da0073e9SAndroid Build Coastguard Worker            from typing import _Protocol  # type: ignore[attr-defined]
2448*da0073e9SAndroid Build Coastguard Worker
2449*da0073e9SAndroid Build Coastguard Worker            Protocol = _Protocol
2450*da0073e9SAndroid Build Coastguard Worker
2451*da0073e9SAndroid Build Coastguard Worker        class P(Protocol):
2452*da0073e9SAndroid Build Coastguard Worker            pass
2453*da0073e9SAndroid Build Coastguard Worker
2454*da0073e9SAndroid Build Coastguard Worker        class A(IterDataPipe[P]):
2455*da0073e9SAndroid Build Coastguard Worker            pass
2456*da0073e9SAndroid Build Coastguard Worker
2457*da0073e9SAndroid Build Coastguard Worker    @skipTyping
2458*da0073e9SAndroid Build Coastguard Worker    def test_subtype(self):
2459*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data.datapipes._typing import issubtype
2460*da0073e9SAndroid Build Coastguard Worker
2461*da0073e9SAndroid Build Coastguard Worker        basic_type = (int, str, bool, float, complex, list, tuple, dict, set, T_co)
2462*da0073e9SAndroid Build Coastguard Worker        for t in basic_type:
2463*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(issubtype(t, t))
2464*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(issubtype(t, Any))
2465*da0073e9SAndroid Build Coastguard Worker            if t == T_co:
2466*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(issubtype(Any, t))
2467*da0073e9SAndroid Build Coastguard Worker            else:
2468*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(issubtype(Any, t))
2469*da0073e9SAndroid Build Coastguard Worker        for t1, t2 in itertools.product(basic_type, basic_type):
2470*da0073e9SAndroid Build Coastguard Worker            if t1 == t2 or t2 == T_co:
2471*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(issubtype(t1, t2))
2472*da0073e9SAndroid Build Coastguard Worker            else:
2473*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(issubtype(t1, t2))
2474*da0073e9SAndroid Build Coastguard Worker
2475*da0073e9SAndroid Build Coastguard Worker        T = TypeVar("T", int, str)
2476*da0073e9SAndroid Build Coastguard Worker        S = TypeVar("S", bool, Union[str, int], Tuple[int, T])  # type: ignore[valid-type]
2477*da0073e9SAndroid Build Coastguard Worker        types = (
2478*da0073e9SAndroid Build Coastguard Worker            (int, Optional[int]),
2479*da0073e9SAndroid Build Coastguard Worker            (List, Union[int, list]),
2480*da0073e9SAndroid Build Coastguard Worker            (Tuple[int, str], S),
2481*da0073e9SAndroid Build Coastguard Worker            (Tuple[int, str], tuple),
2482*da0073e9SAndroid Build Coastguard Worker            (T, S),
2483*da0073e9SAndroid Build Coastguard Worker            (S, T_co),
2484*da0073e9SAndroid Build Coastguard Worker            (T, Union[S, Set]),
2485*da0073e9SAndroid Build Coastguard Worker        )
2486*da0073e9SAndroid Build Coastguard Worker        for sub, par in types:
2487*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(issubtype(sub, par))
2488*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(issubtype(par, sub))
2489*da0073e9SAndroid Build Coastguard Worker
2490*da0073e9SAndroid Build Coastguard Worker        subscriptable_types = {
2491*da0073e9SAndroid Build Coastguard Worker            List: 1,
2492*da0073e9SAndroid Build Coastguard Worker            Tuple: 2,  # use 2 parameters
2493*da0073e9SAndroid Build Coastguard Worker            Set: 1,
2494*da0073e9SAndroid Build Coastguard Worker            Dict: 2,
2495*da0073e9SAndroid Build Coastguard Worker        }
2496*da0073e9SAndroid Build Coastguard Worker        for subscript_type, n in subscriptable_types.items():
2497*da0073e9SAndroid Build Coastguard Worker            for ts in itertools.combinations(types, n):
2498*da0073e9SAndroid Build Coastguard Worker                subs, pars = zip(*ts)
2499*da0073e9SAndroid Build Coastguard Worker                sub = subscript_type[subs]  # type: ignore[index]
2500*da0073e9SAndroid Build Coastguard Worker                par = subscript_type[pars]  # type: ignore[index]
2501*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(issubtype(sub, par))
2502*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(issubtype(par, sub))
2503*da0073e9SAndroid Build Coastguard Worker                # Non-recursive check
2504*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(issubtype(par, sub, recursive=False))
2505*da0073e9SAndroid Build Coastguard Worker
2506*da0073e9SAndroid Build Coastguard Worker    @skipTyping
2507*da0073e9SAndroid Build Coastguard Worker    def test_issubinstance(self):
2508*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data.datapipes._typing import issubinstance
2509*da0073e9SAndroid Build Coastguard Worker
2510*da0073e9SAndroid Build Coastguard Worker        basic_data = (1, "1", True, 1.0, complex(1.0, 0.0))
2511*da0073e9SAndroid Build Coastguard Worker        basic_type = (int, str, bool, float, complex)
2512*da0073e9SAndroid Build Coastguard Worker        S = TypeVar("S", bool, Union[str, int])
2513*da0073e9SAndroid Build Coastguard Worker        for d in basic_data:
2514*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(issubinstance(d, Any))
2515*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(issubinstance(d, T_co))
2516*da0073e9SAndroid Build Coastguard Worker            if type(d) in (bool, int, str):
2517*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(issubinstance(d, S))
2518*da0073e9SAndroid Build Coastguard Worker            else:
2519*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(issubinstance(d, S))
2520*da0073e9SAndroid Build Coastguard Worker            for t in basic_type:
2521*da0073e9SAndroid Build Coastguard Worker                if type(d) == t:
2522*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(issubinstance(d, t))
2523*da0073e9SAndroid Build Coastguard Worker                else:
2524*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(issubinstance(d, t))
2525*da0073e9SAndroid Build Coastguard Worker        # list/set
2526*da0073e9SAndroid Build Coastguard Worker        dt = (([1, "1", 2], List), (set({1, "1", 2}), Set))
2527*da0073e9SAndroid Build Coastguard Worker        for d, t in dt:
2528*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(issubinstance(d, t))
2529*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(issubinstance(d, t[T_co]))  # type: ignore[index]
2530*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(issubinstance(d, t[int]))  # type: ignore[index]
2531*da0073e9SAndroid Build Coastguard Worker
2532*da0073e9SAndroid Build Coastguard Worker        # dict
2533*da0073e9SAndroid Build Coastguard Worker        d = {"1": 1, "2": 2.0}
2534*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubinstance(d, Dict))
2535*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubinstance(d, Dict[str, T_co]))
2536*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(issubinstance(d, Dict[str, int]))
2537*da0073e9SAndroid Build Coastguard Worker
2538*da0073e9SAndroid Build Coastguard Worker        # tuple
2539*da0073e9SAndroid Build Coastguard Worker        d = (1, "1", 2)
2540*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubinstance(d, Tuple))
2541*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubinstance(d, Tuple[int, str, T_co]))
2542*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(issubinstance(d, Tuple[int, Any]))
2543*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(issubinstance(d, Tuple[int, int, int]))
2544*da0073e9SAndroid Build Coastguard Worker
2545*da0073e9SAndroid Build Coastguard Worker    # Static checking annotation
2546*da0073e9SAndroid Build Coastguard Worker    @skipTyping
2547*da0073e9SAndroid Build Coastguard Worker    def test_compile_time(self):
2548*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"):
2549*da0073e9SAndroid Build Coastguard Worker
2550*da0073e9SAndroid Build Coastguard Worker            class InvalidDP1(IterDataPipe[int]):
2551*da0073e9SAndroid Build Coastguard Worker                def __iter__(self) -> str:  # type: ignore[misc, override]
2552*da0073e9SAndroid Build Coastguard Worker                    yield 0
2553*da0073e9SAndroid Build Coastguard Worker
2554*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
2555*da0073e9SAndroid Build Coastguard Worker
2556*da0073e9SAndroid Build Coastguard Worker            class InvalidDP2(IterDataPipe[Tuple]):
2557*da0073e9SAndroid Build Coastguard Worker                def __iter__(self) -> Iterator[int]:  # type: ignore[override]
2558*da0073e9SAndroid Build Coastguard Worker                    yield 0
2559*da0073e9SAndroid Build Coastguard Worker
2560*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
2561*da0073e9SAndroid Build Coastguard Worker
2562*da0073e9SAndroid Build Coastguard Worker            class InvalidDP3(IterDataPipe[Tuple[int, str]]):
2563*da0073e9SAndroid Build Coastguard Worker                def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
2564*da0073e9SAndroid Build Coastguard Worker                    yield (0,)
2565*da0073e9SAndroid Build Coastguard Worker
2566*da0073e9SAndroid Build Coastguard Worker        if _generic_namedtuple_allowed:
2567*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2568*da0073e9SAndroid Build Coastguard Worker                TypeError, r"is not supported by Python typing"
2569*da0073e9SAndroid Build Coastguard Worker            ):
2570*da0073e9SAndroid Build Coastguard Worker
2571*da0073e9SAndroid Build Coastguard Worker                class InvalidDP4(IterDataPipe["InvalidData[int]"]):  # type: ignore[type-arg, misc]
2572*da0073e9SAndroid Build Coastguard Worker                    pass
2573*da0073e9SAndroid Build Coastguard Worker
2574*da0073e9SAndroid Build Coastguard Worker        class DP1(IterDataPipe[Tuple[int, str]]):
2575*da0073e9SAndroid Build Coastguard Worker            def __init__(self, length):
2576*da0073e9SAndroid Build Coastguard Worker                self.length = length
2577*da0073e9SAndroid Build Coastguard Worker
2578*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[Tuple[int, str]]:
2579*da0073e9SAndroid Build Coastguard Worker                for d in range(self.length):
2580*da0073e9SAndroid Build Coastguard Worker                    yield d, str(d)
2581*da0073e9SAndroid Build Coastguard Worker
2582*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP1, IterDataPipe))
2583*da0073e9SAndroid Build Coastguard Worker        dp1 = DP1(10)
2584*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))  # type: ignore[attr-defined]
2585*da0073e9SAndroid Build Coastguard Worker        dp1_ = DP1(5)
2586*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dp1.type, dp1_.type)
2587*da0073e9SAndroid Build Coastguard Worker
2588*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"is not a generic class"):
2589*da0073e9SAndroid Build Coastguard Worker
2590*da0073e9SAndroid Build Coastguard Worker            class InvalidDP5(DP1[tuple]):  # type: ignore[type-arg]
2591*da0073e9SAndroid Build Coastguard Worker                def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
2592*da0073e9SAndroid Build Coastguard Worker                    yield (0,)
2593*da0073e9SAndroid Build Coastguard Worker
2594*da0073e9SAndroid Build Coastguard Worker        class DP2(IterDataPipe[T_co]):
2595*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[T_co]:
2596*da0073e9SAndroid Build Coastguard Worker                yield from range(10)  # type: ignore[misc]
2597*da0073e9SAndroid Build Coastguard Worker
2598*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP2, IterDataPipe))
2599*da0073e9SAndroid Build Coastguard Worker        dp2 = DP2()  # type: ignore[var-annotated]
2600*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type))  # type: ignore[attr-defined]
2601*da0073e9SAndroid Build Coastguard Worker        dp2_ = DP2()  # type: ignore[var-annotated]
2602*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dp2.type, dp2_.type)
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker        class DP3(IterDataPipe[Tuple[T_co, str]]):
2605*da0073e9SAndroid Build Coastguard Worker            r"""DataPipe without fixed type with __init__ function"""
2606*da0073e9SAndroid Build Coastguard Worker
2607*da0073e9SAndroid Build Coastguard Worker            def __init__(self, datasource):
2608*da0073e9SAndroid Build Coastguard Worker                self.datasource = datasource
2609*da0073e9SAndroid Build Coastguard Worker
2610*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[Tuple[T_co, str]]:
2611*da0073e9SAndroid Build Coastguard Worker                for d in self.datasource:
2612*da0073e9SAndroid Build Coastguard Worker                    yield d, str(d)
2613*da0073e9SAndroid Build Coastguard Worker
2614*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP3, IterDataPipe))
2615*da0073e9SAndroid Build Coastguard Worker        dp3 = DP3(range(10))  # type: ignore[var-annotated]
2616*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type))  # type: ignore[attr-defined]
2617*da0073e9SAndroid Build Coastguard Worker        dp3_ = DP3(5)  # type: ignore[var-annotated]
2618*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dp3.type, dp3_.type)
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker        class DP4(IterDataPipe[tuple]):
2621*da0073e9SAndroid Build Coastguard Worker            r"""DataPipe without __iter__ annotation"""
2622*da0073e9SAndroid Build Coastguard Worker
2623*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
2624*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError
2625*da0073e9SAndroid Build Coastguard Worker
2626*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP4, IterDataPipe))
2627*da0073e9SAndroid Build Coastguard Worker        dp4 = DP4()
2628*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(dp4.type.param == tuple)
2629*da0073e9SAndroid Build Coastguard Worker
2630*da0073e9SAndroid Build Coastguard Worker        class DP5(IterDataPipe):
2631*da0073e9SAndroid Build Coastguard Worker            r"""DataPipe without type annotation"""
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[str]:
2634*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError
2635*da0073e9SAndroid Build Coastguard Worker
2636*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP5, IterDataPipe))
2637*da0073e9SAndroid Build Coastguard Worker        dp5 = DP5()
2638*da0073e9SAndroid Build Coastguard Worker        from torch.utils.data.datapipes._typing import issubtype
2639*da0073e9SAndroid Build Coastguard Worker
2640*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
2641*da0073e9SAndroid Build Coastguard Worker            issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param)
2642*da0073e9SAndroid Build Coastguard Worker        )
2643*da0073e9SAndroid Build Coastguard Worker
2644*da0073e9SAndroid Build Coastguard Worker        class DP6(IterDataPipe[int]):
2645*da0073e9SAndroid Build Coastguard Worker            r"""DataPipe with plain Iterator"""
2646*da0073e9SAndroid Build Coastguard Worker
2647*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator:
2648*da0073e9SAndroid Build Coastguard Worker                raise NotImplementedError
2649*da0073e9SAndroid Build Coastguard Worker
2650*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP6, IterDataPipe))
2651*da0073e9SAndroid Build Coastguard Worker        dp6 = DP6()
2652*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(dp6.type.param == int)
2653*da0073e9SAndroid Build Coastguard Worker
2654*da0073e9SAndroid Build Coastguard Worker        class DP7(IterDataPipe[Awaitable[T_co]]):
2655*da0073e9SAndroid Build Coastguard Worker            r"""DataPipe with abstract base class"""
2656*da0073e9SAndroid Build Coastguard Worker
2657*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP7, IterDataPipe))
2658*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(DP7.type.param == Awaitable[T_co])  # type: ignore[attr-defined]
2659*da0073e9SAndroid Build Coastguard Worker
2660*da0073e9SAndroid Build Coastguard Worker        class DP8(DP7[str]):
2661*da0073e9SAndroid Build Coastguard Worker            r"""DataPipe subclass from a DataPipe with abc type"""
2662*da0073e9SAndroid Build Coastguard Worker
2663*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(DP8, IterDataPipe))
2664*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(DP8.type.param == Awaitable[str])  # type: ignore[attr-defined]
2665*da0073e9SAndroid Build Coastguard Worker
2666*da0073e9SAndroid Build Coastguard Worker    @skipTyping
2667*da0073e9SAndroid Build Coastguard Worker    def test_construct_time(self):
2668*da0073e9SAndroid Build Coastguard Worker        class DP0(IterDataPipe[Tuple]):
2669*da0073e9SAndroid Build Coastguard Worker            @argument_validation
2670*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dp: IterDataPipe):
2671*da0073e9SAndroid Build Coastguard Worker                self.dp = dp
2672*da0073e9SAndroid Build Coastguard Worker
2673*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[Tuple]:
2674*da0073e9SAndroid Build Coastguard Worker                for d in self.dp:
2675*da0073e9SAndroid Build Coastguard Worker                    yield d, str(d)
2676*da0073e9SAndroid Build Coastguard Worker
2677*da0073e9SAndroid Build Coastguard Worker        class DP1(IterDataPipe[int]):
2678*da0073e9SAndroid Build Coastguard Worker            @argument_validation
2679*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
2680*da0073e9SAndroid Build Coastguard Worker                self.dp = dp
2681*da0073e9SAndroid Build Coastguard Worker
2682*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[int]:
2683*da0073e9SAndroid Build Coastguard Worker                for a, b in self.dp:
2684*da0073e9SAndroid Build Coastguard Worker                    yield a
2685*da0073e9SAndroid Build Coastguard Worker
2686*da0073e9SAndroid Build Coastguard Worker        # Non-DataPipe input with DataPipe hint
2687*da0073e9SAndroid Build Coastguard Worker        datasource = [(1, "1"), (2, "2"), (3, "3")]
2688*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2689*da0073e9SAndroid Build Coastguard Worker            TypeError, r"Expected argument 'dp' as a IterDataPipe"
2690*da0073e9SAndroid Build Coastguard Worker        ):
2691*da0073e9SAndroid Build Coastguard Worker            dp0 = DP0(datasource)
2692*da0073e9SAndroid Build Coastguard Worker
2693*da0073e9SAndroid Build Coastguard Worker        dp0 = DP0(dp.iter.IterableWrapper(range(10)))
2694*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2695*da0073e9SAndroid Build Coastguard Worker            TypeError, r"Expected type of argument 'dp' as a subtype"
2696*da0073e9SAndroid Build Coastguard Worker        ):
2697*da0073e9SAndroid Build Coastguard Worker            dp1 = DP1(dp0)
2698*da0073e9SAndroid Build Coastguard Worker
2699*da0073e9SAndroid Build Coastguard Worker    @skipTyping
2700*da0073e9SAndroid Build Coastguard Worker    def test_runtime(self):
2701*da0073e9SAndroid Build Coastguard Worker        class DP(IterDataPipe[Tuple[int, T_co]]):
2702*da0073e9SAndroid Build Coastguard Worker            def __init__(self, datasource):
2703*da0073e9SAndroid Build Coastguard Worker                self.ds = datasource
2704*da0073e9SAndroid Build Coastguard Worker
2705*da0073e9SAndroid Build Coastguard Worker            @runtime_validation
2706*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[Tuple[int, T_co]]:
2707*da0073e9SAndroid Build Coastguard Worker                yield from self.ds
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker        dss = ([(1, "1"), (2, "2")], [(1, 1), (2, "2")])
2710*da0073e9SAndroid Build Coastguard Worker        for ds in dss:
2711*da0073e9SAndroid Build Coastguard Worker            dp0 = DP(ds)  # type: ignore[var-annotated]
2712*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(dp0), ds)
2713*da0073e9SAndroid Build Coastguard Worker            # Reset __iter__
2714*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(dp0), ds)
2715*da0073e9SAndroid Build Coastguard Worker
2716*da0073e9SAndroid Build Coastguard Worker        dss = (
2717*da0073e9SAndroid Build Coastguard Worker            [(1, 1), ("2", 2)],  # type: ignore[assignment, list-item]
2718*da0073e9SAndroid Build Coastguard Worker            [[1, "1"], [2, "2"]],  # type: ignore[list-item]
2719*da0073e9SAndroid Build Coastguard Worker            [1, "1", 2, "2"],
2720*da0073e9SAndroid Build Coastguard Worker        )
2721*da0073e9SAndroid Build Coastguard Worker        for ds in dss:
2722*da0073e9SAndroid Build Coastguard Worker            dp0 = DP(ds)
2723*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2724*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"Expected an instance as subtype"
2725*da0073e9SAndroid Build Coastguard Worker            ):
2726*da0073e9SAndroid Build Coastguard Worker                list(dp0)
2727*da0073e9SAndroid Build Coastguard Worker
2728*da0073e9SAndroid Build Coastguard Worker            with runtime_validation_disabled():
2729*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(list(dp0), ds)
2730*da0073e9SAndroid Build Coastguard Worker                with runtime_validation_disabled():
2731*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(list(dp0), ds)
2732*da0073e9SAndroid Build Coastguard Worker
2733*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2734*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"Expected an instance as subtype"
2735*da0073e9SAndroid Build Coastguard Worker            ):
2736*da0073e9SAndroid Build Coastguard Worker                list(dp0)
2737*da0073e9SAndroid Build Coastguard Worker
2738*da0073e9SAndroid Build Coastguard Worker    @skipTyping
2739*da0073e9SAndroid Build Coastguard Worker    def test_reinforce(self):
2740*da0073e9SAndroid Build Coastguard Worker        T = TypeVar("T", int, str)
2741*da0073e9SAndroid Build Coastguard Worker
2742*da0073e9SAndroid Build Coastguard Worker        class DP(IterDataPipe[T]):
2743*da0073e9SAndroid Build Coastguard Worker            def __init__(self, ds):
2744*da0073e9SAndroid Build Coastguard Worker                self.ds = ds
2745*da0073e9SAndroid Build Coastguard Worker
2746*da0073e9SAndroid Build Coastguard Worker            @runtime_validation
2747*da0073e9SAndroid Build Coastguard Worker            def __iter__(self) -> Iterator[T]:
2748*da0073e9SAndroid Build Coastguard Worker                yield from self.ds
2749*da0073e9SAndroid Build Coastguard Worker
2750*da0073e9SAndroid Build Coastguard Worker        ds = list(range(10))
2751*da0073e9SAndroid Build Coastguard Worker        # Valid type reinforcement
2752*da0073e9SAndroid Build Coastguard Worker        dp0 = DP(ds).reinforce_type(int)
2753*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(dp0.type, int)
2754*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(dp0), ds)
2755*da0073e9SAndroid Build Coastguard Worker
2756*da0073e9SAndroid Build Coastguard Worker        # Invalid type
2757*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"):
2758*da0073e9SAndroid Build Coastguard Worker            dp1 = DP(ds).reinforce_type(1)
2759*da0073e9SAndroid Build Coastguard Worker
2760*da0073e9SAndroid Build Coastguard Worker        # Type is not subtype
2761*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2762*da0073e9SAndroid Build Coastguard Worker            TypeError, r"Expected 'expected_type' as subtype of"
2763*da0073e9SAndroid Build Coastguard Worker        ):
2764*da0073e9SAndroid Build Coastguard Worker            dp2 = DP(ds).reinforce_type(float)
2765*da0073e9SAndroid Build Coastguard Worker
2766*da0073e9SAndroid Build Coastguard Worker        # Invalid data at runtime
2767*da0073e9SAndroid Build Coastguard Worker        dp3 = DP(ds).reinforce_type(str)
2768*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
2769*da0073e9SAndroid Build Coastguard Worker            list(dp3)
2770*da0073e9SAndroid Build Coastguard Worker
2771*da0073e9SAndroid Build Coastguard Worker        # Context Manager to disable the runtime validation
2772*da0073e9SAndroid Build Coastguard Worker        with runtime_validation_disabled():
2773*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(dp3), ds)
2774*da0073e9SAndroid Build Coastguard Worker
2775*da0073e9SAndroid Build Coastguard Worker
2776*da0073e9SAndroid Build Coastguard Workerclass NumbersDataset(IterDataPipe):
2777*da0073e9SAndroid Build Coastguard Worker    def __init__(self, size=10):
2778*da0073e9SAndroid Build Coastguard Worker        self.size = size
2779*da0073e9SAndroid Build Coastguard Worker
2780*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
2781*da0073e9SAndroid Build Coastguard Worker        yield from range(self.size)
2782*da0073e9SAndroid Build Coastguard Worker
2783*da0073e9SAndroid Build Coastguard Worker    def __len__(self):
2784*da0073e9SAndroid Build Coastguard Worker        return self.size
2785*da0073e9SAndroid Build Coastguard Worker
2786*da0073e9SAndroid Build Coastguard Worker
2787*da0073e9SAndroid Build Coastguard Workerclass TestGraph(TestCase):
2788*da0073e9SAndroid Build Coastguard Worker    class CustomIterDataPipe(IterDataPipe):
2789*da0073e9SAndroid Build Coastguard Worker        def add_v(self, x):
2790*da0073e9SAndroid Build Coastguard Worker            return x + self.v
2791*da0073e9SAndroid Build Coastguard Worker
2792*da0073e9SAndroid Build Coastguard Worker        def __init__(self, source_dp, v=1):
2793*da0073e9SAndroid Build Coastguard Worker            self._dp = source_dp.map(self.add_v)
2794*da0073e9SAndroid Build Coastguard Worker            self.v = 1
2795*da0073e9SAndroid Build Coastguard Worker
2796*da0073e9SAndroid Build Coastguard Worker        def __iter__(self):
2797*da0073e9SAndroid Build Coastguard Worker            yield from self._dp
2798*da0073e9SAndroid Build Coastguard Worker
2799*da0073e9SAndroid Build Coastguard Worker        def __hash__(self):
2800*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError
2801*da0073e9SAndroid Build Coastguard Worker
2802*da0073e9SAndroid Build Coastguard Worker    def test_simple_traverse(self):
2803*da0073e9SAndroid Build Coastguard Worker        numbers_dp = NumbersDataset(size=50)
2804*da0073e9SAndroid Build Coastguard Worker        shuffled_dp = numbers_dp.shuffle()
2805*da0073e9SAndroid Build Coastguard Worker        sharded_dp = shuffled_dp.sharding_filter()
2806*da0073e9SAndroid Build Coastguard Worker        mapped_dp = sharded_dp.map(lambda x: x * 10)
2807*da0073e9SAndroid Build Coastguard Worker        graph = traverse_dps(mapped_dp)
2808*da0073e9SAndroid Build Coastguard Worker        expected: Dict[Any, Any] = {
2809*da0073e9SAndroid Build Coastguard Worker            id(mapped_dp): (
2810*da0073e9SAndroid Build Coastguard Worker                mapped_dp,
2811*da0073e9SAndroid Build Coastguard Worker                {
2812*da0073e9SAndroid Build Coastguard Worker                    id(sharded_dp): (
2813*da0073e9SAndroid Build Coastguard Worker                        sharded_dp,
2814*da0073e9SAndroid Build Coastguard Worker                        {
2815*da0073e9SAndroid Build Coastguard Worker                            id(shuffled_dp): (
2816*da0073e9SAndroid Build Coastguard Worker                                shuffled_dp,
2817*da0073e9SAndroid Build Coastguard Worker                                {id(numbers_dp): (numbers_dp, {})},
2818*da0073e9SAndroid Build Coastguard Worker                            )
2819*da0073e9SAndroid Build Coastguard Worker                        },
2820*da0073e9SAndroid Build Coastguard Worker                    )
2821*da0073e9SAndroid Build Coastguard Worker                },
2822*da0073e9SAndroid Build Coastguard Worker            )
2823*da0073e9SAndroid Build Coastguard Worker        }
2824*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, graph)
2825*da0073e9SAndroid Build Coastguard Worker
2826*da0073e9SAndroid Build Coastguard Worker        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
2827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(dps), 4)
2828*da0073e9SAndroid Build Coastguard Worker        for datapipe in (numbers_dp, shuffled_dp, sharded_dp, mapped_dp):
2829*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(datapipe in dps)
2830*da0073e9SAndroid Build Coastguard Worker
2831*da0073e9SAndroid Build Coastguard Worker    def test_traverse_forked(self):
2832*da0073e9SAndroid Build Coastguard Worker        numbers_dp = NumbersDataset(size=50)
2833*da0073e9SAndroid Build Coastguard Worker        dp0, dp1, dp2 = numbers_dp.fork(num_instances=3)
2834*da0073e9SAndroid Build Coastguard Worker        dp0_upd = dp0.map(lambda x: x * 10)
2835*da0073e9SAndroid Build Coastguard Worker        dp1_upd = dp1.filter(lambda x: x % 3 == 1)
2836*da0073e9SAndroid Build Coastguard Worker        combined_dp = dp0_upd.mux(dp1_upd, dp2)
2837*da0073e9SAndroid Build Coastguard Worker        graph = traverse_dps(combined_dp)
2838*da0073e9SAndroid Build Coastguard Worker        expected = {
2839*da0073e9SAndroid Build Coastguard Worker            id(combined_dp): (
2840*da0073e9SAndroid Build Coastguard Worker                combined_dp,
2841*da0073e9SAndroid Build Coastguard Worker                {
2842*da0073e9SAndroid Build Coastguard Worker                    id(dp0_upd): (
2843*da0073e9SAndroid Build Coastguard Worker                        dp0_upd,
2844*da0073e9SAndroid Build Coastguard Worker                        {
2845*da0073e9SAndroid Build Coastguard Worker                            id(dp0): (
2846*da0073e9SAndroid Build Coastguard Worker                                dp0,
2847*da0073e9SAndroid Build Coastguard Worker                                {
2848*da0073e9SAndroid Build Coastguard Worker                                    id(dp0.main_datapipe): (
2849*da0073e9SAndroid Build Coastguard Worker                                        dp0.main_datapipe,
2850*da0073e9SAndroid Build Coastguard Worker                                        {
2851*da0073e9SAndroid Build Coastguard Worker                                            id(dp0.main_datapipe.main_datapipe): (
2852*da0073e9SAndroid Build Coastguard Worker                                                dp0.main_datapipe.main_datapipe,
2853*da0073e9SAndroid Build Coastguard Worker                                                {},
2854*da0073e9SAndroid Build Coastguard Worker                                            )
2855*da0073e9SAndroid Build Coastguard Worker                                        },
2856*da0073e9SAndroid Build Coastguard Worker                                    )
2857*da0073e9SAndroid Build Coastguard Worker                                },
2858*da0073e9SAndroid Build Coastguard Worker                            )
2859*da0073e9SAndroid Build Coastguard Worker                        },
2860*da0073e9SAndroid Build Coastguard Worker                    ),
2861*da0073e9SAndroid Build Coastguard Worker                    id(dp1_upd): (
2862*da0073e9SAndroid Build Coastguard Worker                        dp1_upd,
2863*da0073e9SAndroid Build Coastguard Worker                        {
2864*da0073e9SAndroid Build Coastguard Worker                            id(dp1): (
2865*da0073e9SAndroid Build Coastguard Worker                                dp1,
2866*da0073e9SAndroid Build Coastguard Worker                                {
2867*da0073e9SAndroid Build Coastguard Worker                                    id(dp1.main_datapipe): (
2868*da0073e9SAndroid Build Coastguard Worker                                        dp1.main_datapipe,
2869*da0073e9SAndroid Build Coastguard Worker                                        {
2870*da0073e9SAndroid Build Coastguard Worker                                            id(dp1.main_datapipe.main_datapipe): (
2871*da0073e9SAndroid Build Coastguard Worker                                                dp1.main_datapipe.main_datapipe,
2872*da0073e9SAndroid Build Coastguard Worker                                                {},
2873*da0073e9SAndroid Build Coastguard Worker                                            )
2874*da0073e9SAndroid Build Coastguard Worker                                        },
2875*da0073e9SAndroid Build Coastguard Worker                                    )
2876*da0073e9SAndroid Build Coastguard Worker                                },
2877*da0073e9SAndroid Build Coastguard Worker                            )
2878*da0073e9SAndroid Build Coastguard Worker                        },
2879*da0073e9SAndroid Build Coastguard Worker                    ),
2880*da0073e9SAndroid Build Coastguard Worker                    id(dp2): (
2881*da0073e9SAndroid Build Coastguard Worker                        dp2,
2882*da0073e9SAndroid Build Coastguard Worker                        {
2883*da0073e9SAndroid Build Coastguard Worker                            id(dp2.main_datapipe): (
2884*da0073e9SAndroid Build Coastguard Worker                                dp2.main_datapipe,
2885*da0073e9SAndroid Build Coastguard Worker                                {
2886*da0073e9SAndroid Build Coastguard Worker                                    id(dp2.main_datapipe.main_datapipe): (
2887*da0073e9SAndroid Build Coastguard Worker                                        dp2.main_datapipe.main_datapipe,
2888*da0073e9SAndroid Build Coastguard Worker                                        {},
2889*da0073e9SAndroid Build Coastguard Worker                                    )
2890*da0073e9SAndroid Build Coastguard Worker                                },
2891*da0073e9SAndroid Build Coastguard Worker                            )
2892*da0073e9SAndroid Build Coastguard Worker                        },
2893*da0073e9SAndroid Build Coastguard Worker                    ),
2894*da0073e9SAndroid Build Coastguard Worker                },
2895*da0073e9SAndroid Build Coastguard Worker            )
2896*da0073e9SAndroid Build Coastguard Worker        }
2897*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, graph)
2898*da0073e9SAndroid Build Coastguard Worker
2899*da0073e9SAndroid Build Coastguard Worker        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
2900*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(dps), 8)
2901*da0073e9SAndroid Build Coastguard Worker        for _dp in [
2902*da0073e9SAndroid Build Coastguard Worker            numbers_dp,
2903*da0073e9SAndroid Build Coastguard Worker            dp0.main_datapipe,
2904*da0073e9SAndroid Build Coastguard Worker            dp0,
2905*da0073e9SAndroid Build Coastguard Worker            dp1,
2906*da0073e9SAndroid Build Coastguard Worker            dp2,
2907*da0073e9SAndroid Build Coastguard Worker            dp0_upd,
2908*da0073e9SAndroid Build Coastguard Worker            dp1_upd,
2909*da0073e9SAndroid Build Coastguard Worker            combined_dp,
2910*da0073e9SAndroid Build Coastguard Worker        ]:
2911*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(_dp in dps)
2912*da0073e9SAndroid Build Coastguard Worker
2913*da0073e9SAndroid Build Coastguard Worker    def test_traverse_mapdatapipe(self):
2914*da0073e9SAndroid Build Coastguard Worker        source_dp = dp.map.SequenceWrapper(range(10))
2915*da0073e9SAndroid Build Coastguard Worker        map_dp = source_dp.map(partial(_fake_add, 1))
2916*da0073e9SAndroid Build Coastguard Worker        graph = traverse_dps(map_dp)
2917*da0073e9SAndroid Build Coastguard Worker        expected: Dict[Any, Any] = {
2918*da0073e9SAndroid Build Coastguard Worker            id(map_dp): (map_dp, {id(source_dp): (source_dp, {})})
2919*da0073e9SAndroid Build Coastguard Worker        }
2920*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, graph)
2921*da0073e9SAndroid Build Coastguard Worker
2922*da0073e9SAndroid Build Coastguard Worker    def test_traverse_mixdatapipe(self):
2923*da0073e9SAndroid Build Coastguard Worker        source_map_dp = dp.map.SequenceWrapper(range(10))
2924*da0073e9SAndroid Build Coastguard Worker        iter_dp = dp.iter.IterableWrapper(source_map_dp)
2925*da0073e9SAndroid Build Coastguard Worker        graph = traverse_dps(iter_dp)
2926*da0073e9SAndroid Build Coastguard Worker        expected: Dict[Any, Any] = {
2927*da0073e9SAndroid Build Coastguard Worker            id(iter_dp): (iter_dp, {id(source_map_dp): (source_map_dp, {})})
2928*da0073e9SAndroid Build Coastguard Worker        }
2929*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, graph)
2930*da0073e9SAndroid Build Coastguard Worker
2931*da0073e9SAndroid Build Coastguard Worker    def test_traverse_circular_datapipe(self):
2932*da0073e9SAndroid Build Coastguard Worker        source_iter_dp = dp.iter.IterableWrapper(list(range(10)))
2933*da0073e9SAndroid Build Coastguard Worker        circular_dp = TestGraph.CustomIterDataPipe(source_iter_dp)
2934*da0073e9SAndroid Build Coastguard Worker        graph = traverse_dps(circular_dp)
2935*da0073e9SAndroid Build Coastguard Worker        # See issue: https://github.com/pytorch/data/issues/535
2936*da0073e9SAndroid Build Coastguard Worker        expected: Dict[Any, Any] = {
2937*da0073e9SAndroid Build Coastguard Worker            id(circular_dp): (
2938*da0073e9SAndroid Build Coastguard Worker                circular_dp,
2939*da0073e9SAndroid Build Coastguard Worker                {
2940*da0073e9SAndroid Build Coastguard Worker                    id(circular_dp._dp): (
2941*da0073e9SAndroid Build Coastguard Worker                        circular_dp._dp,
2942*da0073e9SAndroid Build Coastguard Worker                        {id(source_iter_dp): (source_iter_dp, {})},
2943*da0073e9SAndroid Build Coastguard Worker                    )
2944*da0073e9SAndroid Build Coastguard Worker                },
2945*da0073e9SAndroid Build Coastguard Worker            )
2946*da0073e9SAndroid Build Coastguard Worker        }
2947*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, graph)
2948*da0073e9SAndroid Build Coastguard Worker
2949*da0073e9SAndroid Build Coastguard Worker        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
2950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(dps), 3)
2951*da0073e9SAndroid Build Coastguard Worker        for _dp in [circular_dp, circular_dp._dp, source_iter_dp]:
2952*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(_dp in dps)
2953*da0073e9SAndroid Build Coastguard Worker
2954*da0073e9SAndroid Build Coastguard Worker    def test_traverse_unhashable_datapipe(self):
2955*da0073e9SAndroid Build Coastguard Worker        source_iter_dp = dp.iter.IterableWrapper(list(range(10)))
2956*da0073e9SAndroid Build Coastguard Worker        unhashable_dp = TestGraph.CustomIterDataPipe(source_iter_dp)
2957*da0073e9SAndroid Build Coastguard Worker        graph = traverse_dps(unhashable_dp)
2958*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(NotImplementedError):
2959*da0073e9SAndroid Build Coastguard Worker            hash(unhashable_dp)
2960*da0073e9SAndroid Build Coastguard Worker        expected: Dict[Any, Any] = {
2961*da0073e9SAndroid Build Coastguard Worker            id(unhashable_dp): (
2962*da0073e9SAndroid Build Coastguard Worker                unhashable_dp,
2963*da0073e9SAndroid Build Coastguard Worker                {
2964*da0073e9SAndroid Build Coastguard Worker                    id(unhashable_dp._dp): (
2965*da0073e9SAndroid Build Coastguard Worker                        unhashable_dp._dp,
2966*da0073e9SAndroid Build Coastguard Worker                        {id(source_iter_dp): (source_iter_dp, {})},
2967*da0073e9SAndroid Build Coastguard Worker                    )
2968*da0073e9SAndroid Build Coastguard Worker                },
2969*da0073e9SAndroid Build Coastguard Worker            )
2970*da0073e9SAndroid Build Coastguard Worker        }
2971*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, graph)
2972*da0073e9SAndroid Build Coastguard Worker
2973*da0073e9SAndroid Build Coastguard Worker
2974*da0073e9SAndroid Build Coastguard Workerdef unbatch(x):
2975*da0073e9SAndroid Build Coastguard Worker    return x[0]
2976*da0073e9SAndroid Build Coastguard Worker
2977*da0073e9SAndroid Build Coastguard Worker
2978*da0073e9SAndroid Build Coastguard Workerclass TestSerialization(TestCase):
2979*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
2980*da0073e9SAndroid Build Coastguard Worker    def test_spawn_lambdas_iter(self):
2981*da0073e9SAndroid Build Coastguard Worker        idp = dp.iter.IterableWrapper(range(3)).map(lambda x: x + 1).shuffle()
2982*da0073e9SAndroid Build Coastguard Worker        dl = DataLoader(
2983*da0073e9SAndroid Build Coastguard Worker            idp,
2984*da0073e9SAndroid Build Coastguard Worker            num_workers=2,
2985*da0073e9SAndroid Build Coastguard Worker            shuffle=True,
2986*da0073e9SAndroid Build Coastguard Worker            multiprocessing_context="spawn",
2987*da0073e9SAndroid Build Coastguard Worker            collate_fn=unbatch,
2988*da0073e9SAndroid Build Coastguard Worker            batch_size=1,
2989*da0073e9SAndroid Build Coastguard Worker        )
2990*da0073e9SAndroid Build Coastguard Worker        result = list(dl)
2991*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result))
2992*da0073e9SAndroid Build Coastguard Worker
2993*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
2994*da0073e9SAndroid Build Coastguard Worker    def test_spawn_lambdas_map(self):
2995*da0073e9SAndroid Build Coastguard Worker        mdp = dp.map.SequenceWrapper(range(3)).map(lambda x: x + 1).shuffle()
2996*da0073e9SAndroid Build Coastguard Worker        dl = DataLoader(
2997*da0073e9SAndroid Build Coastguard Worker            mdp,
2998*da0073e9SAndroid Build Coastguard Worker            num_workers=2,
2999*da0073e9SAndroid Build Coastguard Worker            shuffle=True,
3000*da0073e9SAndroid Build Coastguard Worker            multiprocessing_context="spawn",
3001*da0073e9SAndroid Build Coastguard Worker            collate_fn=unbatch,
3002*da0073e9SAndroid Build Coastguard Worker            batch_size=1,
3003*da0073e9SAndroid Build Coastguard Worker        )
3004*da0073e9SAndroid Build Coastguard Worker        result = list(dl)
3005*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result))
3006*da0073e9SAndroid Build Coastguard Worker
3007*da0073e9SAndroid Build Coastguard Worker
3008*da0073e9SAndroid Build Coastguard Workerclass TestCircularSerialization(TestCase):
3009*da0073e9SAndroid Build Coastguard Worker    class CustomIterDataPipe(IterDataPipe):
3010*da0073e9SAndroid Build Coastguard Worker        @staticmethod
3011*da0073e9SAndroid Build Coastguard Worker        def add_one(x):
3012*da0073e9SAndroid Build Coastguard Worker            return x + 1
3013*da0073e9SAndroid Build Coastguard Worker
3014*da0073e9SAndroid Build Coastguard Worker        @classmethod
3015*da0073e9SAndroid Build Coastguard Worker        def classify(cls, x):
3016*da0073e9SAndroid Build Coastguard Worker            return 0
3017*da0073e9SAndroid Build Coastguard Worker
3018*da0073e9SAndroid Build Coastguard Worker        def add_v(self, x):
3019*da0073e9SAndroid Build Coastguard Worker            return x + self.v
3020*da0073e9SAndroid Build Coastguard Worker
3021*da0073e9SAndroid Build Coastguard Worker        def __init__(self, fn, source_dp=None):
3022*da0073e9SAndroid Build Coastguard Worker            self.fn = fn
3023*da0073e9SAndroid Build Coastguard Worker            self.source_dp = (
3024*da0073e9SAndroid Build Coastguard Worker                source_dp if source_dp else dp.iter.IterableWrapper([1, 2, 4])
3025*da0073e9SAndroid Build Coastguard Worker            )
3026*da0073e9SAndroid Build Coastguard Worker            self._dp = (
3027*da0073e9SAndroid Build Coastguard Worker                self.source_dp.map(self.add_one)
3028*da0073e9SAndroid Build Coastguard Worker                .map(self.add_v)
3029*da0073e9SAndroid Build Coastguard Worker                .demux(2, self.classify)[0]
3030*da0073e9SAndroid Build Coastguard Worker            )
3031*da0073e9SAndroid Build Coastguard Worker            self.v = 1
3032*da0073e9SAndroid Build Coastguard Worker
3033*da0073e9SAndroid Build Coastguard Worker        def __iter__(self):
3034*da0073e9SAndroid Build Coastguard Worker            yield from self._dp
3035*da0073e9SAndroid Build Coastguard Worker
3036*da0073e9SAndroid Build Coastguard Worker    def test_circular_serialization_with_pickle(self):
3037*da0073e9SAndroid Build Coastguard Worker        # Test for circular reference issue with pickle
3038*da0073e9SAndroid Build Coastguard Worker        dp1 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn)
3039*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(list(dp1) == list(pickle.loads(pickle.dumps(dp1))))
3040*da0073e9SAndroid Build Coastguard Worker
3041*da0073e9SAndroid Build Coastguard Worker        child_1 = dp1._dp
3042*da0073e9SAndroid Build Coastguard Worker        dm_1 = child_1.main_datapipe
3043*da0073e9SAndroid Build Coastguard Worker        m2_1 = dm_1.main_datapipe
3044*da0073e9SAndroid Build Coastguard Worker        m1_1 = m2_1.datapipe
3045*da0073e9SAndroid Build Coastguard Worker        src_1 = m1_1.datapipe
3046*da0073e9SAndroid Build Coastguard Worker
3047*da0073e9SAndroid Build Coastguard Worker        res1 = traverse_dps(dp1)
3048*da0073e9SAndroid Build Coastguard Worker        exp_res_1 = {
3049*da0073e9SAndroid Build Coastguard Worker            id(dp1): (
3050*da0073e9SAndroid Build Coastguard Worker                dp1,
3051*da0073e9SAndroid Build Coastguard Worker                {
3052*da0073e9SAndroid Build Coastguard Worker                    id(src_1): (src_1, {}),
3053*da0073e9SAndroid Build Coastguard Worker                    id(child_1): (
3054*da0073e9SAndroid Build Coastguard Worker                        child_1,
3055*da0073e9SAndroid Build Coastguard Worker                        {
3056*da0073e9SAndroid Build Coastguard Worker                            id(dm_1): (
3057*da0073e9SAndroid Build Coastguard Worker                                dm_1,
3058*da0073e9SAndroid Build Coastguard Worker                                {
3059*da0073e9SAndroid Build Coastguard Worker                                    id(m2_1): (
3060*da0073e9SAndroid Build Coastguard Worker                                        m2_1,
3061*da0073e9SAndroid Build Coastguard Worker                                        {id(m1_1): (m1_1, {id(src_1): (src_1, {})})},
3062*da0073e9SAndroid Build Coastguard Worker                                    )
3063*da0073e9SAndroid Build Coastguard Worker                                },
3064*da0073e9SAndroid Build Coastguard Worker                            )
3065*da0073e9SAndroid Build Coastguard Worker                        },
3066*da0073e9SAndroid Build Coastguard Worker                    ),
3067*da0073e9SAndroid Build Coastguard Worker                },
3068*da0073e9SAndroid Build Coastguard Worker            )
3069*da0073e9SAndroid Build Coastguard Worker        }
3070*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, exp_res_1)
3071*da0073e9SAndroid Build Coastguard Worker        dp2 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn, source_dp=dp1)
3072*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(list(dp2) == list(pickle.loads(pickle.dumps(dp2))))
3073*da0073e9SAndroid Build Coastguard Worker
3074*da0073e9SAndroid Build Coastguard Worker        child_2 = dp2._dp
3075*da0073e9SAndroid Build Coastguard Worker        dm_2 = child_2.main_datapipe
3076*da0073e9SAndroid Build Coastguard Worker        m2_2 = dm_2.main_datapipe
3077*da0073e9SAndroid Build Coastguard Worker        m1_2 = m2_2.datapipe
3078*da0073e9SAndroid Build Coastguard Worker
3079*da0073e9SAndroid Build Coastguard Worker        res2 = traverse_dps(dp2)
3080*da0073e9SAndroid Build Coastguard Worker        exp_res_2 = {
3081*da0073e9SAndroid Build Coastguard Worker            id(dp2): (
3082*da0073e9SAndroid Build Coastguard Worker                dp2,
3083*da0073e9SAndroid Build Coastguard Worker                {
3084*da0073e9SAndroid Build Coastguard Worker                    id(dp1): (
3085*da0073e9SAndroid Build Coastguard Worker                        dp1,
3086*da0073e9SAndroid Build Coastguard Worker                        {
3087*da0073e9SAndroid Build Coastguard Worker                            id(src_1): (src_1, {}),
3088*da0073e9SAndroid Build Coastguard Worker                            id(child_1): (
3089*da0073e9SAndroid Build Coastguard Worker                                child_1,
3090*da0073e9SAndroid Build Coastguard Worker                                {
3091*da0073e9SAndroid Build Coastguard Worker                                    id(dm_1): (
3092*da0073e9SAndroid Build Coastguard Worker                                        dm_1,
3093*da0073e9SAndroid Build Coastguard Worker                                        {
3094*da0073e9SAndroid Build Coastguard Worker                                            id(m2_1): (
3095*da0073e9SAndroid Build Coastguard Worker                                                m2_1,
3096*da0073e9SAndroid Build Coastguard Worker                                                {
3097*da0073e9SAndroid Build Coastguard Worker                                                    id(m1_1): (
3098*da0073e9SAndroid Build Coastguard Worker                                                        m1_1,
3099*da0073e9SAndroid Build Coastguard Worker                                                        {id(src_1): (src_1, {})},
3100*da0073e9SAndroid Build Coastguard Worker                                                    )
3101*da0073e9SAndroid Build Coastguard Worker                                                },
3102*da0073e9SAndroid Build Coastguard Worker                                            )
3103*da0073e9SAndroid Build Coastguard Worker                                        },
3104*da0073e9SAndroid Build Coastguard Worker                                    )
3105*da0073e9SAndroid Build Coastguard Worker                                },
3106*da0073e9SAndroid Build Coastguard Worker                            ),
3107*da0073e9SAndroid Build Coastguard Worker                        },
3108*da0073e9SAndroid Build Coastguard Worker                    ),
3109*da0073e9SAndroid Build Coastguard Worker                    id(child_2): (
3110*da0073e9SAndroid Build Coastguard Worker                        child_2,
3111*da0073e9SAndroid Build Coastguard Worker                        {
3112*da0073e9SAndroid Build Coastguard Worker                            id(dm_2): (
3113*da0073e9SAndroid Build Coastguard Worker                                dm_2,
3114*da0073e9SAndroid Build Coastguard Worker                                {
3115*da0073e9SAndroid Build Coastguard Worker                                    id(m2_2): (
3116*da0073e9SAndroid Build Coastguard Worker                                        m2_2,
3117*da0073e9SAndroid Build Coastguard Worker                                        {
3118*da0073e9SAndroid Build Coastguard Worker                                            id(m1_2): (
3119*da0073e9SAndroid Build Coastguard Worker                                                m1_2,
3120*da0073e9SAndroid Build Coastguard Worker                                                {
3121*da0073e9SAndroid Build Coastguard Worker                                                    id(dp1): (
3122*da0073e9SAndroid Build Coastguard Worker                                                        dp1,
3123*da0073e9SAndroid Build Coastguard Worker                                                        {
3124*da0073e9SAndroid Build Coastguard Worker                                                            id(src_1): (src_1, {}),
3125*da0073e9SAndroid Build Coastguard Worker                                                            id(child_1): (
3126*da0073e9SAndroid Build Coastguard Worker                                                                child_1,
3127*da0073e9SAndroid Build Coastguard Worker                                                                {
3128*da0073e9SAndroid Build Coastguard Worker                                                                    id(dm_1): (
3129*da0073e9SAndroid Build Coastguard Worker                                                                        dm_1,
3130*da0073e9SAndroid Build Coastguard Worker                                                                        {
3131*da0073e9SAndroid Build Coastguard Worker                                                                            id(m2_1): (
3132*da0073e9SAndroid Build Coastguard Worker                                                                                m2_1,
3133*da0073e9SAndroid Build Coastguard Worker                                                                                {
3134*da0073e9SAndroid Build Coastguard Worker                                                                                    id(
3135*da0073e9SAndroid Build Coastguard Worker                                                                                        m1_1
3136*da0073e9SAndroid Build Coastguard Worker                                                                                    ): (
3137*da0073e9SAndroid Build Coastguard Worker                                                                                        m1_1,
3138*da0073e9SAndroid Build Coastguard Worker                                                                                        {
3139*da0073e9SAndroid Build Coastguard Worker                                                                                            id(
3140*da0073e9SAndroid Build Coastguard Worker                                                                                                src_1
3141*da0073e9SAndroid Build Coastguard Worker                                                                                            ): (
3142*da0073e9SAndroid Build Coastguard Worker                                                                                                src_1,
3143*da0073e9SAndroid Build Coastguard Worker                                                                                                {},
3144*da0073e9SAndroid Build Coastguard Worker                                                                                            )
3145*da0073e9SAndroid Build Coastguard Worker                                                                                        },
3146*da0073e9SAndroid Build Coastguard Worker                                                                                    )
3147*da0073e9SAndroid Build Coastguard Worker                                                                                },
3148*da0073e9SAndroid Build Coastguard Worker                                                                            )
3149*da0073e9SAndroid Build Coastguard Worker                                                                        },
3150*da0073e9SAndroid Build Coastguard Worker                                                                    )
3151*da0073e9SAndroid Build Coastguard Worker                                                                },
3152*da0073e9SAndroid Build Coastguard Worker                                                            ),
3153*da0073e9SAndroid Build Coastguard Worker                                                        },
3154*da0073e9SAndroid Build Coastguard Worker                                                    ),
3155*da0073e9SAndroid Build Coastguard Worker                                                },
3156*da0073e9SAndroid Build Coastguard Worker                                            )
3157*da0073e9SAndroid Build Coastguard Worker                                        },
3158*da0073e9SAndroid Build Coastguard Worker                                    )
3159*da0073e9SAndroid Build Coastguard Worker                                },
3160*da0073e9SAndroid Build Coastguard Worker                            )
3161*da0073e9SAndroid Build Coastguard Worker                        },
3162*da0073e9SAndroid Build Coastguard Worker                    ),
3163*da0073e9SAndroid Build Coastguard Worker                },
3164*da0073e9SAndroid Build Coastguard Worker            )
3165*da0073e9SAndroid Build Coastguard Worker        }
3166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res2, exp_res_2)
3167*da0073e9SAndroid Build Coastguard Worker
3168*da0073e9SAndroid Build Coastguard Worker    class LambdaIterDataPipe(CustomIterDataPipe):
3169*da0073e9SAndroid Build Coastguard Worker        def __init__(self, fn, source_dp=None):
3170*da0073e9SAndroid Build Coastguard Worker            super().__init__(fn, source_dp)
3171*da0073e9SAndroid Build Coastguard Worker            self.container = [
3172*da0073e9SAndroid Build Coastguard Worker                lambda x: x + 1,
3173*da0073e9SAndroid Build Coastguard Worker            ]
3174*da0073e9SAndroid Build Coastguard Worker            self.lambda_fn = lambda x: x + 1
3175*da0073e9SAndroid Build Coastguard Worker            self._dp = (
3176*da0073e9SAndroid Build Coastguard Worker                self.source_dp.map(self.add_one)
3177*da0073e9SAndroid Build Coastguard Worker                .map(self.lambda_fn)
3178*da0073e9SAndroid Build Coastguard Worker                .map(self.add_v)
3179*da0073e9SAndroid Build Coastguard Worker                .demux(2, self.classify)[0]
3180*da0073e9SAndroid Build Coastguard Worker            )
3181*da0073e9SAndroid Build Coastguard Worker
3182*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDill
3183*da0073e9SAndroid Build Coastguard Worker    @skipIf(True, "Dill Tests")
3184*da0073e9SAndroid Build Coastguard Worker    def test_circular_serialization_with_dill(self):
3185*da0073e9SAndroid Build Coastguard Worker        # Test for circular reference issue with dill
3186*da0073e9SAndroid Build Coastguard Worker        dp1 = TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1)
3187*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(list(dp1) == list(dill.loads(dill.dumps(dp1))))
3188*da0073e9SAndroid Build Coastguard Worker
3189*da0073e9SAndroid Build Coastguard Worker        child_1 = dp1._dp
3190*da0073e9SAndroid Build Coastguard Worker        dm_1 = child_1.main_datapipe
3191*da0073e9SAndroid Build Coastguard Worker        m2_1 = dm_1.main_datapipe
3192*da0073e9SAndroid Build Coastguard Worker        m1_1 = m2_1.datapipe
3193*da0073e9SAndroid Build Coastguard Worker        src_1 = m1_1.datapipe
3194*da0073e9SAndroid Build Coastguard Worker
3195*da0073e9SAndroid Build Coastguard Worker        res1 = traverse_dps(dp1)
3196*da0073e9SAndroid Build Coastguard Worker
3197*da0073e9SAndroid Build Coastguard Worker        exp_res_1 = {
3198*da0073e9SAndroid Build Coastguard Worker            id(dp1): (
3199*da0073e9SAndroid Build Coastguard Worker                dp1,
3200*da0073e9SAndroid Build Coastguard Worker                {
3201*da0073e9SAndroid Build Coastguard Worker                    id(src_1): (src_1, {}),
3202*da0073e9SAndroid Build Coastguard Worker                    id(child_1): (
3203*da0073e9SAndroid Build Coastguard Worker                        child_1,
3204*da0073e9SAndroid Build Coastguard Worker                        {
3205*da0073e9SAndroid Build Coastguard Worker                            id(dm_1): (
3206*da0073e9SAndroid Build Coastguard Worker                                dm_1,
3207*da0073e9SAndroid Build Coastguard Worker                                {
3208*da0073e9SAndroid Build Coastguard Worker                                    id(m2_1): (
3209*da0073e9SAndroid Build Coastguard Worker                                        m2_1,
3210*da0073e9SAndroid Build Coastguard Worker                                        {id(m1_1): (m1_1, {id(src_1): (src_1, {})})},
3211*da0073e9SAndroid Build Coastguard Worker                                    )
3212*da0073e9SAndroid Build Coastguard Worker                                },
3213*da0073e9SAndroid Build Coastguard Worker                            )
3214*da0073e9SAndroid Build Coastguard Worker                        },
3215*da0073e9SAndroid Build Coastguard Worker                    ),
3216*da0073e9SAndroid Build Coastguard Worker                },
3217*da0073e9SAndroid Build Coastguard Worker            )
3218*da0073e9SAndroid Build Coastguard Worker        }
3219*da0073e9SAndroid Build Coastguard Worker
3220*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, exp_res_1)
3221*da0073e9SAndroid Build Coastguard Worker
3222*da0073e9SAndroid Build Coastguard Worker        dp2 = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn, source_dp=dp1)
3223*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(list(dp2) == list(dill.loads(dill.dumps(dp2))))
3224*da0073e9SAndroid Build Coastguard Worker
3225*da0073e9SAndroid Build Coastguard Worker        child_2 = dp2._dp
3226*da0073e9SAndroid Build Coastguard Worker        dm_2 = child_2.main_datapipe
3227*da0073e9SAndroid Build Coastguard Worker        m2_2 = dm_2.main_datapipe
3228*da0073e9SAndroid Build Coastguard Worker        m1_2 = m2_2.datapipe
3229*da0073e9SAndroid Build Coastguard Worker
3230*da0073e9SAndroid Build Coastguard Worker        res2 = traverse_dps(dp2)
3231*da0073e9SAndroid Build Coastguard Worker        exp_res_2 = {
3232*da0073e9SAndroid Build Coastguard Worker            id(dp2): (
3233*da0073e9SAndroid Build Coastguard Worker                dp2,
3234*da0073e9SAndroid Build Coastguard Worker                {
3235*da0073e9SAndroid Build Coastguard Worker                    id(dp1): (
3236*da0073e9SAndroid Build Coastguard Worker                        dp1,
3237*da0073e9SAndroid Build Coastguard Worker                        {
3238*da0073e9SAndroid Build Coastguard Worker                            id(src_1): (src_1, {}),
3239*da0073e9SAndroid Build Coastguard Worker                            id(child_1): (
3240*da0073e9SAndroid Build Coastguard Worker                                child_1,
3241*da0073e9SAndroid Build Coastguard Worker                                {
3242*da0073e9SAndroid Build Coastguard Worker                                    id(dm_1): (
3243*da0073e9SAndroid Build Coastguard Worker                                        dm_1,
3244*da0073e9SAndroid Build Coastguard Worker                                        {
3245*da0073e9SAndroid Build Coastguard Worker                                            id(m2_1): (
3246*da0073e9SAndroid Build Coastguard Worker                                                m2_1,
3247*da0073e9SAndroid Build Coastguard Worker                                                {
3248*da0073e9SAndroid Build Coastguard Worker                                                    id(m1_1): (
3249*da0073e9SAndroid Build Coastguard Worker                                                        m1_1,
3250*da0073e9SAndroid Build Coastguard Worker                                                        {id(src_1): (src_1, {})},
3251*da0073e9SAndroid Build Coastguard Worker                                                    )
3252*da0073e9SAndroid Build Coastguard Worker                                                },
3253*da0073e9SAndroid Build Coastguard Worker                                            )
3254*da0073e9SAndroid Build Coastguard Worker                                        },
3255*da0073e9SAndroid Build Coastguard Worker                                    )
3256*da0073e9SAndroid Build Coastguard Worker                                },
3257*da0073e9SAndroid Build Coastguard Worker                            ),
3258*da0073e9SAndroid Build Coastguard Worker                        },
3259*da0073e9SAndroid Build Coastguard Worker                    ),
3260*da0073e9SAndroid Build Coastguard Worker                    id(child_2): (
3261*da0073e9SAndroid Build Coastguard Worker                        child_2,
3262*da0073e9SAndroid Build Coastguard Worker                        {
3263*da0073e9SAndroid Build Coastguard Worker                            id(dm_2): (
3264*da0073e9SAndroid Build Coastguard Worker                                dm_2,
3265*da0073e9SAndroid Build Coastguard Worker                                {
3266*da0073e9SAndroid Build Coastguard Worker                                    id(m2_2): (
3267*da0073e9SAndroid Build Coastguard Worker                                        m2_2,
3268*da0073e9SAndroid Build Coastguard Worker                                        {
3269*da0073e9SAndroid Build Coastguard Worker                                            id(m1_2): (
3270*da0073e9SAndroid Build Coastguard Worker                                                m1_2,
3271*da0073e9SAndroid Build Coastguard Worker                                                {
3272*da0073e9SAndroid Build Coastguard Worker                                                    id(dp1): (
3273*da0073e9SAndroid Build Coastguard Worker                                                        dp1,
3274*da0073e9SAndroid Build Coastguard Worker                                                        {
3275*da0073e9SAndroid Build Coastguard Worker                                                            id(src_1): (src_1, {}),
3276*da0073e9SAndroid Build Coastguard Worker                                                            id(child_1): (
3277*da0073e9SAndroid Build Coastguard Worker                                                                child_1,
3278*da0073e9SAndroid Build Coastguard Worker                                                                {
3279*da0073e9SAndroid Build Coastguard Worker                                                                    id(dm_1): (
3280*da0073e9SAndroid Build Coastguard Worker                                                                        dm_1,
3281*da0073e9SAndroid Build Coastguard Worker                                                                        {
3282*da0073e9SAndroid Build Coastguard Worker                                                                            id(m2_1): (
3283*da0073e9SAndroid Build Coastguard Worker                                                                                m2_1,
3284*da0073e9SAndroid Build Coastguard Worker                                                                                {
3285*da0073e9SAndroid Build Coastguard Worker                                                                                    id(
3286*da0073e9SAndroid Build Coastguard Worker                                                                                        m1_1
3287*da0073e9SAndroid Build Coastguard Worker                                                                                    ): (
3288*da0073e9SAndroid Build Coastguard Worker                                                                                        m1_1,
3289*da0073e9SAndroid Build Coastguard Worker                                                                                        {
3290*da0073e9SAndroid Build Coastguard Worker                                                                                            id(
3291*da0073e9SAndroid Build Coastguard Worker                                                                                                src_1
3292*da0073e9SAndroid Build Coastguard Worker                                                                                            ): (
3293*da0073e9SAndroid Build Coastguard Worker                                                                                                src_1,
3294*da0073e9SAndroid Build Coastguard Worker                                                                                                {},
3295*da0073e9SAndroid Build Coastguard Worker                                                                                            )
3296*da0073e9SAndroid Build Coastguard Worker                                                                                        },
3297*da0073e9SAndroid Build Coastguard Worker                                                                                    )
3298*da0073e9SAndroid Build Coastguard Worker                                                                                },
3299*da0073e9SAndroid Build Coastguard Worker                                                                            )
3300*da0073e9SAndroid Build Coastguard Worker                                                                        },
3301*da0073e9SAndroid Build Coastguard Worker                                                                    )
3302*da0073e9SAndroid Build Coastguard Worker                                                                },
3303*da0073e9SAndroid Build Coastguard Worker                                                            ),
3304*da0073e9SAndroid Build Coastguard Worker                                                        },
3305*da0073e9SAndroid Build Coastguard Worker                                                    ),
3306*da0073e9SAndroid Build Coastguard Worker                                                },
3307*da0073e9SAndroid Build Coastguard Worker                                            )
3308*da0073e9SAndroid Build Coastguard Worker                                        },
3309*da0073e9SAndroid Build Coastguard Worker                                    )
3310*da0073e9SAndroid Build Coastguard Worker                                },
3311*da0073e9SAndroid Build Coastguard Worker                            )
3312*da0073e9SAndroid Build Coastguard Worker                        },
3313*da0073e9SAndroid Build Coastguard Worker                    ),
3314*da0073e9SAndroid Build Coastguard Worker                },
3315*da0073e9SAndroid Build Coastguard Worker            )
3316*da0073e9SAndroid Build Coastguard Worker        }
3317*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res2, exp_res_2)
3318*da0073e9SAndroid Build Coastguard Worker
3319*da0073e9SAndroid Build Coastguard Worker
3320*da0073e9SAndroid Build Coastguard Workerclass CustomShardingIterDataPipe(IterDataPipe):
3321*da0073e9SAndroid Build Coastguard Worker    def __init__(self, dp):
3322*da0073e9SAndroid Build Coastguard Worker        self.dp = dp
3323*da0073e9SAndroid Build Coastguard Worker        self.num_of_instances = 1
3324*da0073e9SAndroid Build Coastguard Worker        self.instance_id = 0
3325*da0073e9SAndroid Build Coastguard Worker
3326*da0073e9SAndroid Build Coastguard Worker    def apply_sharding(self, num_of_instances, instance_id):
3327*da0073e9SAndroid Build Coastguard Worker        self.num_of_instances = num_of_instances
3328*da0073e9SAndroid Build Coastguard Worker        self.instance_id = instance_id
3329*da0073e9SAndroid Build Coastguard Worker
3330*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
3331*da0073e9SAndroid Build Coastguard Worker        for i, d in enumerate(self.dp):
3332*da0073e9SAndroid Build Coastguard Worker            if i % self.num_of_instances == self.instance_id:
3333*da0073e9SAndroid Build Coastguard Worker                yield d
3334*da0073e9SAndroid Build Coastguard Worker
3335*da0073e9SAndroid Build Coastguard Worker
3336*da0073e9SAndroid Build Coastguard Workerclass TestSharding(TestCase):
3337*da0073e9SAndroid Build Coastguard Worker    def _get_pipeline(self):
3338*da0073e9SAndroid Build Coastguard Worker        numbers_dp = NumbersDataset(size=10)
3339*da0073e9SAndroid Build Coastguard Worker        dp0, dp1 = numbers_dp.fork(num_instances=2)
3340*da0073e9SAndroid Build Coastguard Worker        dp0_upd = dp0.map(_mul_10)
3341*da0073e9SAndroid Build Coastguard Worker        dp1_upd = dp1.filter(_mod_3_test)
3342*da0073e9SAndroid Build Coastguard Worker        combined_dp = dp0_upd.mux(dp1_upd)
3343*da0073e9SAndroid Build Coastguard Worker        return combined_dp
3344*da0073e9SAndroid Build Coastguard Worker
3345*da0073e9SAndroid Build Coastguard Worker    def _get_dill_pipeline(self):
3346*da0073e9SAndroid Build Coastguard Worker        numbers_dp = NumbersDataset(size=10)
3347*da0073e9SAndroid Build Coastguard Worker        dp0, dp1 = numbers_dp.fork(num_instances=2)
3348*da0073e9SAndroid Build Coastguard Worker        dp0_upd = dp0.map(lambda x: x * 10)
3349*da0073e9SAndroid Build Coastguard Worker        dp1_upd = dp1.filter(lambda x: x % 3 == 1)
3350*da0073e9SAndroid Build Coastguard Worker        combined_dp = dp0_upd.mux(dp1_upd)
3351*da0073e9SAndroid Build Coastguard Worker        return combined_dp
3352*da0073e9SAndroid Build Coastguard Worker
3353*da0073e9SAndroid Build Coastguard Worker    def test_simple_sharding(self):
3354*da0073e9SAndroid Build Coastguard Worker        sharded_dp = self._get_pipeline().sharding_filter()
3355*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1)
3356*da0073e9SAndroid Build Coastguard Worker        items = list(sharded_dp)
3357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 20], items)
3358*da0073e9SAndroid Build Coastguard Worker
3359*da0073e9SAndroid Build Coastguard Worker        all_items = [0, 1, 10, 4, 20, 7]
3360*da0073e9SAndroid Build Coastguard Worker        items = []
3361*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
3362*da0073e9SAndroid Build Coastguard Worker            sharded_dp = self._get_pipeline().sharding_filter()
3363*da0073e9SAndroid Build Coastguard Worker            torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, i)
3364*da0073e9SAndroid Build Coastguard Worker            items += list(sharded_dp)
3365*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sorted(all_items), sorted(items))
3366*da0073e9SAndroid Build Coastguard Worker
3367*da0073e9SAndroid Build Coastguard Worker    def test_sharding_groups(self):
3368*da0073e9SAndroid Build Coastguard Worker        def construct_sharded_pipe():
3369*da0073e9SAndroid Build Coastguard Worker            sharding_pipes = []
3370*da0073e9SAndroid Build Coastguard Worker            dp = NumbersDataset(size=90)
3371*da0073e9SAndroid Build Coastguard Worker            dp = dp.sharding_filter(
3372*da0073e9SAndroid Build Coastguard Worker                sharding_group_filter=SHARDING_PRIORITIES.DISTRIBUTED
3373*da0073e9SAndroid Build Coastguard Worker            )
3374*da0073e9SAndroid Build Coastguard Worker            sharding_pipes.append(dp)
3375*da0073e9SAndroid Build Coastguard Worker            dp = dp.sharding_filter(
3376*da0073e9SAndroid Build Coastguard Worker                sharding_group_filter=SHARDING_PRIORITIES.MULTIPROCESSING
3377*da0073e9SAndroid Build Coastguard Worker            )
3378*da0073e9SAndroid Build Coastguard Worker            sharding_pipes.append(dp)
3379*da0073e9SAndroid Build Coastguard Worker            dp = dp.sharding_filter(sharding_group_filter=300)
3380*da0073e9SAndroid Build Coastguard Worker            sharding_pipes.append(dp)
3381*da0073e9SAndroid Build Coastguard Worker            return dp, sharding_pipes
3382*da0073e9SAndroid Build Coastguard Worker
3383*da0073e9SAndroid Build Coastguard Worker        dp, sharding_pipes = construct_sharded_pipe()
3384*da0073e9SAndroid Build Coastguard Worker
3385*da0073e9SAndroid Build Coastguard Worker        for pipe in sharding_pipes:
3386*da0073e9SAndroid Build Coastguard Worker            pipe.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
3387*da0073e9SAndroid Build Coastguard Worker            pipe.apply_sharding(
3388*da0073e9SAndroid Build Coastguard Worker                5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING
3389*da0073e9SAndroid Build Coastguard Worker            )
3390*da0073e9SAndroid Build Coastguard Worker            pipe.apply_sharding(3, 1, sharding_group=300)
3391*da0073e9SAndroid Build Coastguard Worker
3392*da0073e9SAndroid Build Coastguard Worker        actual = list(dp)
3393*da0073e9SAndroid Build Coastguard Worker        expected = [17, 47, 77]
3394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
3395*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, len(dp))
3396*da0073e9SAndroid Build Coastguard Worker
3397*da0073e9SAndroid Build Coastguard Worker        dp, _ = construct_sharded_pipe()
3398*da0073e9SAndroid Build Coastguard Worker        dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
3399*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
3400*da0073e9SAndroid Build Coastguard Worker            dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
3401*da0073e9SAndroid Build Coastguard Worker
3402*da0073e9SAndroid Build Coastguard Worker        dp, _ = construct_sharded_pipe()
3403*da0073e9SAndroid Build Coastguard Worker        dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
3404*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
3405*da0073e9SAndroid Build Coastguard Worker            dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
3406*da0073e9SAndroid Build Coastguard Worker
3407*da0073e9SAndroid Build Coastguard Worker    # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatbility
3408*da0073e9SAndroid Build Coastguard Worker    # TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated
3409*da0073e9SAndroid Build Coastguard Worker    def test_sharding_groups_in_legacy_grouping_package(self):
3410*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
3411*da0073e9SAndroid Build Coastguard Worker            FutureWarning,
3412*da0073e9SAndroid Build Coastguard Worker            r"Please use `SHARDING_PRIORITIES` "
3413*da0073e9SAndroid Build Coastguard Worker            "from the `torch.utils.data.datapipes.iter.sharding`",
3414*da0073e9SAndroid Build Coastguard Worker        ):
3415*da0073e9SAndroid Build Coastguard Worker            from torch.utils.data.datapipes.iter.grouping import (
3416*da0073e9SAndroid Build Coastguard Worker                SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES,
3417*da0073e9SAndroid Build Coastguard Worker            )
3418*da0073e9SAndroid Build Coastguard Worker
3419*da0073e9SAndroid Build Coastguard Worker        def construct_sharded_pipe():
3420*da0073e9SAndroid Build Coastguard Worker            sharding_pipes = []
3421*da0073e9SAndroid Build Coastguard Worker            dp = NumbersDataset(size=90)
3422*da0073e9SAndroid Build Coastguard Worker            dp = dp.sharding_filter(
3423*da0073e9SAndroid Build Coastguard Worker                sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
3424*da0073e9SAndroid Build Coastguard Worker            )
3425*da0073e9SAndroid Build Coastguard Worker            sharding_pipes.append(dp)
3426*da0073e9SAndroid Build Coastguard Worker            dp = dp.sharding_filter(
3427*da0073e9SAndroid Build Coastguard Worker                sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3428*da0073e9SAndroid Build Coastguard Worker            )
3429*da0073e9SAndroid Build Coastguard Worker            sharding_pipes.append(dp)
3430*da0073e9SAndroid Build Coastguard Worker            dp = dp.sharding_filter(sharding_group_filter=300)
3431*da0073e9SAndroid Build Coastguard Worker            sharding_pipes.append(dp)
3432*da0073e9SAndroid Build Coastguard Worker            return dp, sharding_pipes
3433*da0073e9SAndroid Build Coastguard Worker
3434*da0073e9SAndroid Build Coastguard Worker        dp, sharding_pipes = construct_sharded_pipe()
3435*da0073e9SAndroid Build Coastguard Worker
3436*da0073e9SAndroid Build Coastguard Worker        for pipe in sharding_pipes:
3437*da0073e9SAndroid Build Coastguard Worker            pipe.apply_sharding(
3438*da0073e9SAndroid Build Coastguard Worker                2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
3439*da0073e9SAndroid Build Coastguard Worker            )
3440*da0073e9SAndroid Build Coastguard Worker            pipe.apply_sharding(
3441*da0073e9SAndroid Build Coastguard Worker                5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3442*da0073e9SAndroid Build Coastguard Worker            )
3443*da0073e9SAndroid Build Coastguard Worker            pipe.apply_sharding(3, 1, sharding_group=300)
3444*da0073e9SAndroid Build Coastguard Worker
3445*da0073e9SAndroid Build Coastguard Worker        actual = list(dp)
3446*da0073e9SAndroid Build Coastguard Worker        expected = [17, 47, 77]
3447*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
3448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, len(dp))
3449*da0073e9SAndroid Build Coastguard Worker
3450*da0073e9SAndroid Build Coastguard Worker        dp, _ = construct_sharded_pipe()
3451*da0073e9SAndroid Build Coastguard Worker        dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
3452*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
3453*da0073e9SAndroid Build Coastguard Worker            dp.apply_sharding(
3454*da0073e9SAndroid Build Coastguard Worker                5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3455*da0073e9SAndroid Build Coastguard Worker            )
3456*da0073e9SAndroid Build Coastguard Worker
3457*da0073e9SAndroid Build Coastguard Worker        dp, _ = construct_sharded_pipe()
3458*da0073e9SAndroid Build Coastguard Worker        dp.apply_sharding(
3459*da0073e9SAndroid Build Coastguard Worker            5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3460*da0073e9SAndroid Build Coastguard Worker        )
3461*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
3462*da0073e9SAndroid Build Coastguard Worker            dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
3463*da0073e9SAndroid Build Coastguard Worker
3464*da0073e9SAndroid Build Coastguard Worker    def test_legacy_custom_sharding(self):
3465*da0073e9SAndroid Build Coastguard Worker        dp = self._get_pipeline()
3466*da0073e9SAndroid Build Coastguard Worker        sharded_dp = CustomShardingIterDataPipe(dp)
3467*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1)
3468*da0073e9SAndroid Build Coastguard Worker        items = list(sharded_dp)
3469*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([1, 20], items)
3470*da0073e9SAndroid Build Coastguard Worker
3471*da0073e9SAndroid Build Coastguard Worker    def test_sharding_length(self):
3472*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(range(13))
3473*da0073e9SAndroid Build Coastguard Worker        sharded_dp0 = numbers_dp.sharding_filter()
3474*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 3, 0)
3475*da0073e9SAndroid Build Coastguard Worker        sharded_dp1 = numbers_dp.sharding_filter()
3476*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 3, 1)
3477*da0073e9SAndroid Build Coastguard Worker        sharded_dp2 = numbers_dp.sharding_filter()
3478*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(sharded_dp2, 3, 2)
3479*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(13, len(numbers_dp))
3480*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(5, len(sharded_dp0))
3481*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(4, len(sharded_dp1))
3482*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(4, len(sharded_dp2))
3483*da0073e9SAndroid Build Coastguard Worker
3484*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(range(1))
3485*da0073e9SAndroid Build Coastguard Worker        sharded_dp0 = numbers_dp.sharding_filter()
3486*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 2, 0)
3487*da0073e9SAndroid Build Coastguard Worker        sharded_dp1 = numbers_dp.sharding_filter()
3488*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 2, 1)
3489*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, len(sharded_dp0))
3490*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, len(sharded_dp1))
3491*da0073e9SAndroid Build Coastguard Worker
3492*da0073e9SAndroid Build Coastguard Worker    def test_old_dataloader(self):
3493*da0073e9SAndroid Build Coastguard Worker        dp0 = self._get_pipeline()
3494*da0073e9SAndroid Build Coastguard Worker        expected = list(dp0)
3495*da0073e9SAndroid Build Coastguard Worker
3496*da0073e9SAndroid Build Coastguard Worker        dp0 = self._get_pipeline().sharding_filter()
3497*da0073e9SAndroid Build Coastguard Worker        dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2)
3498*da0073e9SAndroid Build Coastguard Worker        items = list(dl)
3499*da0073e9SAndroid Build Coastguard Worker
3500*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sorted(expected), sorted(items))
3501*da0073e9SAndroid Build Coastguard Worker
3502*da0073e9SAndroid Build Coastguard Worker    def test_legacy_custom_sharding_with_old_dataloader(self):
3503*da0073e9SAndroid Build Coastguard Worker        dp0 = self._get_pipeline()
3504*da0073e9SAndroid Build Coastguard Worker        expected = list(dp0)
3505*da0073e9SAndroid Build Coastguard Worker
3506*da0073e9SAndroid Build Coastguard Worker        dp0 = self._get_pipeline()
3507*da0073e9SAndroid Build Coastguard Worker        dp0 = CustomShardingIterDataPipe(dp0)
3508*da0073e9SAndroid Build Coastguard Worker        dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2)
3509*da0073e9SAndroid Build Coastguard Worker        items = list(dl)
3510*da0073e9SAndroid Build Coastguard Worker
3511*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sorted(expected), sorted(items))
3512*da0073e9SAndroid Build Coastguard Worker
3513*da0073e9SAndroid Build Coastguard Worker    def test_multi_sharding(self):
3514*da0073e9SAndroid Build Coastguard Worker        # Raises Error when multiple sharding on the single branch
3515*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(range(13))
3516*da0073e9SAndroid Build Coastguard Worker        sharded_dp = numbers_dp.sharding_filter()
3517*da0073e9SAndroid Build Coastguard Worker        sharded_dp = sharded_dp.sharding_filter()
3518*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3519*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Sharding twice on a single pipeline"
3520*da0073e9SAndroid Build Coastguard Worker        ):
3521*da0073e9SAndroid Build Coastguard Worker            torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 0)
3522*da0073e9SAndroid Build Coastguard Worker
3523*da0073e9SAndroid Build Coastguard Worker        # Raises Error when sharding on both data source and branch
3524*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter()
3525*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = numbers_dp.fork(2)
3526*da0073e9SAndroid Build Coastguard Worker        sharded_dp = dp1.sharding_filter()
3527*da0073e9SAndroid Build Coastguard Worker        zip_dp = dp2.zip(sharded_dp)
3528*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3529*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Sharding twice on a single pipeline"
3530*da0073e9SAndroid Build Coastguard Worker        ):
3531*da0073e9SAndroid Build Coastguard Worker            torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3532*da0073e9SAndroid Build Coastguard Worker
3533*da0073e9SAndroid Build Coastguard Worker        # Raises Error when multiple sharding on the branch and end
3534*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(range(13))
3535*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = numbers_dp.fork(2)
3536*da0073e9SAndroid Build Coastguard Worker        sharded_dp = dp1.sharding_filter()
3537*da0073e9SAndroid Build Coastguard Worker        zip_dp = dp2.zip(sharded_dp).sharding_filter()
3538*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3539*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Sharding twice on a single pipeline"
3540*da0073e9SAndroid Build Coastguard Worker        ):
3541*da0073e9SAndroid Build Coastguard Worker            torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3542*da0073e9SAndroid Build Coastguard Worker
3543*da0073e9SAndroid Build Coastguard Worker        # Single sharding_filter on data source
3544*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter()
3545*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = numbers_dp.fork(2)
3546*da0073e9SAndroid Build Coastguard Worker        zip_dp = dp1.zip(dp2)
3547*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)])
3549*da0073e9SAndroid Build Coastguard Worker
3550*da0073e9SAndroid Build Coastguard Worker        # Single sharding_filter per branch
3551*da0073e9SAndroid Build Coastguard Worker        numbers_dp = dp.iter.IterableWrapper(range(13))
3552*da0073e9SAndroid Build Coastguard Worker        dp1, dp2 = numbers_dp.fork(2)
3553*da0073e9SAndroid Build Coastguard Worker        sharded_dp1 = dp1.sharding_filter()
3554*da0073e9SAndroid Build Coastguard Worker        sharded_dp2 = dp2.sharding_filter()
3555*da0073e9SAndroid Build Coastguard Worker        zip_dp = sharded_dp1.zip(sharded_dp2)
3556*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3557*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)])
3558*da0073e9SAndroid Build Coastguard Worker
3559*da0073e9SAndroid Build Coastguard Worker
3560*da0073e9SAndroid Build Coastguard Workerclass TestIterDataPipeSingletonConstraint(TestCase):
3561*da0073e9SAndroid Build Coastguard Worker    r"""
3562*da0073e9SAndroid Build Coastguard Worker    Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older
3563*da0073e9SAndroid Build Coastguard Worker    iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior.
3564*da0073e9SAndroid Build Coastguard Worker    """
3565*da0073e9SAndroid Build Coastguard Worker
3566*da0073e9SAndroid Build Coastguard Worker    def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe):
3567*da0073e9SAndroid Build Coastguard Worker        r"""
3568*da0073e9SAndroid Build Coastguard Worker        Given a IterDataPipe, verifies that the iterator can be read, reset, and the creation of
3569*da0073e9SAndroid Build Coastguard Worker        a second iterator invalidates the first one.
3570*da0073e9SAndroid Build Coastguard Worker        """
3571*da0073e9SAndroid Build Coastguard Worker        it1 = iter(source_dp)
3572*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it1))
3573*da0073e9SAndroid Build Coastguard Worker        it1 = iter(source_dp)
3574*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3575*da0073e9SAndroid Build Coastguard Worker            list(range(10)), list(it1)
3576*da0073e9SAndroid Build Coastguard Worker        )  # A fresh iterator can be read in full again
3577*da0073e9SAndroid Build Coastguard Worker        it1 = iter(source_dp)
3578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it1))
3579*da0073e9SAndroid Build Coastguard Worker        it2 = iter(source_dp)  # This should invalidate `it1`
3580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it2))  # Should read from the beginning again
3581*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3582*da0073e9SAndroid Build Coastguard Worker            next(it1)
3583*da0073e9SAndroid Build Coastguard Worker
3584*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_singleton_generator(self):
3585*da0073e9SAndroid Build Coastguard Worker        r"""
3586*da0073e9SAndroid Build Coastguard Worker        Testing for the case where IterDataPipe's `__iter__` is a generator function.
3587*da0073e9SAndroid Build Coastguard Worker        """
3588*da0073e9SAndroid Build Coastguard Worker
3589*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check if invalidation logic is correct
3590*da0073e9SAndroid Build Coastguard Worker        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
3591*da0073e9SAndroid Build Coastguard Worker        self._check_single_iterator_invalidation_logic(source_dp)
3592*da0073e9SAndroid Build Coastguard Worker
3593*da0073e9SAndroid Build Coastguard Worker        # Functional Test: extend the test to a pipeline
3594*da0073e9SAndroid Build Coastguard Worker        dps = source_dp.map(_fake_fn).filter(_fake_filter_fn)
3595*da0073e9SAndroid Build Coastguard Worker        self._check_single_iterator_invalidation_logic(dps)
3596*da0073e9SAndroid Build Coastguard Worker
3597*da0073e9SAndroid Build Coastguard Worker        # Functional Test: multiple simultaneous references to the same DataPipe fails
3598*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3599*da0073e9SAndroid Build Coastguard Worker            for _ in zip(source_dp, source_dp):
3600*da0073e9SAndroid Build Coastguard Worker                pass
3601*da0073e9SAndroid Build Coastguard Worker
3602*da0073e9SAndroid Build Coastguard Worker        # Function Test: sequential references work
3603*da0073e9SAndroid Build Coastguard Worker        for _ in zip(list(source_dp), list(source_dp)):
3604*da0073e9SAndroid Build Coastguard Worker            pass
3605*da0073e9SAndroid Build Coastguard Worker
3606*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_singleton_self_next(self):
3607*da0073e9SAndroid Build Coastguard Worker        r"""
3608*da0073e9SAndroid Build Coastguard Worker        Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method
3609*da0073e9SAndroid Build Coastguard Worker        Note that the following DataPipe by is singleton by default (because `__iter__` returns `self`).
3610*da0073e9SAndroid Build Coastguard Worker        """
3611*da0073e9SAndroid Build Coastguard Worker
3612*da0073e9SAndroid Build Coastguard Worker        class _CustomIterDP_Self(IterDataPipe):
3613*da0073e9SAndroid Build Coastguard Worker            def __init__(self, iterable):
3614*da0073e9SAndroid Build Coastguard Worker                self.source = iterable
3615*da0073e9SAndroid Build Coastguard Worker                self.iterable = iter(iterable)
3616*da0073e9SAndroid Build Coastguard Worker
3617*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
3618*da0073e9SAndroid Build Coastguard Worker                self.reset()
3619*da0073e9SAndroid Build Coastguard Worker                return self
3620*da0073e9SAndroid Build Coastguard Worker
3621*da0073e9SAndroid Build Coastguard Worker            def __next__(self):
3622*da0073e9SAndroid Build Coastguard Worker                return next(self.iterable)
3623*da0073e9SAndroid Build Coastguard Worker
3624*da0073e9SAndroid Build Coastguard Worker            def reset(self):
3625*da0073e9SAndroid Build Coastguard Worker                self.iterable = iter(self.source)
3626*da0073e9SAndroid Build Coastguard Worker
3627*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check that every `__iter__` call returns the same object
3628*da0073e9SAndroid Build Coastguard Worker        source_dp = _CustomIterDP_Self(range(10))
3629*da0073e9SAndroid Build Coastguard Worker        res = list(source_dp)
3630*da0073e9SAndroid Build Coastguard Worker        it = iter(source_dp)
3631*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, list(it))
3632*da0073e9SAndroid Build Coastguard Worker
3633*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check if invalidation logic is correct
3634*da0073e9SAndroid Build Coastguard Worker        source_dp = _CustomIterDP_Self(range(10))
3635*da0073e9SAndroid Build Coastguard Worker        self._check_single_iterator_invalidation_logic(source_dp)
3636*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3637*da0073e9SAndroid Build Coastguard Worker            1, next(source_dp)
3638*da0073e9SAndroid Build Coastguard Worker        )  # `source_dp` is still valid and can be read
3639*da0073e9SAndroid Build Coastguard Worker
3640*da0073e9SAndroid Build Coastguard Worker        # Functional Test: extend the test to a pipeline
3641*da0073e9SAndroid Build Coastguard Worker        source_dp = _CustomIterDP_Self(
3642*da0073e9SAndroid Build Coastguard Worker            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
3643*da0073e9SAndroid Build Coastguard Worker        )
3644*da0073e9SAndroid Build Coastguard Worker        self._check_single_iterator_invalidation_logic(source_dp)
3645*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3646*da0073e9SAndroid Build Coastguard Worker            1, next(source_dp)
3647*da0073e9SAndroid Build Coastguard Worker        )  # `source_dp` is still valid and can be read
3648*da0073e9SAndroid Build Coastguard Worker
3649*da0073e9SAndroid Build Coastguard Worker        # Functional Test: multiple simultaneous references to the same DataPipe fails
3650*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3651*da0073e9SAndroid Build Coastguard Worker            for _ in zip(source_dp, source_dp):
3652*da0073e9SAndroid Build Coastguard Worker                pass
3653*da0073e9SAndroid Build Coastguard Worker
3654*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_singleton_new_object(self):
3655*da0073e9SAndroid Build Coastguard Worker        r"""
3656*da0073e9SAndroid Build Coastguard Worker        Testing for the case where IterDataPipe's `__iter__` isn't a generator nor returns `self`,
3657*da0073e9SAndroid Build Coastguard Worker        and there isn't a `__next__` method.
3658*da0073e9SAndroid Build Coastguard Worker        """
3659*da0073e9SAndroid Build Coastguard Worker
3660*da0073e9SAndroid Build Coastguard Worker        class _CustomIterDP(IterDataPipe):
3661*da0073e9SAndroid Build Coastguard Worker            def __init__(self, iterable):
3662*da0073e9SAndroid Build Coastguard Worker                self.iterable = iter(iterable)
3663*da0073e9SAndroid Build Coastguard Worker
3664*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):  # Note that this doesn't reset
3665*da0073e9SAndroid Build Coastguard Worker                return self.iterable  # Intentionally not returning `self`
3666*da0073e9SAndroid Build Coastguard Worker
3667*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check if invalidation logic is correct
3668*da0073e9SAndroid Build Coastguard Worker        source_dp = _CustomIterDP(range(10))
3669*da0073e9SAndroid Build Coastguard Worker        it1 = iter(source_dp)
3670*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it1))
3671*da0073e9SAndroid Build Coastguard Worker        it2 = iter(source_dp)
3672*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, next(it2))
3673*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3674*da0073e9SAndroid Build Coastguard Worker            next(it1)
3675*da0073e9SAndroid Build Coastguard Worker
3676*da0073e9SAndroid Build Coastguard Worker        # Functional Test: extend the test to a pipeline
3677*da0073e9SAndroid Build Coastguard Worker        source_dp = _CustomIterDP(
3678*da0073e9SAndroid Build Coastguard Worker            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
3679*da0073e9SAndroid Build Coastguard Worker        )
3680*da0073e9SAndroid Build Coastguard Worker        it1 = iter(source_dp)
3681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it1))
3682*da0073e9SAndroid Build Coastguard Worker        it2 = iter(source_dp)
3683*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, next(it2))
3684*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3685*da0073e9SAndroid Build Coastguard Worker            next(it1)
3686*da0073e9SAndroid Build Coastguard Worker
3687*da0073e9SAndroid Build Coastguard Worker        # Functional Test: multiple simultaneous references to the same DataPipe fails
3688*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3689*da0073e9SAndroid Build Coastguard Worker            for _ in zip(source_dp, source_dp):
3690*da0073e9SAndroid Build Coastguard Worker                pass
3691*da0073e9SAndroid Build Coastguard Worker
3692*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_singleton_buggy(self):
3693*da0073e9SAndroid Build Coastguard Worker        r"""
3694*da0073e9SAndroid Build Coastguard Worker        Buggy test case case where IterDataPipe's `__iter__` returns a new object, but also has
3695*da0073e9SAndroid Build Coastguard Worker        a `__next__` method.
3696*da0073e9SAndroid Build Coastguard Worker        """
3697*da0073e9SAndroid Build Coastguard Worker
3698*da0073e9SAndroid Build Coastguard Worker        class _CustomIterDP(IterDataPipe):
3699*da0073e9SAndroid Build Coastguard Worker            def __init__(self, iterable):
3700*da0073e9SAndroid Build Coastguard Worker                self.source = iterable
3701*da0073e9SAndroid Build Coastguard Worker                self.iterable = iter(iterable)
3702*da0073e9SAndroid Build Coastguard Worker
3703*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
3704*da0073e9SAndroid Build Coastguard Worker                return iter(self.source)  # Intentionally not returning `self`
3705*da0073e9SAndroid Build Coastguard Worker
3706*da0073e9SAndroid Build Coastguard Worker            def __next__(self):
3707*da0073e9SAndroid Build Coastguard Worker                return next(self.iterable)
3708*da0073e9SAndroid Build Coastguard Worker
3709*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check if invalidation logic is correct
3710*da0073e9SAndroid Build Coastguard Worker        source_dp = _CustomIterDP(range(10))
3711*da0073e9SAndroid Build Coastguard Worker        self._check_single_iterator_invalidation_logic(source_dp)
3712*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(source_dp))  # `__next__` is unrelated with `__iter__`
3713*da0073e9SAndroid Build Coastguard Worker
3714*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Special case to show `__next__` is unrelated with `__iter__`
3715*da0073e9SAndroid Build Coastguard Worker        source_dp = _CustomIterDP(range(10))
3716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(source_dp))
3717*da0073e9SAndroid Build Coastguard Worker        it1 = iter(source_dp)
3718*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it1))
3719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, next(source_dp))
3720*da0073e9SAndroid Build Coastguard Worker        it2 = iter(source_dp)  # invalidates both `it1`
3721*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3722*da0073e9SAndroid Build Coastguard Worker            next(it1)
3723*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, next(source_dp))  # not impacted by the creation of `it2`
3724*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3725*da0073e9SAndroid Build Coastguard Worker            list(range(10)), list(it2)
3726*da0073e9SAndroid Build Coastguard Worker        )  # `it2` still works because it is a new object
3727*da0073e9SAndroid Build Coastguard Worker
3728*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_singleton_constraint_multiple_outputs(self):
3729*da0073e9SAndroid Build Coastguard Worker        r"""
3730*da0073e9SAndroid Build Coastguard Worker        Testing for the case where IterDataPipe has multiple child DataPipes as outputs.
3731*da0073e9SAndroid Build Coastguard Worker        """
3732*da0073e9SAndroid Build Coastguard Worker        # Functional Test: all previous related iterators should be invalidated when a new iterator
3733*da0073e9SAndroid Build Coastguard Worker        #                  is created from a ChildDataPipe
3734*da0073e9SAndroid Build Coastguard Worker        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
3735*da0073e9SAndroid Build Coastguard Worker        cdp1, cdp2 = source_dp.fork(num_instances=2)
3736*da0073e9SAndroid Build Coastguard Worker        it1, it2 = iter(cdp1), iter(cdp2)
3737*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it1))
3738*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it2))
3739*da0073e9SAndroid Build Coastguard Worker        it1, it2 = iter(cdp1), iter(cdp2)
3740*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
3741*da0073e9SAndroid Build Coastguard Worker            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
3742*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
3743*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
3744*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3745*da0073e9SAndroid Build Coastguard Worker            next(it1)
3746*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3747*da0073e9SAndroid Build Coastguard Worker            next(it2)
3748*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it3))
3749*da0073e9SAndroid Build Coastguard Worker        # The next line should not invalidate anything, as there was no new iterator created
3750*da0073e9SAndroid Build Coastguard Worker        # for `cdp2` after `it2` was invalidated
3751*da0073e9SAndroid Build Coastguard Worker        it4 = iter(cdp2)
3752*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, next(it3))  # An error shouldn't be raised here
3753*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it4))
3754*da0073e9SAndroid Build Coastguard Worker
3755*da0073e9SAndroid Build Coastguard Worker        # Functional Test: invalidation when a new iterator is created from `source_dp`
3756*da0073e9SAndroid Build Coastguard Worker        source_dp = dp.iter.IterableWrapper(range(10))
3757*da0073e9SAndroid Build Coastguard Worker        cdp1, cdp2 = source_dp.fork(num_instances=2)
3758*da0073e9SAndroid Build Coastguard Worker        it1, it2 = iter(cdp1), iter(cdp2)
3759*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it1))
3760*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it2))
3761*da0073e9SAndroid Build Coastguard Worker        it1, it2 = iter(cdp1), iter(cdp2)
3762*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it1))
3763*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it2))
3764*da0073e9SAndroid Build Coastguard Worker        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
3765*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3766*da0073e9SAndroid Build Coastguard Worker            0, next(it3)
3767*da0073e9SAndroid Build Coastguard Worker        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
3768*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3769*da0073e9SAndroid Build Coastguard Worker            next(it1)
3770*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, next(it3))
3771*da0073e9SAndroid Build Coastguard Worker
3772*da0073e9SAndroid Build Coastguard Worker        # Function Test: Extending test to pipeline
3773*da0073e9SAndroid Build Coastguard Worker        source_dp = (
3774*da0073e9SAndroid Build Coastguard Worker            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
3775*da0073e9SAndroid Build Coastguard Worker        )
3776*da0073e9SAndroid Build Coastguard Worker        cdp1, cdp2 = source_dp.fork(num_instances=2)
3777*da0073e9SAndroid Build Coastguard Worker        it1, it2 = iter(cdp1), iter(cdp2)
3778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it1))
3779*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(range(10)), list(it2))
3780*da0073e9SAndroid Build Coastguard Worker        it1, it2 = iter(cdp1), iter(cdp2)
3781*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
3782*da0073e9SAndroid Build Coastguard Worker            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
3783*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
3784*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
3785*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3786*da0073e9SAndroid Build Coastguard Worker            next(it1)
3787*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3788*da0073e9SAndroid Build Coastguard Worker            next(it2)
3789*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
3790*da0073e9SAndroid Build Coastguard Worker            it1, it2 = iter(cdp1), iter(cdp2)
3791*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(wa), 1)
3792*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
3793*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it1))
3794*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, next(it2))
3795*da0073e9SAndroid Build Coastguard Worker        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
3796*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3797*da0073e9SAndroid Build Coastguard Worker            0, next(it3)
3798*da0073e9SAndroid Build Coastguard Worker        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
3799*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3800*da0073e9SAndroid Build Coastguard Worker            next(it1)
3801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, next(it3))
3802*da0073e9SAndroid Build Coastguard Worker
3803*da0073e9SAndroid Build Coastguard Worker
3804*da0073e9SAndroid Build Coastguard Workerclass TestIterDataPipeCountSampleYielded(TestCase):
3805*da0073e9SAndroid Build Coastguard Worker    def _yield_count_test_helper(self, datapipe, n_expected_samples):
3806*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check if number of samples yielded is as expected
3807*da0073e9SAndroid Build Coastguard Worker        res = list(datapipe)
3808*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(res), datapipe._number_of_samples_yielded)
3809*da0073e9SAndroid Build Coastguard Worker
3810*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check if the count is correct when DataPipe is partially read
3811*da0073e9SAndroid Build Coastguard Worker        it = iter(datapipe)
3812*da0073e9SAndroid Build Coastguard Worker        res = []
3813*da0073e9SAndroid Build Coastguard Worker        for i, value in enumerate(it):
3814*da0073e9SAndroid Build Coastguard Worker            res.append(value)
3815*da0073e9SAndroid Build Coastguard Worker            if i == n_expected_samples - 1:
3816*da0073e9SAndroid Build Coastguard Worker                break
3817*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded)
3818*da0073e9SAndroid Build Coastguard Worker
3819*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check for reset behavior and if iterator also works
3820*da0073e9SAndroid Build Coastguard Worker        it = iter(datapipe)  # reset the DataPipe
3821*da0073e9SAndroid Build Coastguard Worker        res = list(it)
3822*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(res), datapipe._number_of_samples_yielded)
3823*da0073e9SAndroid Build Coastguard Worker
3824*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_sample_yielded_generator_function(self):
3825*da0073e9SAndroid Build Coastguard Worker        # Functional Test: `__iter__` is a generator function
3826*da0073e9SAndroid Build Coastguard Worker        datapipe: IterDataPipe = dp.iter.IterableWrapper(range(10))
3827*da0073e9SAndroid Build Coastguard Worker        self._yield_count_test_helper(datapipe, n_expected_samples=5)
3828*da0073e9SAndroid Build Coastguard Worker
3829*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_sample_yielded_generator_function_exception(self):
3830*da0073e9SAndroid Build Coastguard Worker        # Functional Test: `__iter__` is a custom generator function with exception
3831*da0073e9SAndroid Build Coastguard Worker        class _CustomGeneratorFnDataPipe(IterDataPipe):
3832*da0073e9SAndroid Build Coastguard Worker            # This class's `__iter__` has a Runtime Error
3833*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
3834*da0073e9SAndroid Build Coastguard Worker                yield 0
3835*da0073e9SAndroid Build Coastguard Worker                yield 1
3836*da0073e9SAndroid Build Coastguard Worker                yield 2
3837*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("Custom test error after yielding 3 elements")
3838*da0073e9SAndroid Build Coastguard Worker                yield 3
3839*da0073e9SAndroid Build Coastguard Worker
3840*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Ensure the count is correct even when exception is raised
3841*da0073e9SAndroid Build Coastguard Worker        datapipe: IterDataPipe = _CustomGeneratorFnDataPipe()
3842*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3843*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Custom test error after yielding 3 elements"
3844*da0073e9SAndroid Build Coastguard Worker        ):
3845*da0073e9SAndroid Build Coastguard Worker            list(datapipe)
3846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, datapipe._number_of_samples_yielded)
3847*da0073e9SAndroid Build Coastguard Worker
3848*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check for reset behavior and if iterator also works
3849*da0073e9SAndroid Build Coastguard Worker        it = iter(datapipe)  # reset the DataPipe
3850*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3851*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Custom test error after yielding 3 elements"
3852*da0073e9SAndroid Build Coastguard Worker        ):
3853*da0073e9SAndroid Build Coastguard Worker            list(it)
3854*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, datapipe._number_of_samples_yielded)
3855*da0073e9SAndroid Build Coastguard Worker
3856*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_sample_yielded_return_self(self):
3857*da0073e9SAndroid Build Coastguard Worker        class _CustomGeneratorDataPipe(IterDataPipe):
3858*da0073e9SAndroid Build Coastguard Worker            # This class's `__iter__` is not a generator function
3859*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3860*da0073e9SAndroid Build Coastguard Worker                self.source = iter(range(10))
3861*da0073e9SAndroid Build Coastguard Worker
3862*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
3863*da0073e9SAndroid Build Coastguard Worker                return self.source
3864*da0073e9SAndroid Build Coastguard Worker
3865*da0073e9SAndroid Build Coastguard Worker            def reset(self):
3866*da0073e9SAndroid Build Coastguard Worker                self.source = iter(range(10))
3867*da0073e9SAndroid Build Coastguard Worker
3868*da0073e9SAndroid Build Coastguard Worker        datapipe: IterDataPipe = _CustomGeneratorDataPipe()
3869*da0073e9SAndroid Build Coastguard Worker        self._yield_count_test_helper(datapipe, n_expected_samples=5)
3870*da0073e9SAndroid Build Coastguard Worker
3871*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_sample_yielded_next(self):
3872*da0073e9SAndroid Build Coastguard Worker        class _CustomNextDataPipe(IterDataPipe):
3873*da0073e9SAndroid Build Coastguard Worker            # This class's `__iter__` returns `self` and has a `__next__`
3874*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3875*da0073e9SAndroid Build Coastguard Worker                self.source = iter(range(10))
3876*da0073e9SAndroid Build Coastguard Worker
3877*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
3878*da0073e9SAndroid Build Coastguard Worker                return self
3879*da0073e9SAndroid Build Coastguard Worker
3880*da0073e9SAndroid Build Coastguard Worker            def __next__(self):
3881*da0073e9SAndroid Build Coastguard Worker                return next(self.source)
3882*da0073e9SAndroid Build Coastguard Worker
3883*da0073e9SAndroid Build Coastguard Worker            def reset(self):
3884*da0073e9SAndroid Build Coastguard Worker                self.source = iter(range(10))
3885*da0073e9SAndroid Build Coastguard Worker
3886*da0073e9SAndroid Build Coastguard Worker        datapipe: IterDataPipe = _CustomNextDataPipe()
3887*da0073e9SAndroid Build Coastguard Worker        self._yield_count_test_helper(datapipe, n_expected_samples=5)
3888*da0073e9SAndroid Build Coastguard Worker
3889*da0073e9SAndroid Build Coastguard Worker    def test_iterdatapipe_sample_yielded_next_exception(self):
3890*da0073e9SAndroid Build Coastguard Worker        class _CustomNextDataPipe(IterDataPipe):
3891*da0073e9SAndroid Build Coastguard Worker            # This class's `__iter__` returns `self` and has a `__next__`
3892*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3893*da0073e9SAndroid Build Coastguard Worker                self.source = iter(range(10))
3894*da0073e9SAndroid Build Coastguard Worker                self.count = 0
3895*da0073e9SAndroid Build Coastguard Worker
3896*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
3897*da0073e9SAndroid Build Coastguard Worker                return self
3898*da0073e9SAndroid Build Coastguard Worker
3899*da0073e9SAndroid Build Coastguard Worker            def __next__(self):
3900*da0073e9SAndroid Build Coastguard Worker                if self.count == 3:
3901*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("Custom test error after yielding 3 elements")
3902*da0073e9SAndroid Build Coastguard Worker                self.count += 1
3903*da0073e9SAndroid Build Coastguard Worker                return next(self.source)
3904*da0073e9SAndroid Build Coastguard Worker
3905*da0073e9SAndroid Build Coastguard Worker            def reset(self):
3906*da0073e9SAndroid Build Coastguard Worker                self.count = 0
3907*da0073e9SAndroid Build Coastguard Worker                self.source = iter(range(10))
3908*da0073e9SAndroid Build Coastguard Worker
3909*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Ensure the count is correct even when exception is raised
3910*da0073e9SAndroid Build Coastguard Worker        datapipe: IterDataPipe = _CustomNextDataPipe()
3911*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3912*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Custom test error after yielding 3 elements"
3913*da0073e9SAndroid Build Coastguard Worker        ):
3914*da0073e9SAndroid Build Coastguard Worker            list(datapipe)
3915*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, datapipe._number_of_samples_yielded)
3916*da0073e9SAndroid Build Coastguard Worker
3917*da0073e9SAndroid Build Coastguard Worker        # Functional Test: Check for reset behavior and if iterator also works
3918*da0073e9SAndroid Build Coastguard Worker        it = iter(datapipe)  # reset the DataPipe
3919*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3920*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Custom test error after yielding 3 elements"
3921*da0073e9SAndroid Build Coastguard Worker        ):
3922*da0073e9SAndroid Build Coastguard Worker            list(it)
3923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, datapipe._number_of_samples_yielded)
3924*da0073e9SAndroid Build Coastguard Worker
3925*da0073e9SAndroid Build Coastguard Worker
3926*da0073e9SAndroid Build Coastguard Workerclass _CustomNonGeneratorTestDataPipe(IterDataPipe):
3927*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
3928*da0073e9SAndroid Build Coastguard Worker        self.n = 10
3929*da0073e9SAndroid Build Coastguard Worker        self.source = list(range(self.n))
3930*da0073e9SAndroid Build Coastguard Worker
3931*da0073e9SAndroid Build Coastguard Worker    # This class's `__iter__` is not a generator function
3932*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
3933*da0073e9SAndroid Build Coastguard Worker        return iter(self.source)
3934*da0073e9SAndroid Build Coastguard Worker
3935*da0073e9SAndroid Build Coastguard Worker    def __len__(self):
3936*da0073e9SAndroid Build Coastguard Worker        return self.n
3937*da0073e9SAndroid Build Coastguard Worker
3938*da0073e9SAndroid Build Coastguard Worker
3939*da0073e9SAndroid Build Coastguard Workerclass _CustomSelfNextTestDataPipe(IterDataPipe):
3940*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
3941*da0073e9SAndroid Build Coastguard Worker        self.n = 10
3942*da0073e9SAndroid Build Coastguard Worker        self.iter = iter(range(self.n))
3943*da0073e9SAndroid Build Coastguard Worker
3944*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
3945*da0073e9SAndroid Build Coastguard Worker        return self
3946*da0073e9SAndroid Build Coastguard Worker
3947*da0073e9SAndroid Build Coastguard Worker    def __next__(self):
3948*da0073e9SAndroid Build Coastguard Worker        return next(self.iter)
3949*da0073e9SAndroid Build Coastguard Worker
3950*da0073e9SAndroid Build Coastguard Worker    def reset(self):
3951*da0073e9SAndroid Build Coastguard Worker        self.iter = iter(range(self.n))
3952*da0073e9SAndroid Build Coastguard Worker
3953*da0073e9SAndroid Build Coastguard Worker    def __len__(self):
3954*da0073e9SAndroid Build Coastguard Worker        return self.n
3955*da0073e9SAndroid Build Coastguard Worker
3956*da0073e9SAndroid Build Coastguard Worker
3957*da0073e9SAndroid Build Coastguard Workerclass TestIterDataPipeGraphFastForward(TestCase):
3958*da0073e9SAndroid Build Coastguard Worker    def _fast_forward_graph_test_helper(
3959*da0073e9SAndroid Build Coastguard Worker        self, datapipe, fast_forward_fn, expected_res, n_iterations=3, rng=None
3960*da0073e9SAndroid Build Coastguard Worker    ):
3961*da0073e9SAndroid Build Coastguard Worker        if rng is None:
3962*da0073e9SAndroid Build Coastguard Worker            rng = torch.Generator()
3963*da0073e9SAndroid Build Coastguard Worker        rng = rng.manual_seed(0)
3964*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)
3965*da0073e9SAndroid Build Coastguard Worker
3966*da0073e9SAndroid Build Coastguard Worker        # Test Case: fast forward works with list
3967*da0073e9SAndroid Build Coastguard Worker        rng.manual_seed(0)
3968*da0073e9SAndroid Build Coastguard Worker        fast_forward_fn(datapipe, n_iterations, rng)
3969*da0073e9SAndroid Build Coastguard Worker        actual_res = list(datapipe)
3970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
3971*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res[n_iterations:], actual_res)
3972*da0073e9SAndroid Build Coastguard Worker
3973*da0073e9SAndroid Build Coastguard Worker        # Test Case: fast forward works with iterator
3974*da0073e9SAndroid Build Coastguard Worker        rng.manual_seed(0)
3975*da0073e9SAndroid Build Coastguard Worker        fast_forward_fn(datapipe, n_iterations, rng)
3976*da0073e9SAndroid Build Coastguard Worker        it = iter(datapipe)
3977*da0073e9SAndroid Build Coastguard Worker        actual_res = list(it)
3978*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
3979*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res[n_iterations:], actual_res)
3980*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(StopIteration):
3981*da0073e9SAndroid Build Coastguard Worker            next(it)
3982*da0073e9SAndroid Build Coastguard Worker
3983*da0073e9SAndroid Build Coastguard Worker    def test_simple_snapshot_graph(self):
3984*da0073e9SAndroid Build Coastguard Worker        graph1 = dp.iter.IterableWrapper(range(10))
3985*da0073e9SAndroid Build Coastguard Worker        res1 = list(range(10))
3986*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
3987*da0073e9SAndroid Build Coastguard Worker            graph1, _simple_graph_snapshot_restoration, expected_res=res1
3988*da0073e9SAndroid Build Coastguard Worker        )
3989*da0073e9SAndroid Build Coastguard Worker
3990*da0073e9SAndroid Build Coastguard Worker        graph2 = graph1.map(_mul_10)
3991*da0073e9SAndroid Build Coastguard Worker        res2 = [10 * x for x in res1]
3992*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
3993*da0073e9SAndroid Build Coastguard Worker            graph2, _simple_graph_snapshot_restoration, expected_res=res2
3994*da0073e9SAndroid Build Coastguard Worker        )
3995*da0073e9SAndroid Build Coastguard Worker
3996*da0073e9SAndroid Build Coastguard Worker        rng = torch.Generator()
3997*da0073e9SAndroid Build Coastguard Worker        graph3 = graph2.shuffle()
3998*da0073e9SAndroid Build Coastguard Worker        rng.manual_seed(0)
3999*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
4000*da0073e9SAndroid Build Coastguard Worker        res3 = list(graph3)
4001*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
4002*da0073e9SAndroid Build Coastguard Worker            graph3, _simple_graph_snapshot_restoration, expected_res=res3
4003*da0073e9SAndroid Build Coastguard Worker        )
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker        graph4 = graph3.map(_mul_10)
4006*da0073e9SAndroid Build Coastguard Worker        res4 = [10 * x for x in res3]
4007*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
4008*da0073e9SAndroid Build Coastguard Worker            graph4, _simple_graph_snapshot_restoration, expected_res=res4
4009*da0073e9SAndroid Build Coastguard Worker        )
4010*da0073e9SAndroid Build Coastguard Worker
4011*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
4012*da0073e9SAndroid Build Coastguard Worker        graph5 = graph4.batch(batch_size)
4013*da0073e9SAndroid Build Coastguard Worker        res5 = [
4014*da0073e9SAndroid Build Coastguard Worker            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
4015*da0073e9SAndroid Build Coastguard Worker        ]  # .batch(2)
4016*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
4017*da0073e9SAndroid Build Coastguard Worker            graph5, _simple_graph_snapshot_restoration, expected_res=res5
4018*da0073e9SAndroid Build Coastguard Worker        )
4019*da0073e9SAndroid Build Coastguard Worker
4020*da0073e9SAndroid Build Coastguard Worker        # With `fork` and `zip`
4021*da0073e9SAndroid Build Coastguard Worker        cdp1, cdp2 = graph5.fork(2)
4022*da0073e9SAndroid Build Coastguard Worker        graph6 = cdp1.zip(cdp2)
4023*da0073e9SAndroid Build Coastguard Worker        rng = rng.manual_seed(100)
4024*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_random_seed(graph6, rng)
4025*da0073e9SAndroid Build Coastguard Worker        res6 = [(x, x) for x in res5]
4026*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
4027*da0073e9SAndroid Build Coastguard Worker            graph6, _simple_graph_snapshot_restoration, expected_res=res6
4028*da0073e9SAndroid Build Coastguard Worker        )
4029*da0073e9SAndroid Build Coastguard Worker
4030*da0073e9SAndroid Build Coastguard Worker        # With `fork` and `concat`
4031*da0073e9SAndroid Build Coastguard Worker        graph7 = cdp1.concat(cdp2)
4032*da0073e9SAndroid Build Coastguard Worker        res7 = res5 * 2
4033*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
4034*da0073e9SAndroid Build Coastguard Worker            graph7, _simple_graph_snapshot_restoration, expected_res=res7
4035*da0073e9SAndroid Build Coastguard Worker        )
4036*da0073e9SAndroid Build Coastguard Worker
4037*da0073e9SAndroid Build Coastguard Worker        # Raises an exception if the graph has already been restored
4038*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4039*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Snapshot restoration cannot be applied."
4040*da0073e9SAndroid Build Coastguard Worker        ):
4041*da0073e9SAndroid Build Coastguard Worker            _simple_graph_snapshot_restoration(graph7, 1)
4042*da0073e9SAndroid Build Coastguard Worker            _simple_graph_snapshot_restoration(graph7, 1)
4043*da0073e9SAndroid Build Coastguard Worker
4044*da0073e9SAndroid Build Coastguard Worker    def test_simple_snapshot_custom_non_generator(self):
4045*da0073e9SAndroid Build Coastguard Worker        graph = _CustomNonGeneratorTestDataPipe()
4046*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
4047*da0073e9SAndroid Build Coastguard Worker            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
4048*da0073e9SAndroid Build Coastguard Worker        )
4049*da0073e9SAndroid Build Coastguard Worker
4050*da0073e9SAndroid Build Coastguard Worker    def test_simple_snapshot_custom_self_next(self):
4051*da0073e9SAndroid Build Coastguard Worker        graph = _CustomSelfNextTestDataPipe()
4052*da0073e9SAndroid Build Coastguard Worker        self._fast_forward_graph_test_helper(
4053*da0073e9SAndroid Build Coastguard Worker            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
4054*da0073e9SAndroid Build Coastguard Worker        )
4055*da0073e9SAndroid Build Coastguard Worker
4056*da0073e9SAndroid Build Coastguard Worker    def _snapshot_test_helper(self, datapipe, expected_res, n_iter=3, rng=None):
4057*da0073e9SAndroid Build Coastguard Worker        """
4058*da0073e9SAndroid Build Coastguard Worker        Extend the previous test with serialization and deserialization test.
4059*da0073e9SAndroid Build Coastguard Worker        """
4060*da0073e9SAndroid Build Coastguard Worker        if rng is None:
4061*da0073e9SAndroid Build Coastguard Worker            rng = torch.Generator()
4062*da0073e9SAndroid Build Coastguard Worker        rng.manual_seed(0)
4063*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)
4064*da0073e9SAndroid Build Coastguard Worker        it = iter(datapipe)
4065*da0073e9SAndroid Build Coastguard Worker        for _ in range(n_iter):
4066*da0073e9SAndroid Build Coastguard Worker            next(it)
4067*da0073e9SAndroid Build Coastguard Worker        serialized_graph = pickle.dumps(datapipe)
4068*da0073e9SAndroid Build Coastguard Worker        deserialized_graph = pickle.loads(serialized_graph)
4069*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n_iter, datapipe._number_of_samples_yielded)
4070*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n_iter, deserialized_graph._number_of_samples_yielded)
4071*da0073e9SAndroid Build Coastguard Worker
4072*da0073e9SAndroid Build Coastguard Worker        rng_for_deserialized = torch.Generator()
4073*da0073e9SAndroid Build Coastguard Worker        rng_for_deserialized.manual_seed(0)
4074*da0073e9SAndroid Build Coastguard Worker        _simple_graph_snapshot_restoration(
4075*da0073e9SAndroid Build Coastguard Worker            deserialized_graph, n_iter, rng=rng_for_deserialized
4076*da0073e9SAndroid Build Coastguard Worker        )
4077*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res[n_iter:], list(it))
4078*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res[n_iter:], list(deserialized_graph))
4079*da0073e9SAndroid Build Coastguard Worker
4080*da0073e9SAndroid Build Coastguard Worker    def test_simple_snapshot_graph_with_serialization(self):
4081*da0073e9SAndroid Build Coastguard Worker        graph1 = dp.iter.IterableWrapper(range(10))
4082*da0073e9SAndroid Build Coastguard Worker        res1 = list(range(10))
4083*da0073e9SAndroid Build Coastguard Worker        self._snapshot_test_helper(graph1, expected_res=res1)
4084*da0073e9SAndroid Build Coastguard Worker
4085*da0073e9SAndroid Build Coastguard Worker        graph2 = graph1.map(_mul_10)
4086*da0073e9SAndroid Build Coastguard Worker        res2 = [10 * x for x in res1]
4087*da0073e9SAndroid Build Coastguard Worker        self._snapshot_test_helper(graph2, expected_res=res2)
4088*da0073e9SAndroid Build Coastguard Worker
4089*da0073e9SAndroid Build Coastguard Worker        rng = torch.Generator()
4090*da0073e9SAndroid Build Coastguard Worker        graph3 = graph2.shuffle()
4091*da0073e9SAndroid Build Coastguard Worker        rng.manual_seed(0)
4092*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
4093*da0073e9SAndroid Build Coastguard Worker        res3 = list(graph3)
4094*da0073e9SAndroid Build Coastguard Worker        self._snapshot_test_helper(graph3, expected_res=res3)
4095*da0073e9SAndroid Build Coastguard Worker
4096*da0073e9SAndroid Build Coastguard Worker        graph4 = graph3.map(_mul_10)
4097*da0073e9SAndroid Build Coastguard Worker        res4 = [10 * x for x in res3]
4098*da0073e9SAndroid Build Coastguard Worker        self._snapshot_test_helper(graph4, expected_res=res4)
4099*da0073e9SAndroid Build Coastguard Worker
4100*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
4101*da0073e9SAndroid Build Coastguard Worker        graph5 = graph4.batch(batch_size)
4102*da0073e9SAndroid Build Coastguard Worker        res5 = [
4103*da0073e9SAndroid Build Coastguard Worker            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
4104*da0073e9SAndroid Build Coastguard Worker        ]  # .batch(2)
4105*da0073e9SAndroid Build Coastguard Worker        self._snapshot_test_helper(graph5, expected_res=res5)
4106*da0073e9SAndroid Build Coastguard Worker
4107*da0073e9SAndroid Build Coastguard Worker        # With `fork` and `zip`
4108*da0073e9SAndroid Build Coastguard Worker        cdp1, cdp2 = graph5.fork(2)
4109*da0073e9SAndroid Build Coastguard Worker        graph6 = cdp1.zip(cdp2)
4110*da0073e9SAndroid Build Coastguard Worker        res6 = [(x, x) for x in res5]
4111*da0073e9SAndroid Build Coastguard Worker        self._snapshot_test_helper(graph6, expected_res=res6)
4112*da0073e9SAndroid Build Coastguard Worker
4113*da0073e9SAndroid Build Coastguard Worker        # With `fork` and `concat`
4114*da0073e9SAndroid Build Coastguard Worker        graph7 = cdp1.concat(cdp2)
4115*da0073e9SAndroid Build Coastguard Worker        res7 = res5 * 2
4116*da0073e9SAndroid Build Coastguard Worker        self._snapshot_test_helper(graph7, expected_res=res7)
4117*da0073e9SAndroid Build Coastguard Worker
4118*da0073e9SAndroid Build Coastguard Worker    def test_simple_snapshot_graph_repeated(self):
4119*da0073e9SAndroid Build Coastguard Worker        cdp1, cdp2 = (
4120*da0073e9SAndroid Build Coastguard Worker            dp.iter.IterableWrapper(range(10))
4121*da0073e9SAndroid Build Coastguard Worker            .map(_mul_10)
4122*da0073e9SAndroid Build Coastguard Worker            .shuffle()
4123*da0073e9SAndroid Build Coastguard Worker            .map(_mul_10)
4124*da0073e9SAndroid Build Coastguard Worker            .map(_mul_10)
4125*da0073e9SAndroid Build Coastguard Worker            .fork(2)
4126*da0073e9SAndroid Build Coastguard Worker        )
4127*da0073e9SAndroid Build Coastguard Worker        graph = cdp1.zip(cdp2)
4128*da0073e9SAndroid Build Coastguard Worker
4129*da0073e9SAndroid Build Coastguard Worker        rng = torch.Generator()
4130*da0073e9SAndroid Build Coastguard Worker        rng.manual_seed(0)
4131*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_random_seed(graph, rng)
4132*da0073e9SAndroid Build Coastguard Worker
4133*da0073e9SAndroid Build Coastguard Worker        # Get expected result
4134*da0073e9SAndroid Build Coastguard Worker        expected_res = list(graph)
4135*da0073e9SAndroid Build Coastguard Worker
4136*da0073e9SAndroid Build Coastguard Worker        rng.manual_seed(0)
4137*da0073e9SAndroid Build Coastguard Worker        torch.utils.data.graph_settings.apply_random_seed(graph, rng)
4138*da0073e9SAndroid Build Coastguard Worker        it = iter(graph)
4139*da0073e9SAndroid Build Coastguard Worker        n_iter = 3
4140*da0073e9SAndroid Build Coastguard Worker        for _ in range(n_iter):
4141*da0073e9SAndroid Build Coastguard Worker            next(it)
4142*da0073e9SAndroid Build Coastguard Worker
4143*da0073e9SAndroid Build Coastguard Worker        # First serialization/deserialization
4144*da0073e9SAndroid Build Coastguard Worker        serialized_graph = pickle.dumps(graph)
4145*da0073e9SAndroid Build Coastguard Worker        deserialized_graph = pickle.loads(serialized_graph)
4146*da0073e9SAndroid Build Coastguard Worker
4147*da0073e9SAndroid Build Coastguard Worker        rng_for_deserialized = torch.Generator()
4148*da0073e9SAndroid Build Coastguard Worker        rng_for_deserialized.manual_seed(0)
4149*da0073e9SAndroid Build Coastguard Worker        _simple_graph_snapshot_restoration(
4150*da0073e9SAndroid Build Coastguard Worker            deserialized_graph,
4151*da0073e9SAndroid Build Coastguard Worker            deserialized_graph._number_of_samples_yielded,
4152*da0073e9SAndroid Build Coastguard Worker            rng=rng_for_deserialized,
4153*da0073e9SAndroid Build Coastguard Worker        )
4154*da0073e9SAndroid Build Coastguard Worker
4155*da0073e9SAndroid Build Coastguard Worker        it = iter(deserialized_graph)
4156*da0073e9SAndroid Build Coastguard Worker        # Get the next element and ensure it is as expected
4157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res[3], next(it))
4158*da0073e9SAndroid Build Coastguard Worker
4159*da0073e9SAndroid Build Coastguard Worker        # Serializalize/Deserialize and fast-forward again after to ensure it works
4160*da0073e9SAndroid Build Coastguard Worker        serialized_graph2 = pickle.dumps(deserialized_graph)
4161*da0073e9SAndroid Build Coastguard Worker        deserialized_graph2 = pickle.loads(serialized_graph2)
4162*da0073e9SAndroid Build Coastguard Worker
4163*da0073e9SAndroid Build Coastguard Worker        rng_for_deserialized = torch.Generator()
4164*da0073e9SAndroid Build Coastguard Worker        rng_for_deserialized.manual_seed(0)
4165*da0073e9SAndroid Build Coastguard Worker        _simple_graph_snapshot_restoration(
4166*da0073e9SAndroid Build Coastguard Worker            deserialized_graph2,
4167*da0073e9SAndroid Build Coastguard Worker            deserialized_graph._number_of_samples_yielded,
4168*da0073e9SAndroid Build Coastguard Worker            rng=rng_for_deserialized,
4169*da0073e9SAndroid Build Coastguard Worker        )
4170*da0073e9SAndroid Build Coastguard Worker
4171*da0073e9SAndroid Build Coastguard Worker        # Get the next element and ensure it is as expected
4172*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_res[4:], list(deserialized_graph2))
4173*da0073e9SAndroid Build Coastguard Worker
4174*da0073e9SAndroid Build Coastguard Worker
4175*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
4176*da0073e9SAndroid Build Coastguard Worker    run_tests()
4177