1*da0073e9SAndroid Build Coastguard Worker# mypy: ignore-errors 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dataloader"] 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport copy 6*da0073e9SAndroid Build Coastguard Workerimport itertools 7*da0073e9SAndroid Build Coastguard Workerimport os 8*da0073e9SAndroid Build Coastguard Workerimport os.path 9*da0073e9SAndroid Build Coastguard Workerimport pickle 10*da0073e9SAndroid Build Coastguard Workerimport pydoc 11*da0073e9SAndroid Build Coastguard Workerimport random 12*da0073e9SAndroid Build Coastguard Workerimport sys 13*da0073e9SAndroid Build Coastguard Workerimport tempfile 14*da0073e9SAndroid Build Coastguard Workerimport warnings 15*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 16*da0073e9SAndroid Build Coastguard Workerfrom typing import ( 17*da0073e9SAndroid Build Coastguard Worker Any, 18*da0073e9SAndroid Build Coastguard Worker Awaitable, 19*da0073e9SAndroid Build Coastguard Worker Dict, 20*da0073e9SAndroid Build Coastguard Worker Generic, 21*da0073e9SAndroid Build Coastguard Worker Iterator, 22*da0073e9SAndroid Build Coastguard Worker List, 23*da0073e9SAndroid Build Coastguard Worker Optional, 24*da0073e9SAndroid Build Coastguard Worker Set, 25*da0073e9SAndroid Build Coastguard Worker Tuple, 26*da0073e9SAndroid Build Coastguard Worker Type, 27*da0073e9SAndroid Build Coastguard Worker TYPE_CHECKING, 28*da0073e9SAndroid Build Coastguard Worker TypeVar, 29*da0073e9SAndroid Build Coastguard Worker Union, 30*da0073e9SAndroid Build Coastguard Worker) 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerif not TYPE_CHECKING: 33*da0073e9SAndroid Build Coastguard Worker # pyre isn't treating this the same as a typing.NamedTuple 34*da0073e9SAndroid Build Coastguard Worker from typing_extensions import NamedTuple 35*da0073e9SAndroid Build Coastguard Workerelse: 36*da0073e9SAndroid Build Coastguard Worker from typing import NamedTuple 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Workerimport operator 39*da0073e9SAndroid Build Coastguard Workerfrom unittest import skipIf 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Workerimport numpy as np 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Workerimport torch 44*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 45*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.datapipes as dp 46*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.graph 47*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data.graph_settings 48*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 49*da0073e9SAndroid Build Coastguard Worker run_tests, 50*da0073e9SAndroid Build Coastguard Worker skipIfNoDill, 51*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 52*da0073e9SAndroid Build Coastguard Worker suppress_warnings, 53*da0073e9SAndroid Build Coastguard Worker TEST_DILL, 54*da0073e9SAndroid Build Coastguard Worker TestCase, 55*da0073e9SAndroid Build Coastguard Worker) 56*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._import_utils import import_dill 57*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import ( 58*da0073e9SAndroid Build Coastguard Worker argument_validation, 59*da0073e9SAndroid Build Coastguard Worker DataChunk, 60*da0073e9SAndroid Build Coastguard Worker DataLoader, 61*da0073e9SAndroid Build Coastguard Worker IterDataPipe, 62*da0073e9SAndroid Build Coastguard Worker MapDataPipe, 63*da0073e9SAndroid Build Coastguard Worker RandomSampler, 64*da0073e9SAndroid Build Coastguard Worker runtime_validation, 65*da0073e9SAndroid Build Coastguard Worker runtime_validation_disabled, 66*da0073e9SAndroid Build Coastguard Worker) 67*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.dataframe import ( 68*da0073e9SAndroid Build Coastguard Worker CaptureDataFrame, 69*da0073e9SAndroid Build Coastguard Worker dataframe_wrapper as df_wrapper, 70*da0073e9SAndroid Build Coastguard Worker) 71*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES 72*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.utils.common import StreamWrapper 73*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.utils.decoder import ( 74*da0073e9SAndroid Build Coastguard Worker basichandlers as decoder_basichandlers, 75*da0073e9SAndroid Build Coastguard Worker) 76*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration 77*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data.graph import traverse_dps 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Workerdill = import_dill() 80*da0073e9SAndroid Build Coastguard WorkerHAS_DILL = TEST_DILL 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Workertry: 83*da0073e9SAndroid Build Coastguard Worker import pandas # type: ignore[import] # noqa: F401 F403 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker HAS_PANDAS = True 86*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 87*da0073e9SAndroid Build Coastguard Worker HAS_PANDAS = False 88*da0073e9SAndroid Build Coastguard WorkerskipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)") 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard WorkerskipTyping = skipIf(True, "TODO: Fix typing bug") 91*da0073e9SAndroid Build Coastguard WorkerT_co = TypeVar("T_co", covariant=True) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Workerdef create_temp_dir_and_files(): 95*da0073e9SAndroid Build Coastguard Worker # The temp dir and files within it will be released and deleted in tearDown(). 96*da0073e9SAndroid Build Coastguard Worker # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function. 97*da0073e9SAndroid Build Coastguard Worker temp_dir = tempfile.TemporaryDirectory() # noqa: P201 98*da0073e9SAndroid Build Coastguard Worker temp_dir_path = temp_dir.name 99*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile( 100*da0073e9SAndroid Build Coastguard Worker dir=temp_dir_path, delete=False, suffix=".txt" 101*da0073e9SAndroid Build Coastguard Worker ) as f: 102*da0073e9SAndroid Build Coastguard Worker temp_file1_name = f.name 103*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile( 104*da0073e9SAndroid Build Coastguard Worker dir=temp_dir_path, delete=False, suffix=".byte" 105*da0073e9SAndroid Build Coastguard Worker ) as f: 106*da0073e9SAndroid Build Coastguard Worker temp_file2_name = f.name 107*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile( 108*da0073e9SAndroid Build Coastguard Worker dir=temp_dir_path, delete=False, suffix=".empty" 109*da0073e9SAndroid Build Coastguard Worker ) as f: 110*da0073e9SAndroid Build Coastguard Worker temp_file3_name = f.name 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker with open(temp_file1_name, "w") as f1: 113*da0073e9SAndroid Build Coastguard Worker f1.write("0123456789abcdef") 114*da0073e9SAndroid Build Coastguard Worker with open(temp_file2_name, "wb") as f2: 115*da0073e9SAndroid Build Coastguard Worker f2.write(b"0123456789abcdef") 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir_path) # noqa: P201 118*da0073e9SAndroid Build Coastguard Worker temp_sub_dir_path = temp_sub_dir.name 119*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile( 120*da0073e9SAndroid Build Coastguard Worker dir=temp_sub_dir_path, delete=False, suffix=".txt" 121*da0073e9SAndroid Build Coastguard Worker ) as f: 122*da0073e9SAndroid Build Coastguard Worker temp_sub_file1_name = f.name 123*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile( 124*da0073e9SAndroid Build Coastguard Worker dir=temp_sub_dir_path, delete=False, suffix=".byte" 125*da0073e9SAndroid Build Coastguard Worker ) as f: 126*da0073e9SAndroid Build Coastguard Worker temp_sub_file2_name = f.name 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker with open(temp_sub_file1_name, "w") as f1: 129*da0073e9SAndroid Build Coastguard Worker f1.write("0123456789abcdef") 130*da0073e9SAndroid Build Coastguard Worker with open(temp_sub_file2_name, "wb") as f2: 131*da0073e9SAndroid Build Coastguard Worker f2.write(b"0123456789abcdef") 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker return [ 134*da0073e9SAndroid Build Coastguard Worker (temp_dir, temp_file1_name, temp_file2_name, temp_file3_name), 135*da0073e9SAndroid Build Coastguard Worker (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name), 136*da0073e9SAndroid Build Coastguard Worker ] 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Workerdef reset_after_n_next_calls( 140*da0073e9SAndroid Build Coastguard Worker datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]], n: int 141*da0073e9SAndroid Build Coastguard Worker) -> Tuple[List[T_co], List[T_co]]: 142*da0073e9SAndroid Build Coastguard Worker """ 143*da0073e9SAndroid Build Coastguard Worker Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list 144*da0073e9SAndroid Build Coastguard Worker Then, reset the DataPipe and return a tuple of two lists 145*da0073e9SAndroid Build Coastguard Worker 1. A list of elements yielded before the reset 146*da0073e9SAndroid Build Coastguard Worker 2. A list of all elements of the DataPipe after the reset 147*da0073e9SAndroid Build Coastguard Worker """ 148*da0073e9SAndroid Build Coastguard Worker it = iter(datapipe) 149*da0073e9SAndroid Build Coastguard Worker res_before_reset = [] 150*da0073e9SAndroid Build Coastguard Worker for _ in range(n): 151*da0073e9SAndroid Build Coastguard Worker res_before_reset.append(next(it)) 152*da0073e9SAndroid Build Coastguard Worker return res_before_reset, list(datapipe) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Workerdef odd_or_even(x: int) -> int: 156*da0073e9SAndroid Build Coastguard Worker return x % 2 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Workerclass TestDataChunk(TestCase): 160*da0073e9SAndroid Build Coastguard Worker def setUp(self): 161*da0073e9SAndroid Build Coastguard Worker self.elements = list(range(10)) 162*da0073e9SAndroid Build Coastguard Worker random.shuffle(self.elements) 163*da0073e9SAndroid Build Coastguard Worker self.chunk: DataChunk[int] = DataChunk(self.elements) 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker def test_getitem(self): 166*da0073e9SAndroid Build Coastguard Worker for i in range(10): 167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.elements[i], self.chunk[i]) 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker def test_iter(self): 170*da0073e9SAndroid Build Coastguard Worker for ele, dc in zip(self.elements, iter(self.chunk)): 171*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ele, dc) 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker def test_len(self): 174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(self.elements), len(self.chunk)) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker def test_as_string(self): 177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(self.chunk), str(self.elements)) 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker batch = [self.elements] * 3 180*da0073e9SAndroid Build Coastguard Worker chunks: List[DataChunk[int]] = [DataChunk(self.elements)] * 3 181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(batch), str(chunks)) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker def test_sort(self): 184*da0073e9SAndroid Build Coastguard Worker chunk: DataChunk[int] = DataChunk(self.elements) 185*da0073e9SAndroid Build Coastguard Worker chunk.sort() 186*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(chunk, DataChunk)) 187*da0073e9SAndroid Build Coastguard Worker for i, d in enumerate(chunk): 188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, d) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker def test_reverse(self): 191*da0073e9SAndroid Build Coastguard Worker chunk: DataChunk[int] = DataChunk(self.elements) 192*da0073e9SAndroid Build Coastguard Worker chunk.reverse() 193*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(chunk, DataChunk)) 194*da0073e9SAndroid Build Coastguard Worker for i in range(10): 195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(chunk[i], self.elements[9 - i]) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker def test_random_shuffle(self): 198*da0073e9SAndroid Build Coastguard Worker elements = list(range(10)) 199*da0073e9SAndroid Build Coastguard Worker chunk: DataChunk[int] = DataChunk(elements) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker rng = random.Random(0) 202*da0073e9SAndroid Build Coastguard Worker rng.shuffle(chunk) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker rng = random.Random(0) 205*da0073e9SAndroid Build Coastguard Worker rng.shuffle(elements) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(chunk, elements) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Workerclass TestStreamWrapper(TestCase): 211*da0073e9SAndroid Build Coastguard Worker class _FakeFD: 212*da0073e9SAndroid Build Coastguard Worker def __init__(self, filepath): 213*da0073e9SAndroid Build Coastguard Worker self.filepath = filepath 214*da0073e9SAndroid Build Coastguard Worker self.opened = False 215*da0073e9SAndroid Build Coastguard Worker self.closed = False 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker def open(self): 218*da0073e9SAndroid Build Coastguard Worker self.opened = True 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker def read(self): 221*da0073e9SAndroid Build Coastguard Worker if self.opened: 222*da0073e9SAndroid Build Coastguard Worker return "".join(self) 223*da0073e9SAndroid Build Coastguard Worker else: 224*da0073e9SAndroid Build Coastguard Worker raise OSError("Cannot read from un-opened file descriptor") 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 227*da0073e9SAndroid Build Coastguard Worker for i in range(5): 228*da0073e9SAndroid Build Coastguard Worker yield str(i) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker def close(self): 231*da0073e9SAndroid Build Coastguard Worker if self.opened: 232*da0073e9SAndroid Build Coastguard Worker self.opened = False 233*da0073e9SAndroid Build Coastguard Worker self.closed = True 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 236*da0073e9SAndroid Build Coastguard Worker return "FakeFD" 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker def test_dir(self): 239*da0073e9SAndroid Build Coastguard Worker fd = TestStreamWrapper._FakeFD("") 240*da0073e9SAndroid Build Coastguard Worker wrap_fd = StreamWrapper(fd) 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker s = set(dir(wrap_fd)) 243*da0073e9SAndroid Build Coastguard Worker for api in ["open", "read", "close"]: 244*da0073e9SAndroid Build Coastguard Worker self.assertTrue(api in s) 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo() 247*da0073e9SAndroid Build Coastguard Worker def test_api(self): 248*da0073e9SAndroid Build Coastguard Worker fd = TestStreamWrapper._FakeFD("") 249*da0073e9SAndroid Build Coastguard Worker wrap_fd = StreamWrapper(fd) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fd.opened) 252*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fd.closed) 253*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IOError, "Cannot read from"): 254*da0073e9SAndroid Build Coastguard Worker wrap_fd.read() 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker wrap_fd.open() 257*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fd.opened) 258*da0073e9SAndroid Build Coastguard Worker self.assertEqual("01234", wrap_fd.read()) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker del wrap_fd 261*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fd.opened) 262*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fd.closed) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def test_pickle(self): 265*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryFile() as f: 266*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError) as ctx1: 267*da0073e9SAndroid Build Coastguard Worker pickle.dumps(f) 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker wrap_f = StreamWrapper(f) 270*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError) as ctx2: 271*da0073e9SAndroid Build Coastguard Worker pickle.dumps(wrap_f) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker # Same exception when pickle 274*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(ctx1.exception), str(ctx2.exception)) 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker fd = TestStreamWrapper._FakeFD("") 277*da0073e9SAndroid Build Coastguard Worker wrap_fd = StreamWrapper(fd) 278*da0073e9SAndroid Build Coastguard Worker _ = pickle.loads(pickle.dumps(wrap_fd)) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker def test_repr(self): 281*da0073e9SAndroid Build Coastguard Worker fd = TestStreamWrapper._FakeFD("") 282*da0073e9SAndroid Build Coastguard Worker wrap_fd = StreamWrapper(fd) 283*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>") 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryFile() as f: 286*da0073e9SAndroid Build Coastguard Worker wrap_f = StreamWrapper(f) 287*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">") 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Workerclass TestIterableDataPipeBasic(TestCase): 291*da0073e9SAndroid Build Coastguard Worker def setUp(self): 292*da0073e9SAndroid Build Coastguard Worker ret = create_temp_dir_and_files() 293*da0073e9SAndroid Build Coastguard Worker self.temp_dir = ret[0][0] 294*da0073e9SAndroid Build Coastguard Worker self.temp_files = ret[0][1:] 295*da0073e9SAndroid Build Coastguard Worker self.temp_sub_dir = ret[1][0] 296*da0073e9SAndroid Build Coastguard Worker self.temp_sub_files = ret[1][1:] 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 299*da0073e9SAndroid Build Coastguard Worker try: 300*da0073e9SAndroid Build Coastguard Worker self.temp_sub_dir.cleanup() 301*da0073e9SAndroid Build Coastguard Worker self.temp_dir.cleanup() 302*da0073e9SAndroid Build Coastguard Worker except Exception as e: 303*da0073e9SAndroid Build Coastguard Worker warnings.warn( 304*da0073e9SAndroid Build Coastguard Worker f"TestIterableDatasetBasic was not able to cleanup temp dir due to {str(e)}" 305*da0073e9SAndroid Build Coastguard Worker ) 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker def test_listdirfiles_iterable_datapipe(self): 308*da0073e9SAndroid Build Coastguard Worker temp_dir = self.temp_dir.name 309*da0073e9SAndroid Build Coastguard Worker datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, "") 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker count = 0 312*da0073e9SAndroid Build Coastguard Worker for pathname in datapipe: 313*da0073e9SAndroid Build Coastguard Worker count = count + 1 314*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pathname in self.temp_files) 315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, len(self.temp_files)) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker count = 0 318*da0073e9SAndroid Build Coastguard Worker datapipe = dp.iter.FileLister(temp_dir, "", recursive=True) 319*da0073e9SAndroid Build Coastguard Worker for pathname in datapipe: 320*da0073e9SAndroid Build Coastguard Worker count = count + 1 321*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 322*da0073e9SAndroid Build Coastguard Worker (pathname in self.temp_files) or (pathname in self.temp_sub_files) 323*da0073e9SAndroid Build Coastguard Worker ) 324*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files)) 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker temp_files = self.temp_files 327*da0073e9SAndroid Build Coastguard Worker datapipe = dp.iter.FileLister([temp_dir, *temp_files]) 328*da0073e9SAndroid Build Coastguard Worker count = 0 329*da0073e9SAndroid Build Coastguard Worker for pathname in datapipe: 330*da0073e9SAndroid Build Coastguard Worker count += 1 331*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pathname in self.temp_files) 332*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, 2 * len(self.temp_files)) 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker # test functional API 335*da0073e9SAndroid Build Coastguard Worker datapipe = datapipe.list_files() 336*da0073e9SAndroid Build Coastguard Worker count = 0 337*da0073e9SAndroid Build Coastguard Worker for pathname in datapipe: 338*da0073e9SAndroid Build Coastguard Worker count += 1 339*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pathname in self.temp_files) 340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, 2 * len(self.temp_files)) 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker def test_listdirfilesdeterministic_iterable_datapipe(self): 343*da0073e9SAndroid Build Coastguard Worker temp_dir = self.temp_dir.name 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker datapipe = dp.iter.FileLister(temp_dir, "") 346*da0073e9SAndroid Build Coastguard Worker # The output order should be always the same. 347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(datapipe), list(datapipe)) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker datapipe = dp.iter.FileLister(temp_dir, "", recursive=True) 350*da0073e9SAndroid Build Coastguard Worker # The output order should be always the same. 351*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(datapipe), list(datapipe)) 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker def test_openfilesfromdisk_iterable_datapipe(self): 354*da0073e9SAndroid Build Coastguard Worker # test import datapipe class directly 355*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.datapipes.iter import FileLister, FileOpener 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker temp_dir = self.temp_dir.name 358*da0073e9SAndroid Build Coastguard Worker datapipe1 = FileLister(temp_dir, "") 359*da0073e9SAndroid Build Coastguard Worker datapipe2 = FileOpener(datapipe1, mode="b") 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker count = 0 362*da0073e9SAndroid Build Coastguard Worker for rec in datapipe2: 363*da0073e9SAndroid Build Coastguard Worker count = count + 1 364*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rec[0] in self.temp_files) 365*da0073e9SAndroid Build Coastguard Worker with open(rec[0], "rb") as f: 366*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[1].read(), f.read()) 367*da0073e9SAndroid Build Coastguard Worker rec[1].close() 368*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, len(self.temp_files)) 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker # functional API 371*da0073e9SAndroid Build Coastguard Worker datapipe3 = datapipe1.open_files(mode="b") 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker count = 0 374*da0073e9SAndroid Build Coastguard Worker for rec in datapipe3: 375*da0073e9SAndroid Build Coastguard Worker count = count + 1 376*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rec[0] in self.temp_files) 377*da0073e9SAndroid Build Coastguard Worker with open(rec[0], "rb") as f: 378*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[1].read(), f.read()) 379*da0073e9SAndroid Build Coastguard Worker rec[1].close() 380*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, len(self.temp_files)) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker # __len__ Test 383*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 384*da0073e9SAndroid Build Coastguard Worker len(datapipe3) 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker def test_routeddecoder_iterable_datapipe(self): 387*da0073e9SAndroid Build Coastguard Worker temp_dir = self.temp_dir.name 388*da0073e9SAndroid Build Coastguard Worker temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png") 389*da0073e9SAndroid Build Coastguard Worker png_data = np.array( 390*da0073e9SAndroid Build Coastguard Worker [[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], 391*da0073e9SAndroid Build Coastguard Worker dtype=np.single, 392*da0073e9SAndroid Build Coastguard Worker ) 393*da0073e9SAndroid Build Coastguard Worker np.save(temp_pngfile_pathname, png_data) 394*da0073e9SAndroid Build Coastguard Worker datapipe1 = dp.iter.FileLister(temp_dir, ["*.png", "*.txt"]) 395*da0073e9SAndroid Build Coastguard Worker datapipe2 = dp.iter.FileOpener(datapipe1, mode="b") 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker def _png_decoder(extension, data): 398*da0073e9SAndroid Build Coastguard Worker if extension != "png": 399*da0073e9SAndroid Build Coastguard Worker return None 400*da0073e9SAndroid Build Coastguard Worker return np.load(data) 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker def _helper(prior_dp, dp, channel_first=False): 403*da0073e9SAndroid Build Coastguard Worker # Byte stream is not closed 404*da0073e9SAndroid Build Coastguard Worker for inp in prior_dp: 405*da0073e9SAndroid Build Coastguard Worker self.assertFalse(inp[1].closed) 406*da0073e9SAndroid Build Coastguard Worker for inp, rec in zip(prior_dp, dp): 407*da0073e9SAndroid Build Coastguard Worker ext = os.path.splitext(rec[0])[1] 408*da0073e9SAndroid Build Coastguard Worker if ext == ".png": 409*da0073e9SAndroid Build Coastguard Worker expected = np.array( 410*da0073e9SAndroid Build Coastguard Worker [ 411*da0073e9SAndroid Build Coastguard Worker [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], 412*da0073e9SAndroid Build Coastguard Worker [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], 413*da0073e9SAndroid Build Coastguard Worker ], 414*da0073e9SAndroid Build Coastguard Worker dtype=np.single, 415*da0073e9SAndroid Build Coastguard Worker ) 416*da0073e9SAndroid Build Coastguard Worker if channel_first: 417*da0073e9SAndroid Build Coastguard Worker expected = expected.transpose(2, 0, 1) 418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[1], expected) 419*da0073e9SAndroid Build Coastguard Worker else: 420*da0073e9SAndroid Build Coastguard Worker with open(rec[0], "rb") as f: 421*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[1], f.read().decode("utf-8")) 422*da0073e9SAndroid Build Coastguard Worker # Corresponding byte stream is closed by Decoder 423*da0073e9SAndroid Build Coastguard Worker self.assertTrue(inp[1].closed) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker cached = list(datapipe2) 426*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 427*da0073e9SAndroid Build Coastguard Worker datapipe3 = dp.iter.RoutedDecoder(cached, _png_decoder) 428*da0073e9SAndroid Build Coastguard Worker datapipe3.add_handler(decoder_basichandlers) 429*da0073e9SAndroid Build Coastguard Worker _helper(cached, datapipe3) 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker cached = list(datapipe2) 432*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 433*da0073e9SAndroid Build Coastguard Worker datapipe4 = dp.iter.RoutedDecoder(cached, decoder_basichandlers) 434*da0073e9SAndroid Build Coastguard Worker datapipe4.add_handler(_png_decoder) 435*da0073e9SAndroid Build Coastguard Worker _helper(cached, datapipe4, channel_first=True) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker def test_groupby_iterable_datapipe(self): 438*da0073e9SAndroid Build Coastguard Worker file_list = [ 439*da0073e9SAndroid Build Coastguard Worker "a.png", 440*da0073e9SAndroid Build Coastguard Worker "b.png", 441*da0073e9SAndroid Build Coastguard Worker "c.json", 442*da0073e9SAndroid Build Coastguard Worker "a.json", 443*da0073e9SAndroid Build Coastguard Worker "c.png", 444*da0073e9SAndroid Build Coastguard Worker "b.json", 445*da0073e9SAndroid Build Coastguard Worker "d.png", 446*da0073e9SAndroid Build Coastguard Worker "d.json", 447*da0073e9SAndroid Build Coastguard Worker "e.png", 448*da0073e9SAndroid Build Coastguard Worker "f.json", 449*da0073e9SAndroid Build Coastguard Worker "g.png", 450*da0073e9SAndroid Build Coastguard Worker "f.png", 451*da0073e9SAndroid Build Coastguard Worker "g.json", 452*da0073e9SAndroid Build Coastguard Worker "e.json", 453*da0073e9SAndroid Build Coastguard Worker "h.txt", 454*da0073e9SAndroid Build Coastguard Worker "h.json", 455*da0073e9SAndroid Build Coastguard Worker ] 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Worker import io 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker datapipe1 = dp.iter.IterableWrapper( 460*da0073e9SAndroid Build Coastguard Worker [(filename, io.BytesIO(b"12345abcde")) for filename in file_list] 461*da0073e9SAndroid Build Coastguard Worker ) 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker def group_fn(data): 464*da0073e9SAndroid Build Coastguard Worker filepath, _ = data 465*da0073e9SAndroid Build Coastguard Worker return os.path.basename(filepath).split(".")[0] 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker datapipe2 = dp.iter.Grouper(datapipe1, group_key_fn=group_fn, group_size=2) 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker def order_fn(data): 470*da0073e9SAndroid Build Coastguard Worker data.sort(key=lambda f: f[0], reverse=True) 471*da0073e9SAndroid Build Coastguard Worker return data 472*da0073e9SAndroid Build Coastguard Worker 473*da0073e9SAndroid Build Coastguard Worker datapipe3 = dp.iter.Mapper(datapipe2, fn=order_fn) # type: ignore[var-annotated] 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker expected_result = [ 476*da0073e9SAndroid Build Coastguard Worker ("a.png", "a.json"), 477*da0073e9SAndroid Build Coastguard Worker ("c.png", "c.json"), 478*da0073e9SAndroid Build Coastguard Worker ("b.png", "b.json"), 479*da0073e9SAndroid Build Coastguard Worker ("d.png", "d.json"), 480*da0073e9SAndroid Build Coastguard Worker ("f.png", "f.json"), 481*da0073e9SAndroid Build Coastguard Worker ("g.png", "g.json"), 482*da0073e9SAndroid Build Coastguard Worker ("e.png", "e.json"), 483*da0073e9SAndroid Build Coastguard Worker ("h.txt", "h.json"), 484*da0073e9SAndroid Build Coastguard Worker ] 485*da0073e9SAndroid Build Coastguard Worker 486*da0073e9SAndroid Build Coastguard Worker count = 0 487*da0073e9SAndroid Build Coastguard Worker for rec, expected in zip(datapipe3, expected_result): 488*da0073e9SAndroid Build Coastguard Worker count = count + 1 489*da0073e9SAndroid Build Coastguard Worker self.assertEqual(os.path.basename(rec[0][0]), expected[0]) 490*da0073e9SAndroid Build Coastguard Worker self.assertEqual(os.path.basename(rec[1][0]), expected[1]) 491*da0073e9SAndroid Build Coastguard Worker for i in [0, 1]: 492*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[i][1].read(), b"12345abcde") 493*da0073e9SAndroid Build Coastguard Worker rec[i][1].close() 494*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, 8) 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker # testing the keep_key option 497*da0073e9SAndroid Build Coastguard Worker datapipe4 = dp.iter.Grouper( 498*da0073e9SAndroid Build Coastguard Worker datapipe1, group_key_fn=group_fn, keep_key=True, group_size=2 499*da0073e9SAndroid Build Coastguard Worker ) 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker def order_fn(data): 502*da0073e9SAndroid Build Coastguard Worker data[1].sort(key=lambda f: f[0], reverse=True) 503*da0073e9SAndroid Build Coastguard Worker return data 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker datapipe5 = dp.iter.Mapper(datapipe4, fn=order_fn) # type: ignore[var-annotated] 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker expected_result = [ 508*da0073e9SAndroid Build Coastguard Worker ("a", ("a.png", "a.json")), 509*da0073e9SAndroid Build Coastguard Worker ("c", ("c.png", "c.json")), 510*da0073e9SAndroid Build Coastguard Worker ("b", ("b.png", "b.json")), 511*da0073e9SAndroid Build Coastguard Worker ("d", ("d.png", "d.json")), 512*da0073e9SAndroid Build Coastguard Worker ("f", ("f.png", "f.json")), 513*da0073e9SAndroid Build Coastguard Worker ("g", ("g.png", "g.json")), 514*da0073e9SAndroid Build Coastguard Worker ("e", ("e.png", "e.json")), 515*da0073e9SAndroid Build Coastguard Worker ("h", ("h.txt", "h.json")), 516*da0073e9SAndroid Build Coastguard Worker ] 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker count = 0 519*da0073e9SAndroid Build Coastguard Worker for rec, expected in zip(datapipe5, expected_result): 520*da0073e9SAndroid Build Coastguard Worker count = count + 1 521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[0], expected[0]) 522*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[1][0][0], expected[1][0]) 523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[1][1][0], expected[1][1]) 524*da0073e9SAndroid Build Coastguard Worker for i in [0, 1]: 525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rec[1][i][1].read(), b"12345abcde") 526*da0073e9SAndroid Build Coastguard Worker rec[1][i][1].close() 527*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count, 8) 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker def test_demux_mux_datapipe(self): 530*da0073e9SAndroid Build Coastguard Worker numbers = NumbersDataset(10) 531*da0073e9SAndroid Build Coastguard Worker n1, n2 = numbers.demux(2, lambda x: x % 2) 532*da0073e9SAndroid Build Coastguard Worker self.assertEqual([0, 2, 4, 6, 8], list(n1)) 533*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 3, 5, 7, 9], list(n2)) 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker # Functional Test: demux and mux works sequentially as expected 536*da0073e9SAndroid Build Coastguard Worker numbers = NumbersDataset(10) 537*da0073e9SAndroid Build Coastguard Worker n1, n2, n3 = numbers.demux(3, lambda x: x % 3) 538*da0073e9SAndroid Build Coastguard Worker n = n1.mux(n2, n3) 539*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(9)), list(n)) 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker # Functional Test: Uneven DataPipes 542*da0073e9SAndroid Build Coastguard Worker source_numbers = list(range(0, 10)) + [10, 12] 543*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(source_numbers) 544*da0073e9SAndroid Build Coastguard Worker n1, n2 = numbers_dp.demux(2, lambda x: x % 2) 545*da0073e9SAndroid Build Coastguard Worker self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1)) 546*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 3, 5, 7, 9], list(n2)) 547*da0073e9SAndroid Build Coastguard Worker n = n1.mux(n2) 548*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(n)) 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker @suppress_warnings # Suppress warning for lambda fn 551*da0073e9SAndroid Build Coastguard Worker def test_map_with_col_file_handle_datapipe(self): 552*da0073e9SAndroid Build Coastguard Worker temp_dir = self.temp_dir.name 553*da0073e9SAndroid Build Coastguard Worker datapipe1 = dp.iter.FileLister(temp_dir, "") 554*da0073e9SAndroid Build Coastguard Worker datapipe2 = dp.iter.FileOpener(datapipe1) 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker def _helper(datapipe): 557*da0073e9SAndroid Build Coastguard Worker dp1 = datapipe.map(lambda x: x.read(), input_col=1) 558*da0073e9SAndroid Build Coastguard Worker dp2 = datapipe.map(lambda x: (x[0], x[1].read())) 559*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp1), list(dp2)) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker # tuple 562*da0073e9SAndroid Build Coastguard Worker _helper(datapipe2) 563*da0073e9SAndroid Build Coastguard Worker # list 564*da0073e9SAndroid Build Coastguard Worker datapipe3 = datapipe2.map(lambda x: list(x)) 565*da0073e9SAndroid Build Coastguard Worker _helper(datapipe3) 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker@skipIfNoDataFrames 569*da0073e9SAndroid Build Coastguard Workerclass TestCaptureDataFrame(TestCase): 570*da0073e9SAndroid Build Coastguard Worker def get_new_df(self): 571*da0073e9SAndroid Build Coastguard Worker return df_wrapper.create_dataframe([[1, 2]], columns=["a", "b"]) 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker def compare_capture_and_eager(self, operations): 574*da0073e9SAndroid Build Coastguard Worker cdf = CaptureDataFrame() 575*da0073e9SAndroid Build Coastguard Worker cdf = operations(cdf) 576*da0073e9SAndroid Build Coastguard Worker df = self.get_new_df() 577*da0073e9SAndroid Build Coastguard Worker cdf = cdf.apply_ops(df) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker df = self.get_new_df() 580*da0073e9SAndroid Build Coastguard Worker df = operations(df) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker self.assertTrue(df.equals(cdf)) 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker def test_basic_capture(self): 585*da0073e9SAndroid Build Coastguard Worker def operations(df): 586*da0073e9SAndroid Build Coastguard Worker df["c"] = df.b + df["a"] * 7 587*da0073e9SAndroid Build Coastguard Worker # somehow swallows pandas UserWarning when `df.c = df.b + df['a'] * 7` 588*da0073e9SAndroid Build Coastguard Worker return df 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker self.compare_capture_and_eager(operations) 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker 593*da0073e9SAndroid Build Coastguard Workerclass TestDataFramesPipes(TestCase): 594*da0073e9SAndroid Build Coastguard Worker """ 595*da0073e9SAndroid Build Coastguard Worker Most of test will fail if pandas instaled, but no dill available. 596*da0073e9SAndroid Build Coastguard Worker Need to rework them to avoid multiple skips. 597*da0073e9SAndroid Build Coastguard Worker """ 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker def _get_datapipe(self, range=10, dataframe_size=7): 600*da0073e9SAndroid Build Coastguard Worker return NumbersDataset(range).map(lambda i: (i, i % 3)) 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker def _get_dataframes_pipe(self, range=10, dataframe_size=7): 603*da0073e9SAndroid Build Coastguard Worker return ( 604*da0073e9SAndroid Build Coastguard Worker NumbersDataset(range) 605*da0073e9SAndroid Build Coastguard Worker .map(lambda i: (i, i % 3)) 606*da0073e9SAndroid Build Coastguard Worker ._to_dataframes_pipe(columns=["i", "j"], dataframe_size=dataframe_size) 607*da0073e9SAndroid Build Coastguard Worker ) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker @skipIfNoDataFrames 610*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map 611*da0073e9SAndroid Build Coastguard Worker def test_capture(self): 612*da0073e9SAndroid Build Coastguard Worker dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0])) 613*da0073e9SAndroid Build Coastguard Worker df_numbers = self._get_dataframes_pipe() 614*da0073e9SAndroid Build Coastguard Worker df_numbers["k"] = df_numbers["j"] + df_numbers.i * 3 615*da0073e9SAndroid Build Coastguard Worker expected = list(dp_numbers) 616*da0073e9SAndroid Build Coastguard Worker actual = list(df_numbers) 617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker @skipIfNoDataFrames 620*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 621*da0073e9SAndroid Build Coastguard Worker def test_shuffle(self): 622*da0073e9SAndroid Build Coastguard Worker # With non-zero (but extremely low) probability (when shuffle do nothing), 623*da0073e9SAndroid Build Coastguard Worker # this test fails, so feel free to restart 624*da0073e9SAndroid Build Coastguard Worker df_numbers = self._get_dataframes_pipe(range=1000).shuffle() 625*da0073e9SAndroid Build Coastguard Worker dp_numbers = self._get_datapipe(range=1000) 626*da0073e9SAndroid Build Coastguard Worker df_result = [tuple(item) for item in df_numbers] 627*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(list(dp_numbers), df_result) 628*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp_numbers), sorted(df_result)) 629*da0073e9SAndroid Build Coastguard Worker 630*da0073e9SAndroid Build Coastguard Worker @skipIfNoDataFrames 631*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 632*da0073e9SAndroid Build Coastguard Worker def test_batch(self): 633*da0073e9SAndroid Build Coastguard Worker df_numbers = self._get_dataframes_pipe(range=100).batch(8) 634*da0073e9SAndroid Build Coastguard Worker df_numbers_list = list(df_numbers) 635*da0073e9SAndroid Build Coastguard Worker last_batch = df_numbers_list[-1] 636*da0073e9SAndroid Build Coastguard Worker self.assertEqual(4, len(last_batch)) 637*da0073e9SAndroid Build Coastguard Worker unpacked_batch = [tuple(row) for row in last_batch] 638*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch) 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Worker @skipIfNoDataFrames 641*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 642*da0073e9SAndroid Build Coastguard Worker def test_unbatch(self): 643*da0073e9SAndroid Build Coastguard Worker df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3) 644*da0073e9SAndroid Build Coastguard Worker dp_numbers = self._get_datapipe(range=100) 645*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2))) 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Worker @skipIfNoDataFrames 648*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 649*da0073e9SAndroid Build Coastguard Worker def test_filter(self): 650*da0073e9SAndroid Build Coastguard Worker df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5) 651*da0073e9SAndroid Build Coastguard Worker actual = list(df_numbers) 652*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], actual) 653*da0073e9SAndroid Build Coastguard Worker 654*da0073e9SAndroid Build Coastguard Worker @skipIfNoDataFrames 655*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 656*da0073e9SAndroid Build Coastguard Worker def test_collate(self): 657*da0073e9SAndroid Build Coastguard Worker def collate_i(column): 658*da0073e9SAndroid Build Coastguard Worker return column.sum() 659*da0073e9SAndroid Build Coastguard Worker 660*da0073e9SAndroid Build Coastguard Worker def collate_j(column): 661*da0073e9SAndroid Build Coastguard Worker return column.prod() 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker df_numbers = self._get_dataframes_pipe(range=30).batch(3) 664*da0073e9SAndroid Build Coastguard Worker df_numbers = df_numbers.collate({"j": collate_j, "i": collate_i}) 665*da0073e9SAndroid Build Coastguard Worker 666*da0073e9SAndroid Build Coastguard Worker expected_i = [ 667*da0073e9SAndroid Build Coastguard Worker 3, 668*da0073e9SAndroid Build Coastguard Worker 12, 669*da0073e9SAndroid Build Coastguard Worker 21, 670*da0073e9SAndroid Build Coastguard Worker 30, 671*da0073e9SAndroid Build Coastguard Worker 39, 672*da0073e9SAndroid Build Coastguard Worker 48, 673*da0073e9SAndroid Build Coastguard Worker 57, 674*da0073e9SAndroid Build Coastguard Worker 66, 675*da0073e9SAndroid Build Coastguard Worker 75, 676*da0073e9SAndroid Build Coastguard Worker 84, 677*da0073e9SAndroid Build Coastguard Worker ] 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker actual_i = [] 680*da0073e9SAndroid Build Coastguard Worker for i, j in df_numbers: 681*da0073e9SAndroid Build Coastguard Worker actual_i.append(i) 682*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_i, actual_i) 683*da0073e9SAndroid Build Coastguard Worker 684*da0073e9SAndroid Build Coastguard Worker actual_i = [] 685*da0073e9SAndroid Build Coastguard Worker for item in df_numbers: 686*da0073e9SAndroid Build Coastguard Worker actual_i.append(item.i) 687*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_i, actual_i) 688*da0073e9SAndroid Build Coastguard Worker 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Workerclass IDP_NoLen(IterDataPipe): 691*da0073e9SAndroid Build Coastguard Worker def __init__(self, input_dp): 692*da0073e9SAndroid Build Coastguard Worker super().__init__() 693*da0073e9SAndroid Build Coastguard Worker self.input_dp = input_dp 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker # Prevent in-place modification 696*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 697*da0073e9SAndroid Build Coastguard Worker input_dp = ( 698*da0073e9SAndroid Build Coastguard Worker self.input_dp 699*da0073e9SAndroid Build Coastguard Worker if isinstance(self.input_dp, IterDataPipe) 700*da0073e9SAndroid Build Coastguard Worker else copy.deepcopy(self.input_dp) 701*da0073e9SAndroid Build Coastguard Worker ) 702*da0073e9SAndroid Build Coastguard Worker yield from input_dp 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Workerdef _fake_fn(data): 706*da0073e9SAndroid Build Coastguard Worker return data 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker 709*da0073e9SAndroid Build Coastguard Workerdef _fake_add(constant, data): 710*da0073e9SAndroid Build Coastguard Worker return constant + data 711*da0073e9SAndroid Build Coastguard Worker 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Workerdef _fake_filter_fn(data): 714*da0073e9SAndroid Build Coastguard Worker return True 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Workerdef _simple_filter_fn(data): 718*da0073e9SAndroid Build Coastguard Worker return data >= 5 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Workerdef _fake_filter_fn_constant(constant, data): 722*da0073e9SAndroid Build Coastguard Worker return data >= constant 723*da0073e9SAndroid Build Coastguard Worker 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Workerdef _mul_10(x): 726*da0073e9SAndroid Build Coastguard Worker return x * 10 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Workerdef _mod_3_test(x): 730*da0073e9SAndroid Build Coastguard Worker return x % 3 == 1 731*da0073e9SAndroid Build Coastguard Worker 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Workerdef _to_list(x): 734*da0073e9SAndroid Build Coastguard Worker return [x] 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Workerlambda_fn1 = lambda x: x # noqa: E731 738*da0073e9SAndroid Build Coastguard Workerlambda_fn2 = lambda x: x % 2 # noqa: E731 739*da0073e9SAndroid Build Coastguard Workerlambda_fn3 = lambda x: x >= 5 # noqa: E731 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Workerclass Add1Module(nn.Module): 743*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 744*da0073e9SAndroid Build Coastguard Worker return x + 1 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Workerclass Add1Callable: 748*da0073e9SAndroid Build Coastguard Worker def __call__(self, x): 749*da0073e9SAndroid Build Coastguard Worker return x + 1 750*da0073e9SAndroid Build Coastguard Worker 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalIterDataPipe(TestCase): 753*da0073e9SAndroid Build Coastguard Worker def _serialization_test_helper(self, datapipe, use_dill): 754*da0073e9SAndroid Build Coastguard Worker if use_dill: 755*da0073e9SAndroid Build Coastguard Worker serialized_dp = dill.dumps(datapipe) 756*da0073e9SAndroid Build Coastguard Worker deserialized_dp = dill.loads(serialized_dp) 757*da0073e9SAndroid Build Coastguard Worker else: 758*da0073e9SAndroid Build Coastguard Worker serialized_dp = pickle.dumps(datapipe) 759*da0073e9SAndroid Build Coastguard Worker deserialized_dp = pickle.loads(serialized_dp) 760*da0073e9SAndroid Build Coastguard Worker try: 761*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(datapipe), list(deserialized_dp)) 762*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 763*da0073e9SAndroid Build Coastguard Worker print(f"{datapipe} is failing.") 764*da0073e9SAndroid Build Coastguard Worker raise e 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker def _serialization_test_for_single_dp(self, dp, use_dill=False): 767*da0073e9SAndroid Build Coastguard Worker # 1. Testing for serialization before any iteration starts 768*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp, use_dill) 769*da0073e9SAndroid Build Coastguard Worker # 2. Testing for serialization after DataPipe is partially read 770*da0073e9SAndroid Build Coastguard Worker it = iter(dp) 771*da0073e9SAndroid Build Coastguard Worker _ = next(it) 772*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp, use_dill) 773*da0073e9SAndroid Build Coastguard Worker # 3. Testing for serialization after DataPipe is fully read 774*da0073e9SAndroid Build Coastguard Worker it = iter(dp) 775*da0073e9SAndroid Build Coastguard Worker _ = list(it) 776*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp, use_dill) 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill=False): 779*da0073e9SAndroid Build Coastguard Worker # 1. Testing for serialization before any iteration starts 780*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp1, use_dill) 781*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp2, use_dill) 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker # 2. Testing for serialization after DataPipe is partially read 784*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(dp1), iter(dp2) 785*da0073e9SAndroid Build Coastguard Worker _, _ = next(it1), next(it2) 786*da0073e9SAndroid Build Coastguard Worker # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning 787*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 788*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp1, use_dill) 789*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp2, use_dill) 790*da0073e9SAndroid Build Coastguard Worker 791*da0073e9SAndroid Build Coastguard Worker # 2.5. Testing for serialization after one child DataPipe is fully read 792*da0073e9SAndroid Build Coastguard Worker # (Only for DataPipes with children DataPipes) 793*da0073e9SAndroid Build Coastguard Worker it1 = iter(dp1) 794*da0073e9SAndroid Build Coastguard Worker _ = list(it1) # fully read one child 795*da0073e9SAndroid Build Coastguard Worker # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning 796*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 797*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp1, use_dill) 798*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp2, use_dill) 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker # 3. Testing for serialization after DataPipe is fully read 801*da0073e9SAndroid Build Coastguard Worker it2 = iter(dp2) 802*da0073e9SAndroid Build Coastguard Worker _ = list(it2) # fully read the other child 803*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp1, use_dill) 804*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp2, use_dill) 805*da0073e9SAndroid Build Coastguard Worker 806*da0073e9SAndroid Build Coastguard Worker def test_serializable(self): 807*da0073e9SAndroid Build Coastguard Worker picklable_datapipes: List = [ 808*da0073e9SAndroid Build Coastguard Worker ( 809*da0073e9SAndroid Build Coastguard Worker dp.iter.Batcher, 810*da0073e9SAndroid Build Coastguard Worker None, 811*da0073e9SAndroid Build Coastguard Worker ( 812*da0073e9SAndroid Build Coastguard Worker 3, 813*da0073e9SAndroid Build Coastguard Worker True, 814*da0073e9SAndroid Build Coastguard Worker ), 815*da0073e9SAndroid Build Coastguard Worker {}, 816*da0073e9SAndroid Build Coastguard Worker ), 817*da0073e9SAndroid Build Coastguard Worker (dp.iter.Collator, None, (_fake_fn,), {}), 818*da0073e9SAndroid Build Coastguard Worker (dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}), 819*da0073e9SAndroid Build Coastguard Worker (dp.iter.Demultiplexer, None, (2, _simple_filter_fn), {}), 820*da0073e9SAndroid Build Coastguard Worker (dp.iter.FileLister, ".", (), {}), 821*da0073e9SAndroid Build Coastguard Worker (dp.iter.FileOpener, None, (), {}), 822*da0073e9SAndroid Build Coastguard Worker (dp.iter.Filter, None, (_fake_filter_fn,), {}), 823*da0073e9SAndroid Build Coastguard Worker (dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}), 824*da0073e9SAndroid Build Coastguard Worker (dp.iter.Forker, None, (2,), {}), 825*da0073e9SAndroid Build Coastguard Worker (dp.iter.Forker, None, (2,), {"copy": "shallow"}), 826*da0073e9SAndroid Build Coastguard Worker (dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}), 827*da0073e9SAndroid Build Coastguard Worker (dp.iter.IterableWrapper, range(10), (), {}), 828*da0073e9SAndroid Build Coastguard Worker (dp.iter.Mapper, None, (_fake_fn,), {}), 829*da0073e9SAndroid Build Coastguard Worker (dp.iter.Mapper, None, (partial(_fake_add, 1),), {}), 830*da0073e9SAndroid Build Coastguard Worker (dp.iter.Multiplexer, None, (dp.iter.IterableWrapper(range(10)),), {}), 831*da0073e9SAndroid Build Coastguard Worker (dp.iter.Sampler, None, (), {}), 832*da0073e9SAndroid Build Coastguard Worker (dp.iter.Shuffler, dp.iter.IterableWrapper([0] * 10), (), {}), 833*da0073e9SAndroid Build Coastguard Worker (dp.iter.StreamReader, None, (), {}), 834*da0073e9SAndroid Build Coastguard Worker (dp.iter.UnBatcher, None, (0,), {}), 835*da0073e9SAndroid Build Coastguard Worker (dp.iter.Zipper, None, (dp.iter.IterableWrapper(range(10)),), {}), 836*da0073e9SAndroid Build Coastguard Worker ] 837*da0073e9SAndroid Build Coastguard Worker # Skipping comparison for these DataPipes 838*da0073e9SAndroid Build Coastguard Worker dp_skip_comparison = {dp.iter.FileOpener, dp.iter.StreamReader} 839*da0073e9SAndroid Build Coastguard Worker # These DataPipes produce multiple DataPipes as outputs and those should be compared 840*da0073e9SAndroid Build Coastguard Worker dp_compare_children = {dp.iter.Demultiplexer, dp.iter.Forker} 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes: 843*da0073e9SAndroid Build Coastguard Worker if custom_input is None: 844*da0073e9SAndroid Build Coastguard Worker custom_input = dp.iter.IterableWrapper(range(10)) 845*da0073e9SAndroid Build Coastguard Worker if ( 846*da0073e9SAndroid Build Coastguard Worker dpipe in dp_skip_comparison 847*da0073e9SAndroid Build Coastguard Worker ): # Merely make sure they are picklable and loadable (no value comparison) 848*da0073e9SAndroid Build Coastguard Worker datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] 849*da0073e9SAndroid Build Coastguard Worker serialized_dp = pickle.dumps(datapipe) 850*da0073e9SAndroid Build Coastguard Worker _ = pickle.loads(serialized_dp) 851*da0073e9SAndroid Build Coastguard Worker elif dpipe in dp_compare_children: # DataPipes that have children 852*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] 853*da0073e9SAndroid Build Coastguard Worker self._serialization_test_for_dp_with_children(dp1, dp2) 854*da0073e9SAndroid Build Coastguard Worker else: # Single DataPipe that requires comparison 855*da0073e9SAndroid Build Coastguard Worker datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] 856*da0073e9SAndroid Build Coastguard Worker self._serialization_test_for_single_dp(datapipe) 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Dict with function as keys") 859*da0073e9SAndroid Build Coastguard Worker def test_serializable_with_dill(self): 860*da0073e9SAndroid Build Coastguard Worker """Only for DataPipes that take in a function as argument""" 861*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(range(10)) 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker datapipes_with_lambda_fn: List[ 864*da0073e9SAndroid Build Coastguard Worker Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]] 865*da0073e9SAndroid Build Coastguard Worker ] = [ 866*da0073e9SAndroid Build Coastguard Worker (dp.iter.Collator, (lambda_fn1,), {}), 867*da0073e9SAndroid Build Coastguard Worker ( 868*da0073e9SAndroid Build Coastguard Worker dp.iter.Demultiplexer, 869*da0073e9SAndroid Build Coastguard Worker ( 870*da0073e9SAndroid Build Coastguard Worker 2, 871*da0073e9SAndroid Build Coastguard Worker lambda_fn2, 872*da0073e9SAndroid Build Coastguard Worker ), 873*da0073e9SAndroid Build Coastguard Worker {}, 874*da0073e9SAndroid Build Coastguard Worker ), 875*da0073e9SAndroid Build Coastguard Worker (dp.iter.Filter, (lambda_fn3,), {}), 876*da0073e9SAndroid Build Coastguard Worker (dp.iter.Grouper, (lambda_fn3,), {}), 877*da0073e9SAndroid Build Coastguard Worker (dp.iter.Mapper, (lambda_fn1,), {}), 878*da0073e9SAndroid Build Coastguard Worker ] 879*da0073e9SAndroid Build Coastguard Worker 880*da0073e9SAndroid Build Coastguard Worker def _local_fns(): 881*da0073e9SAndroid Build Coastguard Worker def _fn1(x): 882*da0073e9SAndroid Build Coastguard Worker return x 883*da0073e9SAndroid Build Coastguard Worker 884*da0073e9SAndroid Build Coastguard Worker def _fn2(x): 885*da0073e9SAndroid Build Coastguard Worker return x % 2 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker def _fn3(x): 888*da0073e9SAndroid Build Coastguard Worker return x >= 5 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker return _fn1, _fn2, _fn3 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker fn1, fn2, fn3 = _local_fns() 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Worker datapipes_with_local_fn: List[ 895*da0073e9SAndroid Build Coastguard Worker Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]] 896*da0073e9SAndroid Build Coastguard Worker ] = [ 897*da0073e9SAndroid Build Coastguard Worker (dp.iter.Collator, (fn1,), {}), 898*da0073e9SAndroid Build Coastguard Worker ( 899*da0073e9SAndroid Build Coastguard Worker dp.iter.Demultiplexer, 900*da0073e9SAndroid Build Coastguard Worker ( 901*da0073e9SAndroid Build Coastguard Worker 2, 902*da0073e9SAndroid Build Coastguard Worker fn2, 903*da0073e9SAndroid Build Coastguard Worker ), 904*da0073e9SAndroid Build Coastguard Worker {}, 905*da0073e9SAndroid Build Coastguard Worker ), 906*da0073e9SAndroid Build Coastguard Worker (dp.iter.Filter, (fn3,), {}), 907*da0073e9SAndroid Build Coastguard Worker (dp.iter.Grouper, (fn3,), {}), 908*da0073e9SAndroid Build Coastguard Worker (dp.iter.Mapper, (fn1,), {}), 909*da0073e9SAndroid Build Coastguard Worker ] 910*da0073e9SAndroid Build Coastguard Worker 911*da0073e9SAndroid Build Coastguard Worker dp_compare_children = {dp.iter.Demultiplexer} 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker if HAS_DILL: 914*da0073e9SAndroid Build Coastguard Worker for dpipe, dp_args, dp_kwargs in ( 915*da0073e9SAndroid Build Coastguard Worker datapipes_with_lambda_fn + datapipes_with_local_fn 916*da0073e9SAndroid Build Coastguard Worker ): 917*da0073e9SAndroid Build Coastguard Worker if dpipe in dp_compare_children: 918*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] 919*da0073e9SAndroid Build Coastguard Worker self._serialization_test_for_dp_with_children( 920*da0073e9SAndroid Build Coastguard Worker dp1, dp2, use_dill=True 921*da0073e9SAndroid Build Coastguard Worker ) 922*da0073e9SAndroid Build Coastguard Worker else: 923*da0073e9SAndroid Build Coastguard Worker datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] 924*da0073e9SAndroid Build Coastguard Worker self._serialization_test_for_single_dp(datapipe, use_dill=True) 925*da0073e9SAndroid Build Coastguard Worker else: 926*da0073e9SAndroid Build Coastguard Worker msgs = ( 927*da0073e9SAndroid Build Coastguard Worker r"^Lambda function is not supported by pickle", 928*da0073e9SAndroid Build Coastguard Worker r"^Local function is not supported by pickle", 929*da0073e9SAndroid Build Coastguard Worker ) 930*da0073e9SAndroid Build Coastguard Worker for dps, msg in zip( 931*da0073e9SAndroid Build Coastguard Worker (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs 932*da0073e9SAndroid Build Coastguard Worker ): 933*da0073e9SAndroid Build Coastguard Worker for dpipe, dp_args, dp_kwargs in dps: 934*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, msg): 935*da0073e9SAndroid Build Coastguard Worker datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] 936*da0073e9SAndroid Build Coastguard Worker with self.assertRaises((pickle.PicklingError, AttributeError)): 937*da0073e9SAndroid Build Coastguard Worker pickle.dumps(datapipe) 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker def test_docstring(self): 940*da0073e9SAndroid Build Coastguard Worker """ 941*da0073e9SAndroid Build Coastguard Worker Ensure functional form of IterDataPipe has the correct docstring from 942*da0073e9SAndroid Build Coastguard Worker the class form. 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker Regression test for https://github.com/pytorch/data/issues/792. 945*da0073e9SAndroid Build Coastguard Worker """ 946*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(range(10)) 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker for dp_funcname in [ 949*da0073e9SAndroid Build Coastguard Worker "batch", 950*da0073e9SAndroid Build Coastguard Worker "collate", 951*da0073e9SAndroid Build Coastguard Worker "concat", 952*da0073e9SAndroid Build Coastguard Worker "demux", 953*da0073e9SAndroid Build Coastguard Worker "filter", 954*da0073e9SAndroid Build Coastguard Worker "fork", 955*da0073e9SAndroid Build Coastguard Worker "map", 956*da0073e9SAndroid Build Coastguard Worker "mux", 957*da0073e9SAndroid Build Coastguard Worker "read_from_stream", 958*da0073e9SAndroid Build Coastguard Worker # "sampler", 959*da0073e9SAndroid Build Coastguard Worker "shuffle", 960*da0073e9SAndroid Build Coastguard Worker "unbatch", 961*da0073e9SAndroid Build Coastguard Worker "zip", 962*da0073e9SAndroid Build Coastguard Worker ]: 963*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 9): 964*da0073e9SAndroid Build Coastguard Worker docstring = pydoc.render_doc( 965*da0073e9SAndroid Build Coastguard Worker thing=getattr(input_dp, dp_funcname), forceload=True 966*da0073e9SAndroid Build Coastguard Worker ) 967*da0073e9SAndroid Build Coastguard Worker elif sys.version_info < (3, 9): 968*da0073e9SAndroid Build Coastguard Worker # pydoc works differently on Python 3.8, see 969*da0073e9SAndroid Build Coastguard Worker # https://docs.python.org/3/whatsnew/3.9.html#pydoc 970*da0073e9SAndroid Build Coastguard Worker docstring = getattr(input_dp, dp_funcname).__doc__ 971*da0073e9SAndroid Build Coastguard Worker 972*da0073e9SAndroid Build Coastguard Worker assert f"(functional name: ``{dp_funcname}``)" in docstring 973*da0073e9SAndroid Build Coastguard Worker assert "Args:" in docstring 974*da0073e9SAndroid Build Coastguard Worker assert "Example:" in docstring or "Examples:" in docstring 975*da0073e9SAndroid Build Coastguard Worker 976*da0073e9SAndroid Build Coastguard Worker def test_iterable_wrapper_datapipe(self): 977*da0073e9SAndroid Build Coastguard Worker input_ls = list(range(10)) 978*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(input_ls) 979*da0073e9SAndroid Build Coastguard Worker 980*da0073e9SAndroid Build Coastguard Worker # Functional Test: values are unchanged and in the same order 981*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_ls, list(input_dp)) 982*da0073e9SAndroid Build Coastguard Worker 983*da0073e9SAndroid Build Coastguard Worker # Functional Test: deep copy by default when an iterator is initialized (first element is read) 984*da0073e9SAndroid Build Coastguard Worker it = iter(input_dp) 985*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 986*da0073e9SAndroid Build Coastguard Worker 0, next(it) 987*da0073e9SAndroid Build Coastguard Worker ) # The deep copy only happens when the first element is read 988*da0073e9SAndroid Build Coastguard Worker input_ls.append(50) 989*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(1, 10)), list(it)) 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker # Functional Test: shallow copy 992*da0073e9SAndroid Build Coastguard Worker input_ls2 = [1, 2, 3] 993*da0073e9SAndroid Build Coastguard Worker input_dp_shallow = dp.iter.IterableWrapper(input_ls2, deepcopy=False) 994*da0073e9SAndroid Build Coastguard Worker input_ls2.append(10) 995*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 2, 3, 10], list(input_dp_shallow)) 996*da0073e9SAndroid Build Coastguard Worker 997*da0073e9SAndroid Build Coastguard Worker # Reset Test: reset the DataPipe 998*da0073e9SAndroid Build Coastguard Worker input_ls = list(range(10)) 999*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(input_ls) 1000*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 5 1001*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 1002*da0073e9SAndroid Build Coastguard Worker input_dp, n_elements_before_reset 1003*da0073e9SAndroid Build Coastguard Worker ) 1004*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_ls[:n_elements_before_reset], res_before_reset) 1005*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_ls, res_after_reset) 1006*da0073e9SAndroid Build Coastguard Worker 1007*da0073e9SAndroid Build Coastguard Worker # __len__ Test: inherits length from sequence 1008*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_ls), len(input_dp)) 1009*da0073e9SAndroid Build Coastguard Worker 1010*da0073e9SAndroid Build Coastguard Worker def test_concat_iterdatapipe(self): 1011*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.iter.IterableWrapper(range(10)) 1012*da0073e9SAndroid Build Coastguard Worker input_dp2 = dp.iter.IterableWrapper(range(5)) 1013*da0073e9SAndroid Build Coastguard Worker 1014*da0073e9SAndroid Build Coastguard Worker # Functional Test: Raises exception for empty input 1015*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): 1016*da0073e9SAndroid Build Coastguard Worker dp.iter.Concater() 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker # Functional Test: Raises exception for non-IterDataPipe input 1019*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1020*da0073e9SAndroid Build Coastguard Worker TypeError, r"Expected all inputs to be `IterDataPipe`" 1021*da0073e9SAndroid Build Coastguard Worker ): 1022*da0073e9SAndroid Build Coastguard Worker dp.iter.Concater(input_dp1, ()) # type: ignore[arg-type] 1023*da0073e9SAndroid Build Coastguard Worker 1024*da0073e9SAndroid Build Coastguard Worker # Functional Test: Concatenate DataPipes as expected 1025*da0073e9SAndroid Build Coastguard Worker concat_dp = input_dp1.concat(input_dp2) 1026*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(concat_dp), 15) 1027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker # Reset Test: reset the DataPipe 1030*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 5 1031*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 1032*da0073e9SAndroid Build Coastguard Worker concat_dp, n_elements_before_reset 1033*da0073e9SAndroid Build Coastguard Worker ) 1034*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)), res_before_reset) 1035*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)) + list(range(5)), res_after_reset) 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker # __len__ Test: inherits length from source DataPipe 1038*da0073e9SAndroid Build Coastguard Worker input_dp_nl = IDP_NoLen(range(5)) 1039*da0073e9SAndroid Build Coastguard Worker concat_dp = input_dp1.concat(input_dp_nl) 1040*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 1041*da0073e9SAndroid Build Coastguard Worker len(concat_dp) 1042*da0073e9SAndroid Build Coastguard Worker 1043*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) 1044*da0073e9SAndroid Build Coastguard Worker 1045*da0073e9SAndroid Build Coastguard Worker def test_fork_iterdatapipe(self): 1046*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(range(10)) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1049*da0073e9SAndroid Build Coastguard Worker input_dp.fork(num_instances=0) 1050*da0073e9SAndroid Build Coastguard Worker 1051*da0073e9SAndroid Build Coastguard Worker dp0 = input_dp.fork(num_instances=1, buffer_size=0) 1052*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dp0, input_dp) 1053*da0073e9SAndroid Build Coastguard Worker 1054*da0073e9SAndroid Build Coastguard Worker # Functional Test: making sure all child DataPipe shares the same reference 1055*da0073e9SAndroid Build Coastguard Worker dp1, dp2, dp3 = input_dp.fork(num_instances=3) 1056*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3))) 1057*da0073e9SAndroid Build Coastguard Worker 1058*da0073e9SAndroid Build Coastguard Worker # Functional Test: one child DataPipe yields all value at a time 1059*da0073e9SAndroid Build Coastguard Worker output1, output2, output3 = list(dp1), list(dp2), list(dp3) 1060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), output1) 1061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), output2) 1062*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), output3) 1063*da0073e9SAndroid Build Coastguard Worker 1064*da0073e9SAndroid Build Coastguard Worker # Functional Test: two child DataPipes yield value together 1065*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.fork(num_instances=2) 1066*da0073e9SAndroid Build Coastguard Worker output = [] 1067*da0073e9SAndroid Build Coastguard Worker for n1, n2 in zip(dp1, dp2): 1068*da0073e9SAndroid Build Coastguard Worker output.append((n1, n2)) 1069*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(i, i) for i in range(10)], output) 1070*da0073e9SAndroid Build Coastguard Worker 1071*da0073e9SAndroid Build Coastguard Worker # Functional Test: one child DataPipe yields all value first, but buffer_size = 5 being too small 1072*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=4) 1073*da0073e9SAndroid Build Coastguard Worker it1 = iter(dp1) 1074*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1075*da0073e9SAndroid Build Coastguard Worker next(it1) 1076*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(BufferError): 1077*da0073e9SAndroid Build Coastguard Worker next(it1) 1078*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(BufferError): 1079*da0073e9SAndroid Build Coastguard Worker list(dp2) 1080*da0073e9SAndroid Build Coastguard Worker 1081*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5) 1082*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(BufferError): 1083*da0073e9SAndroid Build Coastguard Worker list(dp2) 1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Worker # Functional Test: one child DataPipe yields all value first with unlimited buffer 1086*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1087*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1) 1088*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1089*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set") 1090*da0073e9SAndroid Build Coastguard Worker l1, l2 = list(dp1), list(dp2) 1091*da0073e9SAndroid Build Coastguard Worker for d1, d2 in zip(l1, l2): 1092*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d1, d2) 1093*da0073e9SAndroid Build Coastguard Worker 1094*da0073e9SAndroid Build Coastguard Worker # Functional Test: two child DataPipes yield value together with buffer size 1 1095*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1) 1096*da0073e9SAndroid Build Coastguard Worker output = [] 1097*da0073e9SAndroid Build Coastguard Worker for n1, n2 in zip(dp1, dp2): 1098*da0073e9SAndroid Build Coastguard Worker output.append((n1, n2)) 1099*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(i, i) for i in range(10)], output) 1100*da0073e9SAndroid Build Coastguard Worker 1101*da0073e9SAndroid Build Coastguard Worker # Functional Test: two child DataPipes yield shallow copies with copy equals shallow 1102*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.map(_to_list).fork(num_instances=2, copy="shallow") 1103*da0073e9SAndroid Build Coastguard Worker for n1, n2 in zip(dp1, dp2): 1104*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(n1, n2) 1105*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n1, n2) 1106*da0073e9SAndroid Build Coastguard Worker 1107*da0073e9SAndroid Build Coastguard Worker # Functional Test: two child DataPipes yield deep copies with copy equals deep 1108*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = ( 1109*da0073e9SAndroid Build Coastguard Worker input_dp.map(_to_list).map(_to_list).fork(num_instances=2, copy="deep") 1110*da0073e9SAndroid Build Coastguard Worker ) 1111*da0073e9SAndroid Build Coastguard Worker for n1, n2 in zip(dp1, dp2): 1112*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(n1[0], n2[0]) 1113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n1, n2) 1114*da0073e9SAndroid Build Coastguard Worker 1115*da0073e9SAndroid Build Coastguard Worker # Functional Test: fork DataPipe raises error for unknown copy method 1116*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1117*da0073e9SAndroid Build Coastguard Worker input_dp.fork(num_instances=2, copy="unknown") 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker # Functional Test: make sure logic related to slowest_ptr is working properly 1120*da0073e9SAndroid Build Coastguard Worker dp1, dp2, dp3 = input_dp.fork(num_instances=3) 1121*da0073e9SAndroid Build Coastguard Worker output1, output2, output3 = [], [], [] 1122*da0073e9SAndroid Build Coastguard Worker for i, (n1, n2) in enumerate(zip(dp1, dp2)): 1123*da0073e9SAndroid Build Coastguard Worker output1.append(n1) 1124*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1125*da0073e9SAndroid Build Coastguard Worker if i == 4: # yield all of dp3 when halfway through dp1, dp2 1126*da0073e9SAndroid Build Coastguard Worker output3 = list(dp3) 1127*da0073e9SAndroid Build Coastguard Worker break 1128*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)), output1) 1129*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)), output2) 1130*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), output3) 1131*da0073e9SAndroid Build Coastguard Worker 1132*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read 1133*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.fork(num_instances=2) 1134*da0073e9SAndroid Build Coastguard Worker _ = iter(dp1) 1135*da0073e9SAndroid Build Coastguard Worker output2 = [] 1136*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"): 1137*da0073e9SAndroid Build Coastguard Worker for i, n2 in enumerate(dp2): 1138*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1139*da0073e9SAndroid Build Coastguard Worker if i == 4: 1140*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1141*da0073e9SAndroid Build Coastguard Worker _ = iter(dp1) # This will reset all child DataPipes 1142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1143*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1144*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"child DataPipes are not exhausted" 1145*da0073e9SAndroid Build Coastguard Worker ) 1146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)), output2) 1147*da0073e9SAndroid Build Coastguard Worker 1148*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe resets when some of it has been read 1149*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.fork(num_instances=2) 1150*da0073e9SAndroid Build Coastguard Worker output1, output2 = [], [] 1151*da0073e9SAndroid Build Coastguard Worker for i, (n1, n2) in enumerate(zip(dp1, dp2)): 1152*da0073e9SAndroid Build Coastguard Worker output1.append(n1) 1153*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1154*da0073e9SAndroid Build Coastguard Worker if i == 4: 1155*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1156*da0073e9SAndroid Build Coastguard Worker _ = iter(dp1) # Reset both all child DataPipe 1157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1158*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1159*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"Some child DataPipes are not exhausted" 1160*da0073e9SAndroid Build Coastguard Worker ) 1161*da0073e9SAndroid Build Coastguard Worker break 1162*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1163*da0073e9SAndroid Build Coastguard Worker for i, (n1, n2) in enumerate(zip(dp1, dp2)): 1164*da0073e9SAndroid Build Coastguard Worker output1.append(n1) 1165*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1167*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted") 1168*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)) + list(range(10)), output1) 1169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)) + list(range(10)), output2) 1170*da0073e9SAndroid Build Coastguard Worker 1171*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe reset, even when some other child DataPipes are not read 1172*da0073e9SAndroid Build Coastguard Worker dp1, dp2, dp3 = input_dp.fork(num_instances=3) 1173*da0073e9SAndroid Build Coastguard Worker output1, output2 = list(dp1), list(dp2) 1174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), output1) 1175*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), output2) 1176*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1177*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1178*da0073e9SAndroid Build Coastguard Worker list(range(10)), list(dp1) 1179*da0073e9SAndroid Build Coastguard Worker ) # Resets even though dp3 has not been read 1180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1181*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1182*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"Some child DataPipes are not exhausted" 1183*da0073e9SAndroid Build Coastguard Worker ) 1184*da0073e9SAndroid Build Coastguard Worker output3 = [] 1185*da0073e9SAndroid Build Coastguard Worker for i, n3 in enumerate(dp3): 1186*da0073e9SAndroid Build Coastguard Worker output3.append(n3) 1187*da0073e9SAndroid Build Coastguard Worker if i == 4: 1188*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1189*da0073e9SAndroid Build Coastguard Worker output1 = list(dp1) # Resets even though dp3 is only partially read 1190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1191*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1192*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"Some child DataPipes are not exhausted" 1193*da0073e9SAndroid Build Coastguard Worker ) 1194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)), output3) 1195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), output1) 1196*da0073e9SAndroid Build Coastguard Worker break 1197*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1198*da0073e9SAndroid Build Coastguard Worker list(range(10)), list(dp3) 1199*da0073e9SAndroid Build Coastguard Worker ) # dp3 has to read from the start again 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker # __len__ Test: Each DataPipe inherits the source datapipe's length 1202*da0073e9SAndroid Build Coastguard Worker dp1, dp2, dp3 = input_dp.fork(num_instances=3) 1203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_dp), len(dp1)) 1204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_dp), len(dp2)) 1205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_dp), len(dp3)) 1206*da0073e9SAndroid Build Coastguard Worker 1207*da0073e9SAndroid Build Coastguard Worker # Pickle Test: 1208*da0073e9SAndroid Build Coastguard Worker dp1, dp2, dp3 = input_dp.fork(num_instances=3) 1209*da0073e9SAndroid Build Coastguard Worker traverse_dps(dp1) # This should not raise any error 1210*da0073e9SAndroid Build Coastguard Worker for _ in zip(dp1, dp2, dp3): 1211*da0073e9SAndroid Build Coastguard Worker pass 1212*da0073e9SAndroid Build Coastguard Worker traverse_dps(dp2) # This should not raise any error either 1213*da0073e9SAndroid Build Coastguard Worker 1214*da0073e9SAndroid Build Coastguard Worker def test_mux_iterdatapipe(self): 1215*da0073e9SAndroid Build Coastguard Worker # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted 1216*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.iter.IterableWrapper(range(4)) 1217*da0073e9SAndroid Build Coastguard Worker input_dp2 = dp.iter.IterableWrapper(range(4, 8)) 1218*da0073e9SAndroid Build Coastguard Worker input_dp3 = dp.iter.IterableWrapper(range(8, 12)) 1219*da0073e9SAndroid Build Coastguard Worker output_dp = input_dp1.mux(input_dp2, input_dp3) 1220*da0073e9SAndroid Build Coastguard Worker expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11] 1221*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(expected_output), len(output_dp)) 1222*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, list(output_dp)) 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Worker # Functional Test: Uneven input Data Pipes 1225*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.iter.IterableWrapper([1, 2, 3, 4]) 1226*da0073e9SAndroid Build Coastguard Worker input_dp2 = dp.iter.IterableWrapper([10]) 1227*da0073e9SAndroid Build Coastguard Worker input_dp3 = dp.iter.IterableWrapper([100, 200, 300]) 1228*da0073e9SAndroid Build Coastguard Worker output_dp = input_dp1.mux(input_dp2, input_dp3) 1229*da0073e9SAndroid Build Coastguard Worker expected_output = [1, 10, 100] 1230*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(expected_output), len(output_dp)) 1231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, list(output_dp)) 1232*da0073e9SAndroid Build Coastguard Worker 1233*da0073e9SAndroid Build Coastguard Worker # Functional Test: Empty Data Pipe 1234*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.iter.IterableWrapper([0, 1, 2, 3]) 1235*da0073e9SAndroid Build Coastguard Worker input_dp2 = dp.iter.IterableWrapper([]) 1236*da0073e9SAndroid Build Coastguard Worker output_dp = input_dp1.mux(input_dp2) 1237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_dp2), len(output_dp)) 1238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(input_dp2), list(output_dp)) 1239*da0073e9SAndroid Build Coastguard Worker 1240*da0073e9SAndroid Build Coastguard Worker # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__ 1241*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.iter.IterableWrapper(range(10)) 1242*da0073e9SAndroid Build Coastguard Worker input_dp_no_len = IDP_NoLen(range(10)) 1243*da0073e9SAndroid Build Coastguard Worker output_dp = input_dp1.mux(input_dp_no_len) 1244*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 1245*da0073e9SAndroid Build Coastguard Worker len(output_dp) 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker def test_demux_iterdatapipe(self): 1248*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(range(10)) 1249*da0073e9SAndroid Build Coastguard Worker 1250*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1251*da0073e9SAndroid Build Coastguard Worker input_dp.demux(num_instances=0, classifier_fn=lambda x: 0) 1252*da0073e9SAndroid Build Coastguard Worker 1253*da0073e9SAndroid Build Coastguard Worker # Functional Test: split into 2 DataPipes and output them one at a time 1254*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) 1255*da0073e9SAndroid Build Coastguard Worker output1, output2 = list(dp1), list(dp2) 1256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(0, 10, 2)), output1) 1257*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(1, 10, 2)), output2) 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker # Functional Test: split into 2 DataPipes and output them together 1260*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) 1261*da0073e9SAndroid Build Coastguard Worker output = [] 1262*da0073e9SAndroid Build Coastguard Worker for n1, n2 in zip(dp1, dp2): 1263*da0073e9SAndroid Build Coastguard Worker output.append((n1, n2)) 1264*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output) 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Worker # Functional Test: values of the same classification are lumped together, and buffer_size = 3 being too small 1267*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux( 1268*da0073e9SAndroid Build Coastguard Worker num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4 1269*da0073e9SAndroid Build Coastguard Worker ) 1270*da0073e9SAndroid Build Coastguard Worker it1 = iter(dp1) 1271*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(BufferError): 1272*da0073e9SAndroid Build Coastguard Worker next( 1273*da0073e9SAndroid Build Coastguard Worker it1 1274*da0073e9SAndroid Build Coastguard Worker ) # Buffer raises because first 5 elements all belong to the a different child 1275*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(BufferError): 1276*da0073e9SAndroid Build Coastguard Worker list(dp2) 1277*da0073e9SAndroid Build Coastguard Worker 1278*da0073e9SAndroid Build Coastguard Worker # Functional Test: values of the same classification are lumped together, and buffer_size = 5 is just enough 1279*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux( 1280*da0073e9SAndroid Build Coastguard Worker num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5 1281*da0073e9SAndroid Build Coastguard Worker ) 1282*da0073e9SAndroid Build Coastguard Worker output1, output2 = list(dp1), list(dp2) 1283*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5, 10)), output1) 1284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(0, 5)), output2) 1285*da0073e9SAndroid Build Coastguard Worker 1286*da0073e9SAndroid Build Coastguard Worker # Functional Test: values of the same classification are lumped together, and unlimited buffer 1287*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1288*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux( 1289*da0073e9SAndroid Build Coastguard Worker num_instances=2, 1290*da0073e9SAndroid Build Coastguard Worker classifier_fn=lambda x: 0 if x >= 5 else 1, 1291*da0073e9SAndroid Build Coastguard Worker buffer_size=-1, 1292*da0073e9SAndroid Build Coastguard Worker ) 1293*da0073e9SAndroid Build Coastguard Worker exp_l = 1 if HAS_DILL else 2 1294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), exp_l) 1295*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set") 1296*da0073e9SAndroid Build Coastguard Worker output1, output2 = list(dp1), list(dp2) 1297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5, 10)), output1) 1298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(0, 5)), output2) 1299*da0073e9SAndroid Build Coastguard Worker 1300*da0073e9SAndroid Build Coastguard Worker # Functional Test: classifier returns a value outside of [0, num_instance - 1] 1301*da0073e9SAndroid Build Coastguard Worker dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2) 1302*da0073e9SAndroid Build Coastguard Worker it = iter(dp0[0]) 1303*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1304*da0073e9SAndroid Build Coastguard Worker next(it) 1305*da0073e9SAndroid Build Coastguard Worker next(it) 1306*da0073e9SAndroid Build Coastguard Worker 1307*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read 1308*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) 1309*da0073e9SAndroid Build Coastguard Worker _ = iter(dp1) 1310*da0073e9SAndroid Build Coastguard Worker output2 = [] 1311*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"): 1312*da0073e9SAndroid Build Coastguard Worker for i, n2 in enumerate(dp2): 1313*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1314*da0073e9SAndroid Build Coastguard Worker if i == 4: 1315*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1316*da0073e9SAndroid Build Coastguard Worker _ = iter(dp1) # This will reset all child DataPipes 1317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1318*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1319*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"child DataPipes are not exhausted" 1320*da0073e9SAndroid Build Coastguard Worker ) 1321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(1, 10, 2)), output2) 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe resets when some of it has been read 1324*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) 1325*da0073e9SAndroid Build Coastguard Worker output1, output2 = [], [] 1326*da0073e9SAndroid Build Coastguard Worker for n1, n2 in zip(dp1, dp2): 1327*da0073e9SAndroid Build Coastguard Worker output1.append(n1) 1328*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1329*da0073e9SAndroid Build Coastguard Worker if n1 == 4: 1330*da0073e9SAndroid Build Coastguard Worker break 1331*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1332*da0073e9SAndroid Build Coastguard Worker i1 = iter(dp1) # Reset all child DataPipes 1333*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1334*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1335*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"Some child DataPipes are not exhausted" 1336*da0073e9SAndroid Build Coastguard Worker ) 1337*da0073e9SAndroid Build Coastguard Worker for n1, n2 in zip(dp1, dp2): 1338*da0073e9SAndroid Build Coastguard Worker output1.append(n1) 1339*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1340*da0073e9SAndroid Build Coastguard Worker self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1) 1341*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2) 1342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1343*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted") 1344*da0073e9SAndroid Build Coastguard Worker 1345*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe reset, even when not all child DataPipes are exhausted 1346*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) 1347*da0073e9SAndroid Build Coastguard Worker output1 = list(dp1) 1348*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(0, 10, 2)), output1) 1349*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1350*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1351*da0073e9SAndroid Build Coastguard Worker list(range(0, 10, 2)), list(dp1) 1352*da0073e9SAndroid Build Coastguard Worker ) # Reset even when dp2 is not read 1353*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1354*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1355*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"Some child DataPipes are not exhausted" 1356*da0073e9SAndroid Build Coastguard Worker ) 1357*da0073e9SAndroid Build Coastguard Worker output2 = [] 1358*da0073e9SAndroid Build Coastguard Worker for i, n2 in enumerate(dp2): 1359*da0073e9SAndroid Build Coastguard Worker output2.append(n2) 1360*da0073e9SAndroid Build Coastguard Worker if i == 1: 1361*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(1, 5, 2)), output2) 1362*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 1363*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1364*da0073e9SAndroid Build Coastguard Worker list(range(0, 10, 2)), list(dp1) 1365*da0073e9SAndroid Build Coastguard Worker ) # Can reset even when dp2 is partially read 1366*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 1367*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 1368*da0073e9SAndroid Build Coastguard Worker str(wa[0].message), r"Some child DataPipes are not exhausted" 1369*da0073e9SAndroid Build Coastguard Worker ) 1370*da0073e9SAndroid Build Coastguard Worker break 1371*da0073e9SAndroid Build Coastguard Worker output2 = list(dp2) # output2 has to read from beginning again 1372*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(1, 10, 2)), output2) 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker # Functional Test: drop_none = True 1375*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux( 1376*da0073e9SAndroid Build Coastguard Worker num_instances=2, 1377*da0073e9SAndroid Build Coastguard Worker classifier_fn=lambda x: x % 2 if x % 5 != 0 else None, 1378*da0073e9SAndroid Build Coastguard Worker drop_none=True, 1379*da0073e9SAndroid Build Coastguard Worker ) 1380*da0073e9SAndroid Build Coastguard Worker self.assertEqual([2, 4, 6, 8], list(dp1)) 1381*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 3, 7, 9], list(dp2)) 1382*da0073e9SAndroid Build Coastguard Worker 1383*da0073e9SAndroid Build Coastguard Worker # Functional Test: drop_none = False 1384*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux( 1385*da0073e9SAndroid Build Coastguard Worker num_instances=2, 1386*da0073e9SAndroid Build Coastguard Worker classifier_fn=lambda x: x % 2 if x % 5 != 0 else None, 1387*da0073e9SAndroid Build Coastguard Worker drop_none=False, 1388*da0073e9SAndroid Build Coastguard Worker ) 1389*da0073e9SAndroid Build Coastguard Worker it1 = iter(dp1) 1390*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1391*da0073e9SAndroid Build Coastguard Worker next(it1) 1392*da0073e9SAndroid Build Coastguard Worker 1393*da0073e9SAndroid Build Coastguard Worker # __len__ Test: __len__ not implemented 1394*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) 1395*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 1396*da0073e9SAndroid Build Coastguard Worker len( 1397*da0073e9SAndroid Build Coastguard Worker dp1 1398*da0073e9SAndroid Build Coastguard Worker ) # It is not implemented as we do not know length for each child in advance 1399*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 1400*da0073e9SAndroid Build Coastguard Worker len(dp2) 1401*da0073e9SAndroid Build Coastguard Worker 1402*da0073e9SAndroid Build Coastguard Worker # Pickle Test: 1403*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=odd_or_even) 1404*da0073e9SAndroid Build Coastguard Worker traverse_dps(dp1) # This should not raise any error 1405*da0073e9SAndroid Build Coastguard Worker for _ in zip(dp1, dp2): 1406*da0073e9SAndroid Build Coastguard Worker pass 1407*da0073e9SAndroid Build Coastguard Worker traverse_dps(dp2) # This should not raise any error either 1408*da0073e9SAndroid Build Coastguard Worker 1409*da0073e9SAndroid Build Coastguard Worker def test_map_iterdatapipe(self): 1410*da0073e9SAndroid Build Coastguard Worker target_length = 10 1411*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(range(target_length)) 1412*da0073e9SAndroid Build Coastguard Worker 1413*da0073e9SAndroid Build Coastguard Worker def fn(item, dtype=torch.float, *, sum=False): 1414*da0073e9SAndroid Build Coastguard Worker data = torch.tensor(item, dtype=dtype) 1415*da0073e9SAndroid Build Coastguard Worker return data if not sum else data.sum() 1416*da0073e9SAndroid Build Coastguard Worker 1417*da0073e9SAndroid Build Coastguard Worker # Functional Test: apply to each element correctly 1418*da0073e9SAndroid Build Coastguard Worker map_dp = input_dp.map(fn) 1419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(target_length, len(map_dp)) 1420*da0073e9SAndroid Build Coastguard Worker for x, y in zip(map_dp, range(target_length)): 1421*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.tensor(y, dtype=torch.float)) 1422*da0073e9SAndroid Build Coastguard Worker 1423*da0073e9SAndroid Build Coastguard Worker # Functional Test: works with partial function 1424*da0073e9SAndroid Build Coastguard Worker map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True)) 1425*da0073e9SAndroid Build Coastguard Worker for x, y in zip(map_dp, range(target_length)): 1426*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) 1427*da0073e9SAndroid Build Coastguard Worker 1428*da0073e9SAndroid Build Coastguard Worker # __len__ Test: inherits length from source DataPipe 1429*da0073e9SAndroid Build Coastguard Worker self.assertEqual(target_length, len(map_dp)) 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Worker input_dp_nl = IDP_NoLen(range(target_length)) 1432*da0073e9SAndroid Build Coastguard Worker map_dp_nl = input_dp_nl.map(lambda x: x) 1433*da0073e9SAndroid Build Coastguard Worker for x, y in zip(map_dp_nl, range(target_length)): 1434*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.tensor(y, dtype=torch.float)) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker # __len__ Test: inherits length from source DataPipe - raises error when invalid 1437*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 1438*da0073e9SAndroid Build Coastguard Worker len(map_dp_nl) 1439*da0073e9SAndroid Build Coastguard Worker 1440*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe resets properly 1441*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 5 1442*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 1443*da0073e9SAndroid Build Coastguard Worker map_dp, n_elements_before_reset 1444*da0073e9SAndroid Build Coastguard Worker ) 1445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(n_elements_before_reset)), res_before_reset) 1446*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), res_after_reset) 1447*da0073e9SAndroid Build Coastguard Worker 1448*da0073e9SAndroid Build Coastguard Worker @suppress_warnings # Suppress warning for lambda fn 1449*da0073e9SAndroid Build Coastguard Worker def test_map_tuple_list_with_col_iterdatapipe(self): 1450*da0073e9SAndroid Build Coastguard Worker def fn_11(d): 1451*da0073e9SAndroid Build Coastguard Worker return -d 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Worker def fn_1n(d): 1454*da0073e9SAndroid Build Coastguard Worker return -d, d 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker def fn_n1(d0, d1): 1457*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Worker def fn_nn(d0, d1): 1460*da0073e9SAndroid Build Coastguard Worker return -d0, -d1, d0 + d1 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Worker def fn_n1_def(d0, d1=1): 1463*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker def fn_n1_kwargs(d0, d1, **kwargs): 1466*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1467*da0073e9SAndroid Build Coastguard Worker 1468*da0073e9SAndroid Build Coastguard Worker def fn_n1_pos(d0, d1, *args): 1469*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1470*da0073e9SAndroid Build Coastguard Worker 1471*da0073e9SAndroid Build Coastguard Worker def fn_n1_sep_pos(d0, *args, d1): 1472*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1473*da0073e9SAndroid Build Coastguard Worker 1474*da0073e9SAndroid Build Coastguard Worker def fn_cmplx(d0, d1=1, *args, d2, **kwargs): 1475*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1476*da0073e9SAndroid Build Coastguard Worker 1477*da0073e9SAndroid Build Coastguard Worker p_fn_n1 = partial(fn_n1, d1=1) 1478*da0073e9SAndroid Build Coastguard Worker p_fn_cmplx = partial(fn_cmplx, d2=2) 1479*da0073e9SAndroid Build Coastguard Worker p_fn_cmplx_large_arg = partial( 1480*da0073e9SAndroid Build Coastguard Worker fn_cmplx, d2={i: list(range(i)) for i in range(10_000)} 1481*da0073e9SAndroid Build Coastguard Worker ) 1482*da0073e9SAndroid Build Coastguard Worker 1483*da0073e9SAndroid Build Coastguard Worker def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): 1484*da0073e9SAndroid Build Coastguard Worker for constr in (list, tuple): 1485*da0073e9SAndroid Build Coastguard Worker datapipe = dp.iter.IterableWrapper( 1486*da0073e9SAndroid Build Coastguard Worker [constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))] 1487*da0073e9SAndroid Build Coastguard Worker ) 1488*da0073e9SAndroid Build Coastguard Worker if ref_fn is None: 1489*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(error): 1490*da0073e9SAndroid Build Coastguard Worker res_dp = datapipe.map(fn, input_col, output_col) 1491*da0073e9SAndroid Build Coastguard Worker list(res_dp) 1492*da0073e9SAndroid Build Coastguard Worker else: 1493*da0073e9SAndroid Build Coastguard Worker res_dp = datapipe.map(fn, input_col, output_col) 1494*da0073e9SAndroid Build Coastguard Worker ref_dp = datapipe.map(ref_fn) 1495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(res_dp), list(ref_dp)) 1496*da0073e9SAndroid Build Coastguard Worker # Reset 1497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(res_dp), list(ref_dp)) 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, fn_n1_def, 0, 1) 1500*da0073e9SAndroid Build Coastguard Worker _helper( 1501*da0073e9SAndroid Build Coastguard Worker lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2 1502*da0073e9SAndroid Build Coastguard Worker ) 1503*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, p_fn_n1, 0, 1) 1504*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, p_fn_cmplx, 0, 1) 1505*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, p_fn_cmplx_large_arg, 0, 1) 1506*da0073e9SAndroid Build Coastguard Worker _helper( 1507*da0073e9SAndroid Build Coastguard Worker lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2 1508*da0073e9SAndroid Build Coastguard Worker ) 1509*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2]) 1510*da0073e9SAndroid Build Coastguard Worker 1511*da0073e9SAndroid Build Coastguard Worker # Replacing with one input column and default output column 1512*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1) 1513*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1) 1514*da0073e9SAndroid Build Coastguard Worker # The index of input column is out of range 1515*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_1n, 3, error=IndexError) 1516*da0073e9SAndroid Build Coastguard Worker # Unmatched input columns with fn arguments 1517*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1, 1, error=ValueError) 1518*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1, [0, 1, 2], error=ValueError) 1519*da0073e9SAndroid Build Coastguard Worker _helper(None, operator.add, 0, error=ValueError) 1520*da0073e9SAndroid Build Coastguard Worker _helper(None, operator.add, [0, 1, 2], error=ValueError) 1521*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_cmplx, 0, 1, ValueError) 1522*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_pos, 1, error=ValueError) 1523*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError) 1524*da0073e9SAndroid Build Coastguard Worker _helper(None, p_fn_n1, [0, 1], error=ValueError) 1525*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_1n, [1, 2], error=ValueError) 1526*da0073e9SAndroid Build Coastguard Worker # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError) 1527*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError) 1528*da0073e9SAndroid Build Coastguard Worker # Fn has keyword-only arguments 1529*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_kwargs, 1, error=ValueError) 1530*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_cmplx, [0, 1], 2, ValueError) 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker # Replacing with multiple input columns and default output column (the left-most input column) 1533*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0]) 1534*da0073e9SAndroid Build Coastguard Worker _helper( 1535*da0073e9SAndroid Build Coastguard Worker lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), 1536*da0073e9SAndroid Build Coastguard Worker fn_nn, 1537*da0073e9SAndroid Build Coastguard Worker [2, 1], 1538*da0073e9SAndroid Build Coastguard Worker ) 1539*da0073e9SAndroid Build Coastguard Worker 1540*da0073e9SAndroid Build Coastguard Worker # output_col can only be specified when input_col is not None 1541*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1, None, 1, error=ValueError) 1542*da0073e9SAndroid Build Coastguard Worker # output_col can only be single-element list or tuple 1543*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1, None, [0, 1], error=ValueError) 1544*da0073e9SAndroid Build Coastguard Worker # Single-element list as output_col 1545*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0]) 1546*da0073e9SAndroid Build Coastguard Worker # Replacing with one input column and single specified output column 1547*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0) 1548*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2) 1549*da0073e9SAndroid Build Coastguard Worker # The index of output column is out of range 1550*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_1n, 1, 3, error=IndexError) 1551*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1) 1552*da0073e9SAndroid Build Coastguard Worker _helper( 1553*da0073e9SAndroid Build Coastguard Worker lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), 1554*da0073e9SAndroid Build Coastguard Worker fn_nn, 1555*da0073e9SAndroid Build Coastguard Worker [1, 2], 1556*da0073e9SAndroid Build Coastguard Worker 0, 1557*da0073e9SAndroid Build Coastguard Worker ) 1558*da0073e9SAndroid Build Coastguard Worker 1559*da0073e9SAndroid Build Coastguard Worker # Appending the output at the end 1560*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (*data, -data[1]), fn_11, 1, -1) 1561*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1) 1562*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1) 1563*da0073e9SAndroid Build Coastguard Worker _helper( 1564*da0073e9SAndroid Build Coastguard Worker lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), 1565*da0073e9SAndroid Build Coastguard Worker fn_nn, 1566*da0073e9SAndroid Build Coastguard Worker [1, 2], 1567*da0073e9SAndroid Build Coastguard Worker -1, 1568*da0073e9SAndroid Build Coastguard Worker ) 1569*da0073e9SAndroid Build Coastguard Worker 1570*da0073e9SAndroid Build Coastguard Worker # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected 1571*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0) 1572*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0], data[1], int(data[2])), int, 2) 1573*da0073e9SAndroid Build Coastguard Worker 1574*da0073e9SAndroid Build Coastguard Worker # Handle nn.Module and Callable (without __name__ implemented) 1575*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0) 1576*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0) 1577*da0073e9SAndroid Build Coastguard Worker 1578*da0073e9SAndroid Build Coastguard Worker @suppress_warnings # Suppress warning for lambda fn 1579*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo() 1580*da0073e9SAndroid Build Coastguard Worker def test_map_dict_with_col_iterdatapipe(self): 1581*da0073e9SAndroid Build Coastguard Worker def fn_11(d): 1582*da0073e9SAndroid Build Coastguard Worker return -d 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker def fn_1n(d): 1585*da0073e9SAndroid Build Coastguard Worker return -d, d 1586*da0073e9SAndroid Build Coastguard Worker 1587*da0073e9SAndroid Build Coastguard Worker def fn_n1(d0, d1): 1588*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1589*da0073e9SAndroid Build Coastguard Worker 1590*da0073e9SAndroid Build Coastguard Worker def fn_nn(d0, d1): 1591*da0073e9SAndroid Build Coastguard Worker return -d0, -d1, d0 + d1 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker def fn_n1_def(d0, d1=1): 1594*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1595*da0073e9SAndroid Build Coastguard Worker 1596*da0073e9SAndroid Build Coastguard Worker p_fn_n1 = partial(fn_n1, d1=1) 1597*da0073e9SAndroid Build Coastguard Worker 1598*da0073e9SAndroid Build Coastguard Worker def fn_n1_pos(d0, d1, *args): 1599*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1600*da0073e9SAndroid Build Coastguard Worker 1601*da0073e9SAndroid Build Coastguard Worker def fn_n1_kwargs(d0, d1, **kwargs): 1602*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1603*da0073e9SAndroid Build Coastguard Worker 1604*da0073e9SAndroid Build Coastguard Worker def fn_kwonly(*, d0, d1): 1605*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1606*da0073e9SAndroid Build Coastguard Worker 1607*da0073e9SAndroid Build Coastguard Worker def fn_has_nondefault_kwonly(d0, *, d1): 1608*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1609*da0073e9SAndroid Build Coastguard Worker 1610*da0073e9SAndroid Build Coastguard Worker def fn_cmplx(d0, d1=1, *args, d2, **kwargs): 1611*da0073e9SAndroid Build Coastguard Worker return d0 + d1 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker p_fn_cmplx = partial(fn_cmplx, d2=2) 1614*da0073e9SAndroid Build Coastguard Worker p_fn_cmplx_large_arg = partial( 1615*da0073e9SAndroid Build Coastguard Worker fn_cmplx, d2={i: list(range(i)) for i in range(10_000)} 1616*da0073e9SAndroid Build Coastguard Worker ) 1617*da0073e9SAndroid Build Coastguard Worker 1618*da0073e9SAndroid Build Coastguard Worker # Prevent modification in-place to support resetting 1619*da0073e9SAndroid Build Coastguard Worker def _dict_update(data, newdata, remove_idx=None): 1620*da0073e9SAndroid Build Coastguard Worker _data = dict(data) 1621*da0073e9SAndroid Build Coastguard Worker _data.update(newdata) 1622*da0073e9SAndroid Build Coastguard Worker if remove_idx: 1623*da0073e9SAndroid Build Coastguard Worker for idx in remove_idx: 1624*da0073e9SAndroid Build Coastguard Worker del _data[idx] 1625*da0073e9SAndroid Build Coastguard Worker return _data 1626*da0073e9SAndroid Build Coastguard Worker 1627*da0073e9SAndroid Build Coastguard Worker def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): 1628*da0073e9SAndroid Build Coastguard Worker datapipe = dp.iter.IterableWrapper( 1629*da0073e9SAndroid Build Coastguard Worker [ 1630*da0073e9SAndroid Build Coastguard Worker {"x": 0, "y": 1, "z": 2}, 1631*da0073e9SAndroid Build Coastguard Worker {"x": 3, "y": 4, "z": 5}, 1632*da0073e9SAndroid Build Coastguard Worker {"x": 6, "y": 7, "z": 8}, 1633*da0073e9SAndroid Build Coastguard Worker ] 1634*da0073e9SAndroid Build Coastguard Worker ) 1635*da0073e9SAndroid Build Coastguard Worker if ref_fn is None: 1636*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(error): 1637*da0073e9SAndroid Build Coastguard Worker res_dp = datapipe.map(fn, input_col, output_col) 1638*da0073e9SAndroid Build Coastguard Worker list(res_dp) 1639*da0073e9SAndroid Build Coastguard Worker else: 1640*da0073e9SAndroid Build Coastguard Worker res_dp = datapipe.map(fn, input_col, output_col) 1641*da0073e9SAndroid Build Coastguard Worker ref_dp = datapipe.map(ref_fn) 1642*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(res_dp), list(ref_dp)) 1643*da0073e9SAndroid Build Coastguard Worker # Reset 1644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(res_dp), list(ref_dp)) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, fn_n1_def, "x", "y") 1647*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, p_fn_n1, "x", "y") 1648*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, p_fn_cmplx, "x", "y") 1649*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: data, p_fn_cmplx_large_arg, "x", "y") 1650*da0073e9SAndroid Build Coastguard Worker _helper( 1651*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), 1652*da0073e9SAndroid Build Coastguard Worker p_fn_cmplx, 1653*da0073e9SAndroid Build Coastguard Worker ["x", "y", "z"], 1654*da0073e9SAndroid Build Coastguard Worker "z", 1655*da0073e9SAndroid Build Coastguard Worker ) 1656*da0073e9SAndroid Build Coastguard Worker 1657*da0073e9SAndroid Build Coastguard Worker _helper( 1658*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), 1659*da0073e9SAndroid Build Coastguard Worker fn_n1_def, 1660*da0073e9SAndroid Build Coastguard Worker ["x", "y"], 1661*da0073e9SAndroid Build Coastguard Worker "z", 1662*da0073e9SAndroid Build Coastguard Worker ) 1663*da0073e9SAndroid Build Coastguard Worker 1664*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_pos, "x", error=ValueError) 1665*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_kwargs, "x", error=ValueError) 1666*da0073e9SAndroid Build Coastguard Worker # non-default kw-only args 1667*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_kwonly, ["x", "y"], error=ValueError) 1668*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError) 1669*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_cmplx, ["x", "y"], error=ValueError) 1670*da0073e9SAndroid Build Coastguard Worker 1671*da0073e9SAndroid Build Coastguard Worker # Replacing with one input column and default output column 1672*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y") 1673*da0073e9SAndroid Build Coastguard Worker _helper( 1674*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y" 1675*da0073e9SAndroid Build Coastguard Worker ) 1676*da0073e9SAndroid Build Coastguard Worker # The key of input column is not in dict 1677*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_1n, "a", error=KeyError) 1678*da0073e9SAndroid Build Coastguard Worker # Unmatched input columns with fn arguments 1679*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1, "y", error=ValueError) 1680*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_1n, ["x", "y"], error=ValueError) 1681*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError) 1682*da0073e9SAndroid Build Coastguard Worker _helper(None, p_fn_n1, ["x", "y"], error=ValueError) 1683*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError) 1684*da0073e9SAndroid Build Coastguard Worker # Replacing with multiple input columns and default output column (the left-most input column) 1685*da0073e9SAndroid Build Coastguard Worker _helper( 1686*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), 1687*da0073e9SAndroid Build Coastguard Worker fn_n1, 1688*da0073e9SAndroid Build Coastguard Worker ["z", "x"], 1689*da0073e9SAndroid Build Coastguard Worker ) 1690*da0073e9SAndroid Build Coastguard Worker _helper( 1691*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update( 1692*da0073e9SAndroid Build Coastguard Worker data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"] 1693*da0073e9SAndroid Build Coastguard Worker ), 1694*da0073e9SAndroid Build Coastguard Worker fn_nn, 1695*da0073e9SAndroid Build Coastguard Worker ["z", "y"], 1696*da0073e9SAndroid Build Coastguard Worker ) 1697*da0073e9SAndroid Build Coastguard Worker 1698*da0073e9SAndroid Build Coastguard Worker # output_col can only be specified when input_col is not None 1699*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1, None, "x", error=ValueError) 1700*da0073e9SAndroid Build Coastguard Worker # output_col can only be single-element list or tuple 1701*da0073e9SAndroid Build Coastguard Worker _helper(None, fn_n1, None, ["x", "y"], error=ValueError) 1702*da0073e9SAndroid Build Coastguard Worker # Single-element list as output_col 1703*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"]) 1704*da0073e9SAndroid Build Coastguard Worker # Replacing with one input column and single specified output column 1705*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x") 1706*da0073e9SAndroid Build Coastguard Worker _helper( 1707*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), 1708*da0073e9SAndroid Build Coastguard Worker fn_1n, 1709*da0073e9SAndroid Build Coastguard Worker "y", 1710*da0073e9SAndroid Build Coastguard Worker "z", 1711*da0073e9SAndroid Build Coastguard Worker ) 1712*da0073e9SAndroid Build Coastguard Worker _helper( 1713*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), 1714*da0073e9SAndroid Build Coastguard Worker fn_n1, 1715*da0073e9SAndroid Build Coastguard Worker ["x", "z"], 1716*da0073e9SAndroid Build Coastguard Worker "y", 1717*da0073e9SAndroid Build Coastguard Worker ) 1718*da0073e9SAndroid Build Coastguard Worker _helper( 1719*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update( 1720*da0073e9SAndroid Build Coastguard Worker data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])} 1721*da0073e9SAndroid Build Coastguard Worker ), 1722*da0073e9SAndroid Build Coastguard Worker fn_nn, 1723*da0073e9SAndroid Build Coastguard Worker ["y", "z"], 1724*da0073e9SAndroid Build Coastguard Worker "x", 1725*da0073e9SAndroid Build Coastguard Worker ) 1726*da0073e9SAndroid Build Coastguard Worker 1727*da0073e9SAndroid Build Coastguard Worker # Adding new key to dict for the output 1728*da0073e9SAndroid Build Coastguard Worker _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a") 1729*da0073e9SAndroid Build Coastguard Worker _helper( 1730*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), 1731*da0073e9SAndroid Build Coastguard Worker fn_1n, 1732*da0073e9SAndroid Build Coastguard Worker "y", 1733*da0073e9SAndroid Build Coastguard Worker "a", 1734*da0073e9SAndroid Build Coastguard Worker ) 1735*da0073e9SAndroid Build Coastguard Worker _helper( 1736*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), 1737*da0073e9SAndroid Build Coastguard Worker fn_n1, 1738*da0073e9SAndroid Build Coastguard Worker ["x", "z"], 1739*da0073e9SAndroid Build Coastguard Worker "a", 1740*da0073e9SAndroid Build Coastguard Worker ) 1741*da0073e9SAndroid Build Coastguard Worker _helper( 1742*da0073e9SAndroid Build Coastguard Worker lambda data: _dict_update( 1743*da0073e9SAndroid Build Coastguard Worker data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])} 1744*da0073e9SAndroid Build Coastguard Worker ), 1745*da0073e9SAndroid Build Coastguard Worker fn_nn, 1746*da0073e9SAndroid Build Coastguard Worker ["y", "z"], 1747*da0073e9SAndroid Build Coastguard Worker "a", 1748*da0073e9SAndroid Build Coastguard Worker ) 1749*da0073e9SAndroid Build Coastguard Worker 1750*da0073e9SAndroid Build Coastguard Worker def test_collate_iterdatapipe(self): 1751*da0073e9SAndroid Build Coastguard Worker arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 1752*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(arrs) 1753*da0073e9SAndroid Build Coastguard Worker 1754*da0073e9SAndroid Build Coastguard Worker def _collate_fn(batch, default_type=torch.float): 1755*da0073e9SAndroid Build Coastguard Worker return torch.tensor(sum(batch), dtype=default_type) 1756*da0073e9SAndroid Build Coastguard Worker 1757*da0073e9SAndroid Build Coastguard Worker # Functional Test: defaults to the default collate function when a custom one is not specified 1758*da0073e9SAndroid Build Coastguard Worker collate_dp = input_dp.collate() 1759*da0073e9SAndroid Build Coastguard Worker for x, y in zip(arrs, collate_dp): 1760*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(x), y) 1761*da0073e9SAndroid Build Coastguard Worker 1762*da0073e9SAndroid Build Coastguard Worker # Functional Test: custom collate function 1763*da0073e9SAndroid Build Coastguard Worker collate_dp = input_dp.collate(collate_fn=_collate_fn) 1764*da0073e9SAndroid Build Coastguard Worker for x, y in zip(arrs, collate_dp): 1765*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y) 1766*da0073e9SAndroid Build Coastguard Worker 1767*da0073e9SAndroid Build Coastguard Worker # Functional Test: custom, partial collate function 1768*da0073e9SAndroid Build Coastguard Worker collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int)) 1769*da0073e9SAndroid Build Coastguard Worker for x, y in zip(arrs, collate_dp): 1770*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y) 1771*da0073e9SAndroid Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker # Reset Test: reset the DataPipe and results are still correct 1773*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 1 1774*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 1775*da0073e9SAndroid Build Coastguard Worker collate_dp, n_elements_before_reset 1776*da0073e9SAndroid Build Coastguard Worker ) 1777*da0073e9SAndroid Build Coastguard Worker self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset) 1778*da0073e9SAndroid Build Coastguard Worker for x, y in zip(arrs, res_after_reset): 1779*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y) 1780*da0073e9SAndroid Build Coastguard Worker 1781*da0073e9SAndroid Build Coastguard Worker # __len__ Test: __len__ is inherited 1782*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_dp), len(collate_dp)) 1783*da0073e9SAndroid Build Coastguard Worker 1784*da0073e9SAndroid Build Coastguard Worker # __len__ Test: verify that it has no valid __len__ when the source doesn't have it 1785*da0073e9SAndroid Build Coastguard Worker input_dp_nl = IDP_NoLen(arrs) 1786*da0073e9SAndroid Build Coastguard Worker collate_dp_nl = input_dp_nl.collate() 1787*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 1788*da0073e9SAndroid Build Coastguard Worker len(collate_dp_nl) 1789*da0073e9SAndroid Build Coastguard Worker for x, y in zip(arrs, collate_dp_nl): 1790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(x), y) 1791*da0073e9SAndroid Build Coastguard Worker 1792*da0073e9SAndroid Build Coastguard Worker def test_batch_iterdatapipe(self): 1793*da0073e9SAndroid Build Coastguard Worker arrs = list(range(10)) 1794*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(arrs) 1795*da0073e9SAndroid Build Coastguard Worker 1796*da0073e9SAndroid Build Coastguard Worker # Functional Test: raise error when input argument `batch_size = 0` 1797*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 1798*da0073e9SAndroid Build Coastguard Worker input_dp.batch(batch_size=0) 1799*da0073e9SAndroid Build Coastguard Worker 1800*da0073e9SAndroid Build Coastguard Worker # Functional Test: by default, do not drop the last batch 1801*da0073e9SAndroid Build Coastguard Worker bs = 3 1802*da0073e9SAndroid Build Coastguard Worker batch_dp = input_dp.batch(batch_size=bs) 1803*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(batch_dp), 4) 1804*da0073e9SAndroid Build Coastguard Worker for i, batch in enumerate(batch_dp): 1805*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(batch), 1 if i == 3 else bs) 1806*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)]) 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker # Functional Test: Drop the last batch when specified 1809*da0073e9SAndroid Build Coastguard Worker bs = 4 1810*da0073e9SAndroid Build Coastguard Worker batch_dp = input_dp.batch(batch_size=bs, drop_last=True) 1811*da0073e9SAndroid Build Coastguard Worker for i, batch in enumerate(batch_dp): 1812*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)]) 1813*da0073e9SAndroid Build Coastguard Worker 1814*da0073e9SAndroid Build Coastguard Worker # __len__ test: verifying that the overall length and of each batch is correct 1815*da0073e9SAndroid Build Coastguard Worker for i, batch in enumerate(batch_dp): 1816*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(batch), bs) 1817*da0073e9SAndroid Build Coastguard Worker 1818*da0073e9SAndroid Build Coastguard Worker # __len__ Test: the length is missing if the source DataPipe doesn't have length 1819*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(batch_dp), 2) 1820*da0073e9SAndroid Build Coastguard Worker input_dp_nl = IDP_NoLen(range(10)) 1821*da0073e9SAndroid Build Coastguard Worker batch_dp_nl = input_dp_nl.batch(batch_size=2) 1822*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 1823*da0073e9SAndroid Build Coastguard Worker len(batch_dp_nl) 1824*da0073e9SAndroid Build Coastguard Worker 1825*da0073e9SAndroid Build Coastguard Worker # Reset Test: Ensures that the DataPipe can properly reset 1826*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 1 1827*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 1828*da0073e9SAndroid Build Coastguard Worker batch_dp, n_elements_before_reset 1829*da0073e9SAndroid Build Coastguard Worker ) 1830*da0073e9SAndroid Build Coastguard Worker self.assertEqual([[0, 1, 2, 3]], res_before_reset) 1831*da0073e9SAndroid Build Coastguard Worker self.assertEqual([[0, 1, 2, 3], [4, 5, 6, 7]], res_after_reset) 1832*da0073e9SAndroid Build Coastguard Worker 1833*da0073e9SAndroid Build Coastguard Worker def test_unbatch_iterdatapipe(self): 1834*da0073e9SAndroid Build Coastguard Worker target_length = 6 1835*da0073e9SAndroid Build Coastguard Worker prebatch_dp = dp.iter.IterableWrapper(range(target_length)) 1836*da0073e9SAndroid Build Coastguard Worker 1837*da0073e9SAndroid Build Coastguard Worker # Functional Test: Unbatch DataPipe should be the same as pre-batch DataPipe 1838*da0073e9SAndroid Build Coastguard Worker input_dp = prebatch_dp.batch(3) 1839*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch() 1840*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(unbatch_dp)), target_length) # __len__ is as expected 1841*da0073e9SAndroid Build Coastguard Worker for i, res in zip(range(target_length), unbatch_dp): 1842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, res) 1843*da0073e9SAndroid Build Coastguard Worker 1844*da0073e9SAndroid Build Coastguard Worker # Functional Test: unbatch works for an input with nested levels 1845*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]]) 1846*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch() 1847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(unbatch_dp)), target_length) 1848*da0073e9SAndroid Build Coastguard Worker for i, res in zip(range(target_length), unbatch_dp): 1849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, res) 1850*da0073e9SAndroid Build Coastguard Worker 1851*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) 1852*da0073e9SAndroid Build Coastguard Worker 1853*da0073e9SAndroid Build Coastguard Worker # Functional Test: unbatch works for an input with nested levels 1854*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch() 1855*da0073e9SAndroid Build Coastguard Worker expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]] 1856*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(unbatch_dp)), 4) 1857*da0073e9SAndroid Build Coastguard Worker for j, res in zip(expected_dp, unbatch_dp): 1858*da0073e9SAndroid Build Coastguard Worker self.assertEqual(j, res) 1859*da0073e9SAndroid Build Coastguard Worker 1860*da0073e9SAndroid Build Coastguard Worker # Functional Test: unbatching multiple levels at the same time 1861*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch(unbatch_level=2) 1862*da0073e9SAndroid Build Coastguard Worker expected_dp2 = [0, 1, 2, 3, 4, 5, 6, 7] 1863*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(unbatch_dp)), 8) 1864*da0073e9SAndroid Build Coastguard Worker for i, res in zip(expected_dp2, unbatch_dp): 1865*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, res) 1866*da0073e9SAndroid Build Coastguard Worker 1867*da0073e9SAndroid Build Coastguard Worker # Functional Test: unbatching all levels at the same time 1868*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch(unbatch_level=-1) 1869*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(unbatch_dp)), 8) 1870*da0073e9SAndroid Build Coastguard Worker for i, res in zip(expected_dp2, unbatch_dp): 1871*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, res) 1872*da0073e9SAndroid Build Coastguard Worker 1873*da0073e9SAndroid Build Coastguard Worker # Functional Test: raises error when input unbatch_level is less than -1 1874*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]]) 1875*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1876*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch(unbatch_level=-2) 1877*da0073e9SAndroid Build Coastguard Worker for i in unbatch_dp: 1878*da0073e9SAndroid Build Coastguard Worker print(i) 1879*da0073e9SAndroid Build Coastguard Worker 1880*da0073e9SAndroid Build Coastguard Worker # Functional Test: raises error when input unbatch_level is too high 1881*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 1882*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch(unbatch_level=5) 1883*da0073e9SAndroid Build Coastguard Worker for i in unbatch_dp: 1884*da0073e9SAndroid Build Coastguard Worker print(i) 1885*da0073e9SAndroid Build Coastguard Worker 1886*da0073e9SAndroid Build Coastguard Worker # Reset Test: unbatch_dp resets properly 1887*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]]) 1888*da0073e9SAndroid Build Coastguard Worker unbatch_dp = input_dp.unbatch(unbatch_level=-1) 1889*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 3 1890*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 1891*da0073e9SAndroid Build Coastguard Worker unbatch_dp, n_elements_before_reset 1892*da0073e9SAndroid Build Coastguard Worker ) 1893*da0073e9SAndroid Build Coastguard Worker self.assertEqual([0, 1, 2], res_before_reset) 1894*da0073e9SAndroid Build Coastguard Worker self.assertEqual([0, 1, 2, 3, 4, 5], res_after_reset) 1895*da0073e9SAndroid Build Coastguard Worker 1896*da0073e9SAndroid Build Coastguard Worker def test_filter_datapipe(self): 1897*da0073e9SAndroid Build Coastguard Worker input_ds = dp.iter.IterableWrapper(range(10)) 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard Worker def _filter_fn(data, val): 1900*da0073e9SAndroid Build Coastguard Worker return data >= val 1901*da0073e9SAndroid Build Coastguard Worker 1902*da0073e9SAndroid Build Coastguard Worker # Functional Test: filter works with partial function 1903*da0073e9SAndroid Build Coastguard Worker filter_dp = input_ds.filter(partial(_filter_fn, val=5)) 1904*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(filter_dp), list(range(5, 10))) 1905*da0073e9SAndroid Build Coastguard Worker 1906*da0073e9SAndroid Build Coastguard Worker def _non_bool_fn(data): 1907*da0073e9SAndroid Build Coastguard Worker return 1 1908*da0073e9SAndroid Build Coastguard Worker 1909*da0073e9SAndroid Build Coastguard Worker # Functional Test: filter function must return bool 1910*da0073e9SAndroid Build Coastguard Worker filter_dp = input_ds.filter(filter_fn=_non_bool_fn) 1911*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1912*da0073e9SAndroid Build Coastguard Worker temp = list(filter_dp) 1913*da0073e9SAndroid Build Coastguard Worker 1914*da0073e9SAndroid Build Coastguard Worker # Funtional Test: Specify input_col 1915*da0073e9SAndroid Build Coastguard Worker tuple_input_ds = dp.iter.IterableWrapper([(d - 1, d, d + 1) for d in range(10)]) 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker # Single input_col 1918*da0073e9SAndroid Build Coastguard Worker input_col_1_dp = tuple_input_ds.filter(partial(_filter_fn, val=5), input_col=1) 1919*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1920*da0073e9SAndroid Build Coastguard Worker list(input_col_1_dp), [(d - 1, d, d + 1) for d in range(5, 10)] 1921*da0073e9SAndroid Build Coastguard Worker ) 1922*da0073e9SAndroid Build Coastguard Worker 1923*da0073e9SAndroid Build Coastguard Worker # Multiple input_col 1924*da0073e9SAndroid Build Coastguard Worker def _mul_filter_fn(a, b): 1925*da0073e9SAndroid Build Coastguard Worker return a + b < 10 1926*da0073e9SAndroid Build Coastguard Worker 1927*da0073e9SAndroid Build Coastguard Worker input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2]) 1928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)]) 1929*da0073e9SAndroid Build Coastguard Worker 1930*da0073e9SAndroid Build Coastguard Worker # invalid input col 1931*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1932*da0073e9SAndroid Build Coastguard Worker tuple_input_ds.filter(_mul_filter_fn, input_col=0) 1933*da0073e9SAndroid Build Coastguard Worker 1934*da0073e9SAndroid Build Coastguard Worker p_mul_filter_fn = partial(_mul_filter_fn, b=1) 1935*da0073e9SAndroid Build Coastguard Worker out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0) 1936*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)]) 1937*da0073e9SAndroid Build Coastguard Worker 1938*da0073e9SAndroid Build Coastguard Worker def _mul_filter_fn_with_defaults(a, b=1): 1939*da0073e9SAndroid Build Coastguard Worker return a + b < 10 1940*da0073e9SAndroid Build Coastguard Worker 1941*da0073e9SAndroid Build Coastguard Worker out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0) 1942*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)]) 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker def _mul_filter_fn_with_kw_only(*, a, b): 1945*da0073e9SAndroid Build Coastguard Worker return a + b < 10 1946*da0073e9SAndroid Build Coastguard Worker 1947*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1948*da0073e9SAndroid Build Coastguard Worker tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0) 1949*da0073e9SAndroid Build Coastguard Worker 1950*da0073e9SAndroid Build Coastguard Worker def _mul_filter_fn_with_kw_only_1_default(*, a, b=1): 1951*da0073e9SAndroid Build Coastguard Worker return a + b < 10 1952*da0073e9SAndroid Build Coastguard Worker 1953*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 1954*da0073e9SAndroid Build Coastguard Worker tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0) 1955*da0073e9SAndroid Build Coastguard Worker 1956*da0073e9SAndroid Build Coastguard Worker # __len__ Test: DataPipe has no valid len 1957*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"has no len"): 1958*da0073e9SAndroid Build Coastguard Worker len(filter_dp) 1959*da0073e9SAndroid Build Coastguard Worker 1960*da0073e9SAndroid Build Coastguard Worker # Reset Test: DataPipe resets correctly 1961*da0073e9SAndroid Build Coastguard Worker filter_dp = input_ds.filter(partial(_filter_fn, val=5)) 1962*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 3 1963*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 1964*da0073e9SAndroid Build Coastguard Worker filter_dp, n_elements_before_reset 1965*da0073e9SAndroid Build Coastguard Worker ) 1966*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5, 10))[:n_elements_before_reset], res_before_reset) 1967*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5, 10)), res_after_reset) 1968*da0073e9SAndroid Build Coastguard Worker 1969*da0073e9SAndroid Build Coastguard Worker def test_sampler_iterdatapipe(self): 1970*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(range(10)) 1971*da0073e9SAndroid Build Coastguard Worker # Default SequentialSampler 1972*da0073e9SAndroid Build Coastguard Worker sampled_dp = dp.iter.Sampler(input_dp) # type: ignore[var-annotated] 1973*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(sampled_dp), 10) 1974*da0073e9SAndroid Build Coastguard Worker for i, x in enumerate(sampled_dp): 1975*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, i) 1976*da0073e9SAndroid Build Coastguard Worker 1977*da0073e9SAndroid Build Coastguard Worker # RandomSampler 1978*da0073e9SAndroid Build Coastguard Worker random_sampled_dp = dp.iter.Sampler( 1979*da0073e9SAndroid Build Coastguard Worker input_dp, sampler=RandomSampler, sampler_kwargs={"replacement": True} 1980*da0073e9SAndroid Build Coastguard Worker ) # type: ignore[var-annotated] # noqa: B950 1981*da0073e9SAndroid Build Coastguard Worker 1982*da0073e9SAndroid Build Coastguard Worker # Requires `__len__` to build SamplerDataPipe 1983*da0073e9SAndroid Build Coastguard Worker input_dp_nolen = IDP_NoLen(range(10)) 1984*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 1985*da0073e9SAndroid Build Coastguard Worker sampled_dp = dp.iter.Sampler(input_dp_nolen) 1986*da0073e9SAndroid Build Coastguard Worker 1987*da0073e9SAndroid Build Coastguard Worker def test_stream_reader_iterdatapipe(self): 1988*da0073e9SAndroid Build Coastguard Worker from io import StringIO 1989*da0073e9SAndroid Build Coastguard Worker 1990*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper( 1991*da0073e9SAndroid Build Coastguard Worker [("f1", StringIO("abcde")), ("f2", StringIO("bcdef"))] 1992*da0073e9SAndroid Build Coastguard Worker ) 1993*da0073e9SAndroid Build Coastguard Worker expected_res = ["abcde", "bcdef"] 1994*da0073e9SAndroid Build Coastguard Worker 1995*da0073e9SAndroid Build Coastguard Worker # Functional Test: Read full chunk 1996*da0073e9SAndroid Build Coastguard Worker dp1 = input_dp.read_from_stream() 1997*da0073e9SAndroid Build Coastguard Worker self.assertEqual([d[1] for d in dp1], expected_res) 1998*da0073e9SAndroid Build Coastguard Worker 1999*da0073e9SAndroid Build Coastguard Worker # Functional Test: Read full chunk 2000*da0073e9SAndroid Build Coastguard Worker dp2 = input_dp.read_from_stream(chunk=1) 2001*da0073e9SAndroid Build Coastguard Worker self.assertEqual([d[1] for d in dp2], [c for s in expected_res for c in s]) 2002*da0073e9SAndroid Build Coastguard Worker 2003*da0073e9SAndroid Build Coastguard Worker # `__len__` Test 2004*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 2005*da0073e9SAndroid Build Coastguard Worker len(dp1) 2006*da0073e9SAndroid Build Coastguard Worker 2007*da0073e9SAndroid Build Coastguard Worker def test_shuffler_iterdatapipe(self): 2008*da0073e9SAndroid Build Coastguard Worker input_dp = dp.iter.IterableWrapper(list(range(10))) 2009*da0073e9SAndroid Build Coastguard Worker 2010*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 2011*da0073e9SAndroid Build Coastguard Worker shuffle_dp = input_dp.shuffle(buffer_size=0) 2012*da0073e9SAndroid Build Coastguard Worker 2013*da0073e9SAndroid Build Coastguard Worker # Functional Test: No seed 2014*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp.shuffle() 2015*da0073e9SAndroid Build Coastguard Worker self.assertEqual(set(range(10)), set(shuffler_dp)) 2016*da0073e9SAndroid Build Coastguard Worker 2017*da0073e9SAndroid Build Coastguard Worker # Functional Test: With global seed 2018*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 2019*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp.shuffle() 2020*da0073e9SAndroid Build Coastguard Worker res = list(shuffler_dp) 2021*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 2022*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(shuffler_dp), res) 2023*da0073e9SAndroid Build Coastguard Worker 2024*da0073e9SAndroid Build Coastguard Worker # Functional Test: Set seed 2025*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp.shuffle().set_seed(123) 2026*da0073e9SAndroid Build Coastguard Worker res = list(shuffler_dp) 2027*da0073e9SAndroid Build Coastguard Worker shuffler_dp.set_seed(123) 2028*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(shuffler_dp), res) 2029*da0073e9SAndroid Build Coastguard Worker 2030*da0073e9SAndroid Build Coastguard Worker # Functional Test: deactivate shuffling via set_shuffle 2031*da0073e9SAndroid Build Coastguard Worker unshuffled_dp = input_dp.shuffle().set_shuffle(False) 2032*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(unshuffled_dp), list(input_dp)) 2033*da0073e9SAndroid Build Coastguard Worker 2034*da0073e9SAndroid Build Coastguard Worker # Reset Test: 2035*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp.shuffle() 2036*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 5 2037*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 2038*da0073e9SAndroid Build Coastguard Worker shuffler_dp, n_elements_before_reset 2039*da0073e9SAndroid Build Coastguard Worker ) 2040*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, len(res_before_reset)) 2041*da0073e9SAndroid Build Coastguard Worker for x in res_before_reset: 2042*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x in set(range(10))) 2043*da0073e9SAndroid Build Coastguard Worker self.assertEqual(set(range(10)), set(res_after_reset)) 2044*da0073e9SAndroid Build Coastguard Worker 2045*da0073e9SAndroid Build Coastguard Worker # __len__ Test: returns the length of the input DataPipe 2046*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp.shuffle() 2047*da0073e9SAndroid Build Coastguard Worker self.assertEqual(10, len(shuffler_dp)) 2048*da0073e9SAndroid Build Coastguard Worker exp = list(range(100)) 2049*da0073e9SAndroid Build Coastguard Worker 2050*da0073e9SAndroid Build Coastguard Worker # Serialization Test 2051*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.datapipes._hook_iterator import _SnapshotState 2052*da0073e9SAndroid Build Coastguard Worker 2053*da0073e9SAndroid Build Coastguard Worker def _serialization_helper(bs): 2054*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp.shuffle(buffer_size=bs) 2055*da0073e9SAndroid Build Coastguard Worker it = iter(shuffler_dp) 2056*da0073e9SAndroid Build Coastguard Worker for _ in range(2): 2057*da0073e9SAndroid Build Coastguard Worker next(it) 2058*da0073e9SAndroid Build Coastguard Worker shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp)) 2059*da0073e9SAndroid Build Coastguard Worker _simple_graph_snapshot_restoration( 2060*da0073e9SAndroid Build Coastguard Worker shuffler_dp_copy.datapipe, 2061*da0073e9SAndroid Build Coastguard Worker shuffler_dp.datapipe._number_of_samples_yielded, 2062*da0073e9SAndroid Build Coastguard Worker ) 2063*da0073e9SAndroid Build Coastguard Worker 2064*da0073e9SAndroid Build Coastguard Worker exp = list(it) 2065*da0073e9SAndroid Build Coastguard Worker shuffler_dp_copy._snapshot_state = _SnapshotState.Restored 2066*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp, list(shuffler_dp_copy)) 2067*da0073e9SAndroid Build Coastguard Worker 2068*da0073e9SAndroid Build Coastguard Worker buffer_sizes = [2, 5, 15] 2069*da0073e9SAndroid Build Coastguard Worker for bs in buffer_sizes: 2070*da0073e9SAndroid Build Coastguard Worker _serialization_helper(bs) 2071*da0073e9SAndroid Build Coastguard Worker 2072*da0073e9SAndroid Build Coastguard Worker def test_zip_iterdatapipe(self): 2073*da0073e9SAndroid Build Coastguard Worker # Functional Test: raises TypeError when an input is not of type `IterDataPipe` 2074*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 2075*da0073e9SAndroid Build Coastguard Worker dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10))) # type: ignore[arg-type] 2076*da0073e9SAndroid Build Coastguard Worker 2077*da0073e9SAndroid Build Coastguard Worker # Functional Test: raises TypeError when an input does not have valid length 2078*da0073e9SAndroid Build Coastguard Worker zipped_dp = dp.iter.Zipper( 2079*da0073e9SAndroid Build Coastguard Worker dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5)) 2080*da0073e9SAndroid Build Coastguard Worker ) # type: ignore[var-annotated] 2081*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 2082*da0073e9SAndroid Build Coastguard Worker len(zipped_dp) 2083*da0073e9SAndroid Build Coastguard Worker 2084*da0073e9SAndroid Build Coastguard Worker # Functional Test: zips the results properly 2085*da0073e9SAndroid Build Coastguard Worker exp = [(i, i) for i in range(5)] 2086*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(zipped_dp), exp) 2087*da0073e9SAndroid Build Coastguard Worker 2088*da0073e9SAndroid Build Coastguard Worker # Functional Test: zips the inputs properly even when lengths are different (zips to the shortest) 2089*da0073e9SAndroid Build Coastguard Worker zipped_dp = dp.iter.Zipper( 2090*da0073e9SAndroid Build Coastguard Worker dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5)) 2091*da0073e9SAndroid Build Coastguard Worker ) 2092*da0073e9SAndroid Build Coastguard Worker 2093*da0073e9SAndroid Build Coastguard Worker # __len__ Test: length matches the length of the shortest input 2094*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(zipped_dp), 5) 2095*da0073e9SAndroid Build Coastguard Worker 2096*da0073e9SAndroid Build Coastguard Worker # Reset Test: 2097*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 3 2098*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 2099*da0073e9SAndroid Build Coastguard Worker zipped_dp, n_elements_before_reset 2100*da0073e9SAndroid Build Coastguard Worker ) 2101*da0073e9SAndroid Build Coastguard Worker expected_res = [(i, i) for i in range(5)] 2102*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) 2103*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res, res_after_reset) 2104*da0073e9SAndroid Build Coastguard Worker 2105*da0073e9SAndroid Build Coastguard Worker 2106*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalMapDataPipe(TestCase): 2107*da0073e9SAndroid Build Coastguard Worker def _serialization_test_helper(self, datapipe, use_dill): 2108*da0073e9SAndroid Build Coastguard Worker if use_dill: 2109*da0073e9SAndroid Build Coastguard Worker serialized_dp = dill.dumps(datapipe) 2110*da0073e9SAndroid Build Coastguard Worker deserialized_dp = dill.loads(serialized_dp) 2111*da0073e9SAndroid Build Coastguard Worker else: 2112*da0073e9SAndroid Build Coastguard Worker serialized_dp = pickle.dumps(datapipe) 2113*da0073e9SAndroid Build Coastguard Worker deserialized_dp = pickle.loads(serialized_dp) 2114*da0073e9SAndroid Build Coastguard Worker try: 2115*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(datapipe), list(deserialized_dp)) 2116*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 2117*da0073e9SAndroid Build Coastguard Worker print(f"{datapipe} is failing.") 2118*da0073e9SAndroid Build Coastguard Worker raise e 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker def _serialization_test_for_single_dp(self, dp, use_dill=False): 2121*da0073e9SAndroid Build Coastguard Worker # 1. Testing for serialization before any iteration starts 2122*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp, use_dill) 2123*da0073e9SAndroid Build Coastguard Worker # 2. Testing for serialization after DataPipe is partially read 2124*da0073e9SAndroid Build Coastguard Worker it = iter(dp) 2125*da0073e9SAndroid Build Coastguard Worker _ = next(it) 2126*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp, use_dill) 2127*da0073e9SAndroid Build Coastguard Worker # 3. Testing for serialization after DataPipe is fully read 2128*da0073e9SAndroid Build Coastguard Worker _ = list(dp) 2129*da0073e9SAndroid Build Coastguard Worker self._serialization_test_helper(dp, use_dill) 2130*da0073e9SAndroid Build Coastguard Worker 2131*da0073e9SAndroid Build Coastguard Worker def test_serializable(self): 2132*da0073e9SAndroid Build Coastguard Worker picklable_datapipes: List = [ 2133*da0073e9SAndroid Build Coastguard Worker (dp.map.Batcher, None, (2,), {}), 2134*da0073e9SAndroid Build Coastguard Worker (dp.map.Concater, None, (dp.map.SequenceWrapper(range(10)),), {}), 2135*da0073e9SAndroid Build Coastguard Worker (dp.map.Mapper, None, (), {}), 2136*da0073e9SAndroid Build Coastguard Worker (dp.map.Mapper, None, (_fake_fn,), {}), 2137*da0073e9SAndroid Build Coastguard Worker (dp.map.Mapper, None, (partial(_fake_add, 1),), {}), 2138*da0073e9SAndroid Build Coastguard Worker (dp.map.SequenceWrapper, range(10), (), {}), 2139*da0073e9SAndroid Build Coastguard Worker (dp.map.Shuffler, dp.map.SequenceWrapper([0] * 5), (), {}), 2140*da0073e9SAndroid Build Coastguard Worker (dp.map.Zipper, None, (dp.map.SequenceWrapper(range(10)),), {}), 2141*da0073e9SAndroid Build Coastguard Worker ] 2142*da0073e9SAndroid Build Coastguard Worker for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes: 2143*da0073e9SAndroid Build Coastguard Worker if custom_input is None: 2144*da0073e9SAndroid Build Coastguard Worker custom_input = dp.map.SequenceWrapper(range(10)) 2145*da0073e9SAndroid Build Coastguard Worker datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] 2146*da0073e9SAndroid Build Coastguard Worker self._serialization_test_for_single_dp(datapipe) 2147*da0073e9SAndroid Build Coastguard Worker 2148*da0073e9SAndroid Build Coastguard Worker def test_serializable_with_dill(self): 2149*da0073e9SAndroid Build Coastguard Worker """Only for DataPipes that take in a function as argument""" 2150*da0073e9SAndroid Build Coastguard Worker input_dp = dp.map.SequenceWrapper(range(10)) 2151*da0073e9SAndroid Build Coastguard Worker 2152*da0073e9SAndroid Build Coastguard Worker datapipes_with_lambda_fn: List[ 2153*da0073e9SAndroid Build Coastguard Worker Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] 2154*da0073e9SAndroid Build Coastguard Worker ] = [ 2155*da0073e9SAndroid Build Coastguard Worker (dp.map.Mapper, (lambda_fn1,), {}), 2156*da0073e9SAndroid Build Coastguard Worker ] 2157*da0073e9SAndroid Build Coastguard Worker 2158*da0073e9SAndroid Build Coastguard Worker def _local_fns(): 2159*da0073e9SAndroid Build Coastguard Worker def _fn1(x): 2160*da0073e9SAndroid Build Coastguard Worker return x 2161*da0073e9SAndroid Build Coastguard Worker 2162*da0073e9SAndroid Build Coastguard Worker return _fn1 2163*da0073e9SAndroid Build Coastguard Worker 2164*da0073e9SAndroid Build Coastguard Worker fn1 = _local_fns() 2165*da0073e9SAndroid Build Coastguard Worker 2166*da0073e9SAndroid Build Coastguard Worker datapipes_with_local_fn: List[ 2167*da0073e9SAndroid Build Coastguard Worker Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] 2168*da0073e9SAndroid Build Coastguard Worker ] = [ 2169*da0073e9SAndroid Build Coastguard Worker (dp.map.Mapper, (fn1,), {}), 2170*da0073e9SAndroid Build Coastguard Worker ] 2171*da0073e9SAndroid Build Coastguard Worker 2172*da0073e9SAndroid Build Coastguard Worker if HAS_DILL: 2173*da0073e9SAndroid Build Coastguard Worker for dpipe, dp_args, dp_kwargs in ( 2174*da0073e9SAndroid Build Coastguard Worker datapipes_with_lambda_fn + datapipes_with_local_fn 2175*da0073e9SAndroid Build Coastguard Worker ): 2176*da0073e9SAndroid Build Coastguard Worker _ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] 2177*da0073e9SAndroid Build Coastguard Worker else: 2178*da0073e9SAndroid Build Coastguard Worker msgs = ( 2179*da0073e9SAndroid Build Coastguard Worker r"^Lambda function is not supported by pickle", 2180*da0073e9SAndroid Build Coastguard Worker r"^Local function is not supported by pickle", 2181*da0073e9SAndroid Build Coastguard Worker ) 2182*da0073e9SAndroid Build Coastguard Worker for dps, msg in zip( 2183*da0073e9SAndroid Build Coastguard Worker (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs 2184*da0073e9SAndroid Build Coastguard Worker ): 2185*da0073e9SAndroid Build Coastguard Worker for dpipe, dp_args, dp_kwargs in dps: 2186*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, msg): 2187*da0073e9SAndroid Build Coastguard Worker datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] 2188*da0073e9SAndroid Build Coastguard Worker with self.assertRaises((pickle.PicklingError, AttributeError)): 2189*da0073e9SAndroid Build Coastguard Worker pickle.dumps(datapipe) 2190*da0073e9SAndroid Build Coastguard Worker 2191*da0073e9SAndroid Build Coastguard Worker def test_docstring(self): 2192*da0073e9SAndroid Build Coastguard Worker """ 2193*da0073e9SAndroid Build Coastguard Worker Ensure functional form of MapDataPipe has the correct docstring from 2194*da0073e9SAndroid Build Coastguard Worker the class form. 2195*da0073e9SAndroid Build Coastguard Worker 2196*da0073e9SAndroid Build Coastguard Worker Regression test for https://github.com/pytorch/data/issues/792. 2197*da0073e9SAndroid Build Coastguard Worker """ 2198*da0073e9SAndroid Build Coastguard Worker input_dp = dp.map.SequenceWrapper(range(10)) 2199*da0073e9SAndroid Build Coastguard Worker 2200*da0073e9SAndroid Build Coastguard Worker for dp_funcname in [ 2201*da0073e9SAndroid Build Coastguard Worker "batch", 2202*da0073e9SAndroid Build Coastguard Worker "concat", 2203*da0073e9SAndroid Build Coastguard Worker "map", 2204*da0073e9SAndroid Build Coastguard Worker "shuffle", 2205*da0073e9SAndroid Build Coastguard Worker "zip", 2206*da0073e9SAndroid Build Coastguard Worker ]: 2207*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 9): 2208*da0073e9SAndroid Build Coastguard Worker docstring = pydoc.render_doc( 2209*da0073e9SAndroid Build Coastguard Worker thing=getattr(input_dp, dp_funcname), forceload=True 2210*da0073e9SAndroid Build Coastguard Worker ) 2211*da0073e9SAndroid Build Coastguard Worker elif sys.version_info < (3, 9): 2212*da0073e9SAndroid Build Coastguard Worker # pydoc works differently on Python 3.8, see 2213*da0073e9SAndroid Build Coastguard Worker # https://docs.python.org/3/whatsnew/3.9.html#pydoc 2214*da0073e9SAndroid Build Coastguard Worker docstring = getattr(input_dp, dp_funcname).__doc__ 2215*da0073e9SAndroid Build Coastguard Worker assert f"(functional name: ``{dp_funcname}``)" in docstring 2216*da0073e9SAndroid Build Coastguard Worker assert "Args:" in docstring 2217*da0073e9SAndroid Build Coastguard Worker assert "Example:" in docstring or "Examples:" in docstring 2218*da0073e9SAndroid Build Coastguard Worker 2219*da0073e9SAndroid Build Coastguard Worker def test_sequence_wrapper_datapipe(self): 2220*da0073e9SAndroid Build Coastguard Worker seq = list(range(10)) 2221*da0073e9SAndroid Build Coastguard Worker input_dp = dp.map.SequenceWrapper(seq) 2222*da0073e9SAndroid Build Coastguard Worker 2223*da0073e9SAndroid Build Coastguard Worker # Functional Test: all elements are equal in the same order 2224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seq, list(input_dp)) 2225*da0073e9SAndroid Build Coastguard Worker 2226*da0073e9SAndroid Build Coastguard Worker # Functional Test: confirm deepcopy works by default 2227*da0073e9SAndroid Build Coastguard Worker seq.append(11) 2228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(input_dp)) # input_dp shouldn't have 11 2229*da0073e9SAndroid Build Coastguard Worker 2230*da0073e9SAndroid Build Coastguard Worker # Functional Test: non-deepcopy version is working 2231*da0073e9SAndroid Build Coastguard Worker seq2 = [1, 2, 3] 2232*da0073e9SAndroid Build Coastguard Worker input_dp_non_deep = dp.map.SequenceWrapper(seq2, deepcopy=False) 2233*da0073e9SAndroid Build Coastguard Worker seq2.append(4) 2234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(seq2), list(input_dp_non_deep)) # should have 4 2235*da0073e9SAndroid Build Coastguard Worker 2236*da0073e9SAndroid Build Coastguard Worker # Reset Test: reset the DataPipe 2237*da0073e9SAndroid Build Coastguard Worker seq = list(range(10)) 2238*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 5 2239*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 2240*da0073e9SAndroid Build Coastguard Worker input_dp, n_elements_before_reset 2241*da0073e9SAndroid Build Coastguard Worker ) 2242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(5)), res_before_reset) 2243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seq, res_after_reset) 2244*da0073e9SAndroid Build Coastguard Worker 2245*da0073e9SAndroid Build Coastguard Worker # __len__ Test: inherits length from sequence 2246*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(seq), len(input_dp)) 2247*da0073e9SAndroid Build Coastguard Worker 2248*da0073e9SAndroid Build Coastguard Worker def test_concat_mapdatapipe(self): 2249*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.map.SequenceWrapper(range(10)) 2250*da0073e9SAndroid Build Coastguard Worker input_dp2 = dp.map.SequenceWrapper(range(5)) 2251*da0073e9SAndroid Build Coastguard Worker 2252*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): 2253*da0073e9SAndroid Build Coastguard Worker dp.map.Concater() 2254*da0073e9SAndroid Build Coastguard Worker 2255*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2256*da0073e9SAndroid Build Coastguard Worker TypeError, r"Expected all inputs to be `MapDataPipe`" 2257*da0073e9SAndroid Build Coastguard Worker ): 2258*da0073e9SAndroid Build Coastguard Worker dp.map.Concater(input_dp1, ()) # type: ignore[arg-type] 2259*da0073e9SAndroid Build Coastguard Worker 2260*da0073e9SAndroid Build Coastguard Worker concat_dp = input_dp1.concat(input_dp2) 2261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(concat_dp), 15) 2262*da0073e9SAndroid Build Coastguard Worker for index in range(15): 2263*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2264*da0073e9SAndroid Build Coastguard Worker concat_dp[index], (list(range(10)) + list(range(5)))[index] 2265*da0073e9SAndroid Build Coastguard Worker ) 2266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) 2267*da0073e9SAndroid Build Coastguard Worker 2268*da0073e9SAndroid Build Coastguard Worker def test_zip_mapdatapipe(self): 2269*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.map.SequenceWrapper(range(10)) 2270*da0073e9SAndroid Build Coastguard Worker input_dp2 = dp.map.SequenceWrapper(range(5)) 2271*da0073e9SAndroid Build Coastguard Worker input_dp3 = dp.map.SequenceWrapper(range(15)) 2272*da0073e9SAndroid Build Coastguard Worker 2273*da0073e9SAndroid Build Coastguard Worker # Functional Test: requires at least one input DataPipe 2274*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): 2275*da0073e9SAndroid Build Coastguard Worker dp.map.Zipper() 2276*da0073e9SAndroid Build Coastguard Worker 2277*da0073e9SAndroid Build Coastguard Worker # Functional Test: all inputs must be MapDataPipes 2278*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2279*da0073e9SAndroid Build Coastguard Worker TypeError, r"Expected all inputs to be `MapDataPipe`" 2280*da0073e9SAndroid Build Coastguard Worker ): 2281*da0073e9SAndroid Build Coastguard Worker dp.map.Zipper(input_dp1, ()) # type: ignore[arg-type] 2282*da0073e9SAndroid Build Coastguard Worker 2283*da0073e9SAndroid Build Coastguard Worker # Functional Test: Zip the elements up as a tuples 2284*da0073e9SAndroid Build Coastguard Worker zip_dp = input_dp1.zip(input_dp2, input_dp3) 2285*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(i, i, i) for i in range(5)], [zip_dp[i] for i in range(5)]) 2286*da0073e9SAndroid Build Coastguard Worker 2287*da0073e9SAndroid Build Coastguard Worker # Functional Test: Raise IndexError when index equal or exceed the length of the shortest DataPipe 2288*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, r"out of range"): 2289*da0073e9SAndroid Build Coastguard Worker input_dp1.zip(input_dp2, input_dp3)[5] 2290*da0073e9SAndroid Build Coastguard Worker 2291*da0073e9SAndroid Build Coastguard Worker # Functional Test: Ensure `zip` can combine `Batcher` with others 2292*da0073e9SAndroid Build Coastguard Worker dp1 = dp.map.SequenceWrapper(range(10)) 2293*da0073e9SAndroid Build Coastguard Worker shuffle_dp1 = dp1.batch(2) 2294*da0073e9SAndroid Build Coastguard Worker dp2 = dp.map.SequenceWrapper(range(10)) 2295*da0073e9SAndroid Build Coastguard Worker shuffle_dp2 = dp2.batch(3) 2296*da0073e9SAndroid Build Coastguard Worker zip_dp1 = shuffle_dp1.zip(shuffle_dp2) 2297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(4, len(list(zip_dp1))) 2298*da0073e9SAndroid Build Coastguard Worker zip_dp2 = shuffle_dp1.zip(dp2) 2299*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, len(list(zip_dp2))) 2300*da0073e9SAndroid Build Coastguard Worker 2301*da0073e9SAndroid Build Coastguard Worker # __len__ Test: returns the length of the shortest DataPipe 2302*da0073e9SAndroid Build Coastguard Worker zip_dp = input_dp1.zip(input_dp2, input_dp3) 2303*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, len(zip_dp)) 2304*da0073e9SAndroid Build Coastguard Worker 2305*da0073e9SAndroid Build Coastguard Worker def test_shuffler_mapdatapipe(self): 2306*da0073e9SAndroid Build Coastguard Worker input_dp1 = dp.map.SequenceWrapper(range(10)) 2307*da0073e9SAndroid Build Coastguard Worker input_dp2 = dp.map.SequenceWrapper({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) 2308*da0073e9SAndroid Build Coastguard Worker 2309*da0073e9SAndroid Build Coastguard Worker # Functional Test: Assumes 0-index when indices is not given 2310*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp1.shuffle() 2311*da0073e9SAndroid Build Coastguard Worker self.assertEqual(set(range(10)), set(shuffler_dp)) 2312*da0073e9SAndroid Build Coastguard Worker 2313*da0073e9SAndroid Build Coastguard Worker # Functional Test: Custom indices are working 2314*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp2.shuffle(indices=["a", "b", "c", "d", "e"]) 2315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(set(range(1, 6)), set(shuffler_dp)) 2316*da0073e9SAndroid Build Coastguard Worker 2317*da0073e9SAndroid Build Coastguard Worker # Functional Test: With global seed 2318*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 2319*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp1.shuffle() 2320*da0073e9SAndroid Build Coastguard Worker res = list(shuffler_dp) 2321*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 2322*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(shuffler_dp), res) 2323*da0073e9SAndroid Build Coastguard Worker 2324*da0073e9SAndroid Build Coastguard Worker # Functional Test: Set seed 2325*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp1.shuffle().set_seed(123) 2326*da0073e9SAndroid Build Coastguard Worker res = list(shuffler_dp) 2327*da0073e9SAndroid Build Coastguard Worker shuffler_dp.set_seed(123) 2328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(shuffler_dp), res) 2329*da0073e9SAndroid Build Coastguard Worker 2330*da0073e9SAndroid Build Coastguard Worker # Functional Test: deactivate shuffling via set_shuffle 2331*da0073e9SAndroid Build Coastguard Worker unshuffled_dp = input_dp1.shuffle().set_shuffle(False) 2332*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(unshuffled_dp), list(input_dp1)) 2333*da0073e9SAndroid Build Coastguard Worker 2334*da0073e9SAndroid Build Coastguard Worker # Reset Test: 2335*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp1.shuffle() 2336*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 5 2337*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 2338*da0073e9SAndroid Build Coastguard Worker shuffler_dp, n_elements_before_reset 2339*da0073e9SAndroid Build Coastguard Worker ) 2340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, len(res_before_reset)) 2341*da0073e9SAndroid Build Coastguard Worker for x in res_before_reset: 2342*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x in set(range(10))) 2343*da0073e9SAndroid Build Coastguard Worker self.assertEqual(set(range(10)), set(res_after_reset)) 2344*da0073e9SAndroid Build Coastguard Worker 2345*da0073e9SAndroid Build Coastguard Worker # __len__ Test: returns the length of the input DataPipe 2346*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp1.shuffle() 2347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(10, len(shuffler_dp)) 2348*da0073e9SAndroid Build Coastguard Worker 2349*da0073e9SAndroid Build Coastguard Worker # Serialization Test 2350*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.datapipes._hook_iterator import _SnapshotState 2351*da0073e9SAndroid Build Coastguard Worker 2352*da0073e9SAndroid Build Coastguard Worker shuffler_dp = input_dp1.shuffle() 2353*da0073e9SAndroid Build Coastguard Worker it = iter(shuffler_dp) 2354*da0073e9SAndroid Build Coastguard Worker for _ in range(2): 2355*da0073e9SAndroid Build Coastguard Worker next(it) 2356*da0073e9SAndroid Build Coastguard Worker shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp)) 2357*da0073e9SAndroid Build Coastguard Worker 2358*da0073e9SAndroid Build Coastguard Worker exp = list(it) 2359*da0073e9SAndroid Build Coastguard Worker shuffler_dp_copy._snapshot_state = _SnapshotState.Restored 2360*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp, list(shuffler_dp_copy)) 2361*da0073e9SAndroid Build Coastguard Worker 2362*da0073e9SAndroid Build Coastguard Worker def test_map_mapdatapipe(self): 2363*da0073e9SAndroid Build Coastguard Worker arr = range(10) 2364*da0073e9SAndroid Build Coastguard Worker input_dp = dp.map.SequenceWrapper(arr) 2365*da0073e9SAndroid Build Coastguard Worker 2366*da0073e9SAndroid Build Coastguard Worker def fn(item, dtype=torch.float, *, sum=False): 2367*da0073e9SAndroid Build Coastguard Worker data = torch.tensor(item, dtype=dtype) 2368*da0073e9SAndroid Build Coastguard Worker return data if not sum else data.sum() 2369*da0073e9SAndroid Build Coastguard Worker 2370*da0073e9SAndroid Build Coastguard Worker map_dp = input_dp.map(fn) 2371*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_dp), len(map_dp)) 2372*da0073e9SAndroid Build Coastguard Worker for index in arr: 2373*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2374*da0073e9SAndroid Build Coastguard Worker map_dp[index], torch.tensor(input_dp[index], dtype=torch.float) 2375*da0073e9SAndroid Build Coastguard Worker ) 2376*da0073e9SAndroid Build Coastguard Worker 2377*da0073e9SAndroid Build Coastguard Worker map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True)) 2378*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(input_dp), len(map_dp)) 2379*da0073e9SAndroid Build Coastguard Worker for index in arr: 2380*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2381*da0073e9SAndroid Build Coastguard Worker map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum() 2382*da0073e9SAndroid Build Coastguard Worker ) 2383*da0073e9SAndroid Build Coastguard Worker 2384*da0073e9SAndroid Build Coastguard Worker def test_batch_mapdatapipe(self): 2385*da0073e9SAndroid Build Coastguard Worker arr = list(range(13)) 2386*da0073e9SAndroid Build Coastguard Worker input_dp = dp.map.SequenceWrapper(arr) 2387*da0073e9SAndroid Build Coastguard Worker 2388*da0073e9SAndroid Build Coastguard Worker # Functional Test: batches top level by default 2389*da0073e9SAndroid Build Coastguard Worker batch_dp = dp.map.Batcher(input_dp, batch_size=2) 2390*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2391*da0073e9SAndroid Build Coastguard Worker [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12]], list(batch_dp) 2392*da0073e9SAndroid Build Coastguard Worker ) 2393*da0073e9SAndroid Build Coastguard Worker 2394*da0073e9SAndroid Build Coastguard Worker # Functional Test: drop_last on command 2395*da0073e9SAndroid Build Coastguard Worker batch_dp = dp.map.Batcher(input_dp, batch_size=2, drop_last=True) 2396*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2397*da0073e9SAndroid Build Coastguard Worker [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], list(batch_dp) 2398*da0073e9SAndroid Build Coastguard Worker ) 2399*da0073e9SAndroid Build Coastguard Worker 2400*da0073e9SAndroid Build Coastguard Worker # Functional Test: nested batching 2401*da0073e9SAndroid Build Coastguard Worker batch_dp_2 = batch_dp.batch(batch_size=3) 2402*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2403*da0073e9SAndroid Build Coastguard Worker [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], list(batch_dp_2) 2404*da0073e9SAndroid Build Coastguard Worker ) 2405*da0073e9SAndroid Build Coastguard Worker 2406*da0073e9SAndroid Build Coastguard Worker # Reset Test: 2407*da0073e9SAndroid Build Coastguard Worker n_elements_before_reset = 3 2408*da0073e9SAndroid Build Coastguard Worker res_before_reset, res_after_reset = reset_after_n_next_calls( 2409*da0073e9SAndroid Build Coastguard Worker batch_dp, n_elements_before_reset 2410*da0073e9SAndroid Build Coastguard Worker ) 2411*da0073e9SAndroid Build Coastguard Worker self.assertEqual([[0, 1], [2, 3], [4, 5]], res_before_reset) 2412*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2413*da0073e9SAndroid Build Coastguard Worker [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], res_after_reset 2414*da0073e9SAndroid Build Coastguard Worker ) 2415*da0073e9SAndroid Build Coastguard Worker 2416*da0073e9SAndroid Build Coastguard Worker # __len__ Test: 2417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(6, len(batch_dp)) 2418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, len(batch_dp_2)) 2419*da0073e9SAndroid Build Coastguard Worker 2420*da0073e9SAndroid Build Coastguard Worker 2421*da0073e9SAndroid Build Coastguard Worker# Metaclass conflict for Python 3.6 2422*da0073e9SAndroid Build Coastguard Worker# Multiple inheritance with NamedTuple is not supported for Python 3.9 2423*da0073e9SAndroid Build Coastguard Worker_generic_namedtuple_allowed = sys.version_info >= (3, 7) and sys.version_info < (3, 9) 2424*da0073e9SAndroid Build Coastguard Workerif _generic_namedtuple_allowed: 2425*da0073e9SAndroid Build Coastguard Worker 2426*da0073e9SAndroid Build Coastguard Worker class InvalidData(NamedTuple, Generic[T_co]): 2427*da0073e9SAndroid Build Coastguard Worker name: str 2428*da0073e9SAndroid Build Coastguard Worker data: T_co 2429*da0073e9SAndroid Build Coastguard Worker 2430*da0073e9SAndroid Build Coastguard Worker 2431*da0073e9SAndroid Build Coastguard Workerclass TestTyping(TestCase): 2432*da0073e9SAndroid Build Coastguard Worker def test_isinstance(self): 2433*da0073e9SAndroid Build Coastguard Worker class A(IterDataPipe): 2434*da0073e9SAndroid Build Coastguard Worker pass 2435*da0073e9SAndroid Build Coastguard Worker 2436*da0073e9SAndroid Build Coastguard Worker class B(IterDataPipe): 2437*da0073e9SAndroid Build Coastguard Worker pass 2438*da0073e9SAndroid Build Coastguard Worker 2439*da0073e9SAndroid Build Coastguard Worker a = A() 2440*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(a, A)) 2441*da0073e9SAndroid Build Coastguard Worker self.assertFalse(isinstance(a, B)) 2442*da0073e9SAndroid Build Coastguard Worker 2443*da0073e9SAndroid Build Coastguard Worker def test_protocol(self): 2444*da0073e9SAndroid Build Coastguard Worker try: 2445*da0073e9SAndroid Build Coastguard Worker from typing import Protocol # type: ignore[attr-defined] 2446*da0073e9SAndroid Build Coastguard Worker except ImportError: 2447*da0073e9SAndroid Build Coastguard Worker from typing import _Protocol # type: ignore[attr-defined] 2448*da0073e9SAndroid Build Coastguard Worker 2449*da0073e9SAndroid Build Coastguard Worker Protocol = _Protocol 2450*da0073e9SAndroid Build Coastguard Worker 2451*da0073e9SAndroid Build Coastguard Worker class P(Protocol): 2452*da0073e9SAndroid Build Coastguard Worker pass 2453*da0073e9SAndroid Build Coastguard Worker 2454*da0073e9SAndroid Build Coastguard Worker class A(IterDataPipe[P]): 2455*da0073e9SAndroid Build Coastguard Worker pass 2456*da0073e9SAndroid Build Coastguard Worker 2457*da0073e9SAndroid Build Coastguard Worker @skipTyping 2458*da0073e9SAndroid Build Coastguard Worker def test_subtype(self): 2459*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.datapipes._typing import issubtype 2460*da0073e9SAndroid Build Coastguard Worker 2461*da0073e9SAndroid Build Coastguard Worker basic_type = (int, str, bool, float, complex, list, tuple, dict, set, T_co) 2462*da0073e9SAndroid Build Coastguard Worker for t in basic_type: 2463*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubtype(t, t)) 2464*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubtype(t, Any)) 2465*da0073e9SAndroid Build Coastguard Worker if t == T_co: 2466*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubtype(Any, t)) 2467*da0073e9SAndroid Build Coastguard Worker else: 2468*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubtype(Any, t)) 2469*da0073e9SAndroid Build Coastguard Worker for t1, t2 in itertools.product(basic_type, basic_type): 2470*da0073e9SAndroid Build Coastguard Worker if t1 == t2 or t2 == T_co: 2471*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubtype(t1, t2)) 2472*da0073e9SAndroid Build Coastguard Worker else: 2473*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubtype(t1, t2)) 2474*da0073e9SAndroid Build Coastguard Worker 2475*da0073e9SAndroid Build Coastguard Worker T = TypeVar("T", int, str) 2476*da0073e9SAndroid Build Coastguard Worker S = TypeVar("S", bool, Union[str, int], Tuple[int, T]) # type: ignore[valid-type] 2477*da0073e9SAndroid Build Coastguard Worker types = ( 2478*da0073e9SAndroid Build Coastguard Worker (int, Optional[int]), 2479*da0073e9SAndroid Build Coastguard Worker (List, Union[int, list]), 2480*da0073e9SAndroid Build Coastguard Worker (Tuple[int, str], S), 2481*da0073e9SAndroid Build Coastguard Worker (Tuple[int, str], tuple), 2482*da0073e9SAndroid Build Coastguard Worker (T, S), 2483*da0073e9SAndroid Build Coastguard Worker (S, T_co), 2484*da0073e9SAndroid Build Coastguard Worker (T, Union[S, Set]), 2485*da0073e9SAndroid Build Coastguard Worker ) 2486*da0073e9SAndroid Build Coastguard Worker for sub, par in types: 2487*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubtype(sub, par)) 2488*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubtype(par, sub)) 2489*da0073e9SAndroid Build Coastguard Worker 2490*da0073e9SAndroid Build Coastguard Worker subscriptable_types = { 2491*da0073e9SAndroid Build Coastguard Worker List: 1, 2492*da0073e9SAndroid Build Coastguard Worker Tuple: 2, # use 2 parameters 2493*da0073e9SAndroid Build Coastguard Worker Set: 1, 2494*da0073e9SAndroid Build Coastguard Worker Dict: 2, 2495*da0073e9SAndroid Build Coastguard Worker } 2496*da0073e9SAndroid Build Coastguard Worker for subscript_type, n in subscriptable_types.items(): 2497*da0073e9SAndroid Build Coastguard Worker for ts in itertools.combinations(types, n): 2498*da0073e9SAndroid Build Coastguard Worker subs, pars = zip(*ts) 2499*da0073e9SAndroid Build Coastguard Worker sub = subscript_type[subs] # type: ignore[index] 2500*da0073e9SAndroid Build Coastguard Worker par = subscript_type[pars] # type: ignore[index] 2501*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubtype(sub, par)) 2502*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubtype(par, sub)) 2503*da0073e9SAndroid Build Coastguard Worker # Non-recursive check 2504*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubtype(par, sub, recursive=False)) 2505*da0073e9SAndroid Build Coastguard Worker 2506*da0073e9SAndroid Build Coastguard Worker @skipTyping 2507*da0073e9SAndroid Build Coastguard Worker def test_issubinstance(self): 2508*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.datapipes._typing import issubinstance 2509*da0073e9SAndroid Build Coastguard Worker 2510*da0073e9SAndroid Build Coastguard Worker basic_data = (1, "1", True, 1.0, complex(1.0, 0.0)) 2511*da0073e9SAndroid Build Coastguard Worker basic_type = (int, str, bool, float, complex) 2512*da0073e9SAndroid Build Coastguard Worker S = TypeVar("S", bool, Union[str, int]) 2513*da0073e9SAndroid Build Coastguard Worker for d in basic_data: 2514*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, Any)) 2515*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, T_co)) 2516*da0073e9SAndroid Build Coastguard Worker if type(d) in (bool, int, str): 2517*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, S)) 2518*da0073e9SAndroid Build Coastguard Worker else: 2519*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubinstance(d, S)) 2520*da0073e9SAndroid Build Coastguard Worker for t in basic_type: 2521*da0073e9SAndroid Build Coastguard Worker if type(d) == t: 2522*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, t)) 2523*da0073e9SAndroid Build Coastguard Worker else: 2524*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubinstance(d, t)) 2525*da0073e9SAndroid Build Coastguard Worker # list/set 2526*da0073e9SAndroid Build Coastguard Worker dt = (([1, "1", 2], List), (set({1, "1", 2}), Set)) 2527*da0073e9SAndroid Build Coastguard Worker for d, t in dt: 2528*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, t)) 2529*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, t[T_co])) # type: ignore[index] 2530*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubinstance(d, t[int])) # type: ignore[index] 2531*da0073e9SAndroid Build Coastguard Worker 2532*da0073e9SAndroid Build Coastguard Worker # dict 2533*da0073e9SAndroid Build Coastguard Worker d = {"1": 1, "2": 2.0} 2534*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, Dict)) 2535*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, Dict[str, T_co])) 2536*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubinstance(d, Dict[str, int])) 2537*da0073e9SAndroid Build Coastguard Worker 2538*da0073e9SAndroid Build Coastguard Worker # tuple 2539*da0073e9SAndroid Build Coastguard Worker d = (1, "1", 2) 2540*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, Tuple)) 2541*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubinstance(d, Tuple[int, str, T_co])) 2542*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubinstance(d, Tuple[int, Any])) 2543*da0073e9SAndroid Build Coastguard Worker self.assertFalse(issubinstance(d, Tuple[int, int, int])) 2544*da0073e9SAndroid Build Coastguard Worker 2545*da0073e9SAndroid Build Coastguard Worker # Static checking annotation 2546*da0073e9SAndroid Build Coastguard Worker @skipTyping 2547*da0073e9SAndroid Build Coastguard Worker def test_compile_time(self): 2548*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"): 2549*da0073e9SAndroid Build Coastguard Worker 2550*da0073e9SAndroid Build Coastguard Worker class InvalidDP1(IterDataPipe[int]): 2551*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> str: # type: ignore[misc, override] 2552*da0073e9SAndroid Build Coastguard Worker yield 0 2553*da0073e9SAndroid Build Coastguard Worker 2554*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"): 2555*da0073e9SAndroid Build Coastguard Worker 2556*da0073e9SAndroid Build Coastguard Worker class InvalidDP2(IterDataPipe[Tuple]): 2557*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[int]: # type: ignore[override] 2558*da0073e9SAndroid Build Coastguard Worker yield 0 2559*da0073e9SAndroid Build Coastguard Worker 2560*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"): 2561*da0073e9SAndroid Build Coastguard Worker 2562*da0073e9SAndroid Build Coastguard Worker class InvalidDP3(IterDataPipe[Tuple[int, str]]): 2563*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[tuple]: # type: ignore[override] 2564*da0073e9SAndroid Build Coastguard Worker yield (0,) 2565*da0073e9SAndroid Build Coastguard Worker 2566*da0073e9SAndroid Build Coastguard Worker if _generic_namedtuple_allowed: 2567*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2568*da0073e9SAndroid Build Coastguard Worker TypeError, r"is not supported by Python typing" 2569*da0073e9SAndroid Build Coastguard Worker ): 2570*da0073e9SAndroid Build Coastguard Worker 2571*da0073e9SAndroid Build Coastguard Worker class InvalidDP4(IterDataPipe["InvalidData[int]"]): # type: ignore[type-arg, misc] 2572*da0073e9SAndroid Build Coastguard Worker pass 2573*da0073e9SAndroid Build Coastguard Worker 2574*da0073e9SAndroid Build Coastguard Worker class DP1(IterDataPipe[Tuple[int, str]]): 2575*da0073e9SAndroid Build Coastguard Worker def __init__(self, length): 2576*da0073e9SAndroid Build Coastguard Worker self.length = length 2577*da0073e9SAndroid Build Coastguard Worker 2578*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[Tuple[int, str]]: 2579*da0073e9SAndroid Build Coastguard Worker for d in range(self.length): 2580*da0073e9SAndroid Build Coastguard Worker yield d, str(d) 2581*da0073e9SAndroid Build Coastguard Worker 2582*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP1, IterDataPipe)) 2583*da0073e9SAndroid Build Coastguard Worker dp1 = DP1(10) 2584*da0073e9SAndroid Build Coastguard Worker self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type)) # type: ignore[attr-defined] 2585*da0073e9SAndroid Build Coastguard Worker dp1_ = DP1(5) 2586*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dp1.type, dp1_.type) 2587*da0073e9SAndroid Build Coastguard Worker 2588*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"is not a generic class"): 2589*da0073e9SAndroid Build Coastguard Worker 2590*da0073e9SAndroid Build Coastguard Worker class InvalidDP5(DP1[tuple]): # type: ignore[type-arg] 2591*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[tuple]: # type: ignore[override] 2592*da0073e9SAndroid Build Coastguard Worker yield (0,) 2593*da0073e9SAndroid Build Coastguard Worker 2594*da0073e9SAndroid Build Coastguard Worker class DP2(IterDataPipe[T_co]): 2595*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[T_co]: 2596*da0073e9SAndroid Build Coastguard Worker yield from range(10) # type: ignore[misc] 2597*da0073e9SAndroid Build Coastguard Worker 2598*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP2, IterDataPipe)) 2599*da0073e9SAndroid Build Coastguard Worker dp2 = DP2() # type: ignore[var-annotated] 2600*da0073e9SAndroid Build Coastguard Worker self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type)) # type: ignore[attr-defined] 2601*da0073e9SAndroid Build Coastguard Worker dp2_ = DP2() # type: ignore[var-annotated] 2602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dp2.type, dp2_.type) 2603*da0073e9SAndroid Build Coastguard Worker 2604*da0073e9SAndroid Build Coastguard Worker class DP3(IterDataPipe[Tuple[T_co, str]]): 2605*da0073e9SAndroid Build Coastguard Worker r"""DataPipe without fixed type with __init__ function""" 2606*da0073e9SAndroid Build Coastguard Worker 2607*da0073e9SAndroid Build Coastguard Worker def __init__(self, datasource): 2608*da0073e9SAndroid Build Coastguard Worker self.datasource = datasource 2609*da0073e9SAndroid Build Coastguard Worker 2610*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[Tuple[T_co, str]]: 2611*da0073e9SAndroid Build Coastguard Worker for d in self.datasource: 2612*da0073e9SAndroid Build Coastguard Worker yield d, str(d) 2613*da0073e9SAndroid Build Coastguard Worker 2614*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP3, IterDataPipe)) 2615*da0073e9SAndroid Build Coastguard Worker dp3 = DP3(range(10)) # type: ignore[var-annotated] 2616*da0073e9SAndroid Build Coastguard Worker self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type)) # type: ignore[attr-defined] 2617*da0073e9SAndroid Build Coastguard Worker dp3_ = DP3(5) # type: ignore[var-annotated] 2618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dp3.type, dp3_.type) 2619*da0073e9SAndroid Build Coastguard Worker 2620*da0073e9SAndroid Build Coastguard Worker class DP4(IterDataPipe[tuple]): 2621*da0073e9SAndroid Build Coastguard Worker r"""DataPipe without __iter__ annotation""" 2622*da0073e9SAndroid Build Coastguard Worker 2623*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 2624*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 2625*da0073e9SAndroid Build Coastguard Worker 2626*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP4, IterDataPipe)) 2627*da0073e9SAndroid Build Coastguard Worker dp4 = DP4() 2628*da0073e9SAndroid Build Coastguard Worker self.assertTrue(dp4.type.param == tuple) 2629*da0073e9SAndroid Build Coastguard Worker 2630*da0073e9SAndroid Build Coastguard Worker class DP5(IterDataPipe): 2631*da0073e9SAndroid Build Coastguard Worker r"""DataPipe without type annotation""" 2632*da0073e9SAndroid Build Coastguard Worker 2633*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[str]: 2634*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 2635*da0073e9SAndroid Build Coastguard Worker 2636*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP5, IterDataPipe)) 2637*da0073e9SAndroid Build Coastguard Worker dp5 = DP5() 2638*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.datapipes._typing import issubtype 2639*da0073e9SAndroid Build Coastguard Worker 2640*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2641*da0073e9SAndroid Build Coastguard Worker issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param) 2642*da0073e9SAndroid Build Coastguard Worker ) 2643*da0073e9SAndroid Build Coastguard Worker 2644*da0073e9SAndroid Build Coastguard Worker class DP6(IterDataPipe[int]): 2645*da0073e9SAndroid Build Coastguard Worker r"""DataPipe with plain Iterator""" 2646*da0073e9SAndroid Build Coastguard Worker 2647*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator: 2648*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 2649*da0073e9SAndroid Build Coastguard Worker 2650*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP6, IterDataPipe)) 2651*da0073e9SAndroid Build Coastguard Worker dp6 = DP6() 2652*da0073e9SAndroid Build Coastguard Worker self.assertTrue(dp6.type.param == int) 2653*da0073e9SAndroid Build Coastguard Worker 2654*da0073e9SAndroid Build Coastguard Worker class DP7(IterDataPipe[Awaitable[T_co]]): 2655*da0073e9SAndroid Build Coastguard Worker r"""DataPipe with abstract base class""" 2656*da0073e9SAndroid Build Coastguard Worker 2657*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP7, IterDataPipe)) 2658*da0073e9SAndroid Build Coastguard Worker self.assertTrue(DP7.type.param == Awaitable[T_co]) # type: ignore[attr-defined] 2659*da0073e9SAndroid Build Coastguard Worker 2660*da0073e9SAndroid Build Coastguard Worker class DP8(DP7[str]): 2661*da0073e9SAndroid Build Coastguard Worker r"""DataPipe subclass from a DataPipe with abc type""" 2662*da0073e9SAndroid Build Coastguard Worker 2663*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(DP8, IterDataPipe)) 2664*da0073e9SAndroid Build Coastguard Worker self.assertTrue(DP8.type.param == Awaitable[str]) # type: ignore[attr-defined] 2665*da0073e9SAndroid Build Coastguard Worker 2666*da0073e9SAndroid Build Coastguard Worker @skipTyping 2667*da0073e9SAndroid Build Coastguard Worker def test_construct_time(self): 2668*da0073e9SAndroid Build Coastguard Worker class DP0(IterDataPipe[Tuple]): 2669*da0073e9SAndroid Build Coastguard Worker @argument_validation 2670*da0073e9SAndroid Build Coastguard Worker def __init__(self, dp: IterDataPipe): 2671*da0073e9SAndroid Build Coastguard Worker self.dp = dp 2672*da0073e9SAndroid Build Coastguard Worker 2673*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[Tuple]: 2674*da0073e9SAndroid Build Coastguard Worker for d in self.dp: 2675*da0073e9SAndroid Build Coastguard Worker yield d, str(d) 2676*da0073e9SAndroid Build Coastguard Worker 2677*da0073e9SAndroid Build Coastguard Worker class DP1(IterDataPipe[int]): 2678*da0073e9SAndroid Build Coastguard Worker @argument_validation 2679*da0073e9SAndroid Build Coastguard Worker def __init__(self, dp: IterDataPipe[Tuple[int, str]]): 2680*da0073e9SAndroid Build Coastguard Worker self.dp = dp 2681*da0073e9SAndroid Build Coastguard Worker 2682*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[int]: 2683*da0073e9SAndroid Build Coastguard Worker for a, b in self.dp: 2684*da0073e9SAndroid Build Coastguard Worker yield a 2685*da0073e9SAndroid Build Coastguard Worker 2686*da0073e9SAndroid Build Coastguard Worker # Non-DataPipe input with DataPipe hint 2687*da0073e9SAndroid Build Coastguard Worker datasource = [(1, "1"), (2, "2"), (3, "3")] 2688*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2689*da0073e9SAndroid Build Coastguard Worker TypeError, r"Expected argument 'dp' as a IterDataPipe" 2690*da0073e9SAndroid Build Coastguard Worker ): 2691*da0073e9SAndroid Build Coastguard Worker dp0 = DP0(datasource) 2692*da0073e9SAndroid Build Coastguard Worker 2693*da0073e9SAndroid Build Coastguard Worker dp0 = DP0(dp.iter.IterableWrapper(range(10))) 2694*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2695*da0073e9SAndroid Build Coastguard Worker TypeError, r"Expected type of argument 'dp' as a subtype" 2696*da0073e9SAndroid Build Coastguard Worker ): 2697*da0073e9SAndroid Build Coastguard Worker dp1 = DP1(dp0) 2698*da0073e9SAndroid Build Coastguard Worker 2699*da0073e9SAndroid Build Coastguard Worker @skipTyping 2700*da0073e9SAndroid Build Coastguard Worker def test_runtime(self): 2701*da0073e9SAndroid Build Coastguard Worker class DP(IterDataPipe[Tuple[int, T_co]]): 2702*da0073e9SAndroid Build Coastguard Worker def __init__(self, datasource): 2703*da0073e9SAndroid Build Coastguard Worker self.ds = datasource 2704*da0073e9SAndroid Build Coastguard Worker 2705*da0073e9SAndroid Build Coastguard Worker @runtime_validation 2706*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[Tuple[int, T_co]]: 2707*da0073e9SAndroid Build Coastguard Worker yield from self.ds 2708*da0073e9SAndroid Build Coastguard Worker 2709*da0073e9SAndroid Build Coastguard Worker dss = ([(1, "1"), (2, "2")], [(1, 1), (2, "2")]) 2710*da0073e9SAndroid Build Coastguard Worker for ds in dss: 2711*da0073e9SAndroid Build Coastguard Worker dp0 = DP(ds) # type: ignore[var-annotated] 2712*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp0), ds) 2713*da0073e9SAndroid Build Coastguard Worker # Reset __iter__ 2714*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp0), ds) 2715*da0073e9SAndroid Build Coastguard Worker 2716*da0073e9SAndroid Build Coastguard Worker dss = ( 2717*da0073e9SAndroid Build Coastguard Worker [(1, 1), ("2", 2)], # type: ignore[assignment, list-item] 2718*da0073e9SAndroid Build Coastguard Worker [[1, "1"], [2, "2"]], # type: ignore[list-item] 2719*da0073e9SAndroid Build Coastguard Worker [1, "1", 2, "2"], 2720*da0073e9SAndroid Build Coastguard Worker ) 2721*da0073e9SAndroid Build Coastguard Worker for ds in dss: 2722*da0073e9SAndroid Build Coastguard Worker dp0 = DP(ds) 2723*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2724*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Expected an instance as subtype" 2725*da0073e9SAndroid Build Coastguard Worker ): 2726*da0073e9SAndroid Build Coastguard Worker list(dp0) 2727*da0073e9SAndroid Build Coastguard Worker 2728*da0073e9SAndroid Build Coastguard Worker with runtime_validation_disabled(): 2729*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp0), ds) 2730*da0073e9SAndroid Build Coastguard Worker with runtime_validation_disabled(): 2731*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp0), ds) 2732*da0073e9SAndroid Build Coastguard Worker 2733*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2734*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Expected an instance as subtype" 2735*da0073e9SAndroid Build Coastguard Worker ): 2736*da0073e9SAndroid Build Coastguard Worker list(dp0) 2737*da0073e9SAndroid Build Coastguard Worker 2738*da0073e9SAndroid Build Coastguard Worker @skipTyping 2739*da0073e9SAndroid Build Coastguard Worker def test_reinforce(self): 2740*da0073e9SAndroid Build Coastguard Worker T = TypeVar("T", int, str) 2741*da0073e9SAndroid Build Coastguard Worker 2742*da0073e9SAndroid Build Coastguard Worker class DP(IterDataPipe[T]): 2743*da0073e9SAndroid Build Coastguard Worker def __init__(self, ds): 2744*da0073e9SAndroid Build Coastguard Worker self.ds = ds 2745*da0073e9SAndroid Build Coastguard Worker 2746*da0073e9SAndroid Build Coastguard Worker @runtime_validation 2747*da0073e9SAndroid Build Coastguard Worker def __iter__(self) -> Iterator[T]: 2748*da0073e9SAndroid Build Coastguard Worker yield from self.ds 2749*da0073e9SAndroid Build Coastguard Worker 2750*da0073e9SAndroid Build Coastguard Worker ds = list(range(10)) 2751*da0073e9SAndroid Build Coastguard Worker # Valid type reinforcement 2752*da0073e9SAndroid Build Coastguard Worker dp0 = DP(ds).reinforce_type(int) 2753*da0073e9SAndroid Build Coastguard Worker self.assertTrue(dp0.type, int) 2754*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp0), ds) 2755*da0073e9SAndroid Build Coastguard Worker 2756*da0073e9SAndroid Build Coastguard Worker # Invalid type 2757*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"): 2758*da0073e9SAndroid Build Coastguard Worker dp1 = DP(ds).reinforce_type(1) 2759*da0073e9SAndroid Build Coastguard Worker 2760*da0073e9SAndroid Build Coastguard Worker # Type is not subtype 2761*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2762*da0073e9SAndroid Build Coastguard Worker TypeError, r"Expected 'expected_type' as subtype of" 2763*da0073e9SAndroid Build Coastguard Worker ): 2764*da0073e9SAndroid Build Coastguard Worker dp2 = DP(ds).reinforce_type(float) 2765*da0073e9SAndroid Build Coastguard Worker 2766*da0073e9SAndroid Build Coastguard Worker # Invalid data at runtime 2767*da0073e9SAndroid Build Coastguard Worker dp3 = DP(ds).reinforce_type(str) 2768*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"): 2769*da0073e9SAndroid Build Coastguard Worker list(dp3) 2770*da0073e9SAndroid Build Coastguard Worker 2771*da0073e9SAndroid Build Coastguard Worker # Context Manager to disable the runtime validation 2772*da0073e9SAndroid Build Coastguard Worker with runtime_validation_disabled(): 2773*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(dp3), ds) 2774*da0073e9SAndroid Build Coastguard Worker 2775*da0073e9SAndroid Build Coastguard Worker 2776*da0073e9SAndroid Build Coastguard Workerclass NumbersDataset(IterDataPipe): 2777*da0073e9SAndroid Build Coastguard Worker def __init__(self, size=10): 2778*da0073e9SAndroid Build Coastguard Worker self.size = size 2779*da0073e9SAndroid Build Coastguard Worker 2780*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 2781*da0073e9SAndroid Build Coastguard Worker yield from range(self.size) 2782*da0073e9SAndroid Build Coastguard Worker 2783*da0073e9SAndroid Build Coastguard Worker def __len__(self): 2784*da0073e9SAndroid Build Coastguard Worker return self.size 2785*da0073e9SAndroid Build Coastguard Worker 2786*da0073e9SAndroid Build Coastguard Worker 2787*da0073e9SAndroid Build Coastguard Workerclass TestGraph(TestCase): 2788*da0073e9SAndroid Build Coastguard Worker class CustomIterDataPipe(IterDataPipe): 2789*da0073e9SAndroid Build Coastguard Worker def add_v(self, x): 2790*da0073e9SAndroid Build Coastguard Worker return x + self.v 2791*da0073e9SAndroid Build Coastguard Worker 2792*da0073e9SAndroid Build Coastguard Worker def __init__(self, source_dp, v=1): 2793*da0073e9SAndroid Build Coastguard Worker self._dp = source_dp.map(self.add_v) 2794*da0073e9SAndroid Build Coastguard Worker self.v = 1 2795*da0073e9SAndroid Build Coastguard Worker 2796*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 2797*da0073e9SAndroid Build Coastguard Worker yield from self._dp 2798*da0073e9SAndroid Build Coastguard Worker 2799*da0073e9SAndroid Build Coastguard Worker def __hash__(self): 2800*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 2801*da0073e9SAndroid Build Coastguard Worker 2802*da0073e9SAndroid Build Coastguard Worker def test_simple_traverse(self): 2803*da0073e9SAndroid Build Coastguard Worker numbers_dp = NumbersDataset(size=50) 2804*da0073e9SAndroid Build Coastguard Worker shuffled_dp = numbers_dp.shuffle() 2805*da0073e9SAndroid Build Coastguard Worker sharded_dp = shuffled_dp.sharding_filter() 2806*da0073e9SAndroid Build Coastguard Worker mapped_dp = sharded_dp.map(lambda x: x * 10) 2807*da0073e9SAndroid Build Coastguard Worker graph = traverse_dps(mapped_dp) 2808*da0073e9SAndroid Build Coastguard Worker expected: Dict[Any, Any] = { 2809*da0073e9SAndroid Build Coastguard Worker id(mapped_dp): ( 2810*da0073e9SAndroid Build Coastguard Worker mapped_dp, 2811*da0073e9SAndroid Build Coastguard Worker { 2812*da0073e9SAndroid Build Coastguard Worker id(sharded_dp): ( 2813*da0073e9SAndroid Build Coastguard Worker sharded_dp, 2814*da0073e9SAndroid Build Coastguard Worker { 2815*da0073e9SAndroid Build Coastguard Worker id(shuffled_dp): ( 2816*da0073e9SAndroid Build Coastguard Worker shuffled_dp, 2817*da0073e9SAndroid Build Coastguard Worker {id(numbers_dp): (numbers_dp, {})}, 2818*da0073e9SAndroid Build Coastguard Worker ) 2819*da0073e9SAndroid Build Coastguard Worker }, 2820*da0073e9SAndroid Build Coastguard Worker ) 2821*da0073e9SAndroid Build Coastguard Worker }, 2822*da0073e9SAndroid Build Coastguard Worker ) 2823*da0073e9SAndroid Build Coastguard Worker } 2824*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, graph) 2825*da0073e9SAndroid Build Coastguard Worker 2826*da0073e9SAndroid Build Coastguard Worker dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) 2827*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dps), 4) 2828*da0073e9SAndroid Build Coastguard Worker for datapipe in (numbers_dp, shuffled_dp, sharded_dp, mapped_dp): 2829*da0073e9SAndroid Build Coastguard Worker self.assertTrue(datapipe in dps) 2830*da0073e9SAndroid Build Coastguard Worker 2831*da0073e9SAndroid Build Coastguard Worker def test_traverse_forked(self): 2832*da0073e9SAndroid Build Coastguard Worker numbers_dp = NumbersDataset(size=50) 2833*da0073e9SAndroid Build Coastguard Worker dp0, dp1, dp2 = numbers_dp.fork(num_instances=3) 2834*da0073e9SAndroid Build Coastguard Worker dp0_upd = dp0.map(lambda x: x * 10) 2835*da0073e9SAndroid Build Coastguard Worker dp1_upd = dp1.filter(lambda x: x % 3 == 1) 2836*da0073e9SAndroid Build Coastguard Worker combined_dp = dp0_upd.mux(dp1_upd, dp2) 2837*da0073e9SAndroid Build Coastguard Worker graph = traverse_dps(combined_dp) 2838*da0073e9SAndroid Build Coastguard Worker expected = { 2839*da0073e9SAndroid Build Coastguard Worker id(combined_dp): ( 2840*da0073e9SAndroid Build Coastguard Worker combined_dp, 2841*da0073e9SAndroid Build Coastguard Worker { 2842*da0073e9SAndroid Build Coastguard Worker id(dp0_upd): ( 2843*da0073e9SAndroid Build Coastguard Worker dp0_upd, 2844*da0073e9SAndroid Build Coastguard Worker { 2845*da0073e9SAndroid Build Coastguard Worker id(dp0): ( 2846*da0073e9SAndroid Build Coastguard Worker dp0, 2847*da0073e9SAndroid Build Coastguard Worker { 2848*da0073e9SAndroid Build Coastguard Worker id(dp0.main_datapipe): ( 2849*da0073e9SAndroid Build Coastguard Worker dp0.main_datapipe, 2850*da0073e9SAndroid Build Coastguard Worker { 2851*da0073e9SAndroid Build Coastguard Worker id(dp0.main_datapipe.main_datapipe): ( 2852*da0073e9SAndroid Build Coastguard Worker dp0.main_datapipe.main_datapipe, 2853*da0073e9SAndroid Build Coastguard Worker {}, 2854*da0073e9SAndroid Build Coastguard Worker ) 2855*da0073e9SAndroid Build Coastguard Worker }, 2856*da0073e9SAndroid Build Coastguard Worker ) 2857*da0073e9SAndroid Build Coastguard Worker }, 2858*da0073e9SAndroid Build Coastguard Worker ) 2859*da0073e9SAndroid Build Coastguard Worker }, 2860*da0073e9SAndroid Build Coastguard Worker ), 2861*da0073e9SAndroid Build Coastguard Worker id(dp1_upd): ( 2862*da0073e9SAndroid Build Coastguard Worker dp1_upd, 2863*da0073e9SAndroid Build Coastguard Worker { 2864*da0073e9SAndroid Build Coastguard Worker id(dp1): ( 2865*da0073e9SAndroid Build Coastguard Worker dp1, 2866*da0073e9SAndroid Build Coastguard Worker { 2867*da0073e9SAndroid Build Coastguard Worker id(dp1.main_datapipe): ( 2868*da0073e9SAndroid Build Coastguard Worker dp1.main_datapipe, 2869*da0073e9SAndroid Build Coastguard Worker { 2870*da0073e9SAndroid Build Coastguard Worker id(dp1.main_datapipe.main_datapipe): ( 2871*da0073e9SAndroid Build Coastguard Worker dp1.main_datapipe.main_datapipe, 2872*da0073e9SAndroid Build Coastguard Worker {}, 2873*da0073e9SAndroid Build Coastguard Worker ) 2874*da0073e9SAndroid Build Coastguard Worker }, 2875*da0073e9SAndroid Build Coastguard Worker ) 2876*da0073e9SAndroid Build Coastguard Worker }, 2877*da0073e9SAndroid Build Coastguard Worker ) 2878*da0073e9SAndroid Build Coastguard Worker }, 2879*da0073e9SAndroid Build Coastguard Worker ), 2880*da0073e9SAndroid Build Coastguard Worker id(dp2): ( 2881*da0073e9SAndroid Build Coastguard Worker dp2, 2882*da0073e9SAndroid Build Coastguard Worker { 2883*da0073e9SAndroid Build Coastguard Worker id(dp2.main_datapipe): ( 2884*da0073e9SAndroid Build Coastguard Worker dp2.main_datapipe, 2885*da0073e9SAndroid Build Coastguard Worker { 2886*da0073e9SAndroid Build Coastguard Worker id(dp2.main_datapipe.main_datapipe): ( 2887*da0073e9SAndroid Build Coastguard Worker dp2.main_datapipe.main_datapipe, 2888*da0073e9SAndroid Build Coastguard Worker {}, 2889*da0073e9SAndroid Build Coastguard Worker ) 2890*da0073e9SAndroid Build Coastguard Worker }, 2891*da0073e9SAndroid Build Coastguard Worker ) 2892*da0073e9SAndroid Build Coastguard Worker }, 2893*da0073e9SAndroid Build Coastguard Worker ), 2894*da0073e9SAndroid Build Coastguard Worker }, 2895*da0073e9SAndroid Build Coastguard Worker ) 2896*da0073e9SAndroid Build Coastguard Worker } 2897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, graph) 2898*da0073e9SAndroid Build Coastguard Worker 2899*da0073e9SAndroid Build Coastguard Worker dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) 2900*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dps), 8) 2901*da0073e9SAndroid Build Coastguard Worker for _dp in [ 2902*da0073e9SAndroid Build Coastguard Worker numbers_dp, 2903*da0073e9SAndroid Build Coastguard Worker dp0.main_datapipe, 2904*da0073e9SAndroid Build Coastguard Worker dp0, 2905*da0073e9SAndroid Build Coastguard Worker dp1, 2906*da0073e9SAndroid Build Coastguard Worker dp2, 2907*da0073e9SAndroid Build Coastguard Worker dp0_upd, 2908*da0073e9SAndroid Build Coastguard Worker dp1_upd, 2909*da0073e9SAndroid Build Coastguard Worker combined_dp, 2910*da0073e9SAndroid Build Coastguard Worker ]: 2911*da0073e9SAndroid Build Coastguard Worker self.assertTrue(_dp in dps) 2912*da0073e9SAndroid Build Coastguard Worker 2913*da0073e9SAndroid Build Coastguard Worker def test_traverse_mapdatapipe(self): 2914*da0073e9SAndroid Build Coastguard Worker source_dp = dp.map.SequenceWrapper(range(10)) 2915*da0073e9SAndroid Build Coastguard Worker map_dp = source_dp.map(partial(_fake_add, 1)) 2916*da0073e9SAndroid Build Coastguard Worker graph = traverse_dps(map_dp) 2917*da0073e9SAndroid Build Coastguard Worker expected: Dict[Any, Any] = { 2918*da0073e9SAndroid Build Coastguard Worker id(map_dp): (map_dp, {id(source_dp): (source_dp, {})}) 2919*da0073e9SAndroid Build Coastguard Worker } 2920*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, graph) 2921*da0073e9SAndroid Build Coastguard Worker 2922*da0073e9SAndroid Build Coastguard Worker def test_traverse_mixdatapipe(self): 2923*da0073e9SAndroid Build Coastguard Worker source_map_dp = dp.map.SequenceWrapper(range(10)) 2924*da0073e9SAndroid Build Coastguard Worker iter_dp = dp.iter.IterableWrapper(source_map_dp) 2925*da0073e9SAndroid Build Coastguard Worker graph = traverse_dps(iter_dp) 2926*da0073e9SAndroid Build Coastguard Worker expected: Dict[Any, Any] = { 2927*da0073e9SAndroid Build Coastguard Worker id(iter_dp): (iter_dp, {id(source_map_dp): (source_map_dp, {})}) 2928*da0073e9SAndroid Build Coastguard Worker } 2929*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, graph) 2930*da0073e9SAndroid Build Coastguard Worker 2931*da0073e9SAndroid Build Coastguard Worker def test_traverse_circular_datapipe(self): 2932*da0073e9SAndroid Build Coastguard Worker source_iter_dp = dp.iter.IterableWrapper(list(range(10))) 2933*da0073e9SAndroid Build Coastguard Worker circular_dp = TestGraph.CustomIterDataPipe(source_iter_dp) 2934*da0073e9SAndroid Build Coastguard Worker graph = traverse_dps(circular_dp) 2935*da0073e9SAndroid Build Coastguard Worker # See issue: https://github.com/pytorch/data/issues/535 2936*da0073e9SAndroid Build Coastguard Worker expected: Dict[Any, Any] = { 2937*da0073e9SAndroid Build Coastguard Worker id(circular_dp): ( 2938*da0073e9SAndroid Build Coastguard Worker circular_dp, 2939*da0073e9SAndroid Build Coastguard Worker { 2940*da0073e9SAndroid Build Coastguard Worker id(circular_dp._dp): ( 2941*da0073e9SAndroid Build Coastguard Worker circular_dp._dp, 2942*da0073e9SAndroid Build Coastguard Worker {id(source_iter_dp): (source_iter_dp, {})}, 2943*da0073e9SAndroid Build Coastguard Worker ) 2944*da0073e9SAndroid Build Coastguard Worker }, 2945*da0073e9SAndroid Build Coastguard Worker ) 2946*da0073e9SAndroid Build Coastguard Worker } 2947*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, graph) 2948*da0073e9SAndroid Build Coastguard Worker 2949*da0073e9SAndroid Build Coastguard Worker dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) 2950*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(dps), 3) 2951*da0073e9SAndroid Build Coastguard Worker for _dp in [circular_dp, circular_dp._dp, source_iter_dp]: 2952*da0073e9SAndroid Build Coastguard Worker self.assertTrue(_dp in dps) 2953*da0073e9SAndroid Build Coastguard Worker 2954*da0073e9SAndroid Build Coastguard Worker def test_traverse_unhashable_datapipe(self): 2955*da0073e9SAndroid Build Coastguard Worker source_iter_dp = dp.iter.IterableWrapper(list(range(10))) 2956*da0073e9SAndroid Build Coastguard Worker unhashable_dp = TestGraph.CustomIterDataPipe(source_iter_dp) 2957*da0073e9SAndroid Build Coastguard Worker graph = traverse_dps(unhashable_dp) 2958*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(NotImplementedError): 2959*da0073e9SAndroid Build Coastguard Worker hash(unhashable_dp) 2960*da0073e9SAndroid Build Coastguard Worker expected: Dict[Any, Any] = { 2961*da0073e9SAndroid Build Coastguard Worker id(unhashable_dp): ( 2962*da0073e9SAndroid Build Coastguard Worker unhashable_dp, 2963*da0073e9SAndroid Build Coastguard Worker { 2964*da0073e9SAndroid Build Coastguard Worker id(unhashable_dp._dp): ( 2965*da0073e9SAndroid Build Coastguard Worker unhashable_dp._dp, 2966*da0073e9SAndroid Build Coastguard Worker {id(source_iter_dp): (source_iter_dp, {})}, 2967*da0073e9SAndroid Build Coastguard Worker ) 2968*da0073e9SAndroid Build Coastguard Worker }, 2969*da0073e9SAndroid Build Coastguard Worker ) 2970*da0073e9SAndroid Build Coastguard Worker } 2971*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, graph) 2972*da0073e9SAndroid Build Coastguard Worker 2973*da0073e9SAndroid Build Coastguard Worker 2974*da0073e9SAndroid Build Coastguard Workerdef unbatch(x): 2975*da0073e9SAndroid Build Coastguard Worker return x[0] 2976*da0073e9SAndroid Build Coastguard Worker 2977*da0073e9SAndroid Build Coastguard Worker 2978*da0073e9SAndroid Build Coastguard Workerclass TestSerialization(TestCase): 2979*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 2980*da0073e9SAndroid Build Coastguard Worker def test_spawn_lambdas_iter(self): 2981*da0073e9SAndroid Build Coastguard Worker idp = dp.iter.IterableWrapper(range(3)).map(lambda x: x + 1).shuffle() 2982*da0073e9SAndroid Build Coastguard Worker dl = DataLoader( 2983*da0073e9SAndroid Build Coastguard Worker idp, 2984*da0073e9SAndroid Build Coastguard Worker num_workers=2, 2985*da0073e9SAndroid Build Coastguard Worker shuffle=True, 2986*da0073e9SAndroid Build Coastguard Worker multiprocessing_context="spawn", 2987*da0073e9SAndroid Build Coastguard Worker collate_fn=unbatch, 2988*da0073e9SAndroid Build Coastguard Worker batch_size=1, 2989*da0073e9SAndroid Build Coastguard Worker ) 2990*da0073e9SAndroid Build Coastguard Worker result = list(dl) 2991*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result)) 2992*da0073e9SAndroid Build Coastguard Worker 2993*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 2994*da0073e9SAndroid Build Coastguard Worker def test_spawn_lambdas_map(self): 2995*da0073e9SAndroid Build Coastguard Worker mdp = dp.map.SequenceWrapper(range(3)).map(lambda x: x + 1).shuffle() 2996*da0073e9SAndroid Build Coastguard Worker dl = DataLoader( 2997*da0073e9SAndroid Build Coastguard Worker mdp, 2998*da0073e9SAndroid Build Coastguard Worker num_workers=2, 2999*da0073e9SAndroid Build Coastguard Worker shuffle=True, 3000*da0073e9SAndroid Build Coastguard Worker multiprocessing_context="spawn", 3001*da0073e9SAndroid Build Coastguard Worker collate_fn=unbatch, 3002*da0073e9SAndroid Build Coastguard Worker batch_size=1, 3003*da0073e9SAndroid Build Coastguard Worker ) 3004*da0073e9SAndroid Build Coastguard Worker result = list(dl) 3005*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result)) 3006*da0073e9SAndroid Build Coastguard Worker 3007*da0073e9SAndroid Build Coastguard Worker 3008*da0073e9SAndroid Build Coastguard Workerclass TestCircularSerialization(TestCase): 3009*da0073e9SAndroid Build Coastguard Worker class CustomIterDataPipe(IterDataPipe): 3010*da0073e9SAndroid Build Coastguard Worker @staticmethod 3011*da0073e9SAndroid Build Coastguard Worker def add_one(x): 3012*da0073e9SAndroid Build Coastguard Worker return x + 1 3013*da0073e9SAndroid Build Coastguard Worker 3014*da0073e9SAndroid Build Coastguard Worker @classmethod 3015*da0073e9SAndroid Build Coastguard Worker def classify(cls, x): 3016*da0073e9SAndroid Build Coastguard Worker return 0 3017*da0073e9SAndroid Build Coastguard Worker 3018*da0073e9SAndroid Build Coastguard Worker def add_v(self, x): 3019*da0073e9SAndroid Build Coastguard Worker return x + self.v 3020*da0073e9SAndroid Build Coastguard Worker 3021*da0073e9SAndroid Build Coastguard Worker def __init__(self, fn, source_dp=None): 3022*da0073e9SAndroid Build Coastguard Worker self.fn = fn 3023*da0073e9SAndroid Build Coastguard Worker self.source_dp = ( 3024*da0073e9SAndroid Build Coastguard Worker source_dp if source_dp else dp.iter.IterableWrapper([1, 2, 4]) 3025*da0073e9SAndroid Build Coastguard Worker ) 3026*da0073e9SAndroid Build Coastguard Worker self._dp = ( 3027*da0073e9SAndroid Build Coastguard Worker self.source_dp.map(self.add_one) 3028*da0073e9SAndroid Build Coastguard Worker .map(self.add_v) 3029*da0073e9SAndroid Build Coastguard Worker .demux(2, self.classify)[0] 3030*da0073e9SAndroid Build Coastguard Worker ) 3031*da0073e9SAndroid Build Coastguard Worker self.v = 1 3032*da0073e9SAndroid Build Coastguard Worker 3033*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3034*da0073e9SAndroid Build Coastguard Worker yield from self._dp 3035*da0073e9SAndroid Build Coastguard Worker 3036*da0073e9SAndroid Build Coastguard Worker def test_circular_serialization_with_pickle(self): 3037*da0073e9SAndroid Build Coastguard Worker # Test for circular reference issue with pickle 3038*da0073e9SAndroid Build Coastguard Worker dp1 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn) 3039*da0073e9SAndroid Build Coastguard Worker self.assertTrue(list(dp1) == list(pickle.loads(pickle.dumps(dp1)))) 3040*da0073e9SAndroid Build Coastguard Worker 3041*da0073e9SAndroid Build Coastguard Worker child_1 = dp1._dp 3042*da0073e9SAndroid Build Coastguard Worker dm_1 = child_1.main_datapipe 3043*da0073e9SAndroid Build Coastguard Worker m2_1 = dm_1.main_datapipe 3044*da0073e9SAndroid Build Coastguard Worker m1_1 = m2_1.datapipe 3045*da0073e9SAndroid Build Coastguard Worker src_1 = m1_1.datapipe 3046*da0073e9SAndroid Build Coastguard Worker 3047*da0073e9SAndroid Build Coastguard Worker res1 = traverse_dps(dp1) 3048*da0073e9SAndroid Build Coastguard Worker exp_res_1 = { 3049*da0073e9SAndroid Build Coastguard Worker id(dp1): ( 3050*da0073e9SAndroid Build Coastguard Worker dp1, 3051*da0073e9SAndroid Build Coastguard Worker { 3052*da0073e9SAndroid Build Coastguard Worker id(src_1): (src_1, {}), 3053*da0073e9SAndroid Build Coastguard Worker id(child_1): ( 3054*da0073e9SAndroid Build Coastguard Worker child_1, 3055*da0073e9SAndroid Build Coastguard Worker { 3056*da0073e9SAndroid Build Coastguard Worker id(dm_1): ( 3057*da0073e9SAndroid Build Coastguard Worker dm_1, 3058*da0073e9SAndroid Build Coastguard Worker { 3059*da0073e9SAndroid Build Coastguard Worker id(m2_1): ( 3060*da0073e9SAndroid Build Coastguard Worker m2_1, 3061*da0073e9SAndroid Build Coastguard Worker {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}, 3062*da0073e9SAndroid Build Coastguard Worker ) 3063*da0073e9SAndroid Build Coastguard Worker }, 3064*da0073e9SAndroid Build Coastguard Worker ) 3065*da0073e9SAndroid Build Coastguard Worker }, 3066*da0073e9SAndroid Build Coastguard Worker ), 3067*da0073e9SAndroid Build Coastguard Worker }, 3068*da0073e9SAndroid Build Coastguard Worker ) 3069*da0073e9SAndroid Build Coastguard Worker } 3070*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, exp_res_1) 3071*da0073e9SAndroid Build Coastguard Worker dp2 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn, source_dp=dp1) 3072*da0073e9SAndroid Build Coastguard Worker self.assertTrue(list(dp2) == list(pickle.loads(pickle.dumps(dp2)))) 3073*da0073e9SAndroid Build Coastguard Worker 3074*da0073e9SAndroid Build Coastguard Worker child_2 = dp2._dp 3075*da0073e9SAndroid Build Coastguard Worker dm_2 = child_2.main_datapipe 3076*da0073e9SAndroid Build Coastguard Worker m2_2 = dm_2.main_datapipe 3077*da0073e9SAndroid Build Coastguard Worker m1_2 = m2_2.datapipe 3078*da0073e9SAndroid Build Coastguard Worker 3079*da0073e9SAndroid Build Coastguard Worker res2 = traverse_dps(dp2) 3080*da0073e9SAndroid Build Coastguard Worker exp_res_2 = { 3081*da0073e9SAndroid Build Coastguard Worker id(dp2): ( 3082*da0073e9SAndroid Build Coastguard Worker dp2, 3083*da0073e9SAndroid Build Coastguard Worker { 3084*da0073e9SAndroid Build Coastguard Worker id(dp1): ( 3085*da0073e9SAndroid Build Coastguard Worker dp1, 3086*da0073e9SAndroid Build Coastguard Worker { 3087*da0073e9SAndroid Build Coastguard Worker id(src_1): (src_1, {}), 3088*da0073e9SAndroid Build Coastguard Worker id(child_1): ( 3089*da0073e9SAndroid Build Coastguard Worker child_1, 3090*da0073e9SAndroid Build Coastguard Worker { 3091*da0073e9SAndroid Build Coastguard Worker id(dm_1): ( 3092*da0073e9SAndroid Build Coastguard Worker dm_1, 3093*da0073e9SAndroid Build Coastguard Worker { 3094*da0073e9SAndroid Build Coastguard Worker id(m2_1): ( 3095*da0073e9SAndroid Build Coastguard Worker m2_1, 3096*da0073e9SAndroid Build Coastguard Worker { 3097*da0073e9SAndroid Build Coastguard Worker id(m1_1): ( 3098*da0073e9SAndroid Build Coastguard Worker m1_1, 3099*da0073e9SAndroid Build Coastguard Worker {id(src_1): (src_1, {})}, 3100*da0073e9SAndroid Build Coastguard Worker ) 3101*da0073e9SAndroid Build Coastguard Worker }, 3102*da0073e9SAndroid Build Coastguard Worker ) 3103*da0073e9SAndroid Build Coastguard Worker }, 3104*da0073e9SAndroid Build Coastguard Worker ) 3105*da0073e9SAndroid Build Coastguard Worker }, 3106*da0073e9SAndroid Build Coastguard Worker ), 3107*da0073e9SAndroid Build Coastguard Worker }, 3108*da0073e9SAndroid Build Coastguard Worker ), 3109*da0073e9SAndroid Build Coastguard Worker id(child_2): ( 3110*da0073e9SAndroid Build Coastguard Worker child_2, 3111*da0073e9SAndroid Build Coastguard Worker { 3112*da0073e9SAndroid Build Coastguard Worker id(dm_2): ( 3113*da0073e9SAndroid Build Coastguard Worker dm_2, 3114*da0073e9SAndroid Build Coastguard Worker { 3115*da0073e9SAndroid Build Coastguard Worker id(m2_2): ( 3116*da0073e9SAndroid Build Coastguard Worker m2_2, 3117*da0073e9SAndroid Build Coastguard Worker { 3118*da0073e9SAndroid Build Coastguard Worker id(m1_2): ( 3119*da0073e9SAndroid Build Coastguard Worker m1_2, 3120*da0073e9SAndroid Build Coastguard Worker { 3121*da0073e9SAndroid Build Coastguard Worker id(dp1): ( 3122*da0073e9SAndroid Build Coastguard Worker dp1, 3123*da0073e9SAndroid Build Coastguard Worker { 3124*da0073e9SAndroid Build Coastguard Worker id(src_1): (src_1, {}), 3125*da0073e9SAndroid Build Coastguard Worker id(child_1): ( 3126*da0073e9SAndroid Build Coastguard Worker child_1, 3127*da0073e9SAndroid Build Coastguard Worker { 3128*da0073e9SAndroid Build Coastguard Worker id(dm_1): ( 3129*da0073e9SAndroid Build Coastguard Worker dm_1, 3130*da0073e9SAndroid Build Coastguard Worker { 3131*da0073e9SAndroid Build Coastguard Worker id(m2_1): ( 3132*da0073e9SAndroid Build Coastguard Worker m2_1, 3133*da0073e9SAndroid Build Coastguard Worker { 3134*da0073e9SAndroid Build Coastguard Worker id( 3135*da0073e9SAndroid Build Coastguard Worker m1_1 3136*da0073e9SAndroid Build Coastguard Worker ): ( 3137*da0073e9SAndroid Build Coastguard Worker m1_1, 3138*da0073e9SAndroid Build Coastguard Worker { 3139*da0073e9SAndroid Build Coastguard Worker id( 3140*da0073e9SAndroid Build Coastguard Worker src_1 3141*da0073e9SAndroid Build Coastguard Worker ): ( 3142*da0073e9SAndroid Build Coastguard Worker src_1, 3143*da0073e9SAndroid Build Coastguard Worker {}, 3144*da0073e9SAndroid Build Coastguard Worker ) 3145*da0073e9SAndroid Build Coastguard Worker }, 3146*da0073e9SAndroid Build Coastguard Worker ) 3147*da0073e9SAndroid Build Coastguard Worker }, 3148*da0073e9SAndroid Build Coastguard Worker ) 3149*da0073e9SAndroid Build Coastguard Worker }, 3150*da0073e9SAndroid Build Coastguard Worker ) 3151*da0073e9SAndroid Build Coastguard Worker }, 3152*da0073e9SAndroid Build Coastguard Worker ), 3153*da0073e9SAndroid Build Coastguard Worker }, 3154*da0073e9SAndroid Build Coastguard Worker ), 3155*da0073e9SAndroid Build Coastguard Worker }, 3156*da0073e9SAndroid Build Coastguard Worker ) 3157*da0073e9SAndroid Build Coastguard Worker }, 3158*da0073e9SAndroid Build Coastguard Worker ) 3159*da0073e9SAndroid Build Coastguard Worker }, 3160*da0073e9SAndroid Build Coastguard Worker ) 3161*da0073e9SAndroid Build Coastguard Worker }, 3162*da0073e9SAndroid Build Coastguard Worker ), 3163*da0073e9SAndroid Build Coastguard Worker }, 3164*da0073e9SAndroid Build Coastguard Worker ) 3165*da0073e9SAndroid Build Coastguard Worker } 3166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res2, exp_res_2) 3167*da0073e9SAndroid Build Coastguard Worker 3168*da0073e9SAndroid Build Coastguard Worker class LambdaIterDataPipe(CustomIterDataPipe): 3169*da0073e9SAndroid Build Coastguard Worker def __init__(self, fn, source_dp=None): 3170*da0073e9SAndroid Build Coastguard Worker super().__init__(fn, source_dp) 3171*da0073e9SAndroid Build Coastguard Worker self.container = [ 3172*da0073e9SAndroid Build Coastguard Worker lambda x: x + 1, 3173*da0073e9SAndroid Build Coastguard Worker ] 3174*da0073e9SAndroid Build Coastguard Worker self.lambda_fn = lambda x: x + 1 3175*da0073e9SAndroid Build Coastguard Worker self._dp = ( 3176*da0073e9SAndroid Build Coastguard Worker self.source_dp.map(self.add_one) 3177*da0073e9SAndroid Build Coastguard Worker .map(self.lambda_fn) 3178*da0073e9SAndroid Build Coastguard Worker .map(self.add_v) 3179*da0073e9SAndroid Build Coastguard Worker .demux(2, self.classify)[0] 3180*da0073e9SAndroid Build Coastguard Worker ) 3181*da0073e9SAndroid Build Coastguard Worker 3182*da0073e9SAndroid Build Coastguard Worker @skipIfNoDill 3183*da0073e9SAndroid Build Coastguard Worker @skipIf(True, "Dill Tests") 3184*da0073e9SAndroid Build Coastguard Worker def test_circular_serialization_with_dill(self): 3185*da0073e9SAndroid Build Coastguard Worker # Test for circular reference issue with dill 3186*da0073e9SAndroid Build Coastguard Worker dp1 = TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1) 3187*da0073e9SAndroid Build Coastguard Worker self.assertTrue(list(dp1) == list(dill.loads(dill.dumps(dp1)))) 3188*da0073e9SAndroid Build Coastguard Worker 3189*da0073e9SAndroid Build Coastguard Worker child_1 = dp1._dp 3190*da0073e9SAndroid Build Coastguard Worker dm_1 = child_1.main_datapipe 3191*da0073e9SAndroid Build Coastguard Worker m2_1 = dm_1.main_datapipe 3192*da0073e9SAndroid Build Coastguard Worker m1_1 = m2_1.datapipe 3193*da0073e9SAndroid Build Coastguard Worker src_1 = m1_1.datapipe 3194*da0073e9SAndroid Build Coastguard Worker 3195*da0073e9SAndroid Build Coastguard Worker res1 = traverse_dps(dp1) 3196*da0073e9SAndroid Build Coastguard Worker 3197*da0073e9SAndroid Build Coastguard Worker exp_res_1 = { 3198*da0073e9SAndroid Build Coastguard Worker id(dp1): ( 3199*da0073e9SAndroid Build Coastguard Worker dp1, 3200*da0073e9SAndroid Build Coastguard Worker { 3201*da0073e9SAndroid Build Coastguard Worker id(src_1): (src_1, {}), 3202*da0073e9SAndroid Build Coastguard Worker id(child_1): ( 3203*da0073e9SAndroid Build Coastguard Worker child_1, 3204*da0073e9SAndroid Build Coastguard Worker { 3205*da0073e9SAndroid Build Coastguard Worker id(dm_1): ( 3206*da0073e9SAndroid Build Coastguard Worker dm_1, 3207*da0073e9SAndroid Build Coastguard Worker { 3208*da0073e9SAndroid Build Coastguard Worker id(m2_1): ( 3209*da0073e9SAndroid Build Coastguard Worker m2_1, 3210*da0073e9SAndroid Build Coastguard Worker {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}, 3211*da0073e9SAndroid Build Coastguard Worker ) 3212*da0073e9SAndroid Build Coastguard Worker }, 3213*da0073e9SAndroid Build Coastguard Worker ) 3214*da0073e9SAndroid Build Coastguard Worker }, 3215*da0073e9SAndroid Build Coastguard Worker ), 3216*da0073e9SAndroid Build Coastguard Worker }, 3217*da0073e9SAndroid Build Coastguard Worker ) 3218*da0073e9SAndroid Build Coastguard Worker } 3219*da0073e9SAndroid Build Coastguard Worker 3220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, exp_res_1) 3221*da0073e9SAndroid Build Coastguard Worker 3222*da0073e9SAndroid Build Coastguard Worker dp2 = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn, source_dp=dp1) 3223*da0073e9SAndroid Build Coastguard Worker self.assertTrue(list(dp2) == list(dill.loads(dill.dumps(dp2)))) 3224*da0073e9SAndroid Build Coastguard Worker 3225*da0073e9SAndroid Build Coastguard Worker child_2 = dp2._dp 3226*da0073e9SAndroid Build Coastguard Worker dm_2 = child_2.main_datapipe 3227*da0073e9SAndroid Build Coastguard Worker m2_2 = dm_2.main_datapipe 3228*da0073e9SAndroid Build Coastguard Worker m1_2 = m2_2.datapipe 3229*da0073e9SAndroid Build Coastguard Worker 3230*da0073e9SAndroid Build Coastguard Worker res2 = traverse_dps(dp2) 3231*da0073e9SAndroid Build Coastguard Worker exp_res_2 = { 3232*da0073e9SAndroid Build Coastguard Worker id(dp2): ( 3233*da0073e9SAndroid Build Coastguard Worker dp2, 3234*da0073e9SAndroid Build Coastguard Worker { 3235*da0073e9SAndroid Build Coastguard Worker id(dp1): ( 3236*da0073e9SAndroid Build Coastguard Worker dp1, 3237*da0073e9SAndroid Build Coastguard Worker { 3238*da0073e9SAndroid Build Coastguard Worker id(src_1): (src_1, {}), 3239*da0073e9SAndroid Build Coastguard Worker id(child_1): ( 3240*da0073e9SAndroid Build Coastguard Worker child_1, 3241*da0073e9SAndroid Build Coastguard Worker { 3242*da0073e9SAndroid Build Coastguard Worker id(dm_1): ( 3243*da0073e9SAndroid Build Coastguard Worker dm_1, 3244*da0073e9SAndroid Build Coastguard Worker { 3245*da0073e9SAndroid Build Coastguard Worker id(m2_1): ( 3246*da0073e9SAndroid Build Coastguard Worker m2_1, 3247*da0073e9SAndroid Build Coastguard Worker { 3248*da0073e9SAndroid Build Coastguard Worker id(m1_1): ( 3249*da0073e9SAndroid Build Coastguard Worker m1_1, 3250*da0073e9SAndroid Build Coastguard Worker {id(src_1): (src_1, {})}, 3251*da0073e9SAndroid Build Coastguard Worker ) 3252*da0073e9SAndroid Build Coastguard Worker }, 3253*da0073e9SAndroid Build Coastguard Worker ) 3254*da0073e9SAndroid Build Coastguard Worker }, 3255*da0073e9SAndroid Build Coastguard Worker ) 3256*da0073e9SAndroid Build Coastguard Worker }, 3257*da0073e9SAndroid Build Coastguard Worker ), 3258*da0073e9SAndroid Build Coastguard Worker }, 3259*da0073e9SAndroid Build Coastguard Worker ), 3260*da0073e9SAndroid Build Coastguard Worker id(child_2): ( 3261*da0073e9SAndroid Build Coastguard Worker child_2, 3262*da0073e9SAndroid Build Coastguard Worker { 3263*da0073e9SAndroid Build Coastguard Worker id(dm_2): ( 3264*da0073e9SAndroid Build Coastguard Worker dm_2, 3265*da0073e9SAndroid Build Coastguard Worker { 3266*da0073e9SAndroid Build Coastguard Worker id(m2_2): ( 3267*da0073e9SAndroid Build Coastguard Worker m2_2, 3268*da0073e9SAndroid Build Coastguard Worker { 3269*da0073e9SAndroid Build Coastguard Worker id(m1_2): ( 3270*da0073e9SAndroid Build Coastguard Worker m1_2, 3271*da0073e9SAndroid Build Coastguard Worker { 3272*da0073e9SAndroid Build Coastguard Worker id(dp1): ( 3273*da0073e9SAndroid Build Coastguard Worker dp1, 3274*da0073e9SAndroid Build Coastguard Worker { 3275*da0073e9SAndroid Build Coastguard Worker id(src_1): (src_1, {}), 3276*da0073e9SAndroid Build Coastguard Worker id(child_1): ( 3277*da0073e9SAndroid Build Coastguard Worker child_1, 3278*da0073e9SAndroid Build Coastguard Worker { 3279*da0073e9SAndroid Build Coastguard Worker id(dm_1): ( 3280*da0073e9SAndroid Build Coastguard Worker dm_1, 3281*da0073e9SAndroid Build Coastguard Worker { 3282*da0073e9SAndroid Build Coastguard Worker id(m2_1): ( 3283*da0073e9SAndroid Build Coastguard Worker m2_1, 3284*da0073e9SAndroid Build Coastguard Worker { 3285*da0073e9SAndroid Build Coastguard Worker id( 3286*da0073e9SAndroid Build Coastguard Worker m1_1 3287*da0073e9SAndroid Build Coastguard Worker ): ( 3288*da0073e9SAndroid Build Coastguard Worker m1_1, 3289*da0073e9SAndroid Build Coastguard Worker { 3290*da0073e9SAndroid Build Coastguard Worker id( 3291*da0073e9SAndroid Build Coastguard Worker src_1 3292*da0073e9SAndroid Build Coastguard Worker ): ( 3293*da0073e9SAndroid Build Coastguard Worker src_1, 3294*da0073e9SAndroid Build Coastguard Worker {}, 3295*da0073e9SAndroid Build Coastguard Worker ) 3296*da0073e9SAndroid Build Coastguard Worker }, 3297*da0073e9SAndroid Build Coastguard Worker ) 3298*da0073e9SAndroid Build Coastguard Worker }, 3299*da0073e9SAndroid Build Coastguard Worker ) 3300*da0073e9SAndroid Build Coastguard Worker }, 3301*da0073e9SAndroid Build Coastguard Worker ) 3302*da0073e9SAndroid Build Coastguard Worker }, 3303*da0073e9SAndroid Build Coastguard Worker ), 3304*da0073e9SAndroid Build Coastguard Worker }, 3305*da0073e9SAndroid Build Coastguard Worker ), 3306*da0073e9SAndroid Build Coastguard Worker }, 3307*da0073e9SAndroid Build Coastguard Worker ) 3308*da0073e9SAndroid Build Coastguard Worker }, 3309*da0073e9SAndroid Build Coastguard Worker ) 3310*da0073e9SAndroid Build Coastguard Worker }, 3311*da0073e9SAndroid Build Coastguard Worker ) 3312*da0073e9SAndroid Build Coastguard Worker }, 3313*da0073e9SAndroid Build Coastguard Worker ), 3314*da0073e9SAndroid Build Coastguard Worker }, 3315*da0073e9SAndroid Build Coastguard Worker ) 3316*da0073e9SAndroid Build Coastguard Worker } 3317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res2, exp_res_2) 3318*da0073e9SAndroid Build Coastguard Worker 3319*da0073e9SAndroid Build Coastguard Worker 3320*da0073e9SAndroid Build Coastguard Workerclass CustomShardingIterDataPipe(IterDataPipe): 3321*da0073e9SAndroid Build Coastguard Worker def __init__(self, dp): 3322*da0073e9SAndroid Build Coastguard Worker self.dp = dp 3323*da0073e9SAndroid Build Coastguard Worker self.num_of_instances = 1 3324*da0073e9SAndroid Build Coastguard Worker self.instance_id = 0 3325*da0073e9SAndroid Build Coastguard Worker 3326*da0073e9SAndroid Build Coastguard Worker def apply_sharding(self, num_of_instances, instance_id): 3327*da0073e9SAndroid Build Coastguard Worker self.num_of_instances = num_of_instances 3328*da0073e9SAndroid Build Coastguard Worker self.instance_id = instance_id 3329*da0073e9SAndroid Build Coastguard Worker 3330*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3331*da0073e9SAndroid Build Coastguard Worker for i, d in enumerate(self.dp): 3332*da0073e9SAndroid Build Coastguard Worker if i % self.num_of_instances == self.instance_id: 3333*da0073e9SAndroid Build Coastguard Worker yield d 3334*da0073e9SAndroid Build Coastguard Worker 3335*da0073e9SAndroid Build Coastguard Worker 3336*da0073e9SAndroid Build Coastguard Workerclass TestSharding(TestCase): 3337*da0073e9SAndroid Build Coastguard Worker def _get_pipeline(self): 3338*da0073e9SAndroid Build Coastguard Worker numbers_dp = NumbersDataset(size=10) 3339*da0073e9SAndroid Build Coastguard Worker dp0, dp1 = numbers_dp.fork(num_instances=2) 3340*da0073e9SAndroid Build Coastguard Worker dp0_upd = dp0.map(_mul_10) 3341*da0073e9SAndroid Build Coastguard Worker dp1_upd = dp1.filter(_mod_3_test) 3342*da0073e9SAndroid Build Coastguard Worker combined_dp = dp0_upd.mux(dp1_upd) 3343*da0073e9SAndroid Build Coastguard Worker return combined_dp 3344*da0073e9SAndroid Build Coastguard Worker 3345*da0073e9SAndroid Build Coastguard Worker def _get_dill_pipeline(self): 3346*da0073e9SAndroid Build Coastguard Worker numbers_dp = NumbersDataset(size=10) 3347*da0073e9SAndroid Build Coastguard Worker dp0, dp1 = numbers_dp.fork(num_instances=2) 3348*da0073e9SAndroid Build Coastguard Worker dp0_upd = dp0.map(lambda x: x * 10) 3349*da0073e9SAndroid Build Coastguard Worker dp1_upd = dp1.filter(lambda x: x % 3 == 1) 3350*da0073e9SAndroid Build Coastguard Worker combined_dp = dp0_upd.mux(dp1_upd) 3351*da0073e9SAndroid Build Coastguard Worker return combined_dp 3352*da0073e9SAndroid Build Coastguard Worker 3353*da0073e9SAndroid Build Coastguard Worker def test_simple_sharding(self): 3354*da0073e9SAndroid Build Coastguard Worker sharded_dp = self._get_pipeline().sharding_filter() 3355*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1) 3356*da0073e9SAndroid Build Coastguard Worker items = list(sharded_dp) 3357*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 20], items) 3358*da0073e9SAndroid Build Coastguard Worker 3359*da0073e9SAndroid Build Coastguard Worker all_items = [0, 1, 10, 4, 20, 7] 3360*da0073e9SAndroid Build Coastguard Worker items = [] 3361*da0073e9SAndroid Build Coastguard Worker for i in range(3): 3362*da0073e9SAndroid Build Coastguard Worker sharded_dp = self._get_pipeline().sharding_filter() 3363*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, i) 3364*da0073e9SAndroid Build Coastguard Worker items += list(sharded_dp) 3365*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(all_items), sorted(items)) 3366*da0073e9SAndroid Build Coastguard Worker 3367*da0073e9SAndroid Build Coastguard Worker def test_sharding_groups(self): 3368*da0073e9SAndroid Build Coastguard Worker def construct_sharded_pipe(): 3369*da0073e9SAndroid Build Coastguard Worker sharding_pipes = [] 3370*da0073e9SAndroid Build Coastguard Worker dp = NumbersDataset(size=90) 3371*da0073e9SAndroid Build Coastguard Worker dp = dp.sharding_filter( 3372*da0073e9SAndroid Build Coastguard Worker sharding_group_filter=SHARDING_PRIORITIES.DISTRIBUTED 3373*da0073e9SAndroid Build Coastguard Worker ) 3374*da0073e9SAndroid Build Coastguard Worker sharding_pipes.append(dp) 3375*da0073e9SAndroid Build Coastguard Worker dp = dp.sharding_filter( 3376*da0073e9SAndroid Build Coastguard Worker sharding_group_filter=SHARDING_PRIORITIES.MULTIPROCESSING 3377*da0073e9SAndroid Build Coastguard Worker ) 3378*da0073e9SAndroid Build Coastguard Worker sharding_pipes.append(dp) 3379*da0073e9SAndroid Build Coastguard Worker dp = dp.sharding_filter(sharding_group_filter=300) 3380*da0073e9SAndroid Build Coastguard Worker sharding_pipes.append(dp) 3381*da0073e9SAndroid Build Coastguard Worker return dp, sharding_pipes 3382*da0073e9SAndroid Build Coastguard Worker 3383*da0073e9SAndroid Build Coastguard Worker dp, sharding_pipes = construct_sharded_pipe() 3384*da0073e9SAndroid Build Coastguard Worker 3385*da0073e9SAndroid Build Coastguard Worker for pipe in sharding_pipes: 3386*da0073e9SAndroid Build Coastguard Worker pipe.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED) 3387*da0073e9SAndroid Build Coastguard Worker pipe.apply_sharding( 3388*da0073e9SAndroid Build Coastguard Worker 5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING 3389*da0073e9SAndroid Build Coastguard Worker ) 3390*da0073e9SAndroid Build Coastguard Worker pipe.apply_sharding(3, 1, sharding_group=300) 3391*da0073e9SAndroid Build Coastguard Worker 3392*da0073e9SAndroid Build Coastguard Worker actual = list(dp) 3393*da0073e9SAndroid Build Coastguard Worker expected = [17, 47, 77] 3394*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 3395*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, len(dp)) 3396*da0073e9SAndroid Build Coastguard Worker 3397*da0073e9SAndroid Build Coastguard Worker dp, _ = construct_sharded_pipe() 3398*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) 3399*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 3400*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING) 3401*da0073e9SAndroid Build Coastguard Worker 3402*da0073e9SAndroid Build Coastguard Worker dp, _ = construct_sharded_pipe() 3403*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING) 3404*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 3405*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) 3406*da0073e9SAndroid Build Coastguard Worker 3407*da0073e9SAndroid Build Coastguard Worker # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatbility 3408*da0073e9SAndroid Build Coastguard Worker # TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated 3409*da0073e9SAndroid Build Coastguard Worker def test_sharding_groups_in_legacy_grouping_package(self): 3410*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 3411*da0073e9SAndroid Build Coastguard Worker FutureWarning, 3412*da0073e9SAndroid Build Coastguard Worker r"Please use `SHARDING_PRIORITIES` " 3413*da0073e9SAndroid Build Coastguard Worker "from the `torch.utils.data.datapipes.iter.sharding`", 3414*da0073e9SAndroid Build Coastguard Worker ): 3415*da0073e9SAndroid Build Coastguard Worker from torch.utils.data.datapipes.iter.grouping import ( 3416*da0073e9SAndroid Build Coastguard Worker SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES, 3417*da0073e9SAndroid Build Coastguard Worker ) 3418*da0073e9SAndroid Build Coastguard Worker 3419*da0073e9SAndroid Build Coastguard Worker def construct_sharded_pipe(): 3420*da0073e9SAndroid Build Coastguard Worker sharding_pipes = [] 3421*da0073e9SAndroid Build Coastguard Worker dp = NumbersDataset(size=90) 3422*da0073e9SAndroid Build Coastguard Worker dp = dp.sharding_filter( 3423*da0073e9SAndroid Build Coastguard Worker sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED 3424*da0073e9SAndroid Build Coastguard Worker ) 3425*da0073e9SAndroid Build Coastguard Worker sharding_pipes.append(dp) 3426*da0073e9SAndroid Build Coastguard Worker dp = dp.sharding_filter( 3427*da0073e9SAndroid Build Coastguard Worker sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3428*da0073e9SAndroid Build Coastguard Worker ) 3429*da0073e9SAndroid Build Coastguard Worker sharding_pipes.append(dp) 3430*da0073e9SAndroid Build Coastguard Worker dp = dp.sharding_filter(sharding_group_filter=300) 3431*da0073e9SAndroid Build Coastguard Worker sharding_pipes.append(dp) 3432*da0073e9SAndroid Build Coastguard Worker return dp, sharding_pipes 3433*da0073e9SAndroid Build Coastguard Worker 3434*da0073e9SAndroid Build Coastguard Worker dp, sharding_pipes = construct_sharded_pipe() 3435*da0073e9SAndroid Build Coastguard Worker 3436*da0073e9SAndroid Build Coastguard Worker for pipe in sharding_pipes: 3437*da0073e9SAndroid Build Coastguard Worker pipe.apply_sharding( 3438*da0073e9SAndroid Build Coastguard Worker 2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED 3439*da0073e9SAndroid Build Coastguard Worker ) 3440*da0073e9SAndroid Build Coastguard Worker pipe.apply_sharding( 3441*da0073e9SAndroid Build Coastguard Worker 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3442*da0073e9SAndroid Build Coastguard Worker ) 3443*da0073e9SAndroid Build Coastguard Worker pipe.apply_sharding(3, 1, sharding_group=300) 3444*da0073e9SAndroid Build Coastguard Worker 3445*da0073e9SAndroid Build Coastguard Worker actual = list(dp) 3446*da0073e9SAndroid Build Coastguard Worker expected = [17, 47, 77] 3447*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 3448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, len(dp)) 3449*da0073e9SAndroid Build Coastguard Worker 3450*da0073e9SAndroid Build Coastguard Worker dp, _ = construct_sharded_pipe() 3451*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT) 3452*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 3453*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding( 3454*da0073e9SAndroid Build Coastguard Worker 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3455*da0073e9SAndroid Build Coastguard Worker ) 3456*da0073e9SAndroid Build Coastguard Worker 3457*da0073e9SAndroid Build Coastguard Worker dp, _ = construct_sharded_pipe() 3458*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding( 3459*da0073e9SAndroid Build Coastguard Worker 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3460*da0073e9SAndroid Build Coastguard Worker ) 3461*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 3462*da0073e9SAndroid Build Coastguard Worker dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT) 3463*da0073e9SAndroid Build Coastguard Worker 3464*da0073e9SAndroid Build Coastguard Worker def test_legacy_custom_sharding(self): 3465*da0073e9SAndroid Build Coastguard Worker dp = self._get_pipeline() 3466*da0073e9SAndroid Build Coastguard Worker sharded_dp = CustomShardingIterDataPipe(dp) 3467*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1) 3468*da0073e9SAndroid Build Coastguard Worker items = list(sharded_dp) 3469*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 20], items) 3470*da0073e9SAndroid Build Coastguard Worker 3471*da0073e9SAndroid Build Coastguard Worker def test_sharding_length(self): 3472*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(range(13)) 3473*da0073e9SAndroid Build Coastguard Worker sharded_dp0 = numbers_dp.sharding_filter() 3474*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 3, 0) 3475*da0073e9SAndroid Build Coastguard Worker sharded_dp1 = numbers_dp.sharding_filter() 3476*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 3, 1) 3477*da0073e9SAndroid Build Coastguard Worker sharded_dp2 = numbers_dp.sharding_filter() 3478*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp2, 3, 2) 3479*da0073e9SAndroid Build Coastguard Worker self.assertEqual(13, len(numbers_dp)) 3480*da0073e9SAndroid Build Coastguard Worker self.assertEqual(5, len(sharded_dp0)) 3481*da0073e9SAndroid Build Coastguard Worker self.assertEqual(4, len(sharded_dp1)) 3482*da0073e9SAndroid Build Coastguard Worker self.assertEqual(4, len(sharded_dp2)) 3483*da0073e9SAndroid Build Coastguard Worker 3484*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(range(1)) 3485*da0073e9SAndroid Build Coastguard Worker sharded_dp0 = numbers_dp.sharding_filter() 3486*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 2, 0) 3487*da0073e9SAndroid Build Coastguard Worker sharded_dp1 = numbers_dp.sharding_filter() 3488*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 2, 1) 3489*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, len(sharded_dp0)) 3490*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, len(sharded_dp1)) 3491*da0073e9SAndroid Build Coastguard Worker 3492*da0073e9SAndroid Build Coastguard Worker def test_old_dataloader(self): 3493*da0073e9SAndroid Build Coastguard Worker dp0 = self._get_pipeline() 3494*da0073e9SAndroid Build Coastguard Worker expected = list(dp0) 3495*da0073e9SAndroid Build Coastguard Worker 3496*da0073e9SAndroid Build Coastguard Worker dp0 = self._get_pipeline().sharding_filter() 3497*da0073e9SAndroid Build Coastguard Worker dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2) 3498*da0073e9SAndroid Build Coastguard Worker items = list(dl) 3499*da0073e9SAndroid Build Coastguard Worker 3500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(expected), sorted(items)) 3501*da0073e9SAndroid Build Coastguard Worker 3502*da0073e9SAndroid Build Coastguard Worker def test_legacy_custom_sharding_with_old_dataloader(self): 3503*da0073e9SAndroid Build Coastguard Worker dp0 = self._get_pipeline() 3504*da0073e9SAndroid Build Coastguard Worker expected = list(dp0) 3505*da0073e9SAndroid Build Coastguard Worker 3506*da0073e9SAndroid Build Coastguard Worker dp0 = self._get_pipeline() 3507*da0073e9SAndroid Build Coastguard Worker dp0 = CustomShardingIterDataPipe(dp0) 3508*da0073e9SAndroid Build Coastguard Worker dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2) 3509*da0073e9SAndroid Build Coastguard Worker items = list(dl) 3510*da0073e9SAndroid Build Coastguard Worker 3511*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(expected), sorted(items)) 3512*da0073e9SAndroid Build Coastguard Worker 3513*da0073e9SAndroid Build Coastguard Worker def test_multi_sharding(self): 3514*da0073e9SAndroid Build Coastguard Worker # Raises Error when multiple sharding on the single branch 3515*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(range(13)) 3516*da0073e9SAndroid Build Coastguard Worker sharded_dp = numbers_dp.sharding_filter() 3517*da0073e9SAndroid Build Coastguard Worker sharded_dp = sharded_dp.sharding_filter() 3518*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3519*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Sharding twice on a single pipeline" 3520*da0073e9SAndroid Build Coastguard Worker ): 3521*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 0) 3522*da0073e9SAndroid Build Coastguard Worker 3523*da0073e9SAndroid Build Coastguard Worker # Raises Error when sharding on both data source and branch 3524*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter() 3525*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = numbers_dp.fork(2) 3526*da0073e9SAndroid Build Coastguard Worker sharded_dp = dp1.sharding_filter() 3527*da0073e9SAndroid Build Coastguard Worker zip_dp = dp2.zip(sharded_dp) 3528*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3529*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Sharding twice on a single pipeline" 3530*da0073e9SAndroid Build Coastguard Worker ): 3531*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3532*da0073e9SAndroid Build Coastguard Worker 3533*da0073e9SAndroid Build Coastguard Worker # Raises Error when multiple sharding on the branch and end 3534*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(range(13)) 3535*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = numbers_dp.fork(2) 3536*da0073e9SAndroid Build Coastguard Worker sharded_dp = dp1.sharding_filter() 3537*da0073e9SAndroid Build Coastguard Worker zip_dp = dp2.zip(sharded_dp).sharding_filter() 3538*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3539*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Sharding twice on a single pipeline" 3540*da0073e9SAndroid Build Coastguard Worker ): 3541*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3542*da0073e9SAndroid Build Coastguard Worker 3543*da0073e9SAndroid Build Coastguard Worker # Single sharding_filter on data source 3544*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter() 3545*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = numbers_dp.fork(2) 3546*da0073e9SAndroid Build Coastguard Worker zip_dp = dp1.zip(dp2) 3547*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3548*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)]) 3549*da0073e9SAndroid Build Coastguard Worker 3550*da0073e9SAndroid Build Coastguard Worker # Single sharding_filter per branch 3551*da0073e9SAndroid Build Coastguard Worker numbers_dp = dp.iter.IterableWrapper(range(13)) 3552*da0073e9SAndroid Build Coastguard Worker dp1, dp2 = numbers_dp.fork(2) 3553*da0073e9SAndroid Build Coastguard Worker sharded_dp1 = dp1.sharding_filter() 3554*da0073e9SAndroid Build Coastguard Worker sharded_dp2 = dp2.sharding_filter() 3555*da0073e9SAndroid Build Coastguard Worker zip_dp = sharded_dp1.zip(sharded_dp2) 3556*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3557*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)]) 3558*da0073e9SAndroid Build Coastguard Worker 3559*da0073e9SAndroid Build Coastguard Worker 3560*da0073e9SAndroid Build Coastguard Workerclass TestIterDataPipeSingletonConstraint(TestCase): 3561*da0073e9SAndroid Build Coastguard Worker r""" 3562*da0073e9SAndroid Build Coastguard Worker Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older 3563*da0073e9SAndroid Build Coastguard Worker iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior. 3564*da0073e9SAndroid Build Coastguard Worker """ 3565*da0073e9SAndroid Build Coastguard Worker 3566*da0073e9SAndroid Build Coastguard Worker def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe): 3567*da0073e9SAndroid Build Coastguard Worker r""" 3568*da0073e9SAndroid Build Coastguard Worker Given a IterDataPipe, verifies that the iterator can be read, reset, and the creation of 3569*da0073e9SAndroid Build Coastguard Worker a second iterator invalidates the first one. 3570*da0073e9SAndroid Build Coastguard Worker """ 3571*da0073e9SAndroid Build Coastguard Worker it1 = iter(source_dp) 3572*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it1)) 3573*da0073e9SAndroid Build Coastguard Worker it1 = iter(source_dp) 3574*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3575*da0073e9SAndroid Build Coastguard Worker list(range(10)), list(it1) 3576*da0073e9SAndroid Build Coastguard Worker ) # A fresh iterator can be read in full again 3577*da0073e9SAndroid Build Coastguard Worker it1 = iter(source_dp) 3578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it1)) 3579*da0073e9SAndroid Build Coastguard Worker it2 = iter(source_dp) # This should invalidate `it1` 3580*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it2)) # Should read from the beginning again 3581*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3582*da0073e9SAndroid Build Coastguard Worker next(it1) 3583*da0073e9SAndroid Build Coastguard Worker 3584*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_singleton_generator(self): 3585*da0073e9SAndroid Build Coastguard Worker r""" 3586*da0073e9SAndroid Build Coastguard Worker Testing for the case where IterDataPipe's `__iter__` is a generator function. 3587*da0073e9SAndroid Build Coastguard Worker """ 3588*da0073e9SAndroid Build Coastguard Worker 3589*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check if invalidation logic is correct 3590*da0073e9SAndroid Build Coastguard Worker source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10)) 3591*da0073e9SAndroid Build Coastguard Worker self._check_single_iterator_invalidation_logic(source_dp) 3592*da0073e9SAndroid Build Coastguard Worker 3593*da0073e9SAndroid Build Coastguard Worker # Functional Test: extend the test to a pipeline 3594*da0073e9SAndroid Build Coastguard Worker dps = source_dp.map(_fake_fn).filter(_fake_filter_fn) 3595*da0073e9SAndroid Build Coastguard Worker self._check_single_iterator_invalidation_logic(dps) 3596*da0073e9SAndroid Build Coastguard Worker 3597*da0073e9SAndroid Build Coastguard Worker # Functional Test: multiple simultaneous references to the same DataPipe fails 3598*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3599*da0073e9SAndroid Build Coastguard Worker for _ in zip(source_dp, source_dp): 3600*da0073e9SAndroid Build Coastguard Worker pass 3601*da0073e9SAndroid Build Coastguard Worker 3602*da0073e9SAndroid Build Coastguard Worker # Function Test: sequential references work 3603*da0073e9SAndroid Build Coastguard Worker for _ in zip(list(source_dp), list(source_dp)): 3604*da0073e9SAndroid Build Coastguard Worker pass 3605*da0073e9SAndroid Build Coastguard Worker 3606*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_singleton_self_next(self): 3607*da0073e9SAndroid Build Coastguard Worker r""" 3608*da0073e9SAndroid Build Coastguard Worker Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method 3609*da0073e9SAndroid Build Coastguard Worker Note that the following DataPipe by is singleton by default (because `__iter__` returns `self`). 3610*da0073e9SAndroid Build Coastguard Worker """ 3611*da0073e9SAndroid Build Coastguard Worker 3612*da0073e9SAndroid Build Coastguard Worker class _CustomIterDP_Self(IterDataPipe): 3613*da0073e9SAndroid Build Coastguard Worker def __init__(self, iterable): 3614*da0073e9SAndroid Build Coastguard Worker self.source = iterable 3615*da0073e9SAndroid Build Coastguard Worker self.iterable = iter(iterable) 3616*da0073e9SAndroid Build Coastguard Worker 3617*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3618*da0073e9SAndroid Build Coastguard Worker self.reset() 3619*da0073e9SAndroid Build Coastguard Worker return self 3620*da0073e9SAndroid Build Coastguard Worker 3621*da0073e9SAndroid Build Coastguard Worker def __next__(self): 3622*da0073e9SAndroid Build Coastguard Worker return next(self.iterable) 3623*da0073e9SAndroid Build Coastguard Worker 3624*da0073e9SAndroid Build Coastguard Worker def reset(self): 3625*da0073e9SAndroid Build Coastguard Worker self.iterable = iter(self.source) 3626*da0073e9SAndroid Build Coastguard Worker 3627*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check that every `__iter__` call returns the same object 3628*da0073e9SAndroid Build Coastguard Worker source_dp = _CustomIterDP_Self(range(10)) 3629*da0073e9SAndroid Build Coastguard Worker res = list(source_dp) 3630*da0073e9SAndroid Build Coastguard Worker it = iter(source_dp) 3631*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, list(it)) 3632*da0073e9SAndroid Build Coastguard Worker 3633*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check if invalidation logic is correct 3634*da0073e9SAndroid Build Coastguard Worker source_dp = _CustomIterDP_Self(range(10)) 3635*da0073e9SAndroid Build Coastguard Worker self._check_single_iterator_invalidation_logic(source_dp) 3636*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3637*da0073e9SAndroid Build Coastguard Worker 1, next(source_dp) 3638*da0073e9SAndroid Build Coastguard Worker ) # `source_dp` is still valid and can be read 3639*da0073e9SAndroid Build Coastguard Worker 3640*da0073e9SAndroid Build Coastguard Worker # Functional Test: extend the test to a pipeline 3641*da0073e9SAndroid Build Coastguard Worker source_dp = _CustomIterDP_Self( 3642*da0073e9SAndroid Build Coastguard Worker dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn) 3643*da0073e9SAndroid Build Coastguard Worker ) 3644*da0073e9SAndroid Build Coastguard Worker self._check_single_iterator_invalidation_logic(source_dp) 3645*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3646*da0073e9SAndroid Build Coastguard Worker 1, next(source_dp) 3647*da0073e9SAndroid Build Coastguard Worker ) # `source_dp` is still valid and can be read 3648*da0073e9SAndroid Build Coastguard Worker 3649*da0073e9SAndroid Build Coastguard Worker # Functional Test: multiple simultaneous references to the same DataPipe fails 3650*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3651*da0073e9SAndroid Build Coastguard Worker for _ in zip(source_dp, source_dp): 3652*da0073e9SAndroid Build Coastguard Worker pass 3653*da0073e9SAndroid Build Coastguard Worker 3654*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_singleton_new_object(self): 3655*da0073e9SAndroid Build Coastguard Worker r""" 3656*da0073e9SAndroid Build Coastguard Worker Testing for the case where IterDataPipe's `__iter__` isn't a generator nor returns `self`, 3657*da0073e9SAndroid Build Coastguard Worker and there isn't a `__next__` method. 3658*da0073e9SAndroid Build Coastguard Worker """ 3659*da0073e9SAndroid Build Coastguard Worker 3660*da0073e9SAndroid Build Coastguard Worker class _CustomIterDP(IterDataPipe): 3661*da0073e9SAndroid Build Coastguard Worker def __init__(self, iterable): 3662*da0073e9SAndroid Build Coastguard Worker self.iterable = iter(iterable) 3663*da0073e9SAndroid Build Coastguard Worker 3664*da0073e9SAndroid Build Coastguard Worker def __iter__(self): # Note that this doesn't reset 3665*da0073e9SAndroid Build Coastguard Worker return self.iterable # Intentionally not returning `self` 3666*da0073e9SAndroid Build Coastguard Worker 3667*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check if invalidation logic is correct 3668*da0073e9SAndroid Build Coastguard Worker source_dp = _CustomIterDP(range(10)) 3669*da0073e9SAndroid Build Coastguard Worker it1 = iter(source_dp) 3670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it1)) 3671*da0073e9SAndroid Build Coastguard Worker it2 = iter(source_dp) 3672*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, next(it2)) 3673*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3674*da0073e9SAndroid Build Coastguard Worker next(it1) 3675*da0073e9SAndroid Build Coastguard Worker 3676*da0073e9SAndroid Build Coastguard Worker # Functional Test: extend the test to a pipeline 3677*da0073e9SAndroid Build Coastguard Worker source_dp = _CustomIterDP( 3678*da0073e9SAndroid Build Coastguard Worker dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn) 3679*da0073e9SAndroid Build Coastguard Worker ) 3680*da0073e9SAndroid Build Coastguard Worker it1 = iter(source_dp) 3681*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it1)) 3682*da0073e9SAndroid Build Coastguard Worker it2 = iter(source_dp) 3683*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, next(it2)) 3684*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3685*da0073e9SAndroid Build Coastguard Worker next(it1) 3686*da0073e9SAndroid Build Coastguard Worker 3687*da0073e9SAndroid Build Coastguard Worker # Functional Test: multiple simultaneous references to the same DataPipe fails 3688*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3689*da0073e9SAndroid Build Coastguard Worker for _ in zip(source_dp, source_dp): 3690*da0073e9SAndroid Build Coastguard Worker pass 3691*da0073e9SAndroid Build Coastguard Worker 3692*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_singleton_buggy(self): 3693*da0073e9SAndroid Build Coastguard Worker r""" 3694*da0073e9SAndroid Build Coastguard Worker Buggy test case case where IterDataPipe's `__iter__` returns a new object, but also has 3695*da0073e9SAndroid Build Coastguard Worker a `__next__` method. 3696*da0073e9SAndroid Build Coastguard Worker """ 3697*da0073e9SAndroid Build Coastguard Worker 3698*da0073e9SAndroid Build Coastguard Worker class _CustomIterDP(IterDataPipe): 3699*da0073e9SAndroid Build Coastguard Worker def __init__(self, iterable): 3700*da0073e9SAndroid Build Coastguard Worker self.source = iterable 3701*da0073e9SAndroid Build Coastguard Worker self.iterable = iter(iterable) 3702*da0073e9SAndroid Build Coastguard Worker 3703*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3704*da0073e9SAndroid Build Coastguard Worker return iter(self.source) # Intentionally not returning `self` 3705*da0073e9SAndroid Build Coastguard Worker 3706*da0073e9SAndroid Build Coastguard Worker def __next__(self): 3707*da0073e9SAndroid Build Coastguard Worker return next(self.iterable) 3708*da0073e9SAndroid Build Coastguard Worker 3709*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check if invalidation logic is correct 3710*da0073e9SAndroid Build Coastguard Worker source_dp = _CustomIterDP(range(10)) 3711*da0073e9SAndroid Build Coastguard Worker self._check_single_iterator_invalidation_logic(source_dp) 3712*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(source_dp)) # `__next__` is unrelated with `__iter__` 3713*da0073e9SAndroid Build Coastguard Worker 3714*da0073e9SAndroid Build Coastguard Worker # Functional Test: Special case to show `__next__` is unrelated with `__iter__` 3715*da0073e9SAndroid Build Coastguard Worker source_dp = _CustomIterDP(range(10)) 3716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(source_dp)) 3717*da0073e9SAndroid Build Coastguard Worker it1 = iter(source_dp) 3718*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it1)) 3719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, next(source_dp)) 3720*da0073e9SAndroid Build Coastguard Worker it2 = iter(source_dp) # invalidates both `it1` 3721*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3722*da0073e9SAndroid Build Coastguard Worker next(it1) 3723*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, next(source_dp)) # not impacted by the creation of `it2` 3724*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3725*da0073e9SAndroid Build Coastguard Worker list(range(10)), list(it2) 3726*da0073e9SAndroid Build Coastguard Worker ) # `it2` still works because it is a new object 3727*da0073e9SAndroid Build Coastguard Worker 3728*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_singleton_constraint_multiple_outputs(self): 3729*da0073e9SAndroid Build Coastguard Worker r""" 3730*da0073e9SAndroid Build Coastguard Worker Testing for the case where IterDataPipe has multiple child DataPipes as outputs. 3731*da0073e9SAndroid Build Coastguard Worker """ 3732*da0073e9SAndroid Build Coastguard Worker # Functional Test: all previous related iterators should be invalidated when a new iterator 3733*da0073e9SAndroid Build Coastguard Worker # is created from a ChildDataPipe 3734*da0073e9SAndroid Build Coastguard Worker source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10)) 3735*da0073e9SAndroid Build Coastguard Worker cdp1, cdp2 = source_dp.fork(num_instances=2) 3736*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(cdp1), iter(cdp2) 3737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it1)) 3738*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it2)) 3739*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(cdp1), iter(cdp2) 3740*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 3741*da0073e9SAndroid Build Coastguard Worker it3 = iter(cdp1) # This should invalidate `it1` and `it2` 3742*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 3743*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted") 3744*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3745*da0073e9SAndroid Build Coastguard Worker next(it1) 3746*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3747*da0073e9SAndroid Build Coastguard Worker next(it2) 3748*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it3)) 3749*da0073e9SAndroid Build Coastguard Worker # The next line should not invalidate anything, as there was no new iterator created 3750*da0073e9SAndroid Build Coastguard Worker # for `cdp2` after `it2` was invalidated 3751*da0073e9SAndroid Build Coastguard Worker it4 = iter(cdp2) 3752*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, next(it3)) # An error shouldn't be raised here 3753*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it4)) 3754*da0073e9SAndroid Build Coastguard Worker 3755*da0073e9SAndroid Build Coastguard Worker # Functional Test: invalidation when a new iterator is created from `source_dp` 3756*da0073e9SAndroid Build Coastguard Worker source_dp = dp.iter.IterableWrapper(range(10)) 3757*da0073e9SAndroid Build Coastguard Worker cdp1, cdp2 = source_dp.fork(num_instances=2) 3758*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(cdp1), iter(cdp2) 3759*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it1)) 3760*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it2)) 3761*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(cdp1), iter(cdp2) 3762*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it1)) 3763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it2)) 3764*da0073e9SAndroid Build Coastguard Worker it3 = iter(source_dp) # note that a new iterator is created from `source_dp` 3765*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3766*da0073e9SAndroid Build Coastguard Worker 0, next(it3) 3767*da0073e9SAndroid Build Coastguard Worker ) # `it3` should invalidate `it1` and `it2` since they both use `source_dp` 3768*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3769*da0073e9SAndroid Build Coastguard Worker next(it1) 3770*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, next(it3)) 3771*da0073e9SAndroid Build Coastguard Worker 3772*da0073e9SAndroid Build Coastguard Worker # Function Test: Extending test to pipeline 3773*da0073e9SAndroid Build Coastguard Worker source_dp = ( 3774*da0073e9SAndroid Build Coastguard Worker dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn) 3775*da0073e9SAndroid Build Coastguard Worker ) 3776*da0073e9SAndroid Build Coastguard Worker cdp1, cdp2 = source_dp.fork(num_instances=2) 3777*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(cdp1), iter(cdp2) 3778*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it1)) 3779*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(range(10)), list(it2)) 3780*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(cdp1), iter(cdp2) 3781*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 3782*da0073e9SAndroid Build Coastguard Worker it3 = iter(cdp1) # This should invalidate `it1` and `it2` 3783*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 3784*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted") 3785*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3786*da0073e9SAndroid Build Coastguard Worker next(it1) 3787*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3788*da0073e9SAndroid Build Coastguard Worker next(it2) 3789*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 3790*da0073e9SAndroid Build Coastguard Worker it1, it2 = iter(cdp1), iter(cdp2) 3791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 3792*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted") 3793*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it1)) 3794*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, next(it2)) 3795*da0073e9SAndroid Build Coastguard Worker it3 = iter(source_dp) # note that a new iterator is created from `source_dp` 3796*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3797*da0073e9SAndroid Build Coastguard Worker 0, next(it3) 3798*da0073e9SAndroid Build Coastguard Worker ) # `it3` should invalidate `it1` and `it2` since they both use `source_dp` 3799*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3800*da0073e9SAndroid Build Coastguard Worker next(it1) 3801*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, next(it3)) 3802*da0073e9SAndroid Build Coastguard Worker 3803*da0073e9SAndroid Build Coastguard Worker 3804*da0073e9SAndroid Build Coastguard Workerclass TestIterDataPipeCountSampleYielded(TestCase): 3805*da0073e9SAndroid Build Coastguard Worker def _yield_count_test_helper(self, datapipe, n_expected_samples): 3806*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check if number of samples yielded is as expected 3807*da0073e9SAndroid Build Coastguard Worker res = list(datapipe) 3808*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(res), datapipe._number_of_samples_yielded) 3809*da0073e9SAndroid Build Coastguard Worker 3810*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check if the count is correct when DataPipe is partially read 3811*da0073e9SAndroid Build Coastguard Worker it = iter(datapipe) 3812*da0073e9SAndroid Build Coastguard Worker res = [] 3813*da0073e9SAndroid Build Coastguard Worker for i, value in enumerate(it): 3814*da0073e9SAndroid Build Coastguard Worker res.append(value) 3815*da0073e9SAndroid Build Coastguard Worker if i == n_expected_samples - 1: 3816*da0073e9SAndroid Build Coastguard Worker break 3817*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded) 3818*da0073e9SAndroid Build Coastguard Worker 3819*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check for reset behavior and if iterator also works 3820*da0073e9SAndroid Build Coastguard Worker it = iter(datapipe) # reset the DataPipe 3821*da0073e9SAndroid Build Coastguard Worker res = list(it) 3822*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(res), datapipe._number_of_samples_yielded) 3823*da0073e9SAndroid Build Coastguard Worker 3824*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_sample_yielded_generator_function(self): 3825*da0073e9SAndroid Build Coastguard Worker # Functional Test: `__iter__` is a generator function 3826*da0073e9SAndroid Build Coastguard Worker datapipe: IterDataPipe = dp.iter.IterableWrapper(range(10)) 3827*da0073e9SAndroid Build Coastguard Worker self._yield_count_test_helper(datapipe, n_expected_samples=5) 3828*da0073e9SAndroid Build Coastguard Worker 3829*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_sample_yielded_generator_function_exception(self): 3830*da0073e9SAndroid Build Coastguard Worker # Functional Test: `__iter__` is a custom generator function with exception 3831*da0073e9SAndroid Build Coastguard Worker class _CustomGeneratorFnDataPipe(IterDataPipe): 3832*da0073e9SAndroid Build Coastguard Worker # This class's `__iter__` has a Runtime Error 3833*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3834*da0073e9SAndroid Build Coastguard Worker yield 0 3835*da0073e9SAndroid Build Coastguard Worker yield 1 3836*da0073e9SAndroid Build Coastguard Worker yield 2 3837*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Custom test error after yielding 3 elements") 3838*da0073e9SAndroid Build Coastguard Worker yield 3 3839*da0073e9SAndroid Build Coastguard Worker 3840*da0073e9SAndroid Build Coastguard Worker # Functional Test: Ensure the count is correct even when exception is raised 3841*da0073e9SAndroid Build Coastguard Worker datapipe: IterDataPipe = _CustomGeneratorFnDataPipe() 3842*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3843*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Custom test error after yielding 3 elements" 3844*da0073e9SAndroid Build Coastguard Worker ): 3845*da0073e9SAndroid Build Coastguard Worker list(datapipe) 3846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, datapipe._number_of_samples_yielded) 3847*da0073e9SAndroid Build Coastguard Worker 3848*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check for reset behavior and if iterator also works 3849*da0073e9SAndroid Build Coastguard Worker it = iter(datapipe) # reset the DataPipe 3850*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3851*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Custom test error after yielding 3 elements" 3852*da0073e9SAndroid Build Coastguard Worker ): 3853*da0073e9SAndroid Build Coastguard Worker list(it) 3854*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, datapipe._number_of_samples_yielded) 3855*da0073e9SAndroid Build Coastguard Worker 3856*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_sample_yielded_return_self(self): 3857*da0073e9SAndroid Build Coastguard Worker class _CustomGeneratorDataPipe(IterDataPipe): 3858*da0073e9SAndroid Build Coastguard Worker # This class's `__iter__` is not a generator function 3859*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3860*da0073e9SAndroid Build Coastguard Worker self.source = iter(range(10)) 3861*da0073e9SAndroid Build Coastguard Worker 3862*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3863*da0073e9SAndroid Build Coastguard Worker return self.source 3864*da0073e9SAndroid Build Coastguard Worker 3865*da0073e9SAndroid Build Coastguard Worker def reset(self): 3866*da0073e9SAndroid Build Coastguard Worker self.source = iter(range(10)) 3867*da0073e9SAndroid Build Coastguard Worker 3868*da0073e9SAndroid Build Coastguard Worker datapipe: IterDataPipe = _CustomGeneratorDataPipe() 3869*da0073e9SAndroid Build Coastguard Worker self._yield_count_test_helper(datapipe, n_expected_samples=5) 3870*da0073e9SAndroid Build Coastguard Worker 3871*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_sample_yielded_next(self): 3872*da0073e9SAndroid Build Coastguard Worker class _CustomNextDataPipe(IterDataPipe): 3873*da0073e9SAndroid Build Coastguard Worker # This class's `__iter__` returns `self` and has a `__next__` 3874*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3875*da0073e9SAndroid Build Coastguard Worker self.source = iter(range(10)) 3876*da0073e9SAndroid Build Coastguard Worker 3877*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3878*da0073e9SAndroid Build Coastguard Worker return self 3879*da0073e9SAndroid Build Coastguard Worker 3880*da0073e9SAndroid Build Coastguard Worker def __next__(self): 3881*da0073e9SAndroid Build Coastguard Worker return next(self.source) 3882*da0073e9SAndroid Build Coastguard Worker 3883*da0073e9SAndroid Build Coastguard Worker def reset(self): 3884*da0073e9SAndroid Build Coastguard Worker self.source = iter(range(10)) 3885*da0073e9SAndroid Build Coastguard Worker 3886*da0073e9SAndroid Build Coastguard Worker datapipe: IterDataPipe = _CustomNextDataPipe() 3887*da0073e9SAndroid Build Coastguard Worker self._yield_count_test_helper(datapipe, n_expected_samples=5) 3888*da0073e9SAndroid Build Coastguard Worker 3889*da0073e9SAndroid Build Coastguard Worker def test_iterdatapipe_sample_yielded_next_exception(self): 3890*da0073e9SAndroid Build Coastguard Worker class _CustomNextDataPipe(IterDataPipe): 3891*da0073e9SAndroid Build Coastguard Worker # This class's `__iter__` returns `self` and has a `__next__` 3892*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3893*da0073e9SAndroid Build Coastguard Worker self.source = iter(range(10)) 3894*da0073e9SAndroid Build Coastguard Worker self.count = 0 3895*da0073e9SAndroid Build Coastguard Worker 3896*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3897*da0073e9SAndroid Build Coastguard Worker return self 3898*da0073e9SAndroid Build Coastguard Worker 3899*da0073e9SAndroid Build Coastguard Worker def __next__(self): 3900*da0073e9SAndroid Build Coastguard Worker if self.count == 3: 3901*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Custom test error after yielding 3 elements") 3902*da0073e9SAndroid Build Coastguard Worker self.count += 1 3903*da0073e9SAndroid Build Coastguard Worker return next(self.source) 3904*da0073e9SAndroid Build Coastguard Worker 3905*da0073e9SAndroid Build Coastguard Worker def reset(self): 3906*da0073e9SAndroid Build Coastguard Worker self.count = 0 3907*da0073e9SAndroid Build Coastguard Worker self.source = iter(range(10)) 3908*da0073e9SAndroid Build Coastguard Worker 3909*da0073e9SAndroid Build Coastguard Worker # Functional Test: Ensure the count is correct even when exception is raised 3910*da0073e9SAndroid Build Coastguard Worker datapipe: IterDataPipe = _CustomNextDataPipe() 3911*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3912*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Custom test error after yielding 3 elements" 3913*da0073e9SAndroid Build Coastguard Worker ): 3914*da0073e9SAndroid Build Coastguard Worker list(datapipe) 3915*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, datapipe._number_of_samples_yielded) 3916*da0073e9SAndroid Build Coastguard Worker 3917*da0073e9SAndroid Build Coastguard Worker # Functional Test: Check for reset behavior and if iterator also works 3918*da0073e9SAndroid Build Coastguard Worker it = iter(datapipe) # reset the DataPipe 3919*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3920*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Custom test error after yielding 3 elements" 3921*da0073e9SAndroid Build Coastguard Worker ): 3922*da0073e9SAndroid Build Coastguard Worker list(it) 3923*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, datapipe._number_of_samples_yielded) 3924*da0073e9SAndroid Build Coastguard Worker 3925*da0073e9SAndroid Build Coastguard Worker 3926*da0073e9SAndroid Build Coastguard Workerclass _CustomNonGeneratorTestDataPipe(IterDataPipe): 3927*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3928*da0073e9SAndroid Build Coastguard Worker self.n = 10 3929*da0073e9SAndroid Build Coastguard Worker self.source = list(range(self.n)) 3930*da0073e9SAndroid Build Coastguard Worker 3931*da0073e9SAndroid Build Coastguard Worker # This class's `__iter__` is not a generator function 3932*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3933*da0073e9SAndroid Build Coastguard Worker return iter(self.source) 3934*da0073e9SAndroid Build Coastguard Worker 3935*da0073e9SAndroid Build Coastguard Worker def __len__(self): 3936*da0073e9SAndroid Build Coastguard Worker return self.n 3937*da0073e9SAndroid Build Coastguard Worker 3938*da0073e9SAndroid Build Coastguard Worker 3939*da0073e9SAndroid Build Coastguard Workerclass _CustomSelfNextTestDataPipe(IterDataPipe): 3940*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3941*da0073e9SAndroid Build Coastguard Worker self.n = 10 3942*da0073e9SAndroid Build Coastguard Worker self.iter = iter(range(self.n)) 3943*da0073e9SAndroid Build Coastguard Worker 3944*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3945*da0073e9SAndroid Build Coastguard Worker return self 3946*da0073e9SAndroid Build Coastguard Worker 3947*da0073e9SAndroid Build Coastguard Worker def __next__(self): 3948*da0073e9SAndroid Build Coastguard Worker return next(self.iter) 3949*da0073e9SAndroid Build Coastguard Worker 3950*da0073e9SAndroid Build Coastguard Worker def reset(self): 3951*da0073e9SAndroid Build Coastguard Worker self.iter = iter(range(self.n)) 3952*da0073e9SAndroid Build Coastguard Worker 3953*da0073e9SAndroid Build Coastguard Worker def __len__(self): 3954*da0073e9SAndroid Build Coastguard Worker return self.n 3955*da0073e9SAndroid Build Coastguard Worker 3956*da0073e9SAndroid Build Coastguard Worker 3957*da0073e9SAndroid Build Coastguard Workerclass TestIterDataPipeGraphFastForward(TestCase): 3958*da0073e9SAndroid Build Coastguard Worker def _fast_forward_graph_test_helper( 3959*da0073e9SAndroid Build Coastguard Worker self, datapipe, fast_forward_fn, expected_res, n_iterations=3, rng=None 3960*da0073e9SAndroid Build Coastguard Worker ): 3961*da0073e9SAndroid Build Coastguard Worker if rng is None: 3962*da0073e9SAndroid Build Coastguard Worker rng = torch.Generator() 3963*da0073e9SAndroid Build Coastguard Worker rng = rng.manual_seed(0) 3964*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_random_seed(datapipe, rng) 3965*da0073e9SAndroid Build Coastguard Worker 3966*da0073e9SAndroid Build Coastguard Worker # Test Case: fast forward works with list 3967*da0073e9SAndroid Build Coastguard Worker rng.manual_seed(0) 3968*da0073e9SAndroid Build Coastguard Worker fast_forward_fn(datapipe, n_iterations, rng) 3969*da0073e9SAndroid Build Coastguard Worker actual_res = list(datapipe) 3970*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(datapipe) - n_iterations, len(actual_res)) 3971*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res[n_iterations:], actual_res) 3972*da0073e9SAndroid Build Coastguard Worker 3973*da0073e9SAndroid Build Coastguard Worker # Test Case: fast forward works with iterator 3974*da0073e9SAndroid Build Coastguard Worker rng.manual_seed(0) 3975*da0073e9SAndroid Build Coastguard Worker fast_forward_fn(datapipe, n_iterations, rng) 3976*da0073e9SAndroid Build Coastguard Worker it = iter(datapipe) 3977*da0073e9SAndroid Build Coastguard Worker actual_res = list(it) 3978*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(datapipe) - n_iterations, len(actual_res)) 3979*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res[n_iterations:], actual_res) 3980*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(StopIteration): 3981*da0073e9SAndroid Build Coastguard Worker next(it) 3982*da0073e9SAndroid Build Coastguard Worker 3983*da0073e9SAndroid Build Coastguard Worker def test_simple_snapshot_graph(self): 3984*da0073e9SAndroid Build Coastguard Worker graph1 = dp.iter.IterableWrapper(range(10)) 3985*da0073e9SAndroid Build Coastguard Worker res1 = list(range(10)) 3986*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 3987*da0073e9SAndroid Build Coastguard Worker graph1, _simple_graph_snapshot_restoration, expected_res=res1 3988*da0073e9SAndroid Build Coastguard Worker ) 3989*da0073e9SAndroid Build Coastguard Worker 3990*da0073e9SAndroid Build Coastguard Worker graph2 = graph1.map(_mul_10) 3991*da0073e9SAndroid Build Coastguard Worker res2 = [10 * x for x in res1] 3992*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 3993*da0073e9SAndroid Build Coastguard Worker graph2, _simple_graph_snapshot_restoration, expected_res=res2 3994*da0073e9SAndroid Build Coastguard Worker ) 3995*da0073e9SAndroid Build Coastguard Worker 3996*da0073e9SAndroid Build Coastguard Worker rng = torch.Generator() 3997*da0073e9SAndroid Build Coastguard Worker graph3 = graph2.shuffle() 3998*da0073e9SAndroid Build Coastguard Worker rng.manual_seed(0) 3999*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_random_seed(graph3, rng) 4000*da0073e9SAndroid Build Coastguard Worker res3 = list(graph3) 4001*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 4002*da0073e9SAndroid Build Coastguard Worker graph3, _simple_graph_snapshot_restoration, expected_res=res3 4003*da0073e9SAndroid Build Coastguard Worker ) 4004*da0073e9SAndroid Build Coastguard Worker 4005*da0073e9SAndroid Build Coastguard Worker graph4 = graph3.map(_mul_10) 4006*da0073e9SAndroid Build Coastguard Worker res4 = [10 * x for x in res3] 4007*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 4008*da0073e9SAndroid Build Coastguard Worker graph4, _simple_graph_snapshot_restoration, expected_res=res4 4009*da0073e9SAndroid Build Coastguard Worker ) 4010*da0073e9SAndroid Build Coastguard Worker 4011*da0073e9SAndroid Build Coastguard Worker batch_size = 2 4012*da0073e9SAndroid Build Coastguard Worker graph5 = graph4.batch(batch_size) 4013*da0073e9SAndroid Build Coastguard Worker res5 = [ 4014*da0073e9SAndroid Build Coastguard Worker res4[i : i + batch_size] for i in range(0, len(res4), batch_size) 4015*da0073e9SAndroid Build Coastguard Worker ] # .batch(2) 4016*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 4017*da0073e9SAndroid Build Coastguard Worker graph5, _simple_graph_snapshot_restoration, expected_res=res5 4018*da0073e9SAndroid Build Coastguard Worker ) 4019*da0073e9SAndroid Build Coastguard Worker 4020*da0073e9SAndroid Build Coastguard Worker # With `fork` and `zip` 4021*da0073e9SAndroid Build Coastguard Worker cdp1, cdp2 = graph5.fork(2) 4022*da0073e9SAndroid Build Coastguard Worker graph6 = cdp1.zip(cdp2) 4023*da0073e9SAndroid Build Coastguard Worker rng = rng.manual_seed(100) 4024*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_random_seed(graph6, rng) 4025*da0073e9SAndroid Build Coastguard Worker res6 = [(x, x) for x in res5] 4026*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 4027*da0073e9SAndroid Build Coastguard Worker graph6, _simple_graph_snapshot_restoration, expected_res=res6 4028*da0073e9SAndroid Build Coastguard Worker ) 4029*da0073e9SAndroid Build Coastguard Worker 4030*da0073e9SAndroid Build Coastguard Worker # With `fork` and `concat` 4031*da0073e9SAndroid Build Coastguard Worker graph7 = cdp1.concat(cdp2) 4032*da0073e9SAndroid Build Coastguard Worker res7 = res5 * 2 4033*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 4034*da0073e9SAndroid Build Coastguard Worker graph7, _simple_graph_snapshot_restoration, expected_res=res7 4035*da0073e9SAndroid Build Coastguard Worker ) 4036*da0073e9SAndroid Build Coastguard Worker 4037*da0073e9SAndroid Build Coastguard Worker # Raises an exception if the graph has already been restored 4038*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4039*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Snapshot restoration cannot be applied." 4040*da0073e9SAndroid Build Coastguard Worker ): 4041*da0073e9SAndroid Build Coastguard Worker _simple_graph_snapshot_restoration(graph7, 1) 4042*da0073e9SAndroid Build Coastguard Worker _simple_graph_snapshot_restoration(graph7, 1) 4043*da0073e9SAndroid Build Coastguard Worker 4044*da0073e9SAndroid Build Coastguard Worker def test_simple_snapshot_custom_non_generator(self): 4045*da0073e9SAndroid Build Coastguard Worker graph = _CustomNonGeneratorTestDataPipe() 4046*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 4047*da0073e9SAndroid Build Coastguard Worker graph, _simple_graph_snapshot_restoration, expected_res=range(10) 4048*da0073e9SAndroid Build Coastguard Worker ) 4049*da0073e9SAndroid Build Coastguard Worker 4050*da0073e9SAndroid Build Coastguard Worker def test_simple_snapshot_custom_self_next(self): 4051*da0073e9SAndroid Build Coastguard Worker graph = _CustomSelfNextTestDataPipe() 4052*da0073e9SAndroid Build Coastguard Worker self._fast_forward_graph_test_helper( 4053*da0073e9SAndroid Build Coastguard Worker graph, _simple_graph_snapshot_restoration, expected_res=range(10) 4054*da0073e9SAndroid Build Coastguard Worker ) 4055*da0073e9SAndroid Build Coastguard Worker 4056*da0073e9SAndroid Build Coastguard Worker def _snapshot_test_helper(self, datapipe, expected_res, n_iter=3, rng=None): 4057*da0073e9SAndroid Build Coastguard Worker """ 4058*da0073e9SAndroid Build Coastguard Worker Extend the previous test with serialization and deserialization test. 4059*da0073e9SAndroid Build Coastguard Worker """ 4060*da0073e9SAndroid Build Coastguard Worker if rng is None: 4061*da0073e9SAndroid Build Coastguard Worker rng = torch.Generator() 4062*da0073e9SAndroid Build Coastguard Worker rng.manual_seed(0) 4063*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_random_seed(datapipe, rng) 4064*da0073e9SAndroid Build Coastguard Worker it = iter(datapipe) 4065*da0073e9SAndroid Build Coastguard Worker for _ in range(n_iter): 4066*da0073e9SAndroid Build Coastguard Worker next(it) 4067*da0073e9SAndroid Build Coastguard Worker serialized_graph = pickle.dumps(datapipe) 4068*da0073e9SAndroid Build Coastguard Worker deserialized_graph = pickle.loads(serialized_graph) 4069*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n_iter, datapipe._number_of_samples_yielded) 4070*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n_iter, deserialized_graph._number_of_samples_yielded) 4071*da0073e9SAndroid Build Coastguard Worker 4072*da0073e9SAndroid Build Coastguard Worker rng_for_deserialized = torch.Generator() 4073*da0073e9SAndroid Build Coastguard Worker rng_for_deserialized.manual_seed(0) 4074*da0073e9SAndroid Build Coastguard Worker _simple_graph_snapshot_restoration( 4075*da0073e9SAndroid Build Coastguard Worker deserialized_graph, n_iter, rng=rng_for_deserialized 4076*da0073e9SAndroid Build Coastguard Worker ) 4077*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res[n_iter:], list(it)) 4078*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res[n_iter:], list(deserialized_graph)) 4079*da0073e9SAndroid Build Coastguard Worker 4080*da0073e9SAndroid Build Coastguard Worker def test_simple_snapshot_graph_with_serialization(self): 4081*da0073e9SAndroid Build Coastguard Worker graph1 = dp.iter.IterableWrapper(range(10)) 4082*da0073e9SAndroid Build Coastguard Worker res1 = list(range(10)) 4083*da0073e9SAndroid Build Coastguard Worker self._snapshot_test_helper(graph1, expected_res=res1) 4084*da0073e9SAndroid Build Coastguard Worker 4085*da0073e9SAndroid Build Coastguard Worker graph2 = graph1.map(_mul_10) 4086*da0073e9SAndroid Build Coastguard Worker res2 = [10 * x for x in res1] 4087*da0073e9SAndroid Build Coastguard Worker self._snapshot_test_helper(graph2, expected_res=res2) 4088*da0073e9SAndroid Build Coastguard Worker 4089*da0073e9SAndroid Build Coastguard Worker rng = torch.Generator() 4090*da0073e9SAndroid Build Coastguard Worker graph3 = graph2.shuffle() 4091*da0073e9SAndroid Build Coastguard Worker rng.manual_seed(0) 4092*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_random_seed(graph3, rng) 4093*da0073e9SAndroid Build Coastguard Worker res3 = list(graph3) 4094*da0073e9SAndroid Build Coastguard Worker self._snapshot_test_helper(graph3, expected_res=res3) 4095*da0073e9SAndroid Build Coastguard Worker 4096*da0073e9SAndroid Build Coastguard Worker graph4 = graph3.map(_mul_10) 4097*da0073e9SAndroid Build Coastguard Worker res4 = [10 * x for x in res3] 4098*da0073e9SAndroid Build Coastguard Worker self._snapshot_test_helper(graph4, expected_res=res4) 4099*da0073e9SAndroid Build Coastguard Worker 4100*da0073e9SAndroid Build Coastguard Worker batch_size = 2 4101*da0073e9SAndroid Build Coastguard Worker graph5 = graph4.batch(batch_size) 4102*da0073e9SAndroid Build Coastguard Worker res5 = [ 4103*da0073e9SAndroid Build Coastguard Worker res4[i : i + batch_size] for i in range(0, len(res4), batch_size) 4104*da0073e9SAndroid Build Coastguard Worker ] # .batch(2) 4105*da0073e9SAndroid Build Coastguard Worker self._snapshot_test_helper(graph5, expected_res=res5) 4106*da0073e9SAndroid Build Coastguard Worker 4107*da0073e9SAndroid Build Coastguard Worker # With `fork` and `zip` 4108*da0073e9SAndroid Build Coastguard Worker cdp1, cdp2 = graph5.fork(2) 4109*da0073e9SAndroid Build Coastguard Worker graph6 = cdp1.zip(cdp2) 4110*da0073e9SAndroid Build Coastguard Worker res6 = [(x, x) for x in res5] 4111*da0073e9SAndroid Build Coastguard Worker self._snapshot_test_helper(graph6, expected_res=res6) 4112*da0073e9SAndroid Build Coastguard Worker 4113*da0073e9SAndroid Build Coastguard Worker # With `fork` and `concat` 4114*da0073e9SAndroid Build Coastguard Worker graph7 = cdp1.concat(cdp2) 4115*da0073e9SAndroid Build Coastguard Worker res7 = res5 * 2 4116*da0073e9SAndroid Build Coastguard Worker self._snapshot_test_helper(graph7, expected_res=res7) 4117*da0073e9SAndroid Build Coastguard Worker 4118*da0073e9SAndroid Build Coastguard Worker def test_simple_snapshot_graph_repeated(self): 4119*da0073e9SAndroid Build Coastguard Worker cdp1, cdp2 = ( 4120*da0073e9SAndroid Build Coastguard Worker dp.iter.IterableWrapper(range(10)) 4121*da0073e9SAndroid Build Coastguard Worker .map(_mul_10) 4122*da0073e9SAndroid Build Coastguard Worker .shuffle() 4123*da0073e9SAndroid Build Coastguard Worker .map(_mul_10) 4124*da0073e9SAndroid Build Coastguard Worker .map(_mul_10) 4125*da0073e9SAndroid Build Coastguard Worker .fork(2) 4126*da0073e9SAndroid Build Coastguard Worker ) 4127*da0073e9SAndroid Build Coastguard Worker graph = cdp1.zip(cdp2) 4128*da0073e9SAndroid Build Coastguard Worker 4129*da0073e9SAndroid Build Coastguard Worker rng = torch.Generator() 4130*da0073e9SAndroid Build Coastguard Worker rng.manual_seed(0) 4131*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_random_seed(graph, rng) 4132*da0073e9SAndroid Build Coastguard Worker 4133*da0073e9SAndroid Build Coastguard Worker # Get expected result 4134*da0073e9SAndroid Build Coastguard Worker expected_res = list(graph) 4135*da0073e9SAndroid Build Coastguard Worker 4136*da0073e9SAndroid Build Coastguard Worker rng.manual_seed(0) 4137*da0073e9SAndroid Build Coastguard Worker torch.utils.data.graph_settings.apply_random_seed(graph, rng) 4138*da0073e9SAndroid Build Coastguard Worker it = iter(graph) 4139*da0073e9SAndroid Build Coastguard Worker n_iter = 3 4140*da0073e9SAndroid Build Coastguard Worker for _ in range(n_iter): 4141*da0073e9SAndroid Build Coastguard Worker next(it) 4142*da0073e9SAndroid Build Coastguard Worker 4143*da0073e9SAndroid Build Coastguard Worker # First serialization/deserialization 4144*da0073e9SAndroid Build Coastguard Worker serialized_graph = pickle.dumps(graph) 4145*da0073e9SAndroid Build Coastguard Worker deserialized_graph = pickle.loads(serialized_graph) 4146*da0073e9SAndroid Build Coastguard Worker 4147*da0073e9SAndroid Build Coastguard Worker rng_for_deserialized = torch.Generator() 4148*da0073e9SAndroid Build Coastguard Worker rng_for_deserialized.manual_seed(0) 4149*da0073e9SAndroid Build Coastguard Worker _simple_graph_snapshot_restoration( 4150*da0073e9SAndroid Build Coastguard Worker deserialized_graph, 4151*da0073e9SAndroid Build Coastguard Worker deserialized_graph._number_of_samples_yielded, 4152*da0073e9SAndroid Build Coastguard Worker rng=rng_for_deserialized, 4153*da0073e9SAndroid Build Coastguard Worker ) 4154*da0073e9SAndroid Build Coastguard Worker 4155*da0073e9SAndroid Build Coastguard Worker it = iter(deserialized_graph) 4156*da0073e9SAndroid Build Coastguard Worker # Get the next element and ensure it is as expected 4157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res[3], next(it)) 4158*da0073e9SAndroid Build Coastguard Worker 4159*da0073e9SAndroid Build Coastguard Worker # Serializalize/Deserialize and fast-forward again after to ensure it works 4160*da0073e9SAndroid Build Coastguard Worker serialized_graph2 = pickle.dumps(deserialized_graph) 4161*da0073e9SAndroid Build Coastguard Worker deserialized_graph2 = pickle.loads(serialized_graph2) 4162*da0073e9SAndroid Build Coastguard Worker 4163*da0073e9SAndroid Build Coastguard Worker rng_for_deserialized = torch.Generator() 4164*da0073e9SAndroid Build Coastguard Worker rng_for_deserialized.manual_seed(0) 4165*da0073e9SAndroid Build Coastguard Worker _simple_graph_snapshot_restoration( 4166*da0073e9SAndroid Build Coastguard Worker deserialized_graph2, 4167*da0073e9SAndroid Build Coastguard Worker deserialized_graph._number_of_samples_yielded, 4168*da0073e9SAndroid Build Coastguard Worker rng=rng_for_deserialized, 4169*da0073e9SAndroid Build Coastguard Worker ) 4170*da0073e9SAndroid Build Coastguard Worker 4171*da0073e9SAndroid Build Coastguard Worker # Get the next element and ensure it is as expected 4172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res[4:], list(deserialized_graph2)) 4173*da0073e9SAndroid Build Coastguard Worker 4174*da0073e9SAndroid Build Coastguard Worker 4175*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 4176*da0073e9SAndroid Build Coastguard Worker run_tests() 4177