# mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.

This logic is shared in both single- and multi-processing data loading.
"""


class _BaseDatasetFetcher:
    """Hold the state shared by the dataset fetchers.

    Subclasses implement :meth:`fetch`; this base only stores the
    constructor arguments as attributes of the same name.

    Args:
        dataset: The object samples are read from.
        auto_collation: Truthy when ``fetch`` receives a collection of
            indices and should gather one sample per index.
        collate_fn: Callable applied to the fetched data before it is
            returned from ``fetch``.
        drop_last: Truthy when an incomplete batch should be discarded
            (only consulted by ``_IterableDatasetFetcher``).
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index):
        """Fetch and collate data; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetch data by advancing a single iterator created over the dataset."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        # One iterator is created up front and shared by every fetch() call.
        self.dataset_iter = iter(dataset)
        # Becomes True once the iterator has raised StopIteration.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the collated next sample or next batch of samples.

        Args:
            possibly_batched_index: With auto-collation, a sized iterable
                whose length is the requested batch size (its elements are
                not used). Without auto-collation it is ignored.

        Returns:
            ``self.collate_fn`` applied to a list of samples
            (auto-collation) or to a single sample.

        Raises:
            StopIteration: If the iterator is (or was already) exhausted and
                no data can be returned, or if ``drop_last`` is set and the
                final batch is shorter than requested.
        """
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            is_partial = len(data) < len(possibly_batched_index)
            if not data or (self.drop_last and is_partial):
                raise StopIteration
        else:
            try:
                data = next(self.dataset_iter)
            except StopIteration:
                # Record exhaustion here as well, so later calls stop at the
                # guard above instead of touching the exhausted iterator.
                self.ended = True
                raise
        return self.collate_fn(data)


class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetch data by indexing into the dataset."""

    def fetch(self, possibly_batched_index):
        """Return the collated sample(s) addressed by the given index.

        Args:
            possibly_batched_index: With auto-collation, an iterable of
                indices; otherwise a single index passed straight to
                ``self.dataset[...]``.

        Returns:
            ``self.collate_fn`` applied to the fetched data.
        """
        if self.auto_collation:
            # Use the dataset's own batched accessor when it provides a
            # truthy ``__getitems__``; otherwise index one item at a time.
            batched_getter = getattr(self.dataset, "__getitems__", None)
            if batched_getter:
                data = batched_getter(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)