xref: /aosp_15_r20/external/pytorch/torch/utils/data/_utils/fetch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
3
4This logic is shared in both single- and multi-processing data loading.
5"""
6
7
8class _BaseDatasetFetcher:
9    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
10        self.dataset = dataset
11        self.auto_collation = auto_collation
12        self.collate_fn = collate_fn
13        self.drop_last = drop_last
14
15    def fetch(self, possibly_batched_index):
16        raise NotImplementedError
17
18
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that pulls samples sequentially from an iterator over the dataset."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        """Store the settings via the base class and open an iterator on ``dataset``."""
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        # Set once the iterator runs dry while a batch is being assembled.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the next collated sample or batch from the dataset iterator.

        With ``auto_collation`` enabled, one sample is drawn per entry of
        ``possibly_batched_index`` (the entries themselves are not used) and
        the resulting list is passed to ``collate_fn``. Otherwise a single
        sample is drawn and passed to ``collate_fn``.

        Raises:
            StopIteration: If the iterator was already exhausted by an earlier
                batched fetch, if no sample could be drawn, if ``drop_last``
                is set and fewer samples than requested were drawn, or (in
                the non-batched case) when ``next`` on the iterator raises it.
        """
        if self.ended:
            raise StopIteration

        if not self.auto_collation:
            return self.collate_fn(next(self.dataset_iter))

        samples = []
        for _ in possibly_batched_index:
            try:
                sample = next(self.dataset_iter)
            except StopIteration:
                # Remember exhaustion so later calls stop immediately.
                self.ended = True
                break
            samples.append(sample)

        if not samples:
            raise StopIteration
        # A short batch is discarded when drop_last is requested.
        if self.drop_last and len(samples) < len(possibly_batched_index):
            raise StopIteration
        return self.collate_fn(samples)
44
45
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that reads samples from the dataset by index."""

    def fetch(self, possibly_batched_index):
        """Look up the requested sample(s) and return them collated.

        Without ``auto_collation`` the argument is used directly as a single
        key into the dataset. With it, the argument is treated as a collection
        of indices: a truthy ``__getitems__`` attribute on the dataset is
        called with the whole collection, otherwise each index is looked up
        individually.
        """
        if not self.auto_collation:
            return self.collate_fn(self.dataset[possibly_batched_index])

        # Prefer the dataset's bulk accessor when it exists and is truthy.
        bulk_getter = getattr(self.dataset, "__getitems__", None)
        if bulk_getter:
            batch = bulk_getter(possibly_batched_index)
        else:
            batch = [self.dataset[index] for index in possibly_batched_index]
        return self.collate_fn(batch)
56