#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

from torch.distributed.elastic.utils.data import CyclingIterator


class CyclingIteratorTest(unittest.TestCase):
    """Tests for ``CyclingIterator`` driven by a per-epoch generator factory."""

    def generator(self, epoch, stride, max_epochs):
        """Return an iterator over the ``stride`` values belonging to ``epoch``.

        Successive epochs produce a continuously incrementing sequence,
        e.g. with ``stride=3``: epoch 0 -> [0, 1, 2], epoch 1 -> [3, 4, 5],
        epoch 2 -> [6, 7, 8], ...

        Args:
            epoch: index of the epoch whose values are produced.
            stride: number of values produced per epoch.
            max_epochs: unused; accepted only so existing callers keep working.
        """
        start = stride * epoch
        return iter(range(start, start + stride))

    def test_cycling_iterator(self):
        stride = 3
        max_epochs = 90

        def generator_fn(epoch):
            return self.generator(epoch, stride, max_epochs)

        it = CyclingIterator(n=max_epochs, generator_fn=generator_fn)
        # Epochs 0..max_epochs-1 back to back form the sequence
        # 0..stride*max_epochs-1, so each value must come out in order.
        for expected in range(stride * max_epochs):
            self.assertEqual(expected, next(it))

        # After max_epochs epochs the iterator is expected to be exhausted.
        with self.assertRaises(StopIteration):
            next(it)

    def test_cycling_iterator_start_epoch(self):
        stride = 3
        max_epochs = 2
        start_epoch = 1

        def generator_fn(epoch):
            return self.generator(epoch, stride, max_epochs)

        it = CyclingIterator(max_epochs, generator_fn, start_epoch)
        # Starting at start_epoch means the values of the earlier epochs
        # (0..stride*start_epoch-1) are never produced.
        for expected in range(stride * start_epoch, stride * max_epochs):
            self.assertEqual(expected, next(it))

        # Only the epochs from start_epoch up to max_epochs are produced.
        with self.assertRaises(StopIteration):
            next(it)


if __name__ == "__main__":
    unittest.main()