// external/pytorch/test/cpp/api/dataloader.cpp
#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <c10/util/tempfile.h>

#include <algorithm>
#include <chrono>
#include <future>
#include <iostream>
#include <iterator>
#include <limits>
#include <mutex>
#include <numeric>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

using namespace torch::data; // NOLINT

const std::chrono::milliseconds kMillisecond(1);

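// A trivially sized dataset used throughout these tests: get(i) returns i + 1.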
struct DummyDataset : datasets::Dataset<DummyDataset, int> {
  explicit DummyDataset(size_t size = 100) : size_(size) {}

  int get(size_t index) override {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return size_;
  }

  size_t size_;
};

TEST(DataTest, DatasetCallsGetCorrectly) {
  DummyDataset d;
  std::vector<int> batch = d.get_batch({0, 1, 2, 3, 4});
  std::vector<int> expected = {1, 2, 3, 4, 5};
  ASSERT_EQ(batch, expected);
}

TEST(DataTest, TransformCallsGetApplyCorrectly) {
  struct T : transforms::Transform<int, std::string> {
    std::string apply(int input) override {
      return std::to_string(input);
    }
  };

  auto d = DummyDataset{}.map(T{});
  std::vector<std::string> batch = d.get_batch({0, 1, 2, 3, 4});
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(batch, expected);
}

// A dummy chunk data reader with 3 chunks and 35 examples in total; the
// chunks contain 10, 5, and 20 examples, respectively.

struct DummyChunkDataReader : public datasets::ChunkDataReader<int> {
 public:
  using BatchType = datasets::ChunkDataReader<int>::ChunkType;
  using DataType = datasets::ChunkDataReader<int>::ExampleType;

  /// Read an entire chunk.
  BatchType read_chunk(size_t chunk_index) override {
    BatchType batch_data;
    int start_index = chunk_index == 0
        ? 0
        // NOLINTNEXTLINE(bugprone-fold-init-type)
        : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0);

    batch_data.resize(chunk_sizes[chunk_index]);

    std::iota(batch_data.begin(), batch_data.end(), start_index);

    return batch_data;
  }

  size_t chunk_count() override {
    return chunk_count_;
  }

  void reset() override {}

  const static size_t chunk_count_ = 3;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
  size_t chunk_sizes[chunk_count_] = {10, 5, 20};
};

TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  auto initialization_function = [&](size_t preloader_count,
                                     size_t batch_size,
                                     size_t cache_size,
                                     size_t cross_chunk_shuffle_count = 1) {
    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            datasets::ChunkDatasetOptions(
                preloader_count,
                batch_size,
                cache_size,
                cross_chunk_shuffle_count));
  };

  ASSERT_THROWS_WITH(
      initialization_function(0, 1, 1),
      "Preloader count is 0. At least one preloader needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 0, 1),
      "Batch size is 0. A positive batch size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 1, 0),
      "Cache size is 0. A positive cache size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 5),
      "Cache size is less than batch size. Cache needs to be large enough to "
      "hold at least one batch.");
  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 20, 0),
      "cross_chunk_shuffle_count needs to be greater than 0.");
}

struct InfiniteStreamDataset
    : datasets::StreamDataset<InfiniteStreamDataset, std::vector<int>> {
  std::vector<int> get_batch(size_t batch_size) override {
    std::vector<int> batch(batch_size);
    for (auto& i : batch) {
      i = counter++;
    }
    return batch;
  }

  torch::optional<size_t> size() const override {
    return torch::nullopt;
  }

  size_t counter = 0;
};

TEST(DataTest, InfiniteStreamDataset) {
  const size_t kBatchSize = 13;

  auto dataset = InfiniteStreamDataset().map(
      transforms::Lambda<int>([](int x) { return x + 1; }));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      samplers::StreamSampler(/*epoch_size=*/39),
      kBatchSize);

  size_t batch_index = 0;
  for (auto& batch : *data_loader) {
    ASSERT_LT(batch_index, 3);
    ASSERT_EQ(batch.size(), kBatchSize);
    for (const auto j : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(j), 1 + (batch_index * kBatchSize) + j);
    }
    batch_index += 1;
  }
  ASSERT_EQ(batch_index, 3);
}

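// The sequencer tests below exercise the DataLoader's internal helpers for
// ordering worker results: NoSequencer hands results straight through, while
// OrderedSequencer buffers out-of-order results until they can be returned in
// increasing sequence-number order.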
TEST(DataTest, NoSequencerIsIdentity) {
  using namespace torch::data::detail::sequencers; // NOLINT
  NoSequencer<int> no_sequencer;
  const auto value = no_sequencer.next([] { return 5; }).value();
  ASSERT_EQ(value, 5);
}

TEST(DataTest, OrderedSequencerIsSetUpWell) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);
  ASSERT_EQ(sequencer.next_sequence_number_, 0);
  ASSERT_EQ(sequencer.buffer_.size(), kMaxJobs);
}

TEST(DataTest, OrderedSequencerReOrdersValues) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);

  std::vector<size_t> v = {0, 2, 4, 3, 1};
  size_t index = 0;
  auto getter = [&v, &index]() { return S{v.at(index++)}; };

  // Let's say the sequence number matches for the first one, then it should
  // return immediately.
  const auto batch = sequencer.next(getter);
  ASSERT_EQ(batch.value().sequence_number, 0);
  ASSERT_EQ(index, 1);

  // Now it should call the getter until it gets the next value.
  ASSERT_EQ(1, sequencer.next(getter).value().sequence_number);
  ASSERT_EQ(index, 5);

  // The next three should come in order.
  for (size_t i = 2; i <= 4; ++i) {
    // New value doesn't matter. In fact, it shouldn't be accessed.
    ASSERT_EQ(i, sequencer.next(getter).value().sequence_number);
    // The index doesn't change.
    ASSERT_EQ(index, 5);
  }
}

TEST(DataTest, BatchLambdaAppliesFunctionToBatch) {
  using InputBatch = std::vector<int>;
  using OutputBatch = std::string;
  DummyDataset d;
  auto e = d.map(transforms::BatchLambda<InputBatch, OutputBatch>(
      [](std::vector<int> input) {
        return std::to_string(std::accumulate(input.begin(), input.end(), 0));
      }));
  ASSERT_EQ(e.get_batch({1, 2, 3, 4, 5}), std::string("20"));
}

TEST(DataTest, LambdaAppliesFunctionToExample) {
  auto d = DummyDataset().map(transforms::Lambda<int, std::string>(
      static_cast<std::string (*)(int)>(std::to_string)));
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(d.get_batch({0, 1, 2, 3, 4}), expected);
}

TEST(DataTest, CollateReducesBatch) {
  auto d =
      DummyDataset().map(transforms::Collate<int>([](std::vector<int> input) {
        return std::accumulate(input.begin(), input.end(), 0);
      }));
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}

TEST(DataTest, CollationReducesBatch) {
  struct Summer : transforms::Collation<int> {
    int apply_batch(std::vector<int> input) override {
      return std::accumulate(input.begin(), input.end(), 0);
    }
  };
  auto d = DummyDataset().map(Summer{});
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}

TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
  samplers::SequentialSampler sampler(10);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({3, 4, 5, 6, 7}));
  ASSERT_EQ(sampler.next(2).value(), std::vector<size_t>({8, 9}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(100).value(), std::vector<size_t>({3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerResetsWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(
      sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, CanSaveAndLoadSequentialSampler) {
  {
    samplers::SequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::SequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}

TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) {
  samplers::RandomSampler sampler(10);

  std::vector<size_t> indices = sampler.next(3).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  indices = sampler.next(5).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  indices = sampler.next(2).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  ASSERT_FALSE(sampler.next(10).has_value());
}

TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_EQ(sampler.next(100).value().size(), 2);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, RandomSamplerResetsWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
  {
    samplers::RandomSampler a(10);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);

    ASSERT_EQ(a.next(10).value(), b.next(10).value());
  }
  {
    samplers::RandomSampler a(10);
    a.next(3);
    ASSERT_EQ(a.index(), 3);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 3);

    auto b_sequence = b.next(10).value();
    ASSERT_EQ(b_sequence.size(), 7);
    ASSERT_EQ(a.next(10).value(), b_sequence);
  }
}

TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
  samplers::StreamSampler sampler(/*epoch_size=*/100);
  ASSERT_EQ(sampler.next(10).value(), 10);
  ASSERT_EQ(sampler.next(2).value(), 2);
  ASSERT_EQ(sampler.next(85).value(), 85);
  ASSERT_EQ(sampler.next(123).value(), 3);
  ASSERT_FALSE(sampler.next(1).has_value());
}

TEST(DataTest, StreamSamplerResetsWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
  datasets::TensorDataset dataset(torch::eye(5));
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}

TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) {
  std::vector<torch::Tensor> vector = torch::eye(5).chunk(5);
  datasets::TensorDataset dataset(vector);
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}

TEST(DataTest, StackTransformWorksForExample) {
  struct D : public datasets::Dataset<D> {
    Example<> get(size_t index) override {
      return {tensor[index], 1 + tensor[index]};
    }

    torch::optional<size_t> size() const override {
      return tensor.size(0);
    }

    torch::Tensor tensor{torch::eye(4)};
  };

  auto d = D().map(transforms::Stack<Example<>>());

  Example<> batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
  ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));

  Example<> second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
  ASSERT_TRUE(second.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 2, 4)));
}

TEST(DataTest, StackTransformWorksForTensorExample) {
  auto d = datasets::TensorDataset(torch::eye(4))
               .map(transforms::Stack<TensorExample>());

  TensorExample batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));

  TensorExample second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
}

// Template classes cannot be nested in functions.
template <typename Target>
struct T : transforms::TensorTransform<Target> {
  torch::Tensor operator()(torch::Tensor input) override {
    return input * 2;
  }
};

struct TensorStringDataset
    : datasets::
          Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
  Example<torch::Tensor, std::string> get(size_t index) override {
    return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, TensorTransformWorksForAnyTargetType) {
  auto d = TensorStringDataset().map(T<std::string>{});
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}

TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
  auto d = TensorStringDataset().map(transforms::TensorLambda<std::string>(
      [](torch::Tensor input) { return input * 2; }));
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}

struct DummyTensorDataset
    : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
  Example<torch::Tensor, int> get(size_t index) override {
    const auto channels = static_cast<int64_t>(index);
    torch::Tensor tensor =
        (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
    return {tensor, static_cast<int>(channels)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, NormalizeTransform) {
  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));

  // Works for zero (one implicit) channels
  std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
  ASSERT_EQ(output.size(), 1);
  // (1 - 0.5) / 0.1 = 5
  ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
      << output[0].data;

  // Works for one explicit channel
  output = dataset.get_batch(1);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 1);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;

  // Works for two channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
  output = dataset.get_batch(2);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 2);
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with one moment value
  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/2)
                  .allclose(torch::ones({1, 4, 4}) * 12.5))
      << output[0].data;
}

struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
  UnCopyableDataset() = default;

  UnCopyableDataset(const UnCopyableDataset&) = delete;
  UnCopyableDataset& operator=(const UnCopyableDataset&) = delete;

  UnCopyableDataset(UnCopyableDataset&&) = default;
  UnCopyableDataset& operator=(UnCopyableDataset&&) = default;

  ~UnCopyableDataset() override = default;

  Example<> get(size_t index) override {
    return {
        torch::tensor({static_cast<int64_t>(index)}),
        torch::tensor({static_cast<int64_t>(index)})};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, MapDoesNotCopy) {
  auto dataset = UnCopyableDataset()
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 1; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 2; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 3; }));

  auto data = dataset.get_batch(1).at(0).data;
  ASSERT_EQ(data.numel(), 1);
  ASSERT_EQ(data[0].item<float>(), 7);
}

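// detail::Queue is the blocking queue used to pass work between the
// DataLoader's main thread and its workers: pop() blocks until an element is
// available, or throws if an optional timeout expires first.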
TEST(DataTest, QueuePushAndPopFromSameThread) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  ASSERT_EQ(queue.pop(), 1);
  ASSERT_EQ(queue.pop(), 2);
}

TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
  torch::data::detail::Queue<int> queue;
  ASSERT_THROWS_WITH(
      queue.pop(10 * kMillisecond),
      "Timeout in DataLoader queue while waiting for next batch "
      "(timeout was 10 ms)");
}

TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
  using torch::data::detail::Queue;

  // First test: push first and then pop in another thread.
  {
    Queue<int> queue;
    queue.push(1);
    auto future =
        std::async(std::launch::async, [&queue] { return queue.pop(); });
    ASSERT_EQ(future.get(), 1);
  }

  // Second test: attempt to pop first (and block), then push.
  {
    Queue<int> queue;
    std::thread thread([&queue] {
      std::this_thread::sleep_for(20 * kMillisecond);
      queue.push(123);
    });
    ASSERT_EQ(queue.pop(), 123);
    thread.join();
  }
}

TEST(DataTest, QueueClearEmptiesTheQueue) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  queue.push(3);
  ASSERT_EQ(queue.clear(), 3);
  ASSERT_THROWS_WITH(queue.pop(1 * kMillisecond), "Timeout");
}

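// detail::DataShuttle pairs a job queue with a result queue and tracks how
// many jobs are in flight, so that pop_result() only waits when there is an
// outstanding job.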
TEST(DataTest, DataShuttleCanPushAndPopJob) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_job(2);
  ASSERT_EQ(shuttle.pop_job(), 1);
  ASSERT_EQ(shuttle.pop_job(), 2);
}

TEST(DataTest, DataShuttleCanPushAndPopResult) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  // pop_result() will only attempt to pop if there was a push_job() first.
  shuttle.push_job(1);
  shuttle.push_job(2);

  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);

  shuttle.pop_job();
  shuttle.push_result(2);
  ASSERT_EQ(shuttle.pop_result().value(), 2);
}

TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  ASSERT_FALSE(shuttle.pop_result().has_value());
  shuttle.push_job(1);
  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);
  ASSERT_FALSE(shuttle.pop_result().has_value());
  ASSERT_FALSE(shuttle.pop_result().has_value());
}

TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_result(1);
  shuttle.drain();
  ASSERT_FALSE(shuttle.pop_result().has_value());
}

TEST(DataTest, DataShuttlePopResultTimesOut) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  ASSERT_THROWS_WITH(shuttle.pop_result(10 * kMillisecond), "Timeout");
}

struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
  UncopyableDataset(const std::string& /* unused */) {}

  UncopyableDataset(UncopyableDataset&&) = default;
  UncopyableDataset& operator=(UncopyableDataset&&) = default;

  UncopyableDataset(const UncopyableDataset&) = delete;
  UncopyableDataset& operator=(const UncopyableDataset&) = delete;

  int get(size_t index) override {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, SharedBatchDatasetReallyIsShared) {
  // This test will only compile if we really are not making any copies. There
  // is otherwise no logic to test, and because it is not deterministic how
  // many times and when the worker threads access the shared dataset, we make
  // no additional assertions here.

  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          "uncopyable");

  auto data_loader = torch::data::make_data_loader(
      shared_dataset, torch::data::DataLoaderOptions().workers(3));

  for (auto batch : *data_loader) {
    /* exhaust */
  }
}

TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
  // This will not compile if a copy is made.
  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          UncopyableDataset("uncopyable"));
  ASSERT_EQ(shared_dataset.size().value(), 100);
}

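// Samplers are not restricted to std::vector<size_t> indices: any type
// satisfying samplers::CustomBatchRequest can serve as the batch request. The
// dataset/sampler pair below exercises such a custom index type.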
struct TestIndex : public torch::data::samplers::CustomBatchRequest {
  explicit TestIndex(size_t offset, std::vector<size_t> index)
      : offset(offset), index(std::move(index)) {}
  size_t size() const override {
    return index.size();
  }
  size_t offset;
  std::vector<size_t> index;
};

struct TestIndexDataset
    : datasets::BatchDataset<TestIndexDataset, std::vector<int>, TestIndex> {
  explicit TestIndexDataset(size_t size) : data(size) {
    std::iota(data.begin(), data.end(), size_t(0));
  }
  std::vector<int> get_batch(TestIndex index) override {
    std::vector<int> batch;
    for (auto i : index.index) {
      batch.push_back(index.offset + data.at(i));
    }
    return batch;
  }
  torch::optional<size_t> size() const override {
    return data.size();
  }
  std::vector<int> data;
};

struct TestIndexSampler : public samplers::Sampler<TestIndex> {
  explicit TestIndexSampler(size_t size) : size_(size) {}
  void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
  torch::optional<TestIndex> next(size_t batch_size) override {
    if (index_ >= size_) {
      return torch::nullopt;
    }
    std::vector<size_t> indices(batch_size);
    std::iota(indices.begin(), indices.end(), size_t(0));
    index_ += batch_size;
    return TestIndex(batch_size, std::move(indices));
  }
  void save(torch::serialize::OutputArchive& archive) const override {}
  void load(torch::serialize::InputArchive& archive) override {}
  size_t index_ = 0;
  size_t size_;
};

TEST(DataTest, CanUseCustomTypeAsIndexType) {
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader(
      TestIndexDataset(23), TestIndexSampler(23), kBatchSize);

  for (auto batch : *data_loader) {
    for (const auto j : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(j), 10 + j);
    }
  }
}

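// The distributed samplers partition indices across num_replicas participants.
// With allow_duplicates set, some indices are repeated so that every replica
// receives the same number of samples; without it, trailing indices are
// dropped so that replicas stay even.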
TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  samplers::DistributedRandomSampler drs(sample_count);

  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = drs.next(3)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  std::sort(res.begin(), res.end());
  for (const auto i : c10::irange(res.size())) {
    ASSERT_EQ(res[i], i);
  }
}

TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;

    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          std::make_unique<samplers::DistributedRandomSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}

TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
  {
    samplers::DistributedRandomSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::DistributedRandomSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
  {
    samplers::DistributedRandomSampler a(10);
    a.set_epoch(3);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.epoch(), 3);
  }
}

TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t batch_size = 3;
  samplers::DistributedSequentialSampler dss(sample_count);

  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = dss.next(batch_size)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  std::sort(res.begin(), res.end());
  for (const auto i : c10::irange(res.size())) {
    ASSERT_EQ(res[i], i);
  }
}

TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
        samplers;

    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          std::make_unique<samplers::DistributedSequentialSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}

TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
  {
    samplers::DistributedSequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedSequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::DistributedSequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedSequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}

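// Tests for the DataLoader itself, its options, and its iterators follow.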
TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
  DataLoaderOptions partial_options;
  FullDataLoaderOptions full_options(partial_options);
  ASSERT_EQ(full_options.batch_size, 1);
  ASSERT_FALSE(full_options.drop_last);
  ASSERT_EQ(full_options.workers, 0);
  ASSERT_EQ(full_options.max_jobs, 0);
  ASSERT_FALSE(full_options.timeout.has_value());
  ASSERT_TRUE(full_options.enforce_ordering);
}

TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) {
  auto partial_options = DataLoaderOptions(32).workers(10);
  FullDataLoaderOptions full_options(partial_options);
  ASSERT_EQ(full_options.batch_size, 32);
  ASSERT_EQ(full_options.max_jobs, 2 * 10);
}

TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) {
  auto data_loader = torch::data::make_data_loader(
      DummyDataset().map(transforms::Lambda<int>([](int x) { return x + 1; })));
  ASSERT_EQ(data_loader->options().batch_size, 1);
}

struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
  torch::data::Example<> get(size_t i) override {
    return {torch::ones(i), torch::ones(i)};
  }
  torch::optional<size_t> size() const noexcept override {
    return torch::nullopt;
  }
};

TEST(
    DataLoaderTest,
    MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) {
  ASSERT_THROWS_WITH(
      torch::data::make_data_loader(UnsizedDataset{}),
      "Expected the dataset to be sized in order to construct the Sampler");
}

TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto begin = data_loader->begin();
  ASSERT_EQ(begin, begin);
  auto end = data_loader->end();
  ASSERT_EQ(end, end);
}

TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto i = data_loader->begin();
  auto j = data_loader->begin();
  ASSERT_NE(i, j);
  ++j;
  ASSERT_NE(i, j);
}

TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto i = data_loader->end();
  auto j = data_loader->end();
  ASSERT_EQ(i, j);
}

TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 4);
  auto i = data_loader->begin();
  auto end = data_loader->end();
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_EQ(i, end);
}

TEST(DataLoaderTest, IteratorsShareState) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto i = data_loader->begin();
  auto j = i;
  auto end = data_loader->end();
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  ++i;
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  ++j;
  ASSERT_EQ(i, end);
  ASSERT_EQ(j, end);
}

TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset,
          // NOLINTNEXTLINE(bugprone-argument-comment)
          /*batch_size=*/1);
  auto iterator = data_loader->begin();
  std::vector<int> expected = {1};
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
  ++iterator;
  expected[0] = 2;
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
  ++iterator;
  expected[0] = 3;
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
}

TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
  struct D : datasets::BatchDataset<D, int> {
    int get_batch(torch::ArrayRef<size_t> indices) override {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return 1 + indices.front();
    }
    torch::optional<size_t> size() const override {
      return 10;
    }
  };

  D dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset, 1);
  std::vector<int> values;
  std::copy(
      data_loader->begin(), data_loader->end(), std::back_inserter(values));
  std::vector<int> expected(dataset.size().value());
  std::iota(expected.begin(), expected.end(), size_t(1));
  ASSERT_EQ(values, expected);
}

TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, DataLoaderOptions(1).workers(2));
  auto i = data_loader->begin();
  ASSERT_THROWS_WITH(
      data_loader->begin(),
      "Attempted to get a new DataLoader iterator "
      "while another iterator is not yet exhausted");
}

TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++i);
  ASSERT_THROWS_WITH(++i, "Attempted to increment iterator past the end");
}

TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++i);
  ASSERT_THROWS_WITH(
      *i, "Attempted to dereference iterator that was past the end");
}

TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->end();
  ASSERT_THROWS_WITH(
      ++i,
      "Incrementing the DataLoader's past-the-end iterator is not allowed");
}

TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->end();
  ASSERT_THROWS_WITH(
      *i,
      "Dereferencing the DataLoader's past-the-end iterator is not allowed");
}

TEST(DataLoaderTest, YieldsCorrectBatchSize) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(dataset, 25);
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(
    DataLoaderTest,
    ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(false));
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 1);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(
    DataLoaderTest,
    DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(true));
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(DataLoaderTest, RespectsTimeout) {
  struct Baton {
    std::condition_variable cv;
    std::mutex mutex;
  };

  struct D : datasets::Dataset<DummyDataset, int> {
    D(std::shared_ptr<Baton> b) : baton(std::move(b)) {}
    int get(size_t index) override {
      std::unique_lock<std::mutex> lock(baton->mutex);
      baton->cv.wait_for(lock, 1000 * kMillisecond);
      return 0;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    std::shared_ptr<Baton> baton;
  };

  auto baton = std::make_shared<Baton>();

  auto data_loader = torch::data::make_data_loader(
      D{baton}, DataLoaderOptions().workers(1).timeout(10 * kMillisecond));

  auto start = std::chrono::system_clock::now();

  ASSERT_THROWS_WITH(*data_loader->begin(), "Timeout");
  baton->cv.notify_one();

  auto end = std::chrono::system_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
  ASSERT_LT(duration.count(), 1);
}

// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
struct Barrier {
  explicit Barrier(size_t target) : counter_(target) {}
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (--counter_ == 0) {
      cv_.notify_all();
    } else {
      cv_.wait(lock, [this] { return this->counter_ == 0; });
    }
  }

  size_t counter_;
  std::condition_variable cv_;
  std::mutex mutex_;
};

// On the OrderingTest: This test is intended to verify that the
// `enforce_ordering` option of the dataloader works correctly. This flag
// exists because, when the dataloader has multiple workers (threads) enabled
// and the flag is not set, the order in which worker threads finish loading
// their respective batch and push it back to the dataloader's main thread
// (for outside consumption) is not deterministic. Imagine the sampler is a
// SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
// will be a single "job". Inside the dataloader, worker threads block until a
// job is available. It is not deterministic which worker thread wakes up first
// to dequeue a particular batch. Further, some worker threads may take longer
// than others to read the data for their index. As such, it could be that
// worker thread 2 finishes before all other threads and returns its batch to
// the main thread. In that case, the dataloader iterator would return the
// datum at index 2 first, and afterwards the datum from whatever thread
// finishes next. As such, the user may see data from indices 2, 0, 3, 1. On
// another run of the same dataloader on the same data, threads may be
// scheduled differently and return in order 0, 2, 3, 1. To force this ordering
// to deterministically be 0, 1, 2, 3, the `enforce_ordering` flag can be set
// to true. In that case, the dataloader will use a *sequencer* internally
// which keeps track of which datum is expected next, and buffers any other
// results until that next expected value arrives. For example, workers 1, 2, 3
// may finish before worker 0. If `enforce_ordering` is true, the sequencer
// will internally buffer the results from 1, 2, 3 until worker 0 finishes.
// Only then does the dataloader return the datum from worker 0 to the user
// (and then datum 1 the next time, then 2 and so on). A standalone sketch of
// this buffering idea follows below.
//
// The way the test works is that we start `kNumberOfWorkers` workers in the
// dataloader, each of which gets an index from a `SequentialSampler` in the
// range `0...kNumberOfWorkers-1`. Each worker thread has a copy of the
// dataset, and thus `get_batch()` is called on the thread-local copy in each
// worker. We want to simulate out-of-order completion of these threads. For
// this, we first set a barrier in the `get_batch()` method to make sure every
// worker has been assigned an index to fetch. Further, each worker thread has
// a unique ID in `0...kNumberOfWorkers-1`. There is a hard-coded ordering,
// `kOrderInWhichWorkersReturnTheirBatch`, in which we want the worker threads
// to return. For this, an iterator into this order is maintained. When the
// dereferenced iterator (the current order index) matches the thread ID of a
// worker, it knows it can now return its index as well as progress the
// iterator. Inside the dataloader, the sequencer should buffer these indices
// such that they are ultimately returned in order.

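// A minimal, self-contained sketch (added for illustration; not part of the
// original suite) of the buffering idea described above: results that arrive
// out of order are parked in a buffer slot keyed by their sequence number and
// released only once every earlier sequence number has been seen.
TEST(DataLoaderTest, OrderedSequencerBufferingSketch) {
  const size_t kResultCount = 3;
  // Results arrive with sequence numbers 2, 0, 1; the payload is 10 * seq.
  const std::vector<size_t> arrival_order = {2, 0, 1};
  std::vector<torch::optional<size_t>> buffer(kResultCount);
  std::vector<size_t> released;
  size_t next_expected = 0;
  for (const size_t seq : arrival_order) {
    buffer.at(seq) = 10 * seq;
    // Drain the buffer for as long as the next expected result is present.
    while (next_expected < kResultCount && buffer[next_expected].has_value()) {
      released.push_back(*buffer[next_expected]);
      buffer[next_expected].reset();
      ++next_expected;
    }
  }
  // Despite the 2, 0, 1 arrival order, payloads come out in sequence order.
  ASSERT_EQ(released, std::vector<size_t>({0, 10, 20}));
}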
namespace ordering_test {
namespace {
const size_t kNumberOfWorkers = 10;
const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
    {3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
} // namespace

struct Dataset : datasets::BatchDataset<Dataset, size_t> {
  Dataset() = default;

  // This copy constructor will be called when we copy the dataset into a
  // particular thread.
  Dataset(const Dataset& other) {
    static std::atomic<size_t> counter{0};
    thread_id_ = counter.fetch_add(1);
  }

  Dataset(Dataset&& other) noexcept = default;
  Dataset& operator=(const Dataset& other) = delete;
  Dataset& operator=(Dataset&& other) noexcept = delete;

  size_t get_batch(torch::ArrayRef<size_t> indices) override {
    static Barrier barrier(kNumberOfWorkers);
    static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
    static std::condition_variable cv;
    static std::mutex mutex;

    // Wait for all threads to get an index first and arrive here.
1380     barrier.wait();
1381 
1382     std::unique_lock<std::mutex> lock(mutex);
1383     cv.wait(lock, [this] { return *order_iterator == this->thread_id_; });
1384     ++order_iterator;
1385     lock.unlock();
1386     cv.notify_all();
1387 
1388     return indices.front();
1389   }
1390 
sizeordering_test::Dataset1391   torch::optional<size_t> size() const override {
1392     return kNumberOfWorkers;
1393   }
1394 
1395   size_t thread_id_ = 0;
1396 };
1397 
1398 } // namespace ordering_test
1399 
TEST(DataLoaderTest,EnforcesOrderingAmongThreadsWhenConfigured)1400 TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
1401   auto data_loader = torch::data::make_data_loader(
1402       ordering_test::Dataset{},
1403       torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
1404       DataLoaderOptions()
1405           .batch_size(1)
1406           .workers(ordering_test::kNumberOfWorkers)
1407           .enforce_ordering(true));
1408   std::vector<size_t> output;
1409   for (size_t value : *data_loader) {
1410     output.push_back(value);
1411   }
1412   std::vector<size_t> expected(ordering_test::kNumberOfWorkers);
1413   std::iota(expected.begin(), expected.end(), size_t(0));
1414   ASSERT_EQ(expected, output);
1415 }

TEST(DataLoaderTest, Reset) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto end = data_loader->end();

  auto iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);

  iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);

  iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);
}

TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
  // The CRTP self type must be the derived class itself, not DummyDataset.
  struct D : datasets::Dataset<D, int> {
    int get(size_t index) override {
      throw std::invalid_argument("badness");
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
  };

  auto data_loader = torch::data::make_data_loader(
      D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
  auto iterator = data_loader->begin();

  try {
    (void)*iterator;
  } catch (torch::data::WorkerException& e) {
    ASSERT_EQ(
        e.what(),
        std::string("Caught exception in DataLoader worker thread. "
                    "Original message: badness"));
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        std::rethrow_exception(e.original_exception), std::invalid_argument);
  }
}
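
// A hedged sketch of the capture-and-rethrow plumbing the test above relies
// on: a worker catches whatever was thrown as a std::exception_ptr, ships it
// to the consumer thread, and the consumer rethrows it so the original
// exception type survives the thread boundary. (`capture_worker_error` is a
// hypothetical helper for illustration; the real transport lives inside
// torch::data. Assumes <exception> is available, as it is here.)
namespace exception_sketch {
inline std::exception_ptr capture_worker_error() {
  try {
    throw std::invalid_argument("badness"); // stand-in for a failing get()
  } catch (...) {
    return std::current_exception(); // type-erased, safe to hand across threads
  }
}
// Usage: std::rethrow_exception(exception_sketch::capture_worker_error())
// throws std::invalid_argument again in the calling thread.
} // namespace exception_sketch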

TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(D{});

  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
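
// The contract the epoch loops above rely on, driven by hand (a minimal,
// hedged sketch; `CountingSource` and `drain` are hypothetical illustrations,
// not torch::data types): a stateful source yields values until it returns
// nullopt, and reset() rewinds it so the next epoch starts fresh.
namespace stateful_sketch {
struct CountingSource {
  torch::optional<int> next() {
    if (counter < limit) {
      return counter++;
    }
    return torch::nullopt; // exhausted; signals the end of the epoch
  }
  void reset() {
    counter = 0;
  }
  int counter = 0;
  int limit = 10;
};

inline int drain(CountingSource& source) {
  int count = 0;
  while (source.next().has_value()) {
    ++count;
  }
  source.reset(); // mirror what the dataloader does when a new epoch begins
  return count; // == limit on every call
}
} // namespace stateful_sketch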

TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  const int kNumberOfWorkers = 4;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      std::lock_guard<std::mutex> lock(mutex);
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
    std::mutex mutex;
  };

  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::make_shared_dataset<D>(),
      DataLoaderOptions().workers(kNumberOfWorkers));

  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithMap) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(
      D().map(transforms::BatchLambda<int, std::string>(
                  [](int x) { return std::to_string(x); }))
          .map(transforms::BatchLambda<std::string, torch::Tensor>(
              [](const std::string& x) {
                return torch::tensor(static_cast<int64_t>(std::stoi(x)));
              })),
      DataLoaderOptions{});

  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const torch::Tensor& t : *data_loader) {
    ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithCollate) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D> {
    torch::optional<std::vector<Example<>>> get_batch(
        size_t batch_size) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        counter += batch_size;
        std::vector<Example<>> batch(
            /*count=*/batch_size,
            Example<>{
                torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
        return batch;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto d = D().map(transforms::Stack<Example<>>());

  const size_t kBatchSize = 5;

  // Notice that the dataset's `get_batch()` returns a vector<Example>, but the
  // `Stack` collation stacks the tensors into a single Example.
  torch::optional<Example<>> batch = d.get_batch(kBatchSize);
  ASSERT_TRUE(batch.has_value());
  ASSERT_EQ(batch->data.size(0), kBatchSize);
  ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
  ASSERT_EQ(batch->target.size(0), kBatchSize);
  ASSERT_EQ(batch->target.size(1), kBatchSize - 1);

  ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
  ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
}
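
// What `transforms::Stack<Example<>>` does to a batch, in isolation (a
// hedged sketch using plain tensor ops; `stack_examples` is a hypothetical
// helper, not the transform's actual implementation): N examples whose data
// tensors have shape [d] collate into one Example whose data has shape
// [N, d], and likewise for the targets.
namespace stack_sketch {
inline Example<> stack_examples(const std::vector<Example<>>& examples) {
  std::vector<torch::Tensor> data;
  std::vector<torch::Tensor> targets;
  for (const auto& example : examples) {
    data.push_back(example.data);
    targets.push_back(example.target);
  }
  // torch::stack adds a new leading batch dimension.
  return {torch::stack(data), torch::stack(targets)};
}
} // namespace stack_sketch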

// This test exercises the core logic for iterating through a chunk dataset.
// It contains test cases with different parameter combinations (for example,
// different prefetch counts, batch sizes, and data loader worker counts). It
// verifies the size and content of the returned batches when the order is
// deterministic.
TEST(DataLoaderTest, ChunkDataSetGetBatch) {
  // Different prefetch counts for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t prefetch_counts[] = {1, 2, 3, 4};

  // Different batch sizes for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t batch_sizes[] = {5, 7};

  // Test with and without worker threads.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t dataloader_worker_counts[] = {0, 2};

  const size_t total_example_count = 35;
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  // Test functionality across the epoch boundary.
  const int epoch_count = 2;

  for (auto prefetch_count : prefetch_counts) {
    for (auto batch_size : batch_sizes) {
      for (auto dataloader_worker_count : dataloader_worker_counts) {
        datasets::SharedBatchDataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>
            dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
                DummyChunkDataReader,
                samplers::SequentialSampler,
                samplers::SequentialSampler>>(
                data_reader,
                sampler,
                sampler,
                datasets::ChunkDatasetOptions(prefetch_count, batch_size));

        auto data_loader = torch::data::make_data_loader(
            dataset,
            DataLoaderOptions(batch_size).workers(dataloader_worker_count));

        for (const auto epoch_index : c10::irange(epoch_count)) {
          (void)epoch_index; // Suppress unused variable warning.
          std::vector<bool> result(total_example_count, false);
          int iteration_count = 0;
          for (auto iterator = data_loader->begin();
               iterator != data_loader->end();
               ++iterator, ++iteration_count) {
            DummyChunkDataReader::BatchType& batch = *iterator;
            ASSERT_EQ(batch.size(), batch_size);

            // When prefetch_count is 1 and there is no worker thread, the
            // batch order is deterministic, so we can verify the elements in
            // each batch.
            if (prefetch_count == 1 && dataloader_worker_count == 0) {
              for (const auto j : c10::irange(batch_size)) {
                ASSERT_EQ(batch[j], iteration_count * batch_size + j);
              }
            }
            for (const auto j : c10::irange(batch_size)) {
              result[batch[j]] = true;
            }
          }

          for (auto data : result) {
            ASSERT_EQ(data, true);
          }
        }
      }
    }
  }
}

TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  const size_t requested_batch_size = 6;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(requested_batch_size).workers(0));

  std::string exception_msg =
      "The requested batch size does not match with the initialized batch "
      "size.\n The requested batch size is 6, while the dataset is created"
      " with batch size equal to 5";

  ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg);
}

TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
  struct DummyEmptyChunkDataReader : datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      return {};
    }

    size_t chunk_count() override {
      return 1;
    }

    void reset() override {}
  };

  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  DummyEmptyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyEmptyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyEmptyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(batch_size).workers(0));

  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
       ++iterator) {
    ASSERT_EQ(iterator->size(), 0);
  }
}

TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(10, 0);
      return batch_data;
    }

    size_t chunk_count() override {
      return 2;
    }

    void reset() override {}
  };
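
  // Worked sizes for this test (a sanity check of the arithmetic below): the
  // reader holds 2 chunks * 10 examples = 20 examples in total. Requesting
  // batches of 17 therefore yields one batch of 17 and a trailing batch of 3,
  // while requesting 30 is capped at the 20 available examples.
  static_assert(2 * 10 == 17 + 3, "17 + 3 must cover both chunks");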

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t batch_sizes[] = {17, 30};
  D data_reader;
  samplers::SequentialSampler sampler(0);

  for (auto batch_size : batch_sizes) {
    datasets::SharedBatchDataset<datasets::ChunkDataset<
        D,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            D,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            datasets::ChunkDatasetOptions(1, batch_size));

    auto data_loader = torch::data::make_data_loader(
        dataset, DataLoaderOptions(batch_size).workers(0));

    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
         ++iterator) {
      D::BatchType batch = *iterator;
      // `batch_size` is the requested size from the enclosing loop.
      if (batch_size == 17) {
        ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
      }
      if (batch_size == 30) {
        ASSERT_TRUE(batch.size() == 20);
      }
    }
  }
}

TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
  const size_t prefetch_count = 2;
  const size_t batch_size = 5;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);
  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();

  auto data_loader = torch::data::make_data_loader(
      dataset.map(transforms::BatchLambda<
                  DummyChunkDataReader::BatchType,
                  DummyChunkDataReader::DataType>(
          [](DummyChunkDataReader::BatchType batch) {
            return std::accumulate(batch.begin(), batch.end(), 0);
          })),
      DataLoaderOptions(batch_size).workers(0));

  // Before we start, the index should be 0.
  ASSERT_EQ(chunk_sampler.index(), 0);

  size_t sum = 0;
  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
       ++iterator) {
    sum += *iterator;
  }
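  // Sanity-check the expected total with the Gauss formula: the reader yields
  // each value in [0, 35) exactly once, and 34 * 35 / 2 == 595.
  static_assert(34 * 35 / 2 == 595, "sum of 0..34");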
  ASSERT_EQ(sum, 595); // sum([0, 35))
  // 3 chunks were read; once the sampler is exhausted, its index has already
  // been advanced past the last chunk.
  ASSERT_EQ(chunk_sampler.index(), 3);
}

TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
  const size_t prefetch_count = 2;
  const size_t batch_size = 5;
  // This will make the preloaders wait until `get_batch()` is called.
  const size_t cache_size = 10;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);
  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, cache_size));

  auto data_loader = torch::data::make_data_loader(
      dataset.map(transforms::BatchLambda<
                  DummyChunkDataReader::BatchType,
                  DummyChunkDataReader::DataType>(
          [](DummyChunkDataReader::BatchType batch) {
            return std::accumulate(batch.begin(), batch.end(), 0);
          })),
      DataLoaderOptions(batch_size).workers(0));
  // Simply create the iterator without iterating. The chunk preloaders are
  // waiting to fill the batch buffer, but nothing drains it. We still need to
  // exit cleanly.
  auto iterator = data_loader->begin();
}

// Test the ChunkDataset save function.
// Note [save/load ChunkDataset as ChunkSampler]:
// The chunk sampler inside ChunkDataset is used in a separate thread pool,
// not the main thread, so it is very hard to accurately estimate its status
// when ChunkDataset::save/ChunkDataset::load is called. Purely for testing,
// we exploit the implementation fact that the file format for sampler
// serialization is the same as for ChunkDataset serialization, and manually
// control the chunk sampler by calling the sampler's save/load methods for
// value validation. This is only for testing the specific save/load
// functionality; in a real use case, the user should still use the matching
// ChunkDataset::save and ChunkDataset::load methods.
TEST(DataLoaderTest, ChunkDatasetSave) {
  const size_t chunk_count_ = 6;
  const size_t chunk_size = 10;

  struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      return batch_data_;
    }

    size_t chunk_count() override {
      return chunk_count_;
    }

    void reset() override {}
    BatchType batch_data_ = BatchType(chunk_size, 0);
  };

  const size_t prefetch_count = 1;
  const size_t batch_size = chunk_size;
  const size_t dataloader_worker_count = 0;
  samplers::SequentialSampler sampler(0);
  const int epoch_count = 2;

  DummyTestChunkDataReader data_reader;

  // Tested save intervals.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t save_intervals[] = {1, 2};

  using datasets::ChunkDatasetOptions;

  for (auto save_interval : save_intervals) {
    auto tempfile = c10::make_tempfile();

    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyTestChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyTestChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            ChunkDatasetOptions(
                prefetch_count, batch_size, chunk_size /*cache size*/));

    auto data_loader = torch::data::make_data_loader(
        dataset,
        DataLoaderOptions(batch_size).workers(dataloader_worker_count));

    for (const auto epoch_index : c10::irange(epoch_count)) {
      (void)epoch_index; // Suppress unused variable warning.
      unsigned iteration_count = 0;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator, ++iteration_count) {
        if ((iteration_count + 1) % save_interval == 0) {
          torch::save(*dataset, tempfile.name);

          samplers::SequentialSampler new_sampler(0);

          // See Note [save/load ChunkDataset as ChunkSampler]
          torch::load(new_sampler, tempfile.name);

          // Verify the save logic. For ChunkDataset, the chunk data is stored
          // in a cache inside the dataset. One pool of threads constantly
          // writes to the cache, and a different pool constantly reads from
          // it. Because of this asynchrony, which chunk has been written to
          // the cache at the time of get_batch() is not fully deterministic.
          // But we can still derive a restricted window for the expected
          // output, and hence verify the logic. In this test, the cache size
          // is configured to equal the chunk size and the batch size, so
          // chunks are written to the cache one at a time: only once the
          // current batch has been retrieved is the next chunk written. Now,
          // in iteration 0, after the first batch has been retrieved, there
          // are three possible scenarios for the writer thread when we save
          // the dataset's state:
          // 1. It hasn't started loading the next chunk yet, so the
          //    sequential sampler index is still 0.
          // 2. It started to load the second chunk, so the sequential sampler
          //    index is at 1.
          // 3. It finished loading the second chunk and started on the third,
          //    but because the cache is still fully occupied by the data from
          //    the second chunk, it is waiting to write. At this point the
          //    sampler index is at 2.
          // So we have a window of [0, 2] in which we expect the saved
          // sampler index to fall. Note that the sequential sampler advances
          // to the next index automatically inside next(), so what gets saved
          // is the next index rather than the current one: after handing out
          // the first index, the sampler has already moved on to the second,
          // and that is what a save captures. As a result, we need to shift
          // the window by one, giving the expected window [1, 3].
          // The same analysis applies to every iteration, so in general the
          // saved index should fall into the range
          // [iteration_count + 1, iteration_count + 3], which is what the
          // validation below checks.
          ASSERT_TRUE(
              new_sampler.index() >= iteration_count + 1 &&
              new_sampler.index() <= iteration_count + 3);
        }
      }
    }
  }
}
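
// The serialization symmetry the Note above relies on, in isolation (a
// minimal, hedged sketch; `roundtrip_sampler_index` is a hypothetical helper
// for illustration): a SequentialSampler round-trips through torch::save and
// torch::load, which is why a ChunkDataset archive can be validated by
// loading it into a bare sampler.
namespace sampler_roundtrip_sketch {
inline size_t roundtrip_sampler_index(size_t size, size_t advance_by) {
  samplers::SequentialSampler saved(size);
  saved.next(advance_by); // move the sampler forward before checkpointing
  auto tempfile = c10::make_tempfile();
  torch::save(saved, tempfile.name);

  samplers::SequentialSampler restored(0);
  torch::load(restored, tempfile.name);
  return restored.index(); // == advance_by
}
} // namespace sampler_roundtrip_sketch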

// Test the ChunkDataset load function.
TEST(DataLoaderTest, ChunkDatasetLoad) {
  auto tempfile = c10::make_tempfile();

  const size_t prefetch_count = 1;
  const size_t batch_size = 10;
  const size_t dataloader_worker_count = 0;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  const size_t skipped_chunk = 2;

  // Configure the sampler to skip 2 chunks.
  {
    sampler.reset(data_reader.chunk_count());
    sampler.next(skipped_chunk);

    // See Note [save/load ChunkDataset as ChunkSampler]
    torch::save(sampler, tempfile.name);
  }

  // Test functionality across the epoch boundary. The first epoch should be
  // affected by the checkpoint, but the second should start normally.
  const int epoch_count = 2;

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, 20 /*cache size*/));

  torch::load(*dataset, tempfile.name);

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));

  for (const auto epoch_index : c10::irange(epoch_count)) {
    int iteration_count = 0;

    // For the first epoch, batches should come from the third chunk, because
    // the checkpoint skipped the first two chunks. For the next epoch,
    // iteration should start from the first chunk again.
    int initial_value = epoch_index == 0 ? 15 : 0;

    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
         ++iterator, ++iteration_count) {
      DummyChunkDataReader::BatchType batch = *iterator;

      std::vector<int> expected_result;
      size_t expected_size = (epoch_index > 0 && iteration_count == 3) ? 5 : 10;
      expected_result.resize(expected_size);
      std::iota(expected_result.begin(), expected_result.end(), initial_value);

      ASSERT_EQ(batch.size(), expected_result.size());
      ASSERT_TRUE(
          std::equal(batch.begin(), batch.end(), expected_result.begin()));

      initial_value += batch_size;
    }
  }

  samplers::SequentialSampler new_sampler(0);

  // See Note [save/load ChunkDataset as ChunkSampler]
  torch::load(new_sampler, tempfile.name);

  ASSERT_EQ(new_sampler.index(), skipped_chunk);
}
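
// Worked starting value for the first epoch above (a sanity check): the two
// skipped chunks hold 10 + 5 examples, so iteration resumes at value 15.
static_assert(10 + 5 == 15, "examples covered by the two skipped chunks");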

TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
  const size_t chunk_size = 5;
  const size_t batch_size = 5;

  class S : public samplers::Sampler<> {
   public:
    explicit S(size_t size) : size_(size) {}

    void reset(torch::optional<size_t> new_size = torch::nullopt) override {
      if (new_size.has_value()) {
        size_ = *new_size;
      }
      indices_.resize(size_);
      size_t index = 0;

      // Sample the indices in a strided order: 0, batch_size, 2 * batch_size,
      // ..., then 1, batch_size + 1, and so on.
      for (const auto i : c10::irange(batch_size)) {
        for (size_t j = 0; j < size_ / batch_size; ++j) {
          indices_[index++] = i + batch_size * j;
        }
      }
      index_ = 0;
    }

    // Returns the next batch of indices.
    torch::optional<std::vector<size_t>> next(size_t batch_size) override {
      const auto remaining_indices = size_ - index_;
      if (remaining_indices == 0) {
        return torch::nullopt;
      }
      auto return_size = std::min(batch_size, remaining_indices);
      std::vector<size_t> index_batch(
          indices_.begin() + index_, indices_.begin() + index_ + return_size);
      index_ += return_size;

      return index_batch;
    }

    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}

   private:
    size_t size_;
    std::vector<size_t> indices_;
    size_t index_{0};
  };
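
  // What S::reset() produces, worked for size 10 and batch_size 5 (a sanity
  // probe of the sampler above, not part of the original scenario): the
  // strided order {0, 5, 1, 6, 2, 7, 3, 8, 4, 9}.
  {
    S probe(10);
    probe.reset();
    const std::vector<size_t> expected = {0, 5, 1, 6, 2, 7, 3, 8, 4, 9};
    ASSERT_EQ(probe.next(10).value(), expected);
  }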

  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
    D(size_t chunk_count) : chunk_count_(chunk_count) {}

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(chunk_size, chunk_index);
      return batch_data;
    }

    size_t chunk_count() override {
      return chunk_count_;
    }

    void reset() override {}
    size_t chunk_count_;
  };

  const size_t prefetch_count = 1;
  const size_t cache_size = 10;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t cross_chunk_shuffle_counts[] = {2, 3};
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t chunk_counts[] = {3, 4, 5};

  samplers::SequentialSampler chunk_sampler(0);
  S example_sampler(0);

  for (auto chunk_count : chunk_counts) {
    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
      D data_reader(chunk_count);

      datasets::SharedBatchDataset<
          datasets::ChunkDataset<D, samplers::SequentialSampler, S>>
          dataset = datasets::make_shared_dataset<
              datasets::ChunkDataset<D, samplers::SequentialSampler, S>>(
              data_reader,
              chunk_sampler,
              example_sampler,
              datasets::ChunkDatasetOptions(
                  prefetch_count,
                  batch_size,
                  cache_size,
                  cross_chunk_shuffle_count));

      auto data_loader = torch::data::make_data_loader(
          dataset, DataLoaderOptions(batch_size).workers(0));

      std::vector<int> result;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator) {
        auto batch_result = *iterator;
        std::copy(
            batch_result.begin(),
            batch_result.end(),
            std::back_inserter(result));
      }

      std::vector<int> expected_result;
      {
        // Construct the expected result: chunks are grouped
        // `cross_chunk_shuffle_count` at a time, and the group count is the
        // ceiling of chunk_count / cross_chunk_shuffle_count.
        for (const auto i : c10::irange(
                 (chunk_count + cross_chunk_shuffle_count - 1) /
                 cross_chunk_shuffle_count)) {
          for (const auto j : c10::irange(chunk_size)) {
            (void)j; // Suppress unused variable warning.
            for (const auto k : c10::irange(cross_chunk_shuffle_count)) {
              if (i * cross_chunk_shuffle_count + k < chunk_count) {
                expected_result.push_back(i * cross_chunk_shuffle_count + k);
              }
            }
          }
        }
      }

      ASSERT_EQ(result.size(), expected_result.size());
      ASSERT_TRUE(
          std::equal(result.begin(), result.end(), expected_result.begin()));
    }
  }
}
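
// Worked group count for the ceiling division above (a sanity check): 5
// chunks shuffled 2 at a time span ceil(5 / 2) == 3 groups.
static_assert((5 + 2 - 1) / 2 == 3, "ceil(5 / 2) == 3");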

TEST(DataLoaderTest, CustomPreprocessPolicy) {
  const size_t chunk_size = 5;
  const size_t batch_size = 10;

  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
    D(size_t chunk_count) : chunk_count_(chunk_count) {}

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(chunk_size);
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand)
      auto rand_gen = []() { return std::rand() % 100; };
      std::generate(batch_data.begin(), batch_data.end(), rand_gen);
      return batch_data;
    }

    size_t chunk_count() override {
      return chunk_count_;
    }

    void reset() override {}
    size_t chunk_count_;
  };

  // Custom preprocessing policy: sort the data in ascending order.
  auto sorting_policy = [](std::vector<int>& raw_batch_data) {
    std::sort(raw_batch_data.begin(), raw_batch_data.end());
  };
  std::function<void(std::vector<int>&)> policy_function = sorting_policy;
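
  // The policy in isolation (a quick sanity probe, not part of the original
  // scenario): it sorts a raw chunk batch in place before the dataset caches
  // it.
  {
    std::vector<int> probe = {3, 1, 2};
    policy_function(probe);
    ASSERT_TRUE(std::is_sorted(probe.begin(), probe.end()));
  }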

  const size_t prefetch_count = 1;
  const size_t cache_size = 10;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t cross_chunk_shuffle_counts[] = {1, 2};
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t chunk_counts[] = {3, 4};

  samplers::SequentialSampler chunk_sampler(0);

  for (auto chunk_count : chunk_counts) {
    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
      D data_reader(chunk_count);

      datasets::SharedBatchDataset<datasets::ChunkDataset<
          D,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>
          dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
              D,
              samplers::SequentialSampler,
              samplers::SequentialSampler>>(
              data_reader,
              chunk_sampler,
              chunk_sampler,
              datasets::ChunkDatasetOptions(
                  prefetch_count,
                  batch_size,
                  cache_size,
                  cross_chunk_shuffle_count),
              policy_function);

      auto data_loader = torch::data::make_data_loader(
          dataset, DataLoaderOptions(batch_size).workers(0));

      std::vector<int> result;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator) {
        auto batch_result = *iterator;
        if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
          for (unsigned i = 0; i < batch_result.size(); i += chunk_size) {
            ASSERT_TRUE(std::is_sorted(
                batch_result.begin() + i,
                batch_result.begin() + i + chunk_size));
          }
        } else {
          ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end()));
        }
      }
    }
  }
}
2323