#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


class CyclingIterator:
    """
    An iterator decorator that cycles through the
    underlying iterator "n" times. Useful to "unroll"
    the dataset across multiple training epochs.

    The generator function is called as ``generator_fn(epoch)``
    to obtain the underlying iterable, where ``epoch`` is the
    zero-based index of the current cycle, running from
    ``start_epoch`` up to ``n - 1``.

    For example if ``generator_fn`` always returns ``[1,2,3]``
    then ``CyclingIterator(n=2, generator_fn=generator_fn)`` will
    iterate through ``[1,2,3,1,2,3]``

    Args:
        n: total number of epochs; the last epoch requested from
            ``generator_fn`` is ``n - 1``.
        generator_fn: callable taking the epoch number and returning
            an iterable (an iterator, a generator, or a container
            such as a list).
        start_epoch: epoch number passed to the first
            ``generator_fn`` call.

    Note:
        ``generator_fn(start_epoch)`` is called eagerly by the
        constructor, so the first cycle is always iterated, even
        when ``start_epoch >= n``.
    """

    def __init__(self, n: int, generator_fn, start_epoch=0):
        self._n = n
        self._epoch = start_epoch
        self._generator_fn = generator_fn
        # iter() lets generator_fn return any iterable (e.g. a list),
        # not only an iterator; real iterators pass through unchanged.
        self._iter = iter(generator_fn(self._epoch))

    def __iter__(self):
        return self

    def __next__(self):
        # Loop rather than recurse so that a long run of empty epochs
        # cannot exhaust the interpreter's recursion limit.
        while True:
            try:
                return next(self._iter)
            except StopIteration:  # end of data for the current epoch
                if self._epoch >= self._n - 1:
                    # Final epoch is exhausted: propagate end of iteration.
                    raise
                self._epoch += 1
                self._iter = iter(self._generator_fn(self._epoch))