xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/utils/data/cycling_iterator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10
class CyclingIterator:
    """
    An iterator decorator that cycles through the
    underlying iterator "n" times. Useful to "unroll"
    the dataset across multiple training epochs.

    The generator function is called as ``generator_fn(epoch)``
    to obtain the underlying iterator, where ``epoch`` is a
    number less than or equal to ``n`` representing the ``k``th cycle

    For example if ``generator_fn`` always returns ``[1,2,3]``
    then ``CyclingIterator(n=2, generator_fn)`` will iterate through
    ``[1,2,3,1,2,3]``
    """

    def __init__(self, n: int, generator_fn, start_epoch: int = 0):
        # Total number of cycles and the cycle to begin from.
        self._n = n
        self._epoch = start_epoch
        self._generator_fn = generator_fn
        # Eagerly obtain the iterator for the starting epoch so the
        # first __next__ call has something to pull from.
        self._iter = generator_fn(self._epoch)

    def __iter__(self):
        return self

    def __next__(self):
        # Loop instead of recursing on epoch rollover: with the original
        # recursive form, a long run of empty epoch iterators recursed
        # once per epoch and could raise RecursionError for large ``n``.
        while True:
            try:
                return next(self._iter)
            except StopIteration:  # end of data for the current epoch
                if self._epoch >= self._n - 1:
                    # All n cycles consumed — propagate end-of-data.
                    raise
                self._epoch += 1
                self._iter = self._generator_fn(self._epoch)
45