# mypy: allow-untyped-defs
import bisect
import itertools
import math
import warnings
from typing import (
    cast,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, Generator, randperm, Tensor


__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]


_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_dict = Dict[str, _T_co]
_T_tuple = Tuple[_T_co, ...]
_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)


class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, which supports fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which many :class:`~torch.utils.data.Sampler` implementations
    and the default options of :class:`~torch.utils.data.DataLoader` expect to
    return the size of the dataset. Subclasses could also optionally implement
    :meth:`__getitems__` to speed up batched sample loading. This method accepts a
    list of indices for the samples in a batch and returns the list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
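
    A minimal sketch of a map-style dataset (the ``SquaresDataset`` name and its
    contents below are illustrative only, not part of the API)::

        >>> # xdoctest: +SKIP
        >>> class SquaresDataset(torch.utils.data.Dataset):
        ...     def __init__(self, n):
        ...         self.n = n
        ...
        ...     def __getitem__(self, index):
        ...         return index**2
        ...
        ...     def __len__(self):
        ...         return self.n
        ...
        >>> ds = SquaresDataset(4)
        >>> ds[2], len(ds)
        (4, 4)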
60    """
61
62    def __getitem__(self, index) -> _T_co:
63        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
64
65    # def __getitems__(self, indices: List) -> List[_T_co]:
66    # Not implemented to prevent false-positives in fetcher check in
67    # torch.utils.data._utils.fetch._MapDatasetFetcher
68
69    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
70        return ConcatDataset([self, other])
71
72    # No `def __len__(self)` default?
73    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
74    # in pytorch/torch/utils/data/sampler.py
75

class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    This form of dataset is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which should return an
    iterator over the samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[_T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size in the first dimension.
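
    A brief usage sketch (the ``features`` and ``labels`` tensors below are made
    up purely for illustration)::

        >>> # xdoctest: +SKIP
        >>> features = torch.randn(100, 5)
        >>> labels = torch.randint(0, 2, (100,))
        >>> ds = TensorDataset(features, labels)
        >>> x, y = ds[0]  # equals (features[0], labels[0])
        >>> len(ds)
        100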
200    """
201
202    tensors: Tuple[Tensor, ...]
203
204    def __init__(self, *tensors: Tensor) -> None:
205        assert all(
206            tensors[0].size(0) == tensor.size(0) for tensor in tensors
207        ), "Size mismatch between tensors"
208        self.tensors = tensors
209
210    def __getitem__(self, index):
211        return tuple(tensor[index] for tensor in self.tensors)
212
213    def __len__(self):
214        return self.tensors[0].size(0)
215
216
217class StackDataset(Dataset[_T_stack]):
218    r"""Dataset as a stacking of multiple datasets.
219
220    This class is useful to assemble different parts of complex input data, given as datasets.
221
222    Example:
223        >>> # xdoctest: +SKIP
224        >>> images = ImageDataset()
225        >>> texts = TextDataset()
226        >>> tuple_stack = StackDataset(images, texts)
227        >>> tuple_stack[0] == (images[0], texts[0])
228        >>> dict_stack = StackDataset(image=images, text=texts)
229        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
230
231    Args:
232        *args (Dataset): Datasets for stacking returned as tuple.
233        **kwargs (Dataset): Datasets for stacking returned as dict.
234    """
235
236    datasets: Union[tuple, dict]
237
238    def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None:
239        if args:
240            if kwargs:
241                raise ValueError(
242                    "Supported either ``tuple``- (via ``args``) or"
243                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
                )
            self._length = len(args[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = args
        elif kwargs:
            tmp = list(kwargs.values())
            self._length = len(tmp[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = kwargs
        else:
            raise ValueError("At least one dataset should be passed")

    def __getitem__(self, index):
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __getitems__(self, indices: list):
        # add batched sampling support when parent datasets support it.
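        # Illustrative note (not from the original source): for a tuple-style
        # stack of datasets ``a`` and ``b``, ``StackDataset(a, b).__getitems__([0, 1])``
        # returns ``[(a[0], b[0]), (a[1], b[1])]``.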
        if isinstance(self.datasets, dict):
            dict_batch: List[_T_dict] = [{} for _ in indices]
            for k, dataset in self.datasets.items():
                if callable(getattr(dataset, "__getitems__", None)):
                    items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                    if len(items) != len(indices):
                        raise ValueError(
                            "Nested dataset's output size mismatch."
                            f" Expected {len(indices)}, got {len(items)}"
                        )
                    for data, d_sample in zip(items, dict_batch):
                        d_sample[k] = data
                else:
                    for idx, d_sample in zip(indices, dict_batch):
                        d_sample[k] = dataset[idx]
            return dict_batch

        # tuple data
        list_batch: List[list] = [[] for _ in indices]
        for dataset in self.datasets:
            if callable(getattr(dataset, "__getitems__", None)):
                items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                if len(items) != len(indices):
                    raise ValueError(
                        "Nested dataset's output size mismatch."
                        f" Expected {len(indices)}, got {len(items)}"
                    )
                for data, t_sample in zip(items, list_batch):
                    t_sample.append(data)
            else:
                for idx, t_sample in zip(indices, list_batch):
                    t_sample.append(dataset[idx])
        tuple_batch: List[_T_tuple] = [tuple(sample) for sample in list_batch]
        return tuple_batch

    def __len__(self):
        return self._length


class ConcatDataset(Dataset[_T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
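
    A short sketch (``ds_a`` and ``ds_b`` below stand for any two map-style
    datasets, e.g. ``TensorDataset`` instances; they are not defined here)::

        >>> # xdoctest: +SKIP
        >>> full = ConcatDataset([ds_a, ds_b])
        >>> len(full) == len(ds_a) + len(ds_b)
        True
        >>> full[len(ds_a)] == ds_b[0]  # first sample of the second dataset
        True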
311    """
312
313    datasets: List[Dataset[_T_co]]
314    cumulative_sizes: List[int]
315
316    @staticmethod
    def cumsum(sequence):
        # running total of the dataset lengths, e.g. lengths [3, 5] -> [3, 8]
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # locate the constituent dataset that contains ``idx``, then map ``idx``
        # to an index local to that dataset
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    @deprecated(
        "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
        category=FutureWarning,
    )
    def cummulative_sizes(self):
        return self.cumulative_sizes


class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
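
    A brief sketch (``stream_a`` and ``stream_b`` below stand for any two
    :class:`IterableDataset` instances; they are not defined here)::

        >>> # xdoctest: +SKIP
        >>> chained = ChainDataset([stream_a, stream_b])
        >>> # yields every sample of ``stream_a``, then every sample of ``stream_b``
        >>> samples = list(chained)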
370    """
371
372    def __init__(self, datasets: Iterable[Dataset]) -> None:
373        super().__init__()
374        self.datasets = datasets
375
376    def __iter__(self):
377        for d in self.datasets:
378            assert isinstance(
379                d, IterableDataset
380            ), "ChainDataset only supports IterableDataset"
381            yield from d
382
383    def __len__(self):
384        total = 0
385        for d in self.datasets:
386            assert isinstance(
387                d, IterableDataset
388            ), "ChainDataset only supports IterableDataset"
389            total += len(d)  # type: ignore[arg-type]
390        return total
391
392
393class Subset(Dataset[_T_co]):
394    r"""
395    Subset of a dataset at specified indices.
396
397    Args:
398        dataset (Dataset): The whole Dataset
399        indices (sequence): Indices in the whole set selected for subset
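
    A minimal sketch (``full_ds`` below stands for any map-style dataset and is
    not defined here)::

        >>> # xdoctest: +SKIP
        >>> subset = Subset(full_ds, [0, 2, 4])
        >>> len(subset)
        3
        >>> subset[1] == full_ds[2]
        True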
400    """
401
402    dataset: Dataset[_T_co]
403    indices: Sequence[int]
404
405    def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None:
406        self.dataset = dataset
407        self.indices = indices
408
409    def __getitem__(self, idx):
410        if isinstance(idx, list):
411            return self.dataset[[self.indices[i] for i in idx]]
412        return self.dataset[self.indices[idx]]
413
414    def __getitems__(self, indices: List[int]) -> List[_T_co]:
415        # add batched sampling support when parent dataset supports it.
416        # see torch.utils.data._utils.fetch._MapDatasetFetcher
417        if callable(getattr(self.dataset, "__getitems__", None)):
418            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
419        else:
420            return [self.dataset[self.indices[idx]] for idx in indices]
421
422    def __len__(self):
423        return len(self.indices)
424
425
426def random_split(
427    dataset: Dataset[_T],
428    lengths: Sequence[Union[int, float]],
429    generator: Optional[Generator] = default_generator,
430) -> List[Subset[_T]]:
431    r"""
432    Randomly split a dataset into non-overlapping new datasets of given lengths.
433
434    If a list of fractions that sum up to 1 is given,
435    the lengths will be computed automatically as
436    floor(frac * len(dataset)) for each fraction provided.
437
    After computing the lengths, if there is any remainder, one count is
    distributed to the lengths in round-robin fashion until the remainder
    is exhausted.
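
    For example (illustration only, derived from the rule above): splitting a
    dataset of length 11 with fractions ``[0.5, 0.5]`` gives floor lengths
    ``[5, 5]``; the remaining sample is assigned round-robin, yielding splits
    of lengths ``[6, 5]``.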

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset."
                )

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]