xref: /aosp_15_r20/external/pytorch/torch/utils/data/sampler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import (
3    Generic,
4    Iterable,
5    Iterator,
6    List,
7    Optional,
8    Sequence,
9    Sized,
10    TypeVar,
11    Union,
12)
13
14import torch
15
16
# Public API of this module: the names exported by a star-import.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]


# Type of the items a sampler yields (e.g. ``int`` indices, or ``List[int]``
# batches as produced by ``BatchSampler``). Declared covariant so that a
# ``Sampler[Sub]`` is accepted wherever a ``Sampler[Base]`` is expected.
_T_co = TypeVar("_T_co", covariant=True)
28
29
30class Sampler(Generic[_T_co]):
31    r"""Base class for all Samplers.
32
33    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
34    way to iterate over indices or lists of indices (batches) of dataset elements,
35    and may provide a :meth:`__len__` method that returns the length of the returned iterators.
36
37    Args:
38        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
39            You may still have custom implementation that utilizes it.
40
41    Example:
42        >>> # xdoctest: +SKIP
43        >>> class AccedingSequenceLengthSampler(Sampler[int]):
44        >>>     def __init__(self, data: List[str]) -> None:
45        >>>         self.data = data
46        >>>
47        >>>     def __len__(self) -> int:
48        >>>         return len(self.data)
49        >>>
50        >>>     def __iter__(self) -> Iterator[int]:
51        >>>         sizes = torch.tensor([len(x) for x in self.data])
52        >>>         yield from torch.argsort(sizes).tolist()
53        >>>
54        >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
55        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
56        >>>         self.data = data
57        >>>         self.batch_size = batch_size
58        >>>
59        >>>     def __len__(self) -> int:
60        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
61        >>>
62        >>>     def __iter__(self) -> Iterator[List[int]]:
63        >>>         sizes = torch.tensor([len(x) for x in self.data])
64        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
65        >>>             yield batch.tolist()
66
67    .. note:: The :meth:`__len__` method isn't strictly required by
68              :class:`~torch.utils.data.DataLoader`, but is expected in any
69              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
70    """
71
72    def __init__(self, data_source: Optional[Sized] = None) -> None:
73        if data_source is not None:
74            import warnings
75
76            warnings.warn(
77                "`data_source` argument is not used and will be removed in 2.2.0."
78                "You may still have custom implementation that utilizes it."
79            )
80
81    def __iter__(self) -> Iterator[_T_co]:
82        raise NotImplementedError
83
84    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
85    #
86    # Many times we have an abstract class representing a collection/iterable of
87    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
88    # implementing a `__len__` method. In such cases, we must make sure to not
89    # provide a default implementation, because both straightforward default
90    # implementations have their issues:
91    #
92    #   + `return NotImplemented`:
93    #     Calling `len(subclass_instance)` raises:
94    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
95    #
96    #   + `raise NotImplementedError`:
97    #     This prevents triggering some fallback behavior. E.g., the built-in
98    #     `list(X)` tries to call `len(X)` first, and executes a different code
99    #     path if the method is not found or `NotImplemented` is returned, while
100    #     raising a `NotImplementedError` will propagate and make the call fail
101    #     where it could have used `__iter__` to complete the call.
102    #
103    # Thus, the only two sensible things to do are
104    #
105    #   + **not** provide a default `__len__`.
106    #
107    #   + raise a `TypeError` instead, which is what Python uses when users call
108    #     a method that is not defined on an object.
109    #     (@ssnl verifies that this works on at least Python 3.7.)
110
111
class SequentialSampler(Sampler[int]):
    r"""Yields the indices ``0, 1, ..., len(data_source) - 1`` in increasing order.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __len__(self) -> int:
        # One index per dataset element.
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        # The dataset length is read once, when the iterator is created
        # (this is a plain function returning an iterator, not a generator).
        stop = len(self.data_source)
        return iter(range(stop))
129
130
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    With replacement, every index is drawn independently and uniformly from
    ``[0, len(data_source))``. Without replacement, indices come from
    successive random permutations of the dataset: if :attr:`num_samples`
    exceeds ``len(data_source)``, whole permutations are emitted back to back
    and the final one is truncated. :attr:`num_samples` therefore applies in
    both modes.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
            When left as ``None`` the dataset length is re-read on every
            access, so it follows changes in the dataset size.
        generator (Generator): Generator used in sampling. If ``None``, every
            call to ``__iter__`` builds a fresh generator seeded from the
            global torch RNG.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if the effective ``num_samples`` (explicit, or the dataset
            length when not given) is not a positive ``int``. ``bool`` is
            rejected, as in :class:`WeightedRandomSampler` and :class:`BatchSampler`.
    """

    data_source: Sized
    replacement: bool

    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        num_samples: Optional[int] = None,
        generator=None,
    ) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, but got replacement={self.replacement}"
            )

        # Resolve the property once; when ``num_samples`` is None this is
        # ``len(data_source)``, so an empty dataset is rejected here as well.
        effective = self.num_samples
        # ``bool`` is a subclass of ``int``; exclude it explicitly so that
        # ``num_samples=True`` is not silently treated as 1.
        if (
            not isinstance(effective, int)
            or isinstance(effective, bool)
            or effective <= 0
        ):
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={effective}"
            )

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # Draw a seed from the global torch RNG so each epoch gets a new
            # order that is still reproducible under ``torch.manual_seed``.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw in chunks of 32 (presumably to avoid materializing one huge
            # tensor for a large num_samples). The chunking determines how the
            # generator is consumed, so changing it may change the sequence
            # produced for a fixed seed.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(
                    high=n, size=(32,), dtype=torch.int64, generator=generator
                ).tolist()
            yield from torch.randint(
                high=n,
                size=(self.num_samples % 32,),
                dtype=torch.int64,
                generator=generator,
            ).tolist()
        else:
            # NOTE: assumes the dataset is still non-empty; if it became empty
            # after construction, ``num_samples // n`` raises ZeroDivisionError.
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[
                : self.num_samples % n
            ]

    def __len__(self) -> int:
        return self.num_samples
204
205
class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Each call to ``__iter__`` draws a fresh permutation of the positions of
    :attr:`indices` with ``torch.randperm`` and yields the corresponding entries.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling. Passed straight to
            ``torch.randperm``.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # Convert the permutation to Python ints once with ``tolist()``.
        # Iterating the tensor directly would create a 0-d tensor per element,
        # which is much slower and gives the same positions.
        order = torch.randperm(len(self.indices), generator=self.generator).tolist()
        for position in order:
            yield self.indices[position]

    def __len__(self) -> int:
        return len(self.indices)
226
227
class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Indices are drawn with ``torch.multinomial`` on the weights, converted to a
    1-D ``torch.double`` tensor at construction time.

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
            In that case ``torch.multinomial`` requires ``num_samples`` to not
            exceed the number of non-zero weights; this is only checked when
            iterating.
        generator (Generator): Generator used in sampling.

    Raises:
        ValueError: if ``num_samples`` is not a positive ``int`` (``bool`` is
            rejected), if ``replacement`` is not a ``bool``, or if ``weights``
            is not one-dimensional.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: torch.Tensor
    num_samples: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        replacement: bool = True,
        generator=None,
    ) -> None:
        # ``bool`` is a subclass of ``int``; reject it explicitly.
        if (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={num_samples}"
            )
        if not isinstance(replacement, bool):
            raise ValueError(
                f"replacement should be a boolean value, but got replacement={replacement}"
            )

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError(
                "weights should be a 1d sequence but given "
                f"weights have shape {tuple(weights_tensor.shape)}"
            )

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # A single multinomial draw of all indices, yielded as Python ints.
        rand_tensor = torch.multinomial(
            self.weights, self.num_samples, self.replacement, generator=self.generator
        )
        yield from rand_tensor.tolist()

    def __len__(self) -> int:
        return self.num_samples
291
292
class BatchSampler(Sampler[List[int]]):
    r"""Groups the indices produced by another sampler into lists of ``batch_size``.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``batch_size`` is not a positive ``int`` (``bool`` is
            rejected) or ``drop_last`` is not a ``bool``.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        # ``sampler`` is deliberately not checked against
        # ``collections.abc.Iterable``: that ABC ignores ``__getitem__``-based
        # iteration, which ``iter()`` still accepts.
        batch_size_ok = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not batch_size_ok:
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got drop_last={drop_last}"
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Both loop shapes below follow the benchmarking in
        # https://github.com/pytorch/pytorch/pull/76951.
        if self.drop_last:
            # Pull exactly ``batch_size`` items per batch; an incomplete final
            # batch is abandoned when the source runs dry.
            source = iter(self.sampler)
            while True:
                try:
                    chunk = [next(source) for _ in range(self.batch_size)]
                    yield chunk
                except StopIteration:
                    break
            return

        # Fill a preallocated list in place; emit it once full, and emit the
        # filled prefix of the final, partial batch (if any) at the end.
        chunk = [0] * self.batch_size
        filled = 0
        for index in self.sampler:
            chunk[filled] = index
            filled += 1
            if filled == self.batch_size:
                yield chunk
                filled = 0
                chunk = [0] * self.batch_size
        if filled:
            yield chunk[:filled]

    def __len__(self) -> int:
        # Only valid when ``self.sampler`` implements ``__len__``; that cannot
        # be enforced, hence the ``type: ignore`` below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        total = len(self.sampler)  # type: ignore[arg-type]
        full_batches, leftover = divmod(total, self.batch_size)
        if self.drop_last or leftover == 0:
            return full_batches
        # A trailing partial batch is kept when drop_last is False.
        return full_batches + 1
366