xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/utils/data/cycling_iterator_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import unittest
10
11from torch.distributed.elastic.utils.data import CyclingIterator
12
13
class CyclingIteratorTest(unittest.TestCase):
    """Tests for ``CyclingIterator``.

    ``CyclingIterator`` chains the iterators returned by ``generator_fn(epoch)``
    for successive epochs into a single iterator.
    """

    def generator(self, epoch, stride, max_epochs=None):
        """Return an iterator over the ``stride`` consecutive ints of ``epoch``.

        Values keep incrementing across epochs, e.g. with ``stride=3``:
        epoch 0 -> [0, 1, 2], epoch 1 -> [3, 4, 5], epoch 2 -> [6, 7, 8], ...

        Args:
            epoch: index of the epoch whose values to produce.
            stride: number of values produced per epoch.
            max_epochs: unused; kept (now optional) so callers that still
                pass it keep working.
        """
        return iter(range(stride * epoch, stride * (epoch + 1)))

    def test_cycling_iterator(self):
        """All ``n`` epochs are chained in order, then the iterator stops."""
        stride = 3
        max_epochs = 90
        requested_epochs = []

        def generator_fn(epoch):
            # Record which epochs the iterator asks for, so the call pattern
            # can be checked in addition to the produced values.
            requested_epochs.append(epoch)
            return self.generator(epoch, stride)

        it = CyclingIterator(n=max_epochs, generator_fn=generator_fn)
        for expected in range(stride * max_epochs):
            self.assertEqual(expected, next(it))

        with self.assertRaises(StopIteration):
            next(it)

        # Each epoch's iterator is requested exactly once, in order.
        self.assertEqual(list(range(max_epochs)), requested_epochs)

    def test_cycling_iterator_start_epoch(self):
        """Iteration begins at ``start_epoch`` and still ends at epoch ``n - 1``."""
        stride = 3
        max_epochs = 2
        start_epoch = 1
        requested_epochs = []

        def generator_fn(epoch):
            requested_epochs.append(epoch)
            return self.generator(epoch, stride)

        # Positional form: (n, generator_fn, start_epoch).
        it = CyclingIterator(max_epochs, generator_fn, start_epoch)
        for expected in range(stride * start_epoch, stride * max_epochs):
            self.assertEqual(expected, next(it))

        with self.assertRaises(StopIteration):
            next(it)

        # Epochs before ``start_epoch`` are never requested.
        self.assertEqual(list(range(start_epoch, max_epochs)), requested_epochs)
48