xref: /aosp_15_r20/external/pytorch/test/cpp/api/dataloader.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h>
8*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
9*da0073e9SAndroid Build Coastguard Worker #include <c10/util/tempfile.h>
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
12*da0073e9SAndroid Build Coastguard Worker #include <chrono>
13*da0073e9SAndroid Build Coastguard Worker #include <future>
14*da0073e9SAndroid Build Coastguard Worker #include <iostream>
15*da0073e9SAndroid Build Coastguard Worker #include <iterator>
16*da0073e9SAndroid Build Coastguard Worker #include <limits>
17*da0073e9SAndroid Build Coastguard Worker #include <mutex>
18*da0073e9SAndroid Build Coastguard Worker #include <numeric>
19*da0073e9SAndroid Build Coastguard Worker #include <stdexcept>
20*da0073e9SAndroid Build Coastguard Worker #include <string>
21*da0073e9SAndroid Build Coastguard Worker #include <thread>
22*da0073e9SAndroid Build Coastguard Worker #include <unordered_set>
23*da0073e9SAndroid Build Coastguard Worker #include <vector>
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker using namespace torch::data; // NOLINT
26*da0073e9SAndroid Build Coastguard Worker 
// A one-millisecond duration constant.
const std::chrono::milliseconds kMillisecond{1};
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker struct DummyDataset : datasets::Dataset<DummyDataset, int> {
DummyDatasetDummyDataset30*da0073e9SAndroid Build Coastguard Worker   explicit DummyDataset(size_t size = 100) : size_(size) {}
31*da0073e9SAndroid Build Coastguard Worker 
getDummyDataset32*da0073e9SAndroid Build Coastguard Worker   int get(size_t index) override {
33*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
34*da0073e9SAndroid Build Coastguard Worker     return 1 + index;
35*da0073e9SAndroid Build Coastguard Worker   }
sizeDummyDataset36*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
37*da0073e9SAndroid Build Coastguard Worker     return size_;
38*da0073e9SAndroid Build Coastguard Worker   }
39*da0073e9SAndroid Build Coastguard Worker 
40*da0073e9SAndroid Build Coastguard Worker   size_t size_;
41*da0073e9SAndroid Build Coastguard Worker };
42*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DatasetCallsGetCorrectly) {
  // get_batch must return get(i) for every requested index, in request order.
  DummyDataset dataset;
  const std::vector<int> expected{1, 2, 3, 4, 5};
  const std::vector<int> actual = dataset.get_batch({0, 1, 2, 3, 4});
  ASSERT_EQ(actual, expected);
}
49*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, TransformCallsGetApplyCorrectly) {
  // Per-example transform that renders each integer as its decimal string.
  struct ToString : transforms::Transform<int, std::string> {
    std::string apply(int value) override {
      return std::to_string(value);
    }
  };

  auto mapped = DummyDataset{}.map(ToString{});
  const std::vector<std::string> expected{"1", "2", "3", "4", "5"};
  const std::vector<std::string> actual = mapped.get_batch({0, 1, 2, 3, 4});
  ASSERT_EQ(actual, expected);
}
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker // dummy chunk data reader with 3 chunks and 35 examples in total. Each chunk
64*da0073e9SAndroid Build Coastguard Worker // contains 10, 5, 20 examples respectively.
65*da0073e9SAndroid Build Coastguard Worker 
66*da0073e9SAndroid Build Coastguard Worker struct DummyChunkDataReader : public datasets::ChunkDataReader<int> {
67*da0073e9SAndroid Build Coastguard Worker  public:
68*da0073e9SAndroid Build Coastguard Worker   using BatchType = datasets::ChunkDataReader<int>::ChunkType;
69*da0073e9SAndroid Build Coastguard Worker   using DataType = datasets::ChunkDataReader<int>::ExampleType;
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker   /// Read an entire chunk.
read_chunkDummyChunkDataReader72*da0073e9SAndroid Build Coastguard Worker   BatchType read_chunk(size_t chunk_index) override {
73*da0073e9SAndroid Build Coastguard Worker     BatchType batch_data;
74*da0073e9SAndroid Build Coastguard Worker     int start_index = chunk_index == 0
75*da0073e9SAndroid Build Coastguard Worker         ? 0
76*da0073e9SAndroid Build Coastguard Worker         // NOLINTNEXTLINE(bugprone-fold-init-type)
77*da0073e9SAndroid Build Coastguard Worker         : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0);
78*da0073e9SAndroid Build Coastguard Worker 
79*da0073e9SAndroid Build Coastguard Worker     batch_data.resize(chunk_sizes[chunk_index]);
80*da0073e9SAndroid Build Coastguard Worker 
81*da0073e9SAndroid Build Coastguard Worker     std::iota(batch_data.begin(), batch_data.end(), start_index);
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker     return batch_data;
84*da0073e9SAndroid Build Coastguard Worker   }
85*da0073e9SAndroid Build Coastguard Worker 
chunk_countDummyChunkDataReader86*da0073e9SAndroid Build Coastguard Worker   size_t chunk_count() override {
87*da0073e9SAndroid Build Coastguard Worker     return chunk_count_;
88*da0073e9SAndroid Build Coastguard Worker   };
89*da0073e9SAndroid Build Coastguard Worker 
resetDummyChunkDataReader90*da0073e9SAndroid Build Coastguard Worker   void reset() override{};
91*da0073e9SAndroid Build Coastguard Worker 
92*da0073e9SAndroid Build Coastguard Worker   const static size_t chunk_count_ = 3;
93*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
94*da0073e9SAndroid Build Coastguard Worker   size_t chunk_sizes[chunk_count_] = {10, 5, 20};
95*da0073e9SAndroid Build Coastguard Worker };
96*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest,ChunkDataSetWithInvalidInitParameter)97*da0073e9SAndroid Build Coastguard Worker TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
98*da0073e9SAndroid Build Coastguard Worker   DummyChunkDataReader data_reader;
99*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler sampler(0);
100*da0073e9SAndroid Build Coastguard Worker 
101*da0073e9SAndroid Build Coastguard Worker   auto initialization_function = [&](size_t preloader_count,
102*da0073e9SAndroid Build Coastguard Worker                                      size_t batch_size,
103*da0073e9SAndroid Build Coastguard Worker                                      size_t cache_size,
104*da0073e9SAndroid Build Coastguard Worker                                      size_t cross_chunk_shuffle_count = 1) {
105*da0073e9SAndroid Build Coastguard Worker     datasets::SharedBatchDataset<datasets::ChunkDataset<
106*da0073e9SAndroid Build Coastguard Worker         DummyChunkDataReader,
107*da0073e9SAndroid Build Coastguard Worker         samplers::SequentialSampler,
108*da0073e9SAndroid Build Coastguard Worker         samplers::SequentialSampler>>
109*da0073e9SAndroid Build Coastguard Worker         dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
110*da0073e9SAndroid Build Coastguard Worker             DummyChunkDataReader,
111*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler,
112*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler>>(
113*da0073e9SAndroid Build Coastguard Worker             data_reader,
114*da0073e9SAndroid Build Coastguard Worker             sampler,
115*da0073e9SAndroid Build Coastguard Worker             sampler,
116*da0073e9SAndroid Build Coastguard Worker             datasets::ChunkDatasetOptions(
117*da0073e9SAndroid Build Coastguard Worker                 preloader_count,
118*da0073e9SAndroid Build Coastguard Worker                 batch_size,
119*da0073e9SAndroid Build Coastguard Worker                 cache_size,
120*da0073e9SAndroid Build Coastguard Worker                 cross_chunk_shuffle_count));
121*da0073e9SAndroid Build Coastguard Worker   };
122*da0073e9SAndroid Build Coastguard Worker 
123*da0073e9SAndroid Build Coastguard Worker   ASSERT_THROWS_WITH(
124*da0073e9SAndroid Build Coastguard Worker       initialization_function(0, 1, 1),
125*da0073e9SAndroid Build Coastguard Worker       "Preloader count is 0. At least one preloader needs to be specified.");
126*da0073e9SAndroid Build Coastguard Worker 
127*da0073e9SAndroid Build Coastguard Worker   ASSERT_THROWS_WITH(
128*da0073e9SAndroid Build Coastguard Worker       initialization_function(1, 0, 1),
129*da0073e9SAndroid Build Coastguard Worker       "Batch size is 0. A positive batch size needs to be specified.");
130*da0073e9SAndroid Build Coastguard Worker 
131*da0073e9SAndroid Build Coastguard Worker   ASSERT_THROWS_WITH(
132*da0073e9SAndroid Build Coastguard Worker       initialization_function(1, 1, 0),
133*da0073e9SAndroid Build Coastguard Worker       "Cache size is 0. A positive cache size needs to be specified.");
134*da0073e9SAndroid Build Coastguard Worker 
135*da0073e9SAndroid Build Coastguard Worker   ASSERT_THROWS_WITH(
136*da0073e9SAndroid Build Coastguard Worker       initialization_function(1, 10, 5),
137*da0073e9SAndroid Build Coastguard Worker       "Cache size is less than batch size. Cache needs to be large enough to "
138*da0073e9SAndroid Build Coastguard Worker       "hold at least one batch.");
139*da0073e9SAndroid Build Coastguard Worker   ASSERT_THROWS_WITH(
140*da0073e9SAndroid Build Coastguard Worker       initialization_function(1, 10, 20, 0),
141*da0073e9SAndroid Build Coastguard Worker       "cross_chunk_shuffle_count needs to be greater than 0.");
142*da0073e9SAndroid Build Coastguard Worker }
143*da0073e9SAndroid Build Coastguard Worker 
144*da0073e9SAndroid Build Coastguard Worker struct InfiniteStreamDataset
145*da0073e9SAndroid Build Coastguard Worker     : datasets::StreamDataset<InfiniteStreamDataset, std::vector<int>> {
get_batchInfiniteStreamDataset146*da0073e9SAndroid Build Coastguard Worker   std::vector<int> get_batch(size_t batch_size) override {
147*da0073e9SAndroid Build Coastguard Worker     std::vector<int> batch(batch_size);
148*da0073e9SAndroid Build Coastguard Worker     for (auto& i : batch) {
149*da0073e9SAndroid Build Coastguard Worker       i = counter++;
150*da0073e9SAndroid Build Coastguard Worker     }
151*da0073e9SAndroid Build Coastguard Worker     return batch;
152*da0073e9SAndroid Build Coastguard Worker   }
153*da0073e9SAndroid Build Coastguard Worker 
sizeInfiniteStreamDataset154*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
155*da0073e9SAndroid Build Coastguard Worker     return torch::nullopt;
156*da0073e9SAndroid Build Coastguard Worker   }
157*da0073e9SAndroid Build Coastguard Worker 
158*da0073e9SAndroid Build Coastguard Worker   size_t counter = 0;
159*da0073e9SAndroid Build Coastguard Worker };
160*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, InfiniteStreamDataset) {
  const size_t kBatchSize = 13;
  const size_t kEpochSize = 39;
  const size_t kExpectedBatches = kEpochSize / kBatchSize; // == 3

  // Shift every streamed value up by one, so batch k starts at 1 + k * 13.
  auto dataset = InfiniteStreamDataset().map(
      transforms::Lambda<int>([](int x) { return x + 1; }));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      samplers::StreamSampler(/*epoch_size=*/kEpochSize),
      kBatchSize);

  size_t batches_seen = 0;
  for (auto& batch : *data_loader) {
    ASSERT_LT(batches_seen, kExpectedBatches);
    ASSERT_EQ(batch.size(), kBatchSize);
    const size_t first_value = 1 + batches_seen * kBatchSize;
    for (const auto offset : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(offset), first_value + offset);
    }
    ++batches_seen;
  }
  ASSERT_EQ(batches_seen, kExpectedBatches);
}
183*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, NoSequencerIsIdentity) {
  using namespace torch::data::detail::sequencers; // NOLINT
  // NoSequencer should hand back exactly what the producer returned.
  NoSequencer<int> sequencer;
  const auto produce_five = [] { return 5; };
  const auto result = sequencer.next(produce_five).value();
  ASSERT_EQ(result, 5);
}
190*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, OrderedSequencerIsSetUpWell) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct Item {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<Item> sequencer(kMaxJobs);
  // A fresh sequencer starts at sequence number 0 with one buffer slot per job.
  ASSERT_EQ(sequencer.next_sequence_number_, 0);
  ASSERT_EQ(sequencer.buffer_.size(), kMaxJobs);
}
201*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, OrderedSequencerReOrdersValues) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct Item {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<Item> sequencer(kMaxJobs);

  // Results arrive out of order, but the sequencer must emit them as 0..4.
  const std::vector<size_t> arrival_order = {0, 2, 4, 3, 1};
  size_t calls = 0;
  auto getter = [&arrival_order, &calls]() {
    return Item{arrival_order.at(calls++)};
  };

  // The first arrival already has the expected sequence number, so it is
  // returned after a single getter call.
  const auto first = sequencer.next(getter);
  ASSERT_EQ(first.value().sequence_number, 0);
  ASSERT_EQ(calls, 1);

  // Sequence number 1 arrives last, so every remaining result has to be
  // pulled before it can be returned.
  ASSERT_EQ(1, sequencer.next(getter).value().sequence_number);
  ASSERT_EQ(calls, 5);

  // 2, 3 and 4 are already buffered. They come out in order without the
  // getter being called again.
  for (const auto expected : c10::irange(size_t{2}, size_t{5})) {
    ASSERT_EQ(expected, sequencer.next(getter).value().sequence_number);
    ASSERT_EQ(calls, 5);
  }
}
232*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, BatchLambdaAppliesFunctionToBatch) {
  using InputBatch = std::vector<int>;
  using OutputBatch = std::string;
  DummyDataset dataset;
  // Reduce the whole batch to the decimal string of its sum.
  auto summed = dataset.map(transforms::BatchLambda<InputBatch, OutputBatch>(
      [](std::vector<int> input) {
        const int total = std::accumulate(input.begin(), input.end(), 0);
        return std::to_string(total);
      }));
  // Indices 1..5 map to examples 2..6, whose sum is 20.
  ASSERT_EQ(summed.get_batch({1, 2, 3, 4, 5}), std::string("20"));
}
243*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, LambdaAppliesFunctionToExample) {
  // Select the `int` overload of std::to_string explicitly.
  std::string (*const int_to_string)(int) = std::to_string;
  auto mapped = DummyDataset().map(
      transforms::Lambda<int, std::string>(int_to_string));
  const std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(mapped.get_batch({0, 1, 2, 3, 4}), expected);
}
250*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, CollateReducesBatch) {
  // Collate folds a whole batch of examples into one value: their sum.
  auto sum_batch = [](std::vector<int> input) {
    return std::accumulate(input.begin(), input.end(), 0);
  };
  auto summed = DummyDataset().map(transforms::Collate<int>(sum_batch));
  // Indices 1..5 map to examples 2..6, whose sum is 20.
  ASSERT_EQ(summed.get_batch({1, 2, 3, 4, 5}), 20);
}
258*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, CollationReducesBatch) {
  // Collation subclass that reduces a batch to the sum of its examples.
  struct Summer : transforms::Collation<int> {
    int apply_batch(std::vector<int> input) override {
      int total = 0;
      for (const int value : input) {
        total += value;
      }
      return total;
    }
  };
  auto summed = DummyDataset().map(Summer{});
  // Indices 1..5 map to examples 2..6, whose sum is 20.
  ASSERT_EQ(summed.get_batch({1, 2, 3, 4, 5}), 20);
}
268*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
  using Indices = std::vector<size_t>;
  samplers::SequentialSampler sampler(10);
  // Consecutive requests walk through 0..9 in order.
  ASSERT_EQ(sampler.next(3).value(), Indices({0, 1, 2}));
  ASSERT_EQ(sampler.next(5).value(), Indices({3, 4, 5, 6, 7}));
  ASSERT_EQ(sampler.next(2).value(), Indices({8, 9}));
  // All ten indices are used up, so nothing more is returned.
  ASSERT_FALSE(sampler.next(2).has_value());
}
276*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
  using Indices = std::vector<size_t>;
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value(), Indices({0, 1, 2}));
  // Asking for more than remains yields only the two leftover indices.
  ASSERT_EQ(sampler.next(100).value(), Indices({3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
283*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, SequentialSamplerResetsWell) {
  samplers::SequentialSampler sampler(5);
  const std::vector<size_t> full_epoch = {0, 1, 2, 3, 4};
  // Two epochs. The reset() between them must rewind the sampler to index 0.
  for (const auto epoch : c10::irange(2)) {
    if (epoch > 0) {
      sampler.reset();
    }
    ASSERT_EQ(sampler.next(5).value(), full_epoch);
    ASSERT_FALSE(sampler.next(2).has_value());
  }
}
292*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
  samplers::SequentialSampler sampler(5);
  // Returns {0, 1, ..., n - 1}, i.e. one full epoch of a size-n sampler.
  const auto first_n = [](size_t n) {
    std::vector<size_t> indices(n);
    std::iota(indices.begin(), indices.end(), size_t{0});
    return indices;
  };

  ASSERT_EQ(sampler.next(5).value(), first_n(5));
  ASSERT_FALSE(sampler.next(2).has_value());

  // reset(new_size) rewinds the sampler and also resizes it.
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value(), first_n(7));
  ASSERT_FALSE(sampler.next(2).has_value());

  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value(), first_n(3));
  ASSERT_FALSE(sampler.next(2).has_value());
}
305*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, CanSaveAndLoadSequentialSampler) {
  // Each scope round-trips a sampler through torch::save / torch::load and
  // checks that the restored copy reports the same position.
  {
    // Fresh sampler: position 0 must survive the round trip.
    samplers::SequentialSampler original(10);
    ASSERT_EQ(original.index(), 0);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::SequentialSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 0);
  }
  {
    // Partially consumed sampler: 3 + 4 = 7 indices already handed out.
    samplers::SequentialSampler original(10);
    original.next(3);
    original.next(4);
    ASSERT_EQ(original.index(), 7);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::SequentialSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 7);
  }
}
330*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) {
  samplers::RandomSampler sampler(10);

  // Indices are `size_t`, so a `>= 0` check is a tautology (and trips
  // -Wtype-limits). Only the upper bound is meaningful. Across one epoch
  // (3 + 5 + 2 = 10 indices) the sampler should also never repeat an index.
  std::unordered_set<size_t> seen;
  for (const size_t batch_size : std::vector<size_t>{3, 5, 2}) {
    const std::vector<size_t> indices = sampler.next(batch_size).value();
    ASSERT_EQ(indices.size(), batch_size);
    for (const auto i : indices) {
      ASSERT_LT(i, 10);
      ASSERT_TRUE(seen.insert(i).second)
          << "Index " << i << " was returned more than once";
    }
  }
  ASSERT_EQ(seen.size(), 10);

  // The epoch is exhausted.
  ASSERT_FALSE(sampler.next(10).has_value());
}
354*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
  samplers::RandomSampler sampler(5);
  const auto first = sampler.next(3);
  ASSERT_EQ(first.value().size(), 3);
  // Only two of the five indices remain, however many are requested.
  const auto rest = sampler.next(100);
  ASSERT_EQ(rest.value().size(), 2);
  ASSERT_FALSE(sampler.next(2).has_value());
}
361*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, RandomSamplerResetsWell) {
  samplers::RandomSampler sampler(5);
  // Two epochs separated by reset(). Each yields all 5 indices and then stops.
  for (const auto epoch : c10::irange(2)) {
    if (epoch != 0) {
      sampler.reset();
    }
    ASSERT_EQ(sampler.next(5).value().size(), 5);
    ASSERT_FALSE(sampler.next(2).has_value());
  }
}
370*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());

  // After reset(new_size) the sampler yields exactly new_size indices, first
  // growing to 7 and then shrinking to 3.
  for (const size_t new_size : {size_t{7}, size_t{3}}) {
    sampler.reset(new_size);
    ASSERT_EQ(sampler.next(new_size).value().size(), new_size);
    ASSERT_FALSE(sampler.next(2).has_value());
  }
}
382*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
  {
    // A fresh sampler restored from a checkpoint must produce the same
    // sequence of indices as the one that was saved.
    samplers::RandomSampler original(10);

    std::stringstream checkpoint;
    torch::save(original, checkpoint);

    samplers::RandomSampler restored(10);
    torch::load(restored, checkpoint);

    ASSERT_EQ(original.next(10).value(), restored.next(10).value());
  }
  {
    // A partially consumed sampler resumes at the same position and yields
    // the same remaining indices after the round trip.
    samplers::RandomSampler original(10);
    original.next(3);
    ASSERT_EQ(original.index(), 3);

    std::stringstream checkpoint;
    torch::save(original, checkpoint);

    samplers::RandomSampler restored(10);
    torch::load(restored, checkpoint);
    ASSERT_EQ(restored.index(), 3);

    const auto restored_rest = restored.next(10).value();
    ASSERT_EQ(restored_rest.size(), 7);
    ASSERT_EQ(original.next(10).value(), restored_rest);
  }
}
412*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
  samplers::StreamSampler sampler(/*epoch_size=*/100);
  // (requested, granted) pairs. Requests are granted in full until only the
  // remainder 100 - 10 - 2 - 85 = 3 is left.
  const std::vector<std::pair<size_t, size_t>> requests_and_grants = {
      {10, 10}, {2, 2}, {85, 85}, {123, 3}};
  for (const auto& [requested, granted] : requests_and_grants) {
    ASSERT_EQ(sampler.next(requested).value(), granted);
  }
  ASSERT_FALSE(sampler.next(1).has_value());
}
421*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, StreamSamplerResetsWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  // Run two epochs back to back. A bare reset() between them must restore the
  // original epoch size of 5, so both passes look identical.
  for (int epoch = 0; epoch < 2; ++epoch) {
    if (epoch > 0) {
      sampler.reset();
    }
    ASSERT_EQ(sampler.next(5).value().size(), 5);
    ASSERT_FALSE(sampler.next(2).has_value());
  }
}
430*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  // Epoch sizes in order: the constructor's 5, then reset(7), then reset(3).
  // Each epoch must yield exactly its size in a single batch and then stop.
  const std::vector<size_t> epoch_sizes = {5, 7, 3};
  for (size_t epoch = 0; epoch < epoch_sizes.size(); ++epoch) {
    const size_t epoch_size = epoch_sizes[epoch];
    if (epoch != 0) {
      sampler.reset(epoch_size);
    }
    ASSERT_EQ(sampler.next(epoch_size).value().size(), epoch_size);
    ASSERT_FALSE(sampler.next(2).has_value());
  }
}
442*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
  // Every row of the 5x5 identity matrix is one example, so example 2 is the
  // one-hot row with a 1 in position 2.
  datasets::TensorDataset dataset(torch::eye(5));
  const torch::Tensor expected_row =
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32);
  ASSERT_TRUE(expected_row.allclose(dataset.get(2)));
}
448*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) {
  // Split the identity matrix into five row tensors and build the dataset from
  // that list. Example 2 must still be the third identity row.
  std::vector<torch::Tensor> rows = torch::eye(5).chunk(5);
  datasets::TensorDataset dataset(rows);
  const torch::Tensor expected_row =
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32);
  ASSERT_TRUE(expected_row.allclose(dataset.get(2)));
}
455*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, StackTransformWorksForExample) {
  // Toy dataset: example i is (row i of eye(4), 1 + row i of eye(4)).
  struct D : public datasets::Dataset<D> {
    Example<> get(size_t index) override {
      torch::Tensor row = tensor[index];
      return {row, 1 + row};
    }

    torch::optional<size_t> size() const override {
      return tensor.size(0);
    }

    torch::Tensor tensor{torch::eye(4)};
  };

  // Stack collates a batch of examples into one example whose data and target
  // are the per-example tensors stacked along the leading dimension.
  auto stacked = D().map(transforms::Stack<Example<>>());
  const torch::Tensor identity = torch::eye(4);

  // Examples 0 and 1: identity rows 0..1 as data, the same rows + 1 as target.
  Example<> rows_0_1 = stacked.get_batch({0, 1});
  const torch::Tensor head = identity.slice(/*dim=*/0, 0, 2);
  ASSERT_TRUE(rows_0_1.data.allclose(head));
  ASSERT_TRUE(rows_0_1.target.allclose(1 + head));

  // Examples 2 and 3: identity rows 2..3, and the same rows + 1.
  Example<> rows_2_3 = stacked.get_batch({2, 3});
  const torch::Tensor tail = identity.slice(/*dim=*/0, 2, 4);
  ASSERT_TRUE(rows_2_3.data.allclose(tail));
  ASSERT_TRUE(rows_2_3.target.allclose(1 + tail));
}
479*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, StackTransformWorksForTensorExample) {
  // Stack collates TensorExamples (data only, no target) into a single tensor
  // whose leading dimension indexes the examples in the batch.
  auto stacked = datasets::TensorDataset(torch::eye(4))
                     .map(transforms::Stack<TensorExample>());
  const torch::Tensor identity = torch::eye(4);

  // Examples 0 and 1 are identity rows 0..1.
  TensorExample rows_0_1 = stacked.get_batch({0, 1});
  ASSERT_TRUE(rows_0_1.data.allclose(identity.slice(/*dim=*/0, 0, 2)));

  // Examples 2 and 3 are identity rows 2..3.
  TensorExample rows_2_3 = stacked.get_batch({2, 3});
  ASSERT_TRUE(rows_2_3.data.allclose(identity.slice(/*dim=*/0, 2, 4)));
}
490*da0073e9SAndroid Build Coastguard Worker 
491*da0073e9SAndroid Build Coastguard Worker // Template classes cannot be nested in functions.
492*da0073e9SAndroid Build Coastguard Worker template <typename Target>
493*da0073e9SAndroid Build Coastguard Worker struct T : transforms::TensorTransform<Target> {
operator ()T494*da0073e9SAndroid Build Coastguard Worker   torch::Tensor operator()(torch::Tensor input) override {
495*da0073e9SAndroid Build Coastguard Worker     return input * 2;
496*da0073e9SAndroid Build Coastguard Worker   }
497*da0073e9SAndroid Build Coastguard Worker };
498*da0073e9SAndroid Build Coastguard Worker 
499*da0073e9SAndroid Build Coastguard Worker struct TensorStringDataset
500*da0073e9SAndroid Build Coastguard Worker     : datasets::
501*da0073e9SAndroid Build Coastguard Worker           Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
getTensorStringDataset502*da0073e9SAndroid Build Coastguard Worker   Example<torch::Tensor, std::string> get(size_t index) override {
503*da0073e9SAndroid Build Coastguard Worker     return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
504*da0073e9SAndroid Build Coastguard Worker   }
505*da0073e9SAndroid Build Coastguard Worker 
sizeTensorStringDataset506*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
507*da0073e9SAndroid Build Coastguard Worker     return 100;
508*da0073e9SAndroid Build Coastguard Worker   }
509*da0073e9SAndroid Build Coastguard Worker };
510*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, TensorTransformWorksForAnyTargetType) {
  // T<std::string> doubles each data tensor. The std::string target must come
  // through the transform untouched.
  auto d = TensorStringDataset().map(T<std::string>{});
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});
  ASSERT_EQ(batch.size(), 2);

  // {expected data value, expected target} for requested indices 1 and 2.
  const std::vector<std::pair<double, std::string>> expected = {
      {2.0, "1"}, {4.0, "2"}};
  for (size_t i = 0; i < expected.size(); ++i) {
    ASSERT_TRUE(batch[i].data.allclose(torch::tensor(expected[i].first)));
    ASSERT_EQ(batch[i].target, expected[i].second);
  }
}
522*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
  // Same contract as the TensorTransform test, but the doubling is supplied as
  // a lambda wrapped in TensorLambda.
  const auto doubler = [](torch::Tensor input) { return input * 2; };
  auto d = TensorStringDataset().map(
      transforms::TensorLambda<std::string>(doubler));
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});
  ASSERT_EQ(batch.size(), 2);

  const auto& first = batch[0];
  ASSERT_TRUE(first.data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(first.target, "1");

  const auto& second = batch[1];
  ASSERT_TRUE(second.data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(second.target, "2");
}
535*da0073e9SAndroid Build Coastguard Worker 
536*da0073e9SAndroid Build Coastguard Worker struct DummyTensorDataset
537*da0073e9SAndroid Build Coastguard Worker     : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
getDummyTensorDataset538*da0073e9SAndroid Build Coastguard Worker   Example<torch::Tensor, int> get(size_t index) override {
539*da0073e9SAndroid Build Coastguard Worker     const auto channels = static_cast<int64_t>(index);
540*da0073e9SAndroid Build Coastguard Worker     torch::Tensor tensor =
541*da0073e9SAndroid Build Coastguard Worker         (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
542*da0073e9SAndroid Build Coastguard Worker     return {tensor, static_cast<int>(channels)};
543*da0073e9SAndroid Build Coastguard Worker   }
544*da0073e9SAndroid Build Coastguard Worker 
sizeDummyTensorDataset545*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
546*da0073e9SAndroid Build Coastguard Worker     return 100;
547*da0073e9SAndroid Build Coastguard Worker   }
548*da0073e9SAndroid Build Coastguard Worker };
549*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest,NormalizeTransform)550*da0073e9SAndroid Build Coastguard Worker TEST(DataTest, NormalizeTransform) {
551*da0073e9SAndroid Build Coastguard Worker   auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));
552*da0073e9SAndroid Build Coastguard Worker 
553*da0073e9SAndroid Build Coastguard Worker   // Works for zero (one implicit) channels
554*da0073e9SAndroid Build Coastguard Worker   std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
555*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output.size(), 1);
556*da0073e9SAndroid Build Coastguard Worker   // (1 - 0.5) / 0.1 = 5
557*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
558*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
559*da0073e9SAndroid Build Coastguard Worker 
560*da0073e9SAndroid Build Coastguard Worker   // Works for one explicit channel
561*da0073e9SAndroid Build Coastguard Worker   output = dataset.get_batch(1);
562*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output.size(), 1);
563*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output[0].data.size(0), 1);
564*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
565*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
566*da0073e9SAndroid Build Coastguard Worker 
567*da0073e9SAndroid Build Coastguard Worker   // Works for two channels with different moments
568*da0073e9SAndroid Build Coastguard Worker   dataset = DummyTensorDataset().map(
569*da0073e9SAndroid Build Coastguard Worker       transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
570*da0073e9SAndroid Build Coastguard Worker   output = dataset.get_batch(2);
571*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output.size(), 1);
572*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output[0].data.size(0), 2);
573*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0]
574*da0073e9SAndroid Build Coastguard Worker                   .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
575*da0073e9SAndroid Build Coastguard Worker                   .allclose(torch::ones({1, 4, 4}) * 5))
576*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
577*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0]
578*da0073e9SAndroid Build Coastguard Worker                   .data.slice(/*dim=*/0, /*start=*/1)
579*da0073e9SAndroid Build Coastguard Worker                   .allclose(torch::ones({1, 4, 4}) * -2.5))
580*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
581*da0073e9SAndroid Build Coastguard Worker 
582*da0073e9SAndroid Build Coastguard Worker   // Works for three channels with one moment value
583*da0073e9SAndroid Build Coastguard Worker   dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
584*da0073e9SAndroid Build Coastguard Worker   output = dataset.get_batch(3);
585*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output.size(), 1);
586*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output[0].data.size(0), 3);
587*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
588*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
589*da0073e9SAndroid Build Coastguard Worker 
590*da0073e9SAndroid Build Coastguard Worker   // Works for three channels with different moments
591*da0073e9SAndroid Build Coastguard Worker   dataset = DummyTensorDataset().map(
592*da0073e9SAndroid Build Coastguard Worker       transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
593*da0073e9SAndroid Build Coastguard Worker   output = dataset.get_batch(3);
594*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output.size(), 1);
595*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output[0].data.size(0), 3);
596*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0]
597*da0073e9SAndroid Build Coastguard Worker                   .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
598*da0073e9SAndroid Build Coastguard Worker                   .allclose(torch::ones({1, 4, 4}) * 5))
599*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
600*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0]
601*da0073e9SAndroid Build Coastguard Worker                   .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
602*da0073e9SAndroid Build Coastguard Worker                   .allclose(torch::ones({1, 4, 4}) * -2.5))
603*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
604*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(output[0]
605*da0073e9SAndroid Build Coastguard Worker                   .data.slice(/*dim=*/0, /*start=*/2)
606*da0073e9SAndroid Build Coastguard Worker                   .allclose(torch::ones({1, 4, 4}) * 12.5))
607*da0073e9SAndroid Build Coastguard Worker       << output[0].data;
608*da0073e9SAndroid Build Coastguard Worker }
609*da0073e9SAndroid Build Coastguard Worker 
610*da0073e9SAndroid Build Coastguard Worker struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
611*da0073e9SAndroid Build Coastguard Worker   UnCopyableDataset() = default;
612*da0073e9SAndroid Build Coastguard Worker 
613*da0073e9SAndroid Build Coastguard Worker   UnCopyableDataset(const UnCopyableDataset&) = delete;
614*da0073e9SAndroid Build Coastguard Worker   UnCopyableDataset& operator=(const UnCopyableDataset&) = delete;
615*da0073e9SAndroid Build Coastguard Worker 
616*da0073e9SAndroid Build Coastguard Worker   UnCopyableDataset(UnCopyableDataset&&) = default;
617*da0073e9SAndroid Build Coastguard Worker   UnCopyableDataset& operator=(UnCopyableDataset&&) = default;
618*da0073e9SAndroid Build Coastguard Worker 
619*da0073e9SAndroid Build Coastguard Worker   ~UnCopyableDataset() override = default;
620*da0073e9SAndroid Build Coastguard Worker 
getUnCopyableDataset621*da0073e9SAndroid Build Coastguard Worker   Example<> get(size_t index) override {
622*da0073e9SAndroid Build Coastguard Worker     return {
623*da0073e9SAndroid Build Coastguard Worker         torch::tensor({static_cast<int64_t>(index)}),
624*da0073e9SAndroid Build Coastguard Worker         torch::tensor({static_cast<int64_t>(index)})};
625*da0073e9SAndroid Build Coastguard Worker   }
626*da0073e9SAndroid Build Coastguard Worker 
sizeUnCopyableDataset627*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
628*da0073e9SAndroid Build Coastguard Worker     return 100;
629*da0073e9SAndroid Build Coastguard Worker   }
630*da0073e9SAndroid Build Coastguard Worker };
631*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, MapDoesNotCopy) {
  // UnCopyableDataset deletes its copy operations, so this chain of map()
  // calls only compiles if every step moves the previous dataset.
  auto dataset = UnCopyableDataset()
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 1; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 2; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 3; }));

  // Example 1 starts as a one-element tensor holding 1. The three lambdas add
  // 1 + 2 + 3, so the result is 7.
  auto data = dataset.get_batch(1).at(0).data;
  ASSERT_EQ(data.numel(), 1);
  // The dataset builds its tensors from int64_t values, so read the element
  // back as an exact integer. Reading it as float only adds a lossy
  // conversion and a floating-point equality check.
  ASSERT_EQ(data[0].item<int64_t>(), 7);
}
645*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, QueuePushAndPopFromSameThread) {
  // Values pushed and popped on a single thread come out in FIFO order.
  torch::data::detail::Queue<int> queue;
  for (int value : {1, 2}) {
    queue.push(value);
  }
  for (int expected : {1, 2}) {
    ASSERT_EQ(queue.pop(), expected);
  }
}
653*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
  // A timed pop on an empty queue must throw once the timeout elapses, and the
  // message must report the timeout in milliseconds.
  torch::data::detail::Queue<int> empty_queue;
  const auto timeout = 10 * kMillisecond;
  ASSERT_THROWS_WITH(
      empty_queue.pop(timeout),
      "Timeout in DataLoader queue while waiting for next batch "
      "(timeout was 10 ms)");
}
661*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
  using torch::data::detail::Queue;

  // First test: push a value, then pop it from another thread.
  {
    Queue<int> queue;
    queue.push(1);
    // future.get() waits for the worker to finish, so the worker is done with
    // `queue` before this scope ends, even if the assertion fails.
    auto future =
        std::async(std::launch::async, [&queue] { return queue.pop(); });
    ASSERT_EQ(future.get(), 1);
  }

  // Second test: block in pop() on an empty queue until another thread pushes.
  {
    Queue<int> queue;
    std::thread producer([&queue] {
      std::this_thread::sleep_for(20 * kMillisecond);
      queue.push(123);
    });
    const int popped = queue.pop();
    // Join BEFORE asserting. A failing ASSERT_* returns from the test body, and
    // destroying a still-joinable std::thread calls std::terminate, which would
    // abort the whole test binary instead of reporting a failure.
    producer.join();
    ASSERT_EQ(popped, 123);
  }
}
685*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, QueueClearEmptiesTheQueue) {
  torch::data::detail::Queue<int> queue;
  for (int value = 1; value <= 3; ++value) {
    queue.push(value);
  }
  // clear() reports how many elements it discarded...
  ASSERT_EQ(queue.clear(), 3);
  // ...and afterwards a timed pop finds nothing and times out.
  ASSERT_THROWS_WITH(queue.pop(1 * kMillisecond), "Timeout");
}
694*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DataShuttleCanPushAndPopJob) {
  // Jobs come back out of the shuttle in the order they were pushed.
  torch::data::detail::DataShuttle<int, int> shuttle;
  for (int job = 1; job <= 2; ++job) {
    shuttle.push_job(job);
  }
  for (int job = 1; job <= 2; ++job) {
    ASSERT_EQ(shuttle.pop_job(), job);
  }
}
702*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DataShuttleCanPushAndPopResult) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  // pop_result() only attempts a pop while a pushed job is outstanding, so
  // queue both jobs up front.
  shuttle.push_job(1);
  shuttle.push_job(2);

  // Play the worker: take a job, publish its result, and check that the
  // consumer receives that result.
  for (int result = 1; result <= 2; ++result) {
    shuttle.pop_job();
    shuttle.push_result(result);
    ASSERT_EQ(shuttle.pop_result().value(), result);
  }
}
717*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  // Nothing has been pushed yet, so there is no result to wait for.
  ASSERT_FALSE(shuttle.pop_result().has_value());

  // One complete job/result round trip delivers exactly one result...
  shuttle.push_job(1);
  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);

  // ...after which nothing is in flight again, no matter how often we ask.
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_FALSE(shuttle.pop_result().has_value());
  }
}
728*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
  // A job is recorded and a result has been pushed, but after drain() the
  // shuttle must report no result.
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_result(1);
  shuttle.drain();
  const auto result = shuttle.pop_result();
  ASSERT_FALSE(result.has_value());
}
736*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DataShuttlePopResultTimesOut) {
  // A job is in flight but no result is ever pushed, so a bounded wait for
  // the result must fail with a timeout error.
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  const auto timeout = 10 * kMillisecond;
  ASSERT_THROWS_WITH(shuttle.pop_result(timeout), "Timeout");
}
742*da0073e9SAndroid Build Coastguard Worker 
743*da0073e9SAndroid Build Coastguard Worker struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
UncopyableDatasetUncopyableDataset744*da0073e9SAndroid Build Coastguard Worker   UncopyableDataset(const std::string& /* unused */) {}
745*da0073e9SAndroid Build Coastguard Worker 
746*da0073e9SAndroid Build Coastguard Worker   UncopyableDataset(UncopyableDataset&&) = default;
747*da0073e9SAndroid Build Coastguard Worker   UncopyableDataset& operator=(UncopyableDataset&&) = default;
748*da0073e9SAndroid Build Coastguard Worker 
749*da0073e9SAndroid Build Coastguard Worker   UncopyableDataset(const UncopyableDataset&) = delete;
750*da0073e9SAndroid Build Coastguard Worker   UncopyableDataset& operator=(const UncopyableDataset&) = delete;
751*da0073e9SAndroid Build Coastguard Worker 
getUncopyableDataset752*da0073e9SAndroid Build Coastguard Worker   int get(size_t index) override {
753*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
754*da0073e9SAndroid Build Coastguard Worker     return 1 + index;
755*da0073e9SAndroid Build Coastguard Worker   }
sizeUncopyableDataset756*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
757*da0073e9SAndroid Build Coastguard Worker     return 100;
758*da0073e9SAndroid Build Coastguard Worker   }
759*da0073e9SAndroid Build Coastguard Worker };
760*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, SharedBatchDatasetReallyIsShared) {
  // This test will only compile if we really are not making any copies.
  // There is otherwise no logic to test. Because it is not deterministic how
  // many worker threads access the shared dataset, or when, we don't have any
  // additional assertions here.

  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          "uncopyable");

  auto data_loader = torch::data::make_data_loader(
      shared_dataset, torch::data::DataLoaderOptions().workers(3));

  // Bind by const reference. The loop only drains the loader, so copying each
  // batch out of the iterator would be wasted work.
  for (const auto& batch : *data_loader) {
    (void)batch; // exhaust
  }
}
778*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
  // UncopyableDataset deletes its copy constructor, so handing a temporary
  // dataset to make_shared_dataset only compiles if it is moved, not copied.
  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          UncopyableDataset("uncopyable"));
  // The shared wrapper must report the underlying dataset's size.
  const auto size = shared_dataset.size();
  ASSERT_EQ(size.value(), 100);
}
786*da0073e9SAndroid Build Coastguard Worker 
787*da0073e9SAndroid Build Coastguard Worker struct TestIndex : public torch::data::samplers::CustomBatchRequest {
TestIndexTestIndex788*da0073e9SAndroid Build Coastguard Worker   explicit TestIndex(size_t offset, std::vector<size_t> index)
789*da0073e9SAndroid Build Coastguard Worker       : offset(offset), index(std::move(index)) {}
sizeTestIndex790*da0073e9SAndroid Build Coastguard Worker   size_t size() const override {
791*da0073e9SAndroid Build Coastguard Worker     return index.size();
792*da0073e9SAndroid Build Coastguard Worker   }
793*da0073e9SAndroid Build Coastguard Worker   size_t offset;
794*da0073e9SAndroid Build Coastguard Worker   std::vector<size_t> index;
795*da0073e9SAndroid Build Coastguard Worker };
796*da0073e9SAndroid Build Coastguard Worker 
797*da0073e9SAndroid Build Coastguard Worker struct TestIndexDataset
798*da0073e9SAndroid Build Coastguard Worker     : datasets::BatchDataset<TestIndexDataset, std::vector<int>, TestIndex> {
TestIndexDatasetTestIndexDataset799*da0073e9SAndroid Build Coastguard Worker   explicit TestIndexDataset(size_t size) : data(size) {
800*da0073e9SAndroid Build Coastguard Worker     std::iota(data.begin(), data.end(), size_t(0));
801*da0073e9SAndroid Build Coastguard Worker   }
get_batchTestIndexDataset802*da0073e9SAndroid Build Coastguard Worker   std::vector<int> get_batch(TestIndex index) override {
803*da0073e9SAndroid Build Coastguard Worker     std::vector<int> batch;
804*da0073e9SAndroid Build Coastguard Worker     for (auto i : index.index) {
805*da0073e9SAndroid Build Coastguard Worker       batch.push_back(index.offset + data.at(i));
806*da0073e9SAndroid Build Coastguard Worker     }
807*da0073e9SAndroid Build Coastguard Worker     return batch;
808*da0073e9SAndroid Build Coastguard Worker   }
sizeTestIndexDataset809*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
810*da0073e9SAndroid Build Coastguard Worker     return data.size();
811*da0073e9SAndroid Build Coastguard Worker   }
812*da0073e9SAndroid Build Coastguard Worker   std::vector<int> data;
813*da0073e9SAndroid Build Coastguard Worker };
814*da0073e9SAndroid Build Coastguard Worker 
815*da0073e9SAndroid Build Coastguard Worker struct TestIndexSampler : public samplers::Sampler<TestIndex> {
TestIndexSamplerTestIndexSampler816*da0073e9SAndroid Build Coastguard Worker   explicit TestIndexSampler(size_t size) : size_(size) {}
resetTestIndexSampler817*da0073e9SAndroid Build Coastguard Worker   void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
nextTestIndexSampler818*da0073e9SAndroid Build Coastguard Worker   torch::optional<TestIndex> next(size_t batch_size) override {
819*da0073e9SAndroid Build Coastguard Worker     if (index_ >= size_) {
820*da0073e9SAndroid Build Coastguard Worker       return torch::nullopt;
821*da0073e9SAndroid Build Coastguard Worker     }
822*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> indices(batch_size);
823*da0073e9SAndroid Build Coastguard Worker     std::iota(indices.begin(), indices.end(), size_t(0));
824*da0073e9SAndroid Build Coastguard Worker     index_ += batch_size;
825*da0073e9SAndroid Build Coastguard Worker     return TestIndex(batch_size, std::move(indices));
826*da0073e9SAndroid Build Coastguard Worker   }
saveTestIndexSampler827*da0073e9SAndroid Build Coastguard Worker   void save(torch::serialize::OutputArchive& archive) const override {}
loadTestIndexSampler828*da0073e9SAndroid Build Coastguard Worker   void load(torch::serialize::InputArchive& archive) override {}
829*da0073e9SAndroid Build Coastguard Worker   size_t index_ = 0;
830*da0073e9SAndroid Build Coastguard Worker   size_t size_;
831*da0073e9SAndroid Build Coastguard Worker };
832*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, CanUseCustomTypeAsIndexType) {
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader(
      TestIndexDataset(23), TestIndexSampler(23), kBatchSize);

  // TestIndexSampler requests positions [0, kBatchSize) with offset
  // kBatchSize, and TestIndexDataset stores data[k] == k, so every batch
  // must be {kBatchSize + 0, ..., kBatchSize + kBatchSize - 1}.
  size_t batch_count = 0;
  // Bind by reference: copying each batch vector is unnecessary.
  for (const auto& batch : *data_loader) {
    ASSERT_EQ(batch.size(), static_cast<size_t>(kBatchSize));
    for (const auto j : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(j), kBatchSize + j);
    }
    ++batch_count;
  }
  // 23 elements consumed in steps of 10: requests are issued while the
  // sampler's running count is 0, 10 and 20, i.e. three batches.
  ASSERT_EQ(batch_count, size_t(3));
}
844*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  samplers::DistributedRandomSampler sampler(sample_count);

  // Drain the sampler in batches of 3, collecting every emitted index.
  std::vector<size_t> collected;
  for (auto batch = sampler.next(3); batch.has_value(); batch = sampler.next(3)) {
    collected.insert(collected.end(), batch->begin(), batch->end());
  }

  // A single replica must see every index exactly once.
  ASSERT_EQ(collected.size(), sample_count);

  std::sort(collected.begin(), collected.end());
  for (size_t position = 0; position < collected.size(); ++position) {
    ASSERT_EQ(collected[position], position);
  }
}
862*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest,DistributedRandomSamplerMultiReplicaProduceCorrectSamples)863*da0073e9SAndroid Build Coastguard Worker TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
864*da0073e9SAndroid Build Coastguard Worker   size_t sample_count = 10;
865*da0073e9SAndroid Build Coastguard Worker   size_t num_replicas = 3;
866*da0073e9SAndroid Build Coastguard Worker 
867*da0073e9SAndroid Build Coastguard Worker   auto test_function = [&](bool allow_duplicates,
868*da0073e9SAndroid Build Coastguard Worker                            size_t local_sample_count,
869*da0073e9SAndroid Build Coastguard Worker                            std::vector<size_t>& output,
870*da0073e9SAndroid Build Coastguard Worker                            size_t batch_size) {
871*da0073e9SAndroid Build Coastguard Worker     std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;
872*da0073e9SAndroid Build Coastguard Worker 
873*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(num_replicas)) {
874*da0073e9SAndroid Build Coastguard Worker       samplers.emplace_back(
875*da0073e9SAndroid Build Coastguard Worker           std::make_unique<samplers::DistributedRandomSampler>(
876*da0073e9SAndroid Build Coastguard Worker               sample_count, num_replicas, i, allow_duplicates));
877*da0073e9SAndroid Build Coastguard Worker     }
878*da0073e9SAndroid Build Coastguard Worker 
879*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> res;
880*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(num_replicas)) {
881*da0073e9SAndroid Build Coastguard Worker       (*samplers[i]).reset();
882*da0073e9SAndroid Build Coastguard Worker       torch::optional<std::vector<size_t>> idx;
883*da0073e9SAndroid Build Coastguard Worker       while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
884*da0073e9SAndroid Build Coastguard Worker         res.insert(std::end(res), std::begin(*idx), std::end(*idx));
885*da0073e9SAndroid Build Coastguard Worker       }
886*da0073e9SAndroid Build Coastguard Worker       ASSERT_EQ(res.size(), local_sample_count * (i + 1));
887*da0073e9SAndroid Build Coastguard Worker     }
888*da0073e9SAndroid Build Coastguard Worker     std::sort(res.begin(), res.end());
889*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(res, output);
890*da0073e9SAndroid Build Coastguard Worker   };
891*da0073e9SAndroid Build Coastguard Worker 
892*da0073e9SAndroid Build Coastguard Worker   for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
893*da0073e9SAndroid Build Coastguard Worker     size_t local_sample_count =
894*da0073e9SAndroid Build Coastguard Worker         static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
895*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
896*da0073e9SAndroid Build Coastguard Worker     test_function(true, local_sample_count, output1, batch_size);
897*da0073e9SAndroid Build Coastguard Worker 
898*da0073e9SAndroid Build Coastguard Worker     local_sample_count =
899*da0073e9SAndroid Build Coastguard Worker         static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
900*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
901*da0073e9SAndroid Build Coastguard Worker     test_function(false, local_sample_count, output2, batch_size);
902*da0073e9SAndroid Build Coastguard Worker   }
903*da0073e9SAndroid Build Coastguard Worker }
904*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
  // Untouched sampler: index 0 survives a save/load round trip.
  {
    samplers::DistributedRandomSampler original(10);
    ASSERT_EQ(original.index(), 0);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 0);
  }
  // Partially consumed sampler: the advanced index (3 + 4) is restored.
  {
    samplers::DistributedRandomSampler original(10);
    original.next(3);
    original.next(4);
    ASSERT_EQ(original.index(), 7);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 7);
  }
  // The epoch is part of the serialized state as well.
  {
    samplers::DistributedRandomSampler original(10);
    original.set_epoch(3);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.epoch(), 3);
  }
}
939*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t batch_size = 3;
  samplers::DistributedSequentialSampler sampler(sample_count);

  // Pull batches until the sampler reports exhaustion.
  std::vector<size_t> collected;
  for (auto batch = sampler.next(batch_size); batch.has_value();
       batch = sampler.next(batch_size)) {
    collected.insert(collected.end(), batch->begin(), batch->end());
  }

  // One replica owns the whole range, so each index appears once.
  ASSERT_EQ(collected.size(), sample_count);

  std::sort(collected.begin(), collected.end());
  for (size_t position = 0; position < collected.size(); ++position) {
    ASSERT_EQ(collected[position], position);
  }
}
958*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest,DistributedSequentialSamplerMultiReplicaProduceCorrectSamples)959*da0073e9SAndroid Build Coastguard Worker TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
960*da0073e9SAndroid Build Coastguard Worker   size_t sample_count = 10;
961*da0073e9SAndroid Build Coastguard Worker   size_t num_replicas = 3;
962*da0073e9SAndroid Build Coastguard Worker 
963*da0073e9SAndroid Build Coastguard Worker   auto test_function = [&](bool allow_duplicates,
964*da0073e9SAndroid Build Coastguard Worker                            size_t local_sample_count,
965*da0073e9SAndroid Build Coastguard Worker                            std::vector<size_t>& output,
966*da0073e9SAndroid Build Coastguard Worker                            size_t batch_size) {
967*da0073e9SAndroid Build Coastguard Worker     std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
968*da0073e9SAndroid Build Coastguard Worker         samplers;
969*da0073e9SAndroid Build Coastguard Worker 
970*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(num_replicas)) {
971*da0073e9SAndroid Build Coastguard Worker       samplers.emplace_back(
972*da0073e9SAndroid Build Coastguard Worker           std::make_unique<samplers::DistributedSequentialSampler>(
973*da0073e9SAndroid Build Coastguard Worker               sample_count, num_replicas, i, allow_duplicates));
974*da0073e9SAndroid Build Coastguard Worker     }
975*da0073e9SAndroid Build Coastguard Worker 
976*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> res;
977*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(num_replicas)) {
978*da0073e9SAndroid Build Coastguard Worker       (*samplers[i]).reset();
979*da0073e9SAndroid Build Coastguard Worker       torch::optional<std::vector<size_t>> idx;
980*da0073e9SAndroid Build Coastguard Worker       while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
981*da0073e9SAndroid Build Coastguard Worker         res.insert(std::end(res), std::begin(*idx), std::end(*idx));
982*da0073e9SAndroid Build Coastguard Worker       }
983*da0073e9SAndroid Build Coastguard Worker       ASSERT_EQ(res.size(), local_sample_count * (i + 1));
984*da0073e9SAndroid Build Coastguard Worker     }
985*da0073e9SAndroid Build Coastguard Worker     std::sort(res.begin(), res.end());
986*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(res, output);
987*da0073e9SAndroid Build Coastguard Worker   };
988*da0073e9SAndroid Build Coastguard Worker 
989*da0073e9SAndroid Build Coastguard Worker   for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
990*da0073e9SAndroid Build Coastguard Worker     size_t local_sample_count =
991*da0073e9SAndroid Build Coastguard Worker         static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
992*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
993*da0073e9SAndroid Build Coastguard Worker     test_function(true, local_sample_count, output1, batch_size);
994*da0073e9SAndroid Build Coastguard Worker 
995*da0073e9SAndroid Build Coastguard Worker     local_sample_count =
996*da0073e9SAndroid Build Coastguard Worker         static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
997*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
998*da0073e9SAndroid Build Coastguard Worker     test_function(false, local_sample_count, output2, batch_size);
999*da0073e9SAndroid Build Coastguard Worker   }
1000*da0073e9SAndroid Build Coastguard Worker }
1001*da0073e9SAndroid Build Coastguard Worker 
TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
  // Untouched sampler: index 0 survives a save/load round trip.
  {
    samplers::DistributedSequentialSampler original(10);
    ASSERT_EQ(original.index(), 0);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedSequentialSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 0);
  }
  // Partially consumed sampler: the advanced index (3 + 4) is restored.
  {
    samplers::DistributedSequentialSampler original(10);
    original.next(3);
    original.next(4);
    ASSERT_EQ(original.index(), 7);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedSequentialSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 7);
  }
}
1026*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
  // Resolve a default-constructed options object and check each field the
  // full options fill in.
  DataLoaderOptions unset_options;
  FullDataLoaderOptions resolved(unset_options);
  ASSERT_EQ(resolved.batch_size, 1);
  ASSERT_FALSE(resolved.drop_last);
  ASSERT_EQ(resolved.workers, 0);
  ASSERT_EQ(resolved.max_jobs, 0);
  ASSERT_FALSE(resolved.timeout.has_value());
  ASSERT_TRUE(resolved.enforce_ordering);
}
1037*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) {
  const size_t kRequestedBatchSize = 32;
  const size_t kRequestedWorkers = 10;
  auto partial_options =
      DataLoaderOptions(kRequestedBatchSize).workers(kRequestedWorkers);
  FullDataLoaderOptions resolved(partial_options);
  ASSERT_EQ(resolved.batch_size, kRequestedBatchSize);
  // max_jobs was left unset, so it is expected to be twice the worker count.
  ASSERT_EQ(resolved.max_jobs, 2 * kRequestedWorkers);
}
1044*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) {
  // Any mapped dataset works here; the transform just increments values.
  auto increment = transforms::Lambda<int>([](int x) { return x + 1; });
  auto data_loader =
      torch::data::make_data_loader(DummyDataset().map(increment));
  // No options were passed, so the default batch size applies.
  ASSERT_EQ(data_loader->options().batch_size, 1);
}
1050*da0073e9SAndroid Build Coastguard Worker 
1051*da0073e9SAndroid Build Coastguard Worker struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
getUnsizedDataset1052*da0073e9SAndroid Build Coastguard Worker   torch::data::Example<> get(size_t i) override {
1053*da0073e9SAndroid Build Coastguard Worker     return {torch::ones(i), torch::ones(i)};
1054*da0073e9SAndroid Build Coastguard Worker   }
sizeUnsizedDataset1055*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const noexcept override {
1056*da0073e9SAndroid Build Coastguard Worker     return torch::nullopt;
1057*da0073e9SAndroid Build Coastguard Worker   }
1058*da0073e9SAndroid Build Coastguard Worker };
1059*da0073e9SAndroid Build Coastguard Worker 
TEST(
    DataLoaderTest,
    MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) {
  // No sampler was supplied, so make_data_loader must build one from the
  // dataset's size. That is impossible for a dataset that reports none.
  UnsizedDataset unsized;
  ASSERT_THROWS_WITH(
      torch::data::make_data_loader(unsized),
      "Expected the dataset to be sized in order to construct the Sampler");
}
1067*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) {
  auto loader = torch::data::make_data_loader(DummyDataset(), 32);
  // Both a valid iterator and the sentinel must be reflexively equal.
  auto first = loader->begin();
  ASSERT_EQ(first, first);
  auto sentinel = loader->end();
  ASSERT_EQ(sentinel, sentinel);
}
1075*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) {
  auto loader = torch::data::make_data_loader(DummyDataset(), 32);
  // Two distinct valid iterators never compare equal, advanced or not.
  auto left = loader->begin();
  auto right = loader->begin();
  ASSERT_NE(left, right);
  ++right;
  ASSERT_NE(left, right);
}
1084*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) {
  auto loader = torch::data::make_data_loader(DummyDataset(), 32);
  // Every past-the-end iterator is interchangeable with every other.
  auto first_sentinel = loader->end();
  auto second_sentinel = loader->end();
  ASSERT_EQ(first_sentinel, second_sentinel);
}
1091*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) {
  DummyDataset dataset;
  // Batch size is a quarter of the dataset, so four batches cover it.
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 4);
  auto cursor = data_loader->begin();
  auto sentinel = data_loader->end();
  // Before each of the four advances the iterator is still valid.
  for (int batches_left = 4; batches_left > 0; --batches_left) {
    ASSERT_NE(cursor, sentinel);
    ++cursor;
  }
  // After consuming the last batch it becomes equal to the sentinel.
  ASSERT_EQ(cursor, sentinel);
}
1108*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, IteratorsShareState) {
  DummyDataset dataset;
  // Batch size is half the dataset.
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto original = data_loader->begin();
  // A copy of an iterator observes the same underlying loader progress.
  auto copy = original;
  auto sentinel = data_loader->end();
  ASSERT_NE(original, sentinel);
  ASSERT_NE(copy, sentinel);
  // Advancing through the original consumes one batch...
  ++original;
  ASSERT_NE(original, sentinel);
  ASSERT_NE(copy, sentinel);
  // ...and advancing through the copy consumes the other, exhausting both.
  ++copy;
  ASSERT_EQ(original, sentinel);
  ASSERT_EQ(copy, sentinel);
}
1125*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset,
          // NOLINTNEXTLINE(bugprone-argument-comment)
          /*batch_size=*/1);
  auto iterator = data_loader->begin();
  // Walk the first three single-element batches, expected to hold 1, 2, 3.
  for (int value = 1; value <= 3; ++value) {
    if (value > 1) {
      ++iterator;
    }
    std::vector<int> expected = {value};
    // Dereferencing repeatedly must not advance or change the batch.
    ASSERT_EQ(*iterator, expected);
    ASSERT_EQ(*iterator, expected);
  }
}
1146*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
  // Batch dataset of 10 elements whose batch value is 1 + the first index.
  struct D : datasets::BatchDataset<D, int> {
    int get_batch(torch::ArrayRef<size_t> indices) override {
      // Explicit narrowing replaces the previously NOLINT-suppressed
      // implicit size_t -> int conversion; indices here are < 10.
      return static_cast<int>(1 + indices.front());
    }
    torch::optional<size_t> size() const override {
      return 10;
    }
  };

  D dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset, 1);
  // The loader's iterators must work with standard algorithms like copy.
  std::vector<int> values;
  std::copy(
      data_loader->begin(), data_loader->end(), std::back_inserter(values));
  std::vector<int> expected(dataset.size().value());
  // Seed iota with an int to match the element type (values 1..10).
  std::iota(expected.begin(), expected.end(), 1);
  ASSERT_EQ(values, expected);
}
1169*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, DataLoaderOptions(1).workers(2));
  // Keep this iterator alive for the rest of the test: it is the one that
  // is still "in flight" when begin() is called again below.
  auto in_flight = data_loader->begin();
  ASSERT_THROWS_WITH(
      data_loader->begin(),
      "Attempted to get a new DataLoader iterator "
      "while another iterator is not yet exhausted");
}
1180*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  // One batch spans the whole dataset, so a single increment exhausts it.
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto cursor = data_loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++cursor);
  ASSERT_THROWS_WITH(++cursor, "Attempted to increment iterator past the end");
}
1190*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  // One batch spans the whole dataset, so a single increment exhausts it.
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto cursor = data_loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++cursor);
  ASSERT_THROWS_WITH(
      *cursor, "Attempted to dereference iterator that was past the end");
}
1201*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  // The sentinel itself may never be advanced.
  auto sentinel = data_loader->end();
  ASSERT_THROWS_WITH(
      ++sentinel,
      "Incrementing the DataLoader's past-the-end iterator is not allowed");
}
1211*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) {
  DummyDataset dataset;
  const size_t single_batch = dataset.size().value();
  auto loader = torch::data::make_data_loader(dataset, single_batch);
  // end() is the past-the-end sentinel; reading through it is a usage error.
  auto sentinel = loader->end();
  ASSERT_THROWS_WITH(
      *sentinel,
      "Dereferencing the DataLoader's past-the-end iterator is not allowed");
}
1221*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, YieldsCorrectBatchSize) {
  // The default DummyDataset holds 100 examples, so a batch size of 25 gives
  // exactly four full batches and then the end of the epoch.
  DummyDataset dataset;
  auto loader = torch::data::make_data_loader(dataset, 25);
  auto it = loader->begin();
  ASSERT_EQ(it->size(), 25);
  for (int remaining_batch = 1; remaining_batch < 4; ++remaining_batch) {
    ASSERT_EQ((++it)->size(), 25);
  }
  ASSERT_EQ(++it, loader->end());
}
1232*da0073e9SAndroid Build Coastguard Worker 
TEST(
    DataLoaderTest,
    ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) {
  // 100 examples in batches of 33: three full batches plus a one-element
  // remainder, which is kept because drop_last is false.
  DummyDataset dataset;
  auto loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(false));
  auto it = loader->begin();
  ASSERT_EQ(it->size(), 33);
  ++it;
  ASSERT_EQ(it->size(), 33);
  ++it;
  ASSERT_EQ(it->size(), 33);
  ++it;
  ASSERT_EQ(it->size(), 1);
  ++it;
  ASSERT_EQ(it, loader->end());
}
1246*da0073e9SAndroid Build Coastguard Worker 
TEST(
    DataLoaderTest,
    DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) {
  // 100 examples in batches of 33: the one-element remainder is discarded
  // because drop_last is true, leaving exactly three full batches.
  DummyDataset dataset;
  auto loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(true));
  auto it = loader->begin();
  ASSERT_EQ(it->size(), 33);
  ++it;
  ASSERT_EQ(it->size(), 33);
  ++it;
  ASSERT_EQ(it->size(), 33);
  ++it;
  ASSERT_EQ(it, loader->end());
}
1259*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, RespectsTimeout) {
  // Shared state that lets the test wake the worker blocked in `D::get()`.
  struct Baton {
    std::condition_variable cv;
    std::mutex mutex;
  };

  // Dataset whose `get()` blocks for up to 1000ms (or until notified through
  // the baton), far longer than the loader's 10ms timeout.
  // CRTP: the `Self` template argument must be this type itself.
  struct D : datasets::Dataset<D, int> {
    explicit D(std::shared_ptr<Baton> b) : baton(std::move(b)) {}
    int get(size_t /*index*/) override {
      std::unique_lock<std::mutex> lock(baton->mutex);
      baton->cv.wait_for(lock, 1000 * kMillisecond);
      return 0;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    std::shared_ptr<Baton> baton;
  };

  auto baton = std::make_shared<Baton>();

  auto data_loader = torch::data::make_data_loader(
      D{baton}, DataLoaderOptions().workers(1).timeout(10 * kMillisecond));

  auto start = std::chrono::system_clock::now();

  // The only worker is stuck in `get()`, so fetching the first batch must
  // give up once the configured timeout elapses.
  ASSERT_THROWS_WITH(*data_loader->begin(), "Timeout");
  // Release the blocked worker so it does not sit out its full wait.
  baton->cv.notify_one();

  auto end = std::chrono::system_clock::now();
  // Truncated to whole seconds: the call must return well before the
  // dataset's 1000ms wait would have finished on its own.
  auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
  ASSERT_LT(duration.count(), 1);
}
1293*da0073e9SAndroid Build Coastguard Worker 
// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
// Single-use rendezvous: each of the first `target - 1` callers of wait()
// blocks until the `target`-th caller arrives; then all of them proceed.
// The counter is never restored, so the barrier cannot be reused.
struct Barrier {
  explicit Barrier(size_t target) : counter_(target) {}

  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    --counter_;
    if (counter_ != 0) {
      // Not the last arrival: sleep until the last one drives the count to 0.
      cv_.wait(lock, [this] { return counter_ == 0; });
      return;
    }
    // Last arrival: wake everyone who is waiting.
    cv_.notify_all();
  }

  size_t counter_;
  std::condition_variable cv_;
  std::mutex mutex_;
};
1310*da0073e9SAndroid Build Coastguard Worker 
1311*da0073e9SAndroid Build Coastguard Worker // On the OrderingTest: This test is intended to verify that the
1312*da0073e9SAndroid Build Coastguard Worker // `enforce_ordering` option of the dataloader works correctly. The reason this
1313*da0073e9SAndroid Build Coastguard Worker // flag exists is because when the dataloader has multiple workers (threads)
1314*da0073e9SAndroid Build Coastguard Worker // enabled and this flag is not set, the order in which worker threads finish
1315*da0073e9SAndroid Build Coastguard Worker // loading their respective batch and push it back to the dataloader's main
1316*da0073e9SAndroid Build Coastguard Worker // thread (for outside consumption) is not deterministic. Imagine the sampler is
1317*da0073e9SAndroid Build Coastguard Worker // a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
1318*da0073e9SAndroid Build Coastguard Worker // will be a single "job". Inside the dataloader, worker threads block until a
// job is available. It is not deterministic which worker thread wakes up first
1320*da0073e9SAndroid Build Coastguard Worker // to dequeue a particular batch. Further, some worker threads may take longer
1321*da0073e9SAndroid Build Coastguard Worker // than others to read the data for their index. As such, it could be that
1322*da0073e9SAndroid Build Coastguard Worker // worker thread 2 finishes before all other threads and returns its batch to
1323*da0073e9SAndroid Build Coastguard Worker // the main thread. In that case, the dataloader iterator would return the datum
// at index 2 first, and afterwards the datum from whatever thread finishes
1325*da0073e9SAndroid Build Coastguard Worker // next. As such, the user may see data from indices 2, 0, 3, 1. On another run
1326*da0073e9SAndroid Build Coastguard Worker // of the same dataloader on the same data, threads may be scheduled differently
1327*da0073e9SAndroid Build Coastguard Worker // and return in order 0, 2, 3, 1. To force this ordering to deterministically
1328*da0073e9SAndroid Build Coastguard Worker // be 0, 1, 2, 3, the `enforce_ordering` flag can be set to true. In that case,
1329*da0073e9SAndroid Build Coastguard Worker // the dataloader will use a *sequencer* internally which keeps track of which
1330*da0073e9SAndroid Build Coastguard Worker // datum is expected next, and buffers any other results until that next
1331*da0073e9SAndroid Build Coastguard Worker // expected value arrives. For example, workers 1, 2, 3 may finish before worker
1332*da0073e9SAndroid Build Coastguard Worker // 0. If `enforce_ordering` is true, the sequencer will internally buffer the
1333*da0073e9SAndroid Build Coastguard Worker // results from 1, 2, 3 until worker 0 finishes. Only then does the dataloader
1334*da0073e9SAndroid Build Coastguard Worker // return the datum from worker 0 to the user (and then datum 1 the next time,
1335*da0073e9SAndroid Build Coastguard Worker // then 2 and so on).
1336*da0073e9SAndroid Build Coastguard Worker //
1337*da0073e9SAndroid Build Coastguard Worker // The way the test works is that we start
1338*da0073e9SAndroid Build Coastguard Worker // `kNumberOfWorkers` workers in the dataloader, which each get an index from a
1339*da0073e9SAndroid Build Coastguard Worker // `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread
1340*da0073e9SAndroid Build Coastguard Worker // has a copy of the dataset, and thus `get_batch()` is called on the
1341*da0073e9SAndroid Build Coastguard Worker // thread-local copy in each worker. We want to simulate out-of-order completion
// of these threads. For this, we first set a barrier in the `get_batch()`
1343*da0073e9SAndroid Build Coastguard Worker // method to make sure every worker has some index to fetch assigned. Further,
1344*da0073e9SAndroid Build Coastguard Worker // each worker thread has a unique ID in `0...kNumberOfWorkers-1`.
1345*da0073e9SAndroid Build Coastguard Worker // There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in
1346*da0073e9SAndroid Build Coastguard Worker // which we want the worker threads to return. For this, an iterator into this
1347*da0073e9SAndroid Build Coastguard Worker // order is maintained. When the dereferenced iterator (the current order index)
1348*da0073e9SAndroid Build Coastguard Worker // matches the thread ID of a worker, it knows it can now return its index as
1349*da0073e9SAndroid Build Coastguard Worker // well as progress the iterator. Inside the dataloader, the sequencer should
1350*da0073e9SAndroid Build Coastguard Worker // buffer these indices such that they are ultimately returned in order.
1351*da0073e9SAndroid Build Coastguard Worker 
1352*da0073e9SAndroid Build Coastguard Worker namespace ordering_test {
1353*da0073e9SAndroid Build Coastguard Worker namespace {
1354*da0073e9SAndroid Build Coastguard Worker const size_t kNumberOfWorkers = 10;
1355*da0073e9SAndroid Build Coastguard Worker const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
1356*da0073e9SAndroid Build Coastguard Worker     {3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
1357*da0073e9SAndroid Build Coastguard Worker } // namespace
1358*da0073e9SAndroid Build Coastguard Worker 
1359*da0073e9SAndroid Build Coastguard Worker struct Dataset : datasets::BatchDataset<Dataset, size_t> {
1360*da0073e9SAndroid Build Coastguard Worker   Dataset() = default;
1361*da0073e9SAndroid Build Coastguard Worker 
1362*da0073e9SAndroid Build Coastguard Worker   // This copy constructor will be called when we copy the dataset into a
1363*da0073e9SAndroid Build Coastguard Worker   // particular thread.
Datasetordering_test::Dataset1364*da0073e9SAndroid Build Coastguard Worker   Dataset(const Dataset& other) {
1365*da0073e9SAndroid Build Coastguard Worker     static std::atomic<size_t> counter{0};
1366*da0073e9SAndroid Build Coastguard Worker     thread_id_ = counter.fetch_add(1);
1367*da0073e9SAndroid Build Coastguard Worker   }
1368*da0073e9SAndroid Build Coastguard Worker 
1369*da0073e9SAndroid Build Coastguard Worker   Dataset(Dataset&& other) noexcept = default;
1370*da0073e9SAndroid Build Coastguard Worker   Dataset& operator=(const Dataset& other) = delete;
1371*da0073e9SAndroid Build Coastguard Worker   Dataset& operator=(Dataset&& other) noexcept = delete;
1372*da0073e9SAndroid Build Coastguard Worker 
get_batchordering_test::Dataset1373*da0073e9SAndroid Build Coastguard Worker   size_t get_batch(torch::ArrayRef<size_t> indices) override {
1374*da0073e9SAndroid Build Coastguard Worker     static Barrier barrier(kNumberOfWorkers);
1375*da0073e9SAndroid Build Coastguard Worker     static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
1376*da0073e9SAndroid Build Coastguard Worker     static std::condition_variable cv;
1377*da0073e9SAndroid Build Coastguard Worker     static std::mutex mutex;
1378*da0073e9SAndroid Build Coastguard Worker 
1379*da0073e9SAndroid Build Coastguard Worker     // Wait for all threads to get an index batch and arrive here.
1380*da0073e9SAndroid Build Coastguard Worker     barrier.wait();
1381*da0073e9SAndroid Build Coastguard Worker 
1382*da0073e9SAndroid Build Coastguard Worker     std::unique_lock<std::mutex> lock(mutex);
1383*da0073e9SAndroid Build Coastguard Worker     cv.wait(lock, [this] { return *order_iterator == this->thread_id_; });
1384*da0073e9SAndroid Build Coastguard Worker     ++order_iterator;
1385*da0073e9SAndroid Build Coastguard Worker     lock.unlock();
1386*da0073e9SAndroid Build Coastguard Worker     cv.notify_all();
1387*da0073e9SAndroid Build Coastguard Worker 
1388*da0073e9SAndroid Build Coastguard Worker     return indices.front();
1389*da0073e9SAndroid Build Coastguard Worker   }
1390*da0073e9SAndroid Build Coastguard Worker 
sizeordering_test::Dataset1391*da0073e9SAndroid Build Coastguard Worker   torch::optional<size_t> size() const override {
1392*da0073e9SAndroid Build Coastguard Worker     return kNumberOfWorkers;
1393*da0073e9SAndroid Build Coastguard Worker   }
1394*da0073e9SAndroid Build Coastguard Worker 
1395*da0073e9SAndroid Build Coastguard Worker   size_t thread_id_ = 0;
1396*da0073e9SAndroid Build Coastguard Worker };
1397*da0073e9SAndroid Build Coastguard Worker 
1398*da0073e9SAndroid Build Coastguard Worker } // namespace ordering_test
1399*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
  // Workers are released in the scrambled order
  // `ordering_test::kOrderInWhichWorkersReturnTheirBatch`, yet with
  // enforce_ordering(true) the loader must still yield 0, 1, ..., N-1.
  const size_t worker_count = ordering_test::kNumberOfWorkers;
  auto data_loader = torch::data::make_data_loader(
      ordering_test::Dataset{},
      torch::data::samplers::SequentialSampler(worker_count),
      DataLoaderOptions()
          .batch_size(1)
          .workers(worker_count)
          .enforce_ordering(true));

  std::vector<size_t> output;
  for (const size_t value : *data_loader) {
    output.push_back(value);
  }

  std::vector<size_t> expected;
  for (size_t index = 0; index < worker_count; ++index) {
    expected.push_back(index);
  }
  ASSERT_EQ(expected, output);
}
1416*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, Reset) {
  // Two batches of half the dataset each. Every call to begin() must restart
  // iteration, so three consecutive passes each see exactly two batches.
  DummyDataset dataset;
  const size_t half_of_dataset = dataset.size().value() / 2;
  auto data_loader = torch::data::make_data_loader(dataset, half_of_dataset);
  auto end = data_loader->end();

  for (int pass = 0; pass < 3; ++pass) {
    auto iterator = data_loader->begin();
    ASSERT_NE(iterator, end);
    ASSERT_NE(++iterator, end);
    ASSERT_EQ(++iterator, end);
  }
}
1438*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
  // Dataset whose every `get()` throws. The test checks that the exception
  // reaches the main thread wrapped in a WorkerException.
  // CRTP: the `Self` template argument must be this type itself.
  struct D : datasets::Dataset<D, int> {
    int get(size_t /*index*/) override {
      throw std::invalid_argument("badness");
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
  };

  auto data_loader = torch::data::make_data_loader(
      D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
  auto iterator = data_loader->begin();

  try {
    (void)*iterator;
    // Without this, a loader that fails to propagate the worker's exception
    // would make the test pass silently.
    FAIL() << "Expected dereferencing the iterator to throw a WorkerException";
  } catch (const torch::data::WorkerException& e) {
    ASSERT_EQ(
        e.what(),
        std::string("Caught exception in DataLoader worker thread. "
                    "Original message: badness"));
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        std::rethrow_exception(e.original_exception), std::invalid_argument);
  }
}
1465*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Stateful dataset that yields 0, 1, ..., 9 (one int per `get_batch()`
  // call). After that it reports exhaustion with nullopt until `reset()`
  // rewinds the counter.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter >= kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return torch::nullopt;
      }
      return counter++;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    // Serialization is not exercised by this test.
    void save(torch::serialize::OutputArchive& /*archive*/) const override {}
    void load(torch::serialize::InputArchive& /*archive*/) override {}
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(D{});

  // Every epoch is expected to start from a freshly reset dataset, so each
  // pass must see the same number of batches.
  for (const auto epoch : c10::irange(10)) {
    const auto batches_seen =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(batches_seen, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << epoch;
  }

  for (const int value : *data_loader) {
    ASSERT_LT(value, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
1501*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,StatefulDatasetWithManyWorkers)1502*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
1503*da0073e9SAndroid Build Coastguard Worker   const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
1504*da0073e9SAndroid Build Coastguard Worker   const int kNumberOfWorkers = 4;
1505*da0073e9SAndroid Build Coastguard Worker 
1506*da0073e9SAndroid Build Coastguard Worker   struct D : datasets::StatefulDataset<D, int, size_t> {
1507*da0073e9SAndroid Build Coastguard Worker     torch::optional<int> get_batch(size_t) override {
1508*da0073e9SAndroid Build Coastguard Worker       std::lock_guard<std::mutex> lock(mutex);
1509*da0073e9SAndroid Build Coastguard Worker       if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
1510*da0073e9SAndroid Build Coastguard Worker         return counter++;
1511*da0073e9SAndroid Build Coastguard Worker       }
1512*da0073e9SAndroid Build Coastguard Worker       return torch::nullopt;
1513*da0073e9SAndroid Build Coastguard Worker     }
1514*da0073e9SAndroid Build Coastguard Worker     torch::optional<size_t> size() const override {
1515*da0073e9SAndroid Build Coastguard Worker       return 100;
1516*da0073e9SAndroid Build Coastguard Worker     }
1517*da0073e9SAndroid Build Coastguard Worker     void reset() override {
1518*da0073e9SAndroid Build Coastguard Worker       counter = 0;
1519*da0073e9SAndroid Build Coastguard Worker     }
1520*da0073e9SAndroid Build Coastguard Worker     void save(torch::serialize::OutputArchive& archive) const override{};
1521*da0073e9SAndroid Build Coastguard Worker     void load(torch::serialize::InputArchive& archive) override {}
1522*da0073e9SAndroid Build Coastguard Worker     int counter = 0;
1523*da0073e9SAndroid Build Coastguard Worker     std::mutex mutex;
1524*da0073e9SAndroid Build Coastguard Worker   };
1525*da0073e9SAndroid Build Coastguard Worker 
1526*da0073e9SAndroid Build Coastguard Worker   auto data_loader = torch::data::make_data_loader(
1527*da0073e9SAndroid Build Coastguard Worker       torch::data::datasets::make_shared_dataset<D>(),
1528*da0073e9SAndroid Build Coastguard Worker       DataLoaderOptions().workers(kNumberOfWorkers));
1529*da0073e9SAndroid Build Coastguard Worker 
1530*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(10)) {
1531*da0073e9SAndroid Build Coastguard Worker     const auto number_of_iterations =
1532*da0073e9SAndroid Build Coastguard Worker         std::distance(data_loader->begin(), data_loader->end());
1533*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
1534*da0073e9SAndroid Build Coastguard Worker         number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
1535*da0073e9SAndroid Build Coastguard Worker         << "epoch " << i;
1536*da0073e9SAndroid Build Coastguard Worker   }
1537*da0073e9SAndroid Build Coastguard Worker 
1538*da0073e9SAndroid Build Coastguard Worker   for (const int i : *data_loader) {
1539*da0073e9SAndroid Build Coastguard Worker     ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
1540*da0073e9SAndroid Build Coastguard Worker   }
1541*da0073e9SAndroid Build Coastguard Worker }
1542*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,StatefulDatasetWithMap)1543*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, StatefulDatasetWithMap) {
1544*da0073e9SAndroid Build Coastguard Worker   const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
1545*da0073e9SAndroid Build Coastguard Worker 
1546*da0073e9SAndroid Build Coastguard Worker   struct D : datasets::StatefulDataset<D, int, size_t> {
1547*da0073e9SAndroid Build Coastguard Worker     torch::optional<int> get_batch(size_t) override {
1548*da0073e9SAndroid Build Coastguard Worker       if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
1549*da0073e9SAndroid Build Coastguard Worker         return counter++;
1550*da0073e9SAndroid Build Coastguard Worker       }
1551*da0073e9SAndroid Build Coastguard Worker       return torch::nullopt;
1552*da0073e9SAndroid Build Coastguard Worker     }
1553*da0073e9SAndroid Build Coastguard Worker     torch::optional<size_t> size() const override {
1554*da0073e9SAndroid Build Coastguard Worker       return 100;
1555*da0073e9SAndroid Build Coastguard Worker     }
1556*da0073e9SAndroid Build Coastguard Worker     void reset() override {
1557*da0073e9SAndroid Build Coastguard Worker       counter = 0;
1558*da0073e9SAndroid Build Coastguard Worker     }
1559*da0073e9SAndroid Build Coastguard Worker     void save(torch::serialize::OutputArchive& archive) const override{};
1560*da0073e9SAndroid Build Coastguard Worker     void load(torch::serialize::InputArchive& archive) override {}
1561*da0073e9SAndroid Build Coastguard Worker     int counter = 0;
1562*da0073e9SAndroid Build Coastguard Worker   };
1563*da0073e9SAndroid Build Coastguard Worker 
1564*da0073e9SAndroid Build Coastguard Worker   auto data_loader = torch::data::make_data_loader(
1565*da0073e9SAndroid Build Coastguard Worker       D().map(transforms::BatchLambda<int, std::string>(
1566*da0073e9SAndroid Build Coastguard Worker                   [](int x) { return std::to_string(x); }))
1567*da0073e9SAndroid Build Coastguard Worker           .map(transforms::BatchLambda<std::string, torch::Tensor>(
1568*da0073e9SAndroid Build Coastguard Worker               [](const std::string& x) {
1569*da0073e9SAndroid Build Coastguard Worker                 return torch::tensor(static_cast<int64_t>(std::stoi(x)));
1570*da0073e9SAndroid Build Coastguard Worker               })),
1571*da0073e9SAndroid Build Coastguard Worker       DataLoaderOptions{});
1572*da0073e9SAndroid Build Coastguard Worker 
1573*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(10)) {
1574*da0073e9SAndroid Build Coastguard Worker     const auto number_of_iterations =
1575*da0073e9SAndroid Build Coastguard Worker         std::distance(data_loader->begin(), data_loader->end());
1576*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(
1577*da0073e9SAndroid Build Coastguard Worker         number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
1578*da0073e9SAndroid Build Coastguard Worker         << "epoch " << i;
1579*da0073e9SAndroid Build Coastguard Worker   }
1580*da0073e9SAndroid Build Coastguard Worker 
1581*da0073e9SAndroid Build Coastguard Worker   for (const torch::Tensor& t : *data_loader) {
1582*da0073e9SAndroid Build Coastguard Worker     ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
1583*da0073e9SAndroid Build Coastguard Worker   }
1584*da0073e9SAndroid Build Coastguard Worker }
1585*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,StatefulDatasetWithCollate)1586*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, StatefulDatasetWithCollate) {
1587*da0073e9SAndroid Build Coastguard Worker   const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
1588*da0073e9SAndroid Build Coastguard Worker 
1589*da0073e9SAndroid Build Coastguard Worker   struct D : datasets::StatefulDataset<D> {
1590*da0073e9SAndroid Build Coastguard Worker     torch::optional<std::vector<Example<>>> get_batch(
1591*da0073e9SAndroid Build Coastguard Worker         size_t batch_size) override {
1592*da0073e9SAndroid Build Coastguard Worker       if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
1593*da0073e9SAndroid Build Coastguard Worker         counter += batch_size;
1594*da0073e9SAndroid Build Coastguard Worker         std::vector<Example<>> batch(
1595*da0073e9SAndroid Build Coastguard Worker             /*count=*/batch_size,
1596*da0073e9SAndroid Build Coastguard Worker             Example<>{
1597*da0073e9SAndroid Build Coastguard Worker                 torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
1598*da0073e9SAndroid Build Coastguard Worker         return batch;
1599*da0073e9SAndroid Build Coastguard Worker       }
1600*da0073e9SAndroid Build Coastguard Worker       return torch::nullopt;
1601*da0073e9SAndroid Build Coastguard Worker     }
1602*da0073e9SAndroid Build Coastguard Worker     torch::optional<size_t> size() const override {
1603*da0073e9SAndroid Build Coastguard Worker       return 100;
1604*da0073e9SAndroid Build Coastguard Worker     }
1605*da0073e9SAndroid Build Coastguard Worker     void reset() override {
1606*da0073e9SAndroid Build Coastguard Worker       counter = 0;
1607*da0073e9SAndroid Build Coastguard Worker     }
1608*da0073e9SAndroid Build Coastguard Worker     void save(torch::serialize::OutputArchive& archive) const override{};
1609*da0073e9SAndroid Build Coastguard Worker     void load(torch::serialize::InputArchive& archive) override {}
1610*da0073e9SAndroid Build Coastguard Worker     int counter = 0;
1611*da0073e9SAndroid Build Coastguard Worker   };
1612*da0073e9SAndroid Build Coastguard Worker 
1613*da0073e9SAndroid Build Coastguard Worker   auto d = D().map(transforms::Stack<Example<>>());
1614*da0073e9SAndroid Build Coastguard Worker 
1615*da0073e9SAndroid Build Coastguard Worker   const size_t kBatchSize = 5;
1616*da0073e9SAndroid Build Coastguard Worker 
1617*da0073e9SAndroid Build Coastguard Worker   // Notice that the `get_batch()` of the dataset returns a vector<Example>, but
1618*da0073e9SAndroid Build Coastguard Worker   // the `Stack` collation stacks the tensors into one.
1619*da0073e9SAndroid Build Coastguard Worker   torch::optional<Example<>> batch = d.get_batch(kBatchSize);
1620*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(batch.has_value());
1621*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(batch->data.size(0), kBatchSize);
1622*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
1623*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(batch->target.size(0), kBatchSize);
1624*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(batch->target.size(1), kBatchSize - 1);
1625*da0073e9SAndroid Build Coastguard Worker 
1626*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
1627*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
1628*da0073e9SAndroid Build Coastguard Worker }
1629*da0073e9SAndroid Build Coastguard Worker 
1630*da0073e9SAndroid Build Coastguard Worker // This test tests the core function for iterate through a chunk dataset. It
1631*da0073e9SAndroid Build Coastguard Worker // contains test cases with different parameter combination. (For example,
1632*da0073e9SAndroid Build Coastguard Worker // different prefetch count, batch size and data loader worker count). It
1633*da0073e9SAndroid Build Coastguard Worker // verifies the return batches size and content when the order is deterministic.
TEST(DataLoaderTest,ChunkDataSetGetBatch)1634*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, ChunkDataSetGetBatch) {
1635*da0073e9SAndroid Build Coastguard Worker   // different prefetch count for testing.
1636*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
1637*da0073e9SAndroid Build Coastguard Worker   const size_t prefetch_counts[] = {1, 2, 3, 4};
1638*da0073e9SAndroid Build Coastguard Worker 
1639*da0073e9SAndroid Build Coastguard Worker   // different batch size for testing.
1640*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
1641*da0073e9SAndroid Build Coastguard Worker   const size_t batch_sizes[] = {5, 7};
1642*da0073e9SAndroid Build Coastguard Worker 
1643*da0073e9SAndroid Build Coastguard Worker   // test with/without worker threads
1644*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
1645*da0073e9SAndroid Build Coastguard Worker   const size_t dataloader_worker_counts[] = {0, 2};
1646*da0073e9SAndroid Build Coastguard Worker 
1647*da0073e9SAndroid Build Coastguard Worker   const size_t total_example_count = 35;
1648*da0073e9SAndroid Build Coastguard Worker   DummyChunkDataReader data_reader;
1649*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler sampler(0);
1650*da0073e9SAndroid Build Coastguard Worker 
1651*da0073e9SAndroid Build Coastguard Worker   // test functionality across epoch boundary
1652*da0073e9SAndroid Build Coastguard Worker   const int epoch_count = 2;
1653*da0073e9SAndroid Build Coastguard Worker 
1654*da0073e9SAndroid Build Coastguard Worker   for (auto prefetch_count : prefetch_counts) {
1655*da0073e9SAndroid Build Coastguard Worker     for (auto batch_size : batch_sizes) {
1656*da0073e9SAndroid Build Coastguard Worker       for (auto dataloader_worker_count : dataloader_worker_counts) {
1657*da0073e9SAndroid Build Coastguard Worker         datasets::SharedBatchDataset<datasets::ChunkDataset<
1658*da0073e9SAndroid Build Coastguard Worker             DummyChunkDataReader,
1659*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler,
1660*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler>>
1661*da0073e9SAndroid Build Coastguard Worker             dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1662*da0073e9SAndroid Build Coastguard Worker                 DummyChunkDataReader,
1663*da0073e9SAndroid Build Coastguard Worker                 samplers::SequentialSampler,
1664*da0073e9SAndroid Build Coastguard Worker                 samplers::SequentialSampler>>(
1665*da0073e9SAndroid Build Coastguard Worker                 data_reader,
1666*da0073e9SAndroid Build Coastguard Worker                 sampler,
1667*da0073e9SAndroid Build Coastguard Worker                 sampler,
1668*da0073e9SAndroid Build Coastguard Worker                 datasets::ChunkDatasetOptions(prefetch_count, batch_size));
1669*da0073e9SAndroid Build Coastguard Worker 
1670*da0073e9SAndroid Build Coastguard Worker         auto data_loader = torch::data::make_data_loader(
1671*da0073e9SAndroid Build Coastguard Worker             dataset,
1672*da0073e9SAndroid Build Coastguard Worker             DataLoaderOptions(batch_size).workers(dataloader_worker_count));
1673*da0073e9SAndroid Build Coastguard Worker 
1674*da0073e9SAndroid Build Coastguard Worker         for (const auto epoch_index : c10::irange(epoch_count)) {
1675*da0073e9SAndroid Build Coastguard Worker           (void)epoch_index; // Suppress unused variable warning
1676*da0073e9SAndroid Build Coastguard Worker           std::vector<bool> result(total_example_count, false);
1677*da0073e9SAndroid Build Coastguard Worker           int iteration_count = 0;
1678*da0073e9SAndroid Build Coastguard Worker           for (auto iterator = data_loader->begin();
1679*da0073e9SAndroid Build Coastguard Worker                iterator != data_loader->end();
1680*da0073e9SAndroid Build Coastguard Worker                ++iterator, ++iteration_count) {
1681*da0073e9SAndroid Build Coastguard Worker             DummyChunkDataReader::BatchType& batch = *iterator;
1682*da0073e9SAndroid Build Coastguard Worker             ASSERT_EQ(batch.size(), batch_size);
1683*da0073e9SAndroid Build Coastguard Worker 
1684*da0073e9SAndroid Build Coastguard Worker             // When prefetch_count is equal to 1 and no worker thread, the batch
1685*da0073e9SAndroid Build Coastguard Worker             // order is deterministic. So we can verify elements in each batch.
1686*da0073e9SAndroid Build Coastguard Worker             if (prefetch_count == 1 && dataloader_worker_count == 0) {
1687*da0073e9SAndroid Build Coastguard Worker               for (const auto j : c10::irange(batch_size)) {
1688*da0073e9SAndroid Build Coastguard Worker                 ASSERT_EQ(batch[j], iteration_count * batch_size + j);
1689*da0073e9SAndroid Build Coastguard Worker               }
1690*da0073e9SAndroid Build Coastguard Worker             }
1691*da0073e9SAndroid Build Coastguard Worker             for (const auto j : c10::irange(batch_size)) {
1692*da0073e9SAndroid Build Coastguard Worker               result[batch[j]] = true;
1693*da0073e9SAndroid Build Coastguard Worker             }
1694*da0073e9SAndroid Build Coastguard Worker           }
1695*da0073e9SAndroid Build Coastguard Worker 
1696*da0073e9SAndroid Build Coastguard Worker           for (auto data : result) {
1697*da0073e9SAndroid Build Coastguard Worker             ASSERT_EQ(data, true);
1698*da0073e9SAndroid Build Coastguard Worker           }
1699*da0073e9SAndroid Build Coastguard Worker         }
1700*da0073e9SAndroid Build Coastguard Worker       }
1701*da0073e9SAndroid Build Coastguard Worker     }
1702*da0073e9SAndroid Build Coastguard Worker   }
1703*da0073e9SAndroid Build Coastguard Worker }
1704*da0073e9SAndroid Build Coastguard Worker 
// Requesting batches of a different size than the ChunkDataset was configured
// with must be rejected with a descriptive error on the first fetch.
TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  const size_t requested_batch_size = 6;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(requested_batch_size).workers(0));

  // Built from the constants above so the expected text cannot drift from the
  // sizes actually passed to the dataset and the data loader.
  const std::string exception_msg =
      "The requested batch size does not match with the initialized batch "
      "size.\n The requested batch size is " +
      std::to_string(requested_batch_size) +
      ", while the dataset is created with batch size equal to " +
      std::to_string(batch_size);

  ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg);
}
1736*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,ChunkDataSetWithEmptyBatch)1737*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
1738*da0073e9SAndroid Build Coastguard Worker   struct DummyEmptyChunkDataReader : datasets::ChunkDataReader<int> {
1739*da0073e9SAndroid Build Coastguard Worker    public:
1740*da0073e9SAndroid Build Coastguard Worker     using BatchType = datasets::ChunkDataReader<int>::ChunkType;
1741*da0073e9SAndroid Build Coastguard Worker 
1742*da0073e9SAndroid Build Coastguard Worker     BatchType read_chunk(size_t chunk_index) override {
1743*da0073e9SAndroid Build Coastguard Worker       return {};
1744*da0073e9SAndroid Build Coastguard Worker     }
1745*da0073e9SAndroid Build Coastguard Worker 
1746*da0073e9SAndroid Build Coastguard Worker     size_t chunk_count() override {
1747*da0073e9SAndroid Build Coastguard Worker       return 1;
1748*da0073e9SAndroid Build Coastguard Worker     };
1749*da0073e9SAndroid Build Coastguard Worker 
1750*da0073e9SAndroid Build Coastguard Worker     void reset() override{};
1751*da0073e9SAndroid Build Coastguard Worker   };
1752*da0073e9SAndroid Build Coastguard Worker 
1753*da0073e9SAndroid Build Coastguard Worker   const size_t prefetch_count = 1;
1754*da0073e9SAndroid Build Coastguard Worker   const size_t batch_size = 5;
1755*da0073e9SAndroid Build Coastguard Worker   DummyEmptyChunkDataReader data_reader;
1756*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler sampler(0);
1757*da0073e9SAndroid Build Coastguard Worker 
1758*da0073e9SAndroid Build Coastguard Worker   datasets::SharedBatchDataset<datasets::ChunkDataset<
1759*da0073e9SAndroid Build Coastguard Worker       DummyEmptyChunkDataReader,
1760*da0073e9SAndroid Build Coastguard Worker       samplers::SequentialSampler,
1761*da0073e9SAndroid Build Coastguard Worker       samplers::SequentialSampler>>
1762*da0073e9SAndroid Build Coastguard Worker       dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1763*da0073e9SAndroid Build Coastguard Worker           DummyEmptyChunkDataReader,
1764*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler,
1765*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler>>(
1766*da0073e9SAndroid Build Coastguard Worker           data_reader,
1767*da0073e9SAndroid Build Coastguard Worker           sampler,
1768*da0073e9SAndroid Build Coastguard Worker           sampler,
1769*da0073e9SAndroid Build Coastguard Worker           datasets::ChunkDatasetOptions(prefetch_count, batch_size));
1770*da0073e9SAndroid Build Coastguard Worker 
1771*da0073e9SAndroid Build Coastguard Worker   auto data_loader = torch::data::make_data_loader(
1772*da0073e9SAndroid Build Coastguard Worker       dataset, DataLoaderOptions(batch_size).workers(0));
1773*da0073e9SAndroid Build Coastguard Worker 
1774*da0073e9SAndroid Build Coastguard Worker   for (auto iterator = data_loader->begin(); iterator != data_loader->end();
1775*da0073e9SAndroid Build Coastguard Worker        ++iterator) {
1776*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(iterator->size(), 0);
1777*da0073e9SAndroid Build Coastguard Worker   }
1778*da0073e9SAndroid Build Coastguard Worker }
1779*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,ChunkDataSetGetBatchWithUnevenBatchSize)1780*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
1781*da0073e9SAndroid Build Coastguard Worker   struct D : public datasets::ChunkDataReader<int> {
1782*da0073e9SAndroid Build Coastguard Worker    public:
1783*da0073e9SAndroid Build Coastguard Worker     using BatchType = datasets::ChunkDataReader<int>::ChunkType;
1784*da0073e9SAndroid Build Coastguard Worker 
1785*da0073e9SAndroid Build Coastguard Worker     BatchType read_chunk(size_t chunk_index) override {
1786*da0073e9SAndroid Build Coastguard Worker       BatchType batch_data(10, 0);
1787*da0073e9SAndroid Build Coastguard Worker       return batch_data;
1788*da0073e9SAndroid Build Coastguard Worker     }
1789*da0073e9SAndroid Build Coastguard Worker 
1790*da0073e9SAndroid Build Coastguard Worker     size_t chunk_count() override {
1791*da0073e9SAndroid Build Coastguard Worker       return 2;
1792*da0073e9SAndroid Build Coastguard Worker     };
1793*da0073e9SAndroid Build Coastguard Worker 
1794*da0073e9SAndroid Build Coastguard Worker     void reset() override{};
1795*da0073e9SAndroid Build Coastguard Worker   };
1796*da0073e9SAndroid Build Coastguard Worker 
1797*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
1798*da0073e9SAndroid Build Coastguard Worker   const size_t batch_sizes[] = {17, 30};
1799*da0073e9SAndroid Build Coastguard Worker   D data_reader;
1800*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler sampler(0);
1801*da0073e9SAndroid Build Coastguard Worker 
1802*da0073e9SAndroid Build Coastguard Worker   for (auto batch_size : batch_sizes) {
1803*da0073e9SAndroid Build Coastguard Worker     datasets::SharedBatchDataset<datasets::ChunkDataset<
1804*da0073e9SAndroid Build Coastguard Worker         D,
1805*da0073e9SAndroid Build Coastguard Worker         samplers::SequentialSampler,
1806*da0073e9SAndroid Build Coastguard Worker         samplers::SequentialSampler>>
1807*da0073e9SAndroid Build Coastguard Worker         dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1808*da0073e9SAndroid Build Coastguard Worker             D,
1809*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler,
1810*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler>>(
1811*da0073e9SAndroid Build Coastguard Worker             data_reader,
1812*da0073e9SAndroid Build Coastguard Worker             sampler,
1813*da0073e9SAndroid Build Coastguard Worker             sampler,
1814*da0073e9SAndroid Build Coastguard Worker             datasets::ChunkDatasetOptions(1, batch_size));
1815*da0073e9SAndroid Build Coastguard Worker 
1816*da0073e9SAndroid Build Coastguard Worker     auto data_loader = torch::data::make_data_loader(
1817*da0073e9SAndroid Build Coastguard Worker         dataset, DataLoaderOptions(batch_size).workers(0));
1818*da0073e9SAndroid Build Coastguard Worker 
1819*da0073e9SAndroid Build Coastguard Worker     for (auto iterator = data_loader->begin(); iterator != data_loader->end();
1820*da0073e9SAndroid Build Coastguard Worker          ++iterator) {
1821*da0073e9SAndroid Build Coastguard Worker       DummyChunkDataReader::BatchType batch = *iterator;
1822*da0073e9SAndroid Build Coastguard Worker       auto batch_size = batch.size();
1823*da0073e9SAndroid Build Coastguard Worker       if (batch_size == 17) {
1824*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
1825*da0073e9SAndroid Build Coastguard Worker       }
1826*da0073e9SAndroid Build Coastguard Worker       if (batch_size == 30) {
1827*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(batch.size() == 20);
1828*da0073e9SAndroid Build Coastguard Worker       }
1829*da0073e9SAndroid Build Coastguard Worker     }
1830*da0073e9SAndroid Build Coastguard Worker   }
1831*da0073e9SAndroid Build Coastguard Worker }
1832*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,CanAccessChunkSamplerWithChunkDataSet)1833*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
1834*da0073e9SAndroid Build Coastguard Worker   const size_t prefetch_count = 2;
1835*da0073e9SAndroid Build Coastguard Worker   const size_t batch_size = 5;
1836*da0073e9SAndroid Build Coastguard Worker 
1837*da0073e9SAndroid Build Coastguard Worker   DummyChunkDataReader data_reader;
1838*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler sampler(0);
1839*da0073e9SAndroid Build Coastguard Worker   datasets::SharedBatchDataset<datasets::ChunkDataset<
1840*da0073e9SAndroid Build Coastguard Worker       DummyChunkDataReader,
1841*da0073e9SAndroid Build Coastguard Worker       samplers::SequentialSampler,
1842*da0073e9SAndroid Build Coastguard Worker       samplers::SequentialSampler>>
1843*da0073e9SAndroid Build Coastguard Worker       dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1844*da0073e9SAndroid Build Coastguard Worker           DummyChunkDataReader,
1845*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler,
1846*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler>>(
1847*da0073e9SAndroid Build Coastguard Worker           data_reader,
1848*da0073e9SAndroid Build Coastguard Worker           sampler,
1849*da0073e9SAndroid Build Coastguard Worker           sampler,
1850*da0073e9SAndroid Build Coastguard Worker           datasets::ChunkDatasetOptions(prefetch_count, batch_size));
1851*da0073e9SAndroid Build Coastguard Worker 
1852*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
1853*da0073e9SAndroid Build Coastguard Worker 
1854*da0073e9SAndroid Build Coastguard Worker   auto data_loader = torch::data::make_data_loader(
1855*da0073e9SAndroid Build Coastguard Worker       dataset.map(transforms::BatchLambda<
1856*da0073e9SAndroid Build Coastguard Worker                   DummyChunkDataReader::BatchType,
1857*da0073e9SAndroid Build Coastguard Worker                   DummyChunkDataReader::DataType>(
1858*da0073e9SAndroid Build Coastguard Worker           [](DummyChunkDataReader::BatchType batch) {
1859*da0073e9SAndroid Build Coastguard Worker             return std::accumulate(batch.begin(), batch.end(), 0);
1860*da0073e9SAndroid Build Coastguard Worker           })),
1861*da0073e9SAndroid Build Coastguard Worker       DataLoaderOptions(batch_size).workers(0));
1862*da0073e9SAndroid Build Coastguard Worker 
1863*da0073e9SAndroid Build Coastguard Worker   // before we start, the index should be 0.
1864*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(chunk_sampler.index(), 0);
1865*da0073e9SAndroid Build Coastguard Worker 
1866*da0073e9SAndroid Build Coastguard Worker   size_t sum = 0;
1867*da0073e9SAndroid Build Coastguard Worker   for (auto iterator = data_loader->begin(); iterator != data_loader->end();
1868*da0073e9SAndroid Build Coastguard Worker        ++iterator) {
1869*da0073e9SAndroid Build Coastguard Worker     sum += *iterator;
1870*da0073e9SAndroid Build Coastguard Worker   }
1871*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sum, 595); // sum([0, 35))
1872*da0073e9SAndroid Build Coastguard Worker   // 3 chunks, and when exhausted the value is already incremented.
1873*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(chunk_sampler.index(), 3);
1874*da0073e9SAndroid Build Coastguard Worker }
1875*da0073e9SAndroid Build Coastguard Worker 
// Creating an iterator over a ChunkDataset and never advancing it must still
// allow the test to tear down cleanly instead of hanging.
TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
  const size_t prefetch_count = 2;
  const size_t batch_size = 5;
  // A cache this small makes the preloaders wait until `get_batch()` drains it.
  const size_t cache_size = 10;

  using ChunkDatasetType = datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);
  datasets::SharedBatchDataset<ChunkDatasetType> dataset =
      datasets::make_shared_dataset<ChunkDatasetType>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, cache_size));

  auto sum_of_batch = [](DummyChunkDataReader::BatchType batch) {
    return std::accumulate(batch.begin(), batch.end(), 0);
  };
  auto data_loader = torch::data::make_data_loader(
      dataset.map(transforms::BatchLambda<
                  DummyChunkDataReader::BatchType,
                  DummyChunkDataReader::DataType>(sum_of_batch)),
      DataLoaderOptions(batch_size).workers(0));

  // Only create the iterator; never iterate. The chunk preloaders keep waiting
  // to fill a batch buffer that nobody drains, yet leaving this scope must
  // still exit cleanly.
  auto iterator = data_loader->begin();
  (void)iterator; // Suppress unused variable warning
}
1911*da0073e9SAndroid Build Coastguard Worker 
// Test ChunkDataset save function.
// Note [save/load ChunkDataset as ChunkSampler]:
// The chunk sampler inside ChunkDataset is driven by a thread pool separate
// from the main thread, so it is very hard to know its exact state at the
// moment ChunkDataset::save/ChunkDataset::load is called. For testing purposes
// only, we rely on the implementation detail that the file format for sampler
// serialization is the same as for ChunkDataset serialization, and manually
// control the chunk sampler by calling the sampler's save/load methods to
// validate the values. This is only for testing the specific save/load
// functionality; in real use cases, callers should still pair
// ChunkDataset::save with ChunkDataset::load.
TEST(DataLoaderTest,ChunkDatasetSave)1923*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, ChunkDatasetSave) {
1924*da0073e9SAndroid Build Coastguard Worker   const size_t chunk_count_ = 6;
1925*da0073e9SAndroid Build Coastguard Worker   const size_t chunk_size = 10;
1926*da0073e9SAndroid Build Coastguard Worker 
1927*da0073e9SAndroid Build Coastguard Worker   struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
1928*da0073e9SAndroid Build Coastguard Worker    public:
1929*da0073e9SAndroid Build Coastguard Worker     using BatchType = datasets::ChunkDataReader<int>::ChunkType;
1930*da0073e9SAndroid Build Coastguard Worker 
1931*da0073e9SAndroid Build Coastguard Worker     BatchType read_chunk(size_t chunk_index) override {
1932*da0073e9SAndroid Build Coastguard Worker       return batch_data_;
1933*da0073e9SAndroid Build Coastguard Worker     }
1934*da0073e9SAndroid Build Coastguard Worker 
1935*da0073e9SAndroid Build Coastguard Worker     size_t chunk_count() override {
1936*da0073e9SAndroid Build Coastguard Worker       return chunk_count_;
1937*da0073e9SAndroid Build Coastguard Worker     };
1938*da0073e9SAndroid Build Coastguard Worker 
1939*da0073e9SAndroid Build Coastguard Worker     void reset() override{};
1940*da0073e9SAndroid Build Coastguard Worker     BatchType batch_data_ = BatchType(chunk_size, 0);
1941*da0073e9SAndroid Build Coastguard Worker   };
1942*da0073e9SAndroid Build Coastguard Worker 
1943*da0073e9SAndroid Build Coastguard Worker   const size_t prefetch_count = 1;
1944*da0073e9SAndroid Build Coastguard Worker   const size_t batch_size = chunk_size;
1945*da0073e9SAndroid Build Coastguard Worker   const size_t dataloader_worker_count = 0;
1946*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler sampler(0);
1947*da0073e9SAndroid Build Coastguard Worker   const int epoch_count = 2;
1948*da0073e9SAndroid Build Coastguard Worker 
1949*da0073e9SAndroid Build Coastguard Worker   DummyTestChunkDataReader data_reader;
1950*da0073e9SAndroid Build Coastguard Worker 
1951*da0073e9SAndroid Build Coastguard Worker   // tested save_intervals
1952*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
1953*da0073e9SAndroid Build Coastguard Worker   const size_t save_intervals[] = {1, 2};
1954*da0073e9SAndroid Build Coastguard Worker 
1955*da0073e9SAndroid Build Coastguard Worker   using datasets::ChunkDatasetOptions;
1956*da0073e9SAndroid Build Coastguard Worker 
1957*da0073e9SAndroid Build Coastguard Worker   for (auto save_interval : save_intervals) {
1958*da0073e9SAndroid Build Coastguard Worker     auto tempfile = c10::make_tempfile();
1959*da0073e9SAndroid Build Coastguard Worker 
1960*da0073e9SAndroid Build Coastguard Worker     datasets::SharedBatchDataset<datasets::ChunkDataset<
1961*da0073e9SAndroid Build Coastguard Worker         DummyTestChunkDataReader,
1962*da0073e9SAndroid Build Coastguard Worker         samplers::SequentialSampler,
1963*da0073e9SAndroid Build Coastguard Worker         samplers::SequentialSampler>>
1964*da0073e9SAndroid Build Coastguard Worker         dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1965*da0073e9SAndroid Build Coastguard Worker             DummyTestChunkDataReader,
1966*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler,
1967*da0073e9SAndroid Build Coastguard Worker             samplers::SequentialSampler>>(
1968*da0073e9SAndroid Build Coastguard Worker             data_reader,
1969*da0073e9SAndroid Build Coastguard Worker             sampler,
1970*da0073e9SAndroid Build Coastguard Worker             sampler,
1971*da0073e9SAndroid Build Coastguard Worker             ChunkDatasetOptions(
1972*da0073e9SAndroid Build Coastguard Worker                 prefetch_count, batch_size, chunk_size /*cache size*/));
1973*da0073e9SAndroid Build Coastguard Worker 
1974*da0073e9SAndroid Build Coastguard Worker     auto data_loader = torch::data::make_data_loader(
1975*da0073e9SAndroid Build Coastguard Worker         dataset,
1976*da0073e9SAndroid Build Coastguard Worker         DataLoaderOptions(batch_size).workers(dataloader_worker_count));
1977*da0073e9SAndroid Build Coastguard Worker 
1978*da0073e9SAndroid Build Coastguard Worker     for (const auto epoch_index : c10::irange(epoch_count)) {
1979*da0073e9SAndroid Build Coastguard Worker       (void)epoch_index; // Suppress unused variable warning
1980*da0073e9SAndroid Build Coastguard Worker       unsigned iteration_count = 0;
1981*da0073e9SAndroid Build Coastguard Worker       for (auto iterator = data_loader->begin(); iterator != data_loader->end();
1982*da0073e9SAndroid Build Coastguard Worker            ++iterator, ++iteration_count) {
1983*da0073e9SAndroid Build Coastguard Worker         if ((iteration_count + 1) % save_interval == 0) {
1984*da0073e9SAndroid Build Coastguard Worker           torch::save(*dataset, tempfile.name);
1985*da0073e9SAndroid Build Coastguard Worker 
1986*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler new_sampler(0);
1987*da0073e9SAndroid Build Coastguard Worker 
1988*da0073e9SAndroid Build Coastguard Worker           // See Note [save/load ChunkDataset as ChunkSampler]
1989*da0073e9SAndroid Build Coastguard Worker           torch::load(new_sampler, tempfile.name);
1990*da0073e9SAndroid Build Coastguard Worker 
1991*da0073e9SAndroid Build Coastguard Worker           // Verify save logic. For ChunkDataset, the chunk data is stored in a
1992*da0073e9SAndroid Build Coastguard Worker           // cache inside the dataset. One pool of threads are constantly
1993*da0073e9SAndroid Build Coastguard Worker           // writing to the cache, and a different pool of thread are constantly
1994*da0073e9SAndroid Build Coastguard Worker           // reading from the cache. Due to the nature of asynchronization, at
1995*da0073e9SAndroid Build Coastguard Worker           // the time of get_batch(), which chunk is written to the cache is not
1996*da0073e9SAndroid Build Coastguard Worker           // fully deterministic.
1997*da0073e9SAndroid Build Coastguard Worker           // But we can still calculate a restricted window on the expected
1998*da0073e9SAndroid Build Coastguard Worker           // output, hence verify the logic. In this test, the cache size is
1999*da0073e9SAndroid Build Coastguard Worker           // configured to be the same as chunk size and batch size. So the
2000*da0073e9SAndroid Build Coastguard Worker           // chunk data is written to the cache one by one. Only the current
2001*da0073e9SAndroid Build Coastguard Worker           // batch is retrieved, the next chunk is written. Now in iteration 0,
2002*da0073e9SAndroid Build Coastguard Worker           // after the first batch is retrieved, when we save the dataset
2003*da0073e9SAndroid Build Coastguard Worker           // statues, there are three possible scenarios for the writer thread:
2004*da0073e9SAndroid Build Coastguard Worker           // 1. it hasn't started loading the next chunk data yet, so the
2005*da0073e9SAndroid Build Coastguard Worker           // sequential sampler index is still 0;
2006*da0073e9SAndroid Build Coastguard Worker           // 2. it started to load the second chunk, so the sequential sampler
2007*da0073e9SAndroid Build Coastguard Worker           // index is at 1;
2008*da0073e9SAndroid Build Coastguard Worker           // 3. it finished loading the second chunk, and start to load the
2009*da0073e9SAndroid Build Coastguard Worker           // third chunk, because the cache is still fully occupied by the data
2010*da0073e9SAndroid Build Coastguard Worker           // from the second chunk, it is waiting to write to the cache. At this
2011*da0073e9SAndroid Build Coastguard Worker           // point, the sampler index is at 2.
2012*da0073e9SAndroid Build Coastguard Worker           // So now we have a window of [0, 2], which is what we expected the
2013*da0073e9SAndroid Build Coastguard Worker           // sampler to save the index from. Now noted for sequential sampler,
2014*da0073e9SAndroid Build Coastguard Worker           // it advances to the next index automatically in the call next(). So
2015*da0073e9SAndroid Build Coastguard Worker           // when save the index, it saves the next index in stead of the
2016*da0073e9SAndroid Build Coastguard Worker           // current one. In other word, after getting the first index from
2017*da0073e9SAndroid Build Coastguard Worker           // sequential sampler, it already moves to the second index. So when
2018*da0073e9SAndroid Build Coastguard Worker           // we save it, it is the second index we save. As a result,
2019*da0073e9SAndroid Build Coastguard Worker           // we need to advance the window by one. Now we have the expected
2020*da0073e9SAndroid Build Coastguard Worker           // window of [1, 3].
2021*da0073e9SAndroid Build Coastguard Worker           // This analysis applies to all scenarios. So extend it to a more
2022*da0073e9SAndroid Build Coastguard Worker           // general case: the expected saved index should falling into the
2023*da0073e9SAndroid Build Coastguard Worker           // range of [iteration, iteration + 3], which is the validation
2024*da0073e9SAndroid Build Coastguard Worker           // below.
2025*da0073e9SAndroid Build Coastguard Worker           ASSERT_TRUE(
2026*da0073e9SAndroid Build Coastguard Worker               new_sampler.index() >= iteration_count + 1 &&
2027*da0073e9SAndroid Build Coastguard Worker               new_sampler.index() <= iteration_count + 3);
2028*da0073e9SAndroid Build Coastguard Worker         }
2029*da0073e9SAndroid Build Coastguard Worker       }
2030*da0073e9SAndroid Build Coastguard Worker     }
2031*da0073e9SAndroid Build Coastguard Worker   }
2032*da0073e9SAndroid Build Coastguard Worker }
2033*da0073e9SAndroid Build Coastguard Worker 
2034*da0073e9SAndroid Build Coastguard Worker // Test ChunkDataset load function.
TEST(DataLoaderTest,ChunkDatasetLoad)2035*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, ChunkDatasetLoad) {
2036*da0073e9SAndroid Build Coastguard Worker   auto tempfile = c10::make_tempfile();
2037*da0073e9SAndroid Build Coastguard Worker 
2038*da0073e9SAndroid Build Coastguard Worker   const size_t prefetch_count = 1;
2039*da0073e9SAndroid Build Coastguard Worker   const size_t batch_size = 10;
2040*da0073e9SAndroid Build Coastguard Worker   const size_t dataloader_worker_count = 0;
2041*da0073e9SAndroid Build Coastguard Worker 
2042*da0073e9SAndroid Build Coastguard Worker   DummyChunkDataReader data_reader;
2043*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler sampler(0);
2044*da0073e9SAndroid Build Coastguard Worker 
2045*da0073e9SAndroid Build Coastguard Worker   const size_t skipped_chunk = 2;
2046*da0073e9SAndroid Build Coastguard Worker 
2047*da0073e9SAndroid Build Coastguard Worker   // Configure sampler to skip 2 chunks
2048*da0073e9SAndroid Build Coastguard Worker   {
2049*da0073e9SAndroid Build Coastguard Worker     sampler.reset(data_reader.chunk_count());
2050*da0073e9SAndroid Build Coastguard Worker     sampler.next(skipped_chunk);
2051*da0073e9SAndroid Build Coastguard Worker 
2052*da0073e9SAndroid Build Coastguard Worker     // See Note [save/load ChunkDataset as ChunkSampler]
2053*da0073e9SAndroid Build Coastguard Worker     torch::save(sampler, tempfile.name);
2054*da0073e9SAndroid Build Coastguard Worker   }
2055*da0073e9SAndroid Build Coastguard Worker 
2056*da0073e9SAndroid Build Coastguard Worker   // test functionality across epoch boundary. The first epoch should be
2057*da0073e9SAndroid Build Coastguard Worker   // affected by the checkpoint, but the second should start normally.
2058*da0073e9SAndroid Build Coastguard Worker   const int epoch_count = 2;
2059*da0073e9SAndroid Build Coastguard Worker 
2060*da0073e9SAndroid Build Coastguard Worker   datasets::SharedBatchDataset<datasets::ChunkDataset<
2061*da0073e9SAndroid Build Coastguard Worker       DummyChunkDataReader,
2062*da0073e9SAndroid Build Coastguard Worker       samplers::SequentialSampler,
2063*da0073e9SAndroid Build Coastguard Worker       samplers::SequentialSampler>>
2064*da0073e9SAndroid Build Coastguard Worker       dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
2065*da0073e9SAndroid Build Coastguard Worker           DummyChunkDataReader,
2066*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler,
2067*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler>>(
2068*da0073e9SAndroid Build Coastguard Worker           data_reader,
2069*da0073e9SAndroid Build Coastguard Worker           sampler,
2070*da0073e9SAndroid Build Coastguard Worker           sampler,
2071*da0073e9SAndroid Build Coastguard Worker           datasets::ChunkDatasetOptions(
2072*da0073e9SAndroid Build Coastguard Worker               prefetch_count, batch_size, 20 /*cache size*/));
2073*da0073e9SAndroid Build Coastguard Worker 
2074*da0073e9SAndroid Build Coastguard Worker   torch::load(*dataset, tempfile.name);
2075*da0073e9SAndroid Build Coastguard Worker 
2076*da0073e9SAndroid Build Coastguard Worker   auto data_loader = torch::data::make_data_loader(
2077*da0073e9SAndroid Build Coastguard Worker       dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));
2078*da0073e9SAndroid Build Coastguard Worker 
2079*da0073e9SAndroid Build Coastguard Worker   for (const auto epoch_index : c10::irange(epoch_count)) {
2080*da0073e9SAndroid Build Coastguard Worker     int iteration_count = 0;
2081*da0073e9SAndroid Build Coastguard Worker 
2082*da0073e9SAndroid Build Coastguard Worker     // For the first epoch, the returned batch should be returned from the
2083*da0073e9SAndroid Build Coastguard Worker     // third chunk, because the check point skipped the first two chunks. But
2084*da0073e9SAndroid Build Coastguard Worker     // for the next epoch, it should start from the first batch.
2085*da0073e9SAndroid Build Coastguard Worker     int initial_value = epoch_index == 0 ? 15 : 0;
2086*da0073e9SAndroid Build Coastguard Worker 
2087*da0073e9SAndroid Build Coastguard Worker     for (auto iterator = data_loader->begin(); iterator != data_loader->end();
2088*da0073e9SAndroid Build Coastguard Worker          ++iterator, ++iteration_count) {
2089*da0073e9SAndroid Build Coastguard Worker       DummyChunkDataReader::BatchType batch = *iterator;
2090*da0073e9SAndroid Build Coastguard Worker 
2091*da0073e9SAndroid Build Coastguard Worker       std::vector<int> expected_result;
2092*da0073e9SAndroid Build Coastguard Worker       size_t expected_size = (epoch_index > 0 && iteration_count == 3) ? 5 : 10;
2093*da0073e9SAndroid Build Coastguard Worker       expected_result.resize(expected_size);
2094*da0073e9SAndroid Build Coastguard Worker       std::iota(expected_result.begin(), expected_result.end(), initial_value);
2095*da0073e9SAndroid Build Coastguard Worker 
2096*da0073e9SAndroid Build Coastguard Worker       ASSERT_EQ(batch.size(), expected_result.size());
2097*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(
2098*da0073e9SAndroid Build Coastguard Worker           std::equal(batch.begin(), batch.end(), expected_result.begin()));
2099*da0073e9SAndroid Build Coastguard Worker 
2100*da0073e9SAndroid Build Coastguard Worker       initial_value += batch_size;
2101*da0073e9SAndroid Build Coastguard Worker     }
2102*da0073e9SAndroid Build Coastguard Worker   }
2103*da0073e9SAndroid Build Coastguard Worker 
2104*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler new_sampler(0);
2105*da0073e9SAndroid Build Coastguard Worker 
2106*da0073e9SAndroid Build Coastguard Worker   // See Note [save/load ChunkDataset as ChunkSampler]
2107*da0073e9SAndroid Build Coastguard Worker   torch::load(new_sampler, tempfile.name);
2108*da0073e9SAndroid Build Coastguard Worker 
2109*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(new_sampler.index(), skipped_chunk);
2110*da0073e9SAndroid Build Coastguard Worker }
2111*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,ChunkDatasetCrossChunkShuffle)2112*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
2113*da0073e9SAndroid Build Coastguard Worker   const size_t chunk_size = 5;
2114*da0073e9SAndroid Build Coastguard Worker   const size_t batch_size = 5;
2115*da0073e9SAndroid Build Coastguard Worker 
2116*da0073e9SAndroid Build Coastguard Worker   class S : public samplers::Sampler<> {
2117*da0073e9SAndroid Build Coastguard Worker    public:
2118*da0073e9SAndroid Build Coastguard Worker     explicit S(size_t size) : size_(size), index_(0){};
2119*da0073e9SAndroid Build Coastguard Worker 
2120*da0073e9SAndroid Build Coastguard Worker     void reset(torch::optional<size_t> new_size = torch::nullopt) override {
2121*da0073e9SAndroid Build Coastguard Worker       if (new_size.has_value()) {
2122*da0073e9SAndroid Build Coastguard Worker         size_ = *new_size;
2123*da0073e9SAndroid Build Coastguard Worker       }
2124*da0073e9SAndroid Build Coastguard Worker       indices_.resize(size_);
2125*da0073e9SAndroid Build Coastguard Worker       size_t index = 0;
2126*da0073e9SAndroid Build Coastguard Worker 
2127*da0073e9SAndroid Build Coastguard Worker       // Repeatedly sample every 5 indices.
2128*da0073e9SAndroid Build Coastguard Worker       for (const auto i : c10::irange(batch_size)) {
2129*da0073e9SAndroid Build Coastguard Worker         for (size_t j = 0; j < size_ / batch_size; ++j) {
2130*da0073e9SAndroid Build Coastguard Worker           indices_[index++] = i + batch_size * j;
2131*da0073e9SAndroid Build Coastguard Worker         }
2132*da0073e9SAndroid Build Coastguard Worker       }
2133*da0073e9SAndroid Build Coastguard Worker       index_ = 0;
2134*da0073e9SAndroid Build Coastguard Worker     }
2135*da0073e9SAndroid Build Coastguard Worker 
2136*da0073e9SAndroid Build Coastguard Worker     // Returns the next batch of indices.
2137*da0073e9SAndroid Build Coastguard Worker     torch::optional<std::vector<size_t>> next(size_t batch_size) override {
2138*da0073e9SAndroid Build Coastguard Worker       const auto remaining_indices = size_ - index_;
2139*da0073e9SAndroid Build Coastguard Worker       if (remaining_indices == 0) {
2140*da0073e9SAndroid Build Coastguard Worker         return torch::nullopt;
2141*da0073e9SAndroid Build Coastguard Worker       }
2142*da0073e9SAndroid Build Coastguard Worker       auto return_size = std::min(batch_size, remaining_indices);
2143*da0073e9SAndroid Build Coastguard Worker       std::vector<size_t> index_batch(
2144*da0073e9SAndroid Build Coastguard Worker           indices_.begin() + index_, indices_.begin() + index_ + return_size);
2145*da0073e9SAndroid Build Coastguard Worker       index_ += return_size;
2146*da0073e9SAndroid Build Coastguard Worker 
2147*da0073e9SAndroid Build Coastguard Worker       return index_batch;
2148*da0073e9SAndroid Build Coastguard Worker     }
2149*da0073e9SAndroid Build Coastguard Worker 
2150*da0073e9SAndroid Build Coastguard Worker     void save(torch::serialize::OutputArchive& archive) const override {}
2151*da0073e9SAndroid Build Coastguard Worker     void load(torch::serialize::InputArchive& archive) override {}
2152*da0073e9SAndroid Build Coastguard Worker 
2153*da0073e9SAndroid Build Coastguard Worker    private:
2154*da0073e9SAndroid Build Coastguard Worker     size_t size_;
2155*da0073e9SAndroid Build Coastguard Worker     std::vector<size_t> indices_;
2156*da0073e9SAndroid Build Coastguard Worker     size_t index_{0};
2157*da0073e9SAndroid Build Coastguard Worker   };
2158*da0073e9SAndroid Build Coastguard Worker 
2159*da0073e9SAndroid Build Coastguard Worker   struct D : public datasets::ChunkDataReader<int> {
2160*da0073e9SAndroid Build Coastguard Worker    public:
2161*da0073e9SAndroid Build Coastguard Worker     using BatchType = datasets::ChunkDataReader<int>::ChunkType;
2162*da0073e9SAndroid Build Coastguard Worker     D(size_t chunk_count) : chunk_count_(chunk_count) {}
2163*da0073e9SAndroid Build Coastguard Worker 
2164*da0073e9SAndroid Build Coastguard Worker     BatchType read_chunk(size_t chunk_index) override {
2165*da0073e9SAndroid Build Coastguard Worker       BatchType batch_data(chunk_size, chunk_index);
2166*da0073e9SAndroid Build Coastguard Worker       return batch_data;
2167*da0073e9SAndroid Build Coastguard Worker     }
2168*da0073e9SAndroid Build Coastguard Worker 
2169*da0073e9SAndroid Build Coastguard Worker     size_t chunk_count() override {
2170*da0073e9SAndroid Build Coastguard Worker       return chunk_count_;
2171*da0073e9SAndroid Build Coastguard Worker     };
2172*da0073e9SAndroid Build Coastguard Worker 
2173*da0073e9SAndroid Build Coastguard Worker     void reset() override{};
2174*da0073e9SAndroid Build Coastguard Worker     size_t chunk_count_;
2175*da0073e9SAndroid Build Coastguard Worker   };
2176*da0073e9SAndroid Build Coastguard Worker 
2177*da0073e9SAndroid Build Coastguard Worker   const size_t prefetch_count = 1;
2178*da0073e9SAndroid Build Coastguard Worker   const size_t cache_size = 10;
2179*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2180*da0073e9SAndroid Build Coastguard Worker   const size_t cross_chunk_shuffle_counts[] = {2, 3};
2181*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2182*da0073e9SAndroid Build Coastguard Worker   const size_t chunk_counts[] = {3, 4, 5};
2183*da0073e9SAndroid Build Coastguard Worker 
2184*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler chunk_sampler(0);
2185*da0073e9SAndroid Build Coastguard Worker   S example_sampler(0);
2186*da0073e9SAndroid Build Coastguard Worker 
2187*da0073e9SAndroid Build Coastguard Worker   for (auto chunk_count : chunk_counts) {
2188*da0073e9SAndroid Build Coastguard Worker     for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
2189*da0073e9SAndroid Build Coastguard Worker       D data_reader(chunk_count);
2190*da0073e9SAndroid Build Coastguard Worker 
2191*da0073e9SAndroid Build Coastguard Worker       datasets::SharedBatchDataset<
2192*da0073e9SAndroid Build Coastguard Worker           datasets::ChunkDataset<D, samplers::SequentialSampler, S>>
2193*da0073e9SAndroid Build Coastguard Worker           dataset = datasets::make_shared_dataset<
2194*da0073e9SAndroid Build Coastguard Worker               datasets::ChunkDataset<D, samplers::SequentialSampler, S>>(
2195*da0073e9SAndroid Build Coastguard Worker               data_reader,
2196*da0073e9SAndroid Build Coastguard Worker               chunk_sampler,
2197*da0073e9SAndroid Build Coastguard Worker               example_sampler,
2198*da0073e9SAndroid Build Coastguard Worker               datasets::ChunkDatasetOptions(
2199*da0073e9SAndroid Build Coastguard Worker                   prefetch_count,
2200*da0073e9SAndroid Build Coastguard Worker                   batch_size,
2201*da0073e9SAndroid Build Coastguard Worker                   cache_size,
2202*da0073e9SAndroid Build Coastguard Worker                   cross_chunk_shuffle_count));
2203*da0073e9SAndroid Build Coastguard Worker 
2204*da0073e9SAndroid Build Coastguard Worker       auto data_loader = torch::data::make_data_loader(
2205*da0073e9SAndroid Build Coastguard Worker           dataset, DataLoaderOptions(batch_size).workers(0));
2206*da0073e9SAndroid Build Coastguard Worker 
2207*da0073e9SAndroid Build Coastguard Worker       std::vector<int> result;
2208*da0073e9SAndroid Build Coastguard Worker       for (auto iterator = data_loader->begin(); iterator != data_loader->end();
2209*da0073e9SAndroid Build Coastguard Worker            ++iterator) {
2210*da0073e9SAndroid Build Coastguard Worker         auto batch_result = *iterator;
2211*da0073e9SAndroid Build Coastguard Worker         std::copy(
2212*da0073e9SAndroid Build Coastguard Worker             batch_result.begin(),
2213*da0073e9SAndroid Build Coastguard Worker             batch_result.end(),
2214*da0073e9SAndroid Build Coastguard Worker             std::back_inserter(result));
2215*da0073e9SAndroid Build Coastguard Worker       }
2216*da0073e9SAndroid Build Coastguard Worker 
2217*da0073e9SAndroid Build Coastguard Worker       std::vector<int> expected_result;
2218*da0073e9SAndroid Build Coastguard Worker       {
2219*da0073e9SAndroid Build Coastguard Worker         // construct expected result
2220*da0073e9SAndroid Build Coastguard Worker         for (const auto i : c10::irange(
2221*da0073e9SAndroid Build Coastguard Worker                  (chunk_count + cross_chunk_shuffle_count - 1) /
2222*da0073e9SAndroid Build Coastguard Worker                  cross_chunk_shuffle_count)) {
2223*da0073e9SAndroid Build Coastguard Worker           for (const auto j : c10::irange(chunk_size)) {
2224*da0073e9SAndroid Build Coastguard Worker             (void)j; // Suppress unused variable warning
2225*da0073e9SAndroid Build Coastguard Worker             for (const auto k : c10::irange(cross_chunk_shuffle_count)) {
2226*da0073e9SAndroid Build Coastguard Worker               if (i * cross_chunk_shuffle_count + k < chunk_count) {
2227*da0073e9SAndroid Build Coastguard Worker                 expected_result.push_back(i * cross_chunk_shuffle_count + k);
2228*da0073e9SAndroid Build Coastguard Worker               }
2229*da0073e9SAndroid Build Coastguard Worker             }
2230*da0073e9SAndroid Build Coastguard Worker           }
2231*da0073e9SAndroid Build Coastguard Worker         }
2232*da0073e9SAndroid Build Coastguard Worker       }
2233*da0073e9SAndroid Build Coastguard Worker 
2234*da0073e9SAndroid Build Coastguard Worker       ASSERT_EQ(result.size(), expected_result.size());
2235*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(
2236*da0073e9SAndroid Build Coastguard Worker           std::equal(result.begin(), result.end(), expected_result.begin()));
2237*da0073e9SAndroid Build Coastguard Worker     }
2238*da0073e9SAndroid Build Coastguard Worker   }
2239*da0073e9SAndroid Build Coastguard Worker }
2240*da0073e9SAndroid Build Coastguard Worker 
TEST(DataLoaderTest,CustomPreprocessPolicy)2241*da0073e9SAndroid Build Coastguard Worker TEST(DataLoaderTest, CustomPreprocessPolicy) {
2242*da0073e9SAndroid Build Coastguard Worker   const size_t chunk_size = 5;
2243*da0073e9SAndroid Build Coastguard Worker   const size_t batch_size = 10;
2244*da0073e9SAndroid Build Coastguard Worker 
2245*da0073e9SAndroid Build Coastguard Worker   struct D : public datasets::ChunkDataReader<int> {
2246*da0073e9SAndroid Build Coastguard Worker    public:
2247*da0073e9SAndroid Build Coastguard Worker     using BatchType = datasets::ChunkDataReader<int>::ChunkType;
2248*da0073e9SAndroid Build Coastguard Worker     D(size_t chunk_count) : chunk_count_(chunk_count) {}
2249*da0073e9SAndroid Build Coastguard Worker 
2250*da0073e9SAndroid Build Coastguard Worker     BatchType read_chunk(size_t chunk_index) override {
2251*da0073e9SAndroid Build Coastguard Worker       BatchType batch_data(chunk_size);
2252*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand)
2253*da0073e9SAndroid Build Coastguard Worker       auto rand_gen = []() { return std::rand() % 100; };
2254*da0073e9SAndroid Build Coastguard Worker       std::generate(batch_data.begin(), batch_data.end(), rand_gen);
2255*da0073e9SAndroid Build Coastguard Worker       return batch_data;
2256*da0073e9SAndroid Build Coastguard Worker     }
2257*da0073e9SAndroid Build Coastguard Worker 
2258*da0073e9SAndroid Build Coastguard Worker     size_t chunk_count() override {
2259*da0073e9SAndroid Build Coastguard Worker       return chunk_count_;
2260*da0073e9SAndroid Build Coastguard Worker     };
2261*da0073e9SAndroid Build Coastguard Worker 
2262*da0073e9SAndroid Build Coastguard Worker     void reset() override{};
2263*da0073e9SAndroid Build Coastguard Worker     size_t chunk_count_;
2264*da0073e9SAndroid Build Coastguard Worker   };
2265*da0073e9SAndroid Build Coastguard Worker 
2266*da0073e9SAndroid Build Coastguard Worker   // custom preprocessing policy - sort the data ascendingly
2267*da0073e9SAndroid Build Coastguard Worker   auto sorting_policy = [](std::vector<int>& raw_batch_data) {
2268*da0073e9SAndroid Build Coastguard Worker     std::sort(raw_batch_data.begin(), raw_batch_data.end());
2269*da0073e9SAndroid Build Coastguard Worker   };
2270*da0073e9SAndroid Build Coastguard Worker   std::function<void(std::vector<int>&)> policy_function = sorting_policy;
2271*da0073e9SAndroid Build Coastguard Worker 
2272*da0073e9SAndroid Build Coastguard Worker   const size_t prefetch_count = 1;
2273*da0073e9SAndroid Build Coastguard Worker   const size_t cache_size = 10;
2274*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2275*da0073e9SAndroid Build Coastguard Worker   const size_t cross_chunk_shuffle_counts[] = {1, 2};
2276*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2277*da0073e9SAndroid Build Coastguard Worker   const size_t chunk_counts[] = {3, 4};
2278*da0073e9SAndroid Build Coastguard Worker 
2279*da0073e9SAndroid Build Coastguard Worker   samplers::SequentialSampler chunk_sampler(0);
2280*da0073e9SAndroid Build Coastguard Worker 
2281*da0073e9SAndroid Build Coastguard Worker   for (auto chunk_count : chunk_counts) {
2282*da0073e9SAndroid Build Coastguard Worker     for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
2283*da0073e9SAndroid Build Coastguard Worker       D data_reader(chunk_count);
2284*da0073e9SAndroid Build Coastguard Worker 
2285*da0073e9SAndroid Build Coastguard Worker       datasets::SharedBatchDataset<datasets::ChunkDataset<
2286*da0073e9SAndroid Build Coastguard Worker           D,
2287*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler,
2288*da0073e9SAndroid Build Coastguard Worker           samplers::SequentialSampler>>
2289*da0073e9SAndroid Build Coastguard Worker           dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
2290*da0073e9SAndroid Build Coastguard Worker               D,
2291*da0073e9SAndroid Build Coastguard Worker               samplers::SequentialSampler,
2292*da0073e9SAndroid Build Coastguard Worker               samplers::SequentialSampler>>(
2293*da0073e9SAndroid Build Coastguard Worker               data_reader,
2294*da0073e9SAndroid Build Coastguard Worker               chunk_sampler,
2295*da0073e9SAndroid Build Coastguard Worker               chunk_sampler,
2296*da0073e9SAndroid Build Coastguard Worker               datasets::ChunkDatasetOptions(
2297*da0073e9SAndroid Build Coastguard Worker                   prefetch_count,
2298*da0073e9SAndroid Build Coastguard Worker                   batch_size,
2299*da0073e9SAndroid Build Coastguard Worker                   cache_size,
2300*da0073e9SAndroid Build Coastguard Worker                   cross_chunk_shuffle_count),
2301*da0073e9SAndroid Build Coastguard Worker               policy_function);
2302*da0073e9SAndroid Build Coastguard Worker 
2303*da0073e9SAndroid Build Coastguard Worker       auto data_loader = torch::data::make_data_loader(
2304*da0073e9SAndroid Build Coastguard Worker           dataset, DataLoaderOptions(batch_size).workers(0));
2305*da0073e9SAndroid Build Coastguard Worker 
2306*da0073e9SAndroid Build Coastguard Worker       std::vector<int> result;
2307*da0073e9SAndroid Build Coastguard Worker       for (auto iterator = data_loader->begin(); iterator != data_loader->end();
2308*da0073e9SAndroid Build Coastguard Worker            ++iterator) {
2309*da0073e9SAndroid Build Coastguard Worker         auto batch_result = *iterator;
2310*da0073e9SAndroid Build Coastguard Worker         if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
2311*da0073e9SAndroid Build Coastguard Worker           for (unsigned i = 0; i < batch_result.size(); i += chunk_size) {
2312*da0073e9SAndroid Build Coastguard Worker             ASSERT_TRUE(std::is_sorted(
2313*da0073e9SAndroid Build Coastguard Worker                 batch_result.begin() + i,
2314*da0073e9SAndroid Build Coastguard Worker                 batch_result.begin() + i + chunk_size));
2315*da0073e9SAndroid Build Coastguard Worker           }
2316*da0073e9SAndroid Build Coastguard Worker         } else {
2317*da0073e9SAndroid Build Coastguard Worker           ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end()));
2318*da0073e9SAndroid Build Coastguard Worker         }
2319*da0073e9SAndroid Build Coastguard Worker       }
2320*da0073e9SAndroid Build Coastguard Worker     }
2321*da0073e9SAndroid Build Coastguard Worker   }
2322*da0073e9SAndroid Build Coastguard Worker }
2323