1 #include <gtest/gtest.h>
2
3 #include <torch/torch.h>
4
5 #include <test/cpp/api/support.h>
6
7 #include <c10/util/ArrayRef.h>
8 #include <c10/util/irange.h>
9 #include <c10/util/tempfile.h>
10
11 #include <algorithm>
12 #include <chrono>
13 #include <future>
14 #include <iostream>
15 #include <iterator>
16 #include <limits>
17 #include <mutex>
18 #include <numeric>
19 #include <stdexcept>
20 #include <string>
21 #include <thread>
22 #include <unordered_set>
23 #include <vector>
24
using namespace torch::data; // NOLINT

// Base time unit used to scale timeouts and sleeps in the queue/shuttle tests.
const std::chrono::milliseconds kMillisecond(1);
28
29 struct DummyDataset : datasets::Dataset<DummyDataset, int> {
DummyDatasetDummyDataset30 explicit DummyDataset(size_t size = 100) : size_(size) {}
31
getDummyDataset32 int get(size_t index) override {
33 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
34 return 1 + index;
35 }
sizeDummyDataset36 torch::optional<size_t> size() const override {
37 return size_;
38 }
39
40 size_t size_;
41 };
42
TEST(DataTest, DatasetCallsGetCorrectly) {
  // get_batch() must fan out to get() for each requested index.
  DummyDataset dataset;
  const std::vector<int> expected = {1, 2, 3, 4, 5};
  const auto actual = dataset.get_batch({0, 1, 2, 3, 4});
  ASSERT_EQ(actual, expected);
}
49
TEST(DataTest, TransformCallsGetApplyCorrectly) {
  // Element-wise transform: int -> its decimal string representation.
  struct T : transforms::Transform<int, std::string> {
    std::string apply(int input) override {
      return std::to_string(input);
    }
  };

  // map() should apply T to every example produced by the dataset's get().
  auto d = DummyDataset{}.map(T{});
  std::vector<std::string> batch = d.get_batch({0, 1, 2, 3, 4});
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(batch, expected);
}
62
63 // dummy chunk data reader with 3 chunks and 35 examples in total. Each chunk
64 // contains 10, 5, 20 examples respectively.
65
struct DummyChunkDataReader : public datasets::ChunkDataReader<int> {
 public:
  using BatchType = datasets::ChunkDataReader<int>::ChunkType;
  using DataType = datasets::ChunkDataReader<int>::ExampleType;

  /// Read an entire chunk.
  BatchType read_chunk(size_t chunk_index) override {
    BatchType batch_data;
    // Examples are numbered consecutively across chunks, so the first value
    // in this chunk is the sum of the sizes of all preceding chunks.
    int start_index = chunk_index == 0
        ? 0
        // NOLINTNEXTLINE(bugprone-fold-init-type)
        : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0);

    batch_data.resize(chunk_sizes[chunk_index]);

    // Fill with start_index, start_index + 1, ...
    std::iota(batch_data.begin(), batch_data.end(), start_index);

    return batch_data;
  }

  size_t chunk_count() override {
    return chunk_count_;
  };

  // No per-epoch state to clear.
  void reset() override{};

  const static size_t chunk_count_ = 3;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
  size_t chunk_sizes[chunk_count_] = {10, 5, 20};
};
96
TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  // Constructs a ChunkDataset with the given options; each invalid
  // combination below is expected to throw with the quoted message.
  auto initialization_function = [&](size_t preloader_count,
                                     size_t batch_size,
                                     size_t cache_size,
                                     size_t cross_chunk_shuffle_count = 1) {
    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            datasets::ChunkDatasetOptions(
                preloader_count,
                batch_size,
                cache_size,
                cross_chunk_shuffle_count));
  };

  ASSERT_THROWS_WITH(
      initialization_function(0, 1, 1),
      "Preloader count is 0. At least one preloader needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 0, 1),
      "Batch size is 0. A positive batch size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 1, 0),
      "Cache size is 0. A positive cache size needs to be specified.");

  // Cache must be able to hold at least one full batch.
  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 5),
      "Cache size is less than batch size. Cache needs to be large enough to "
      "hold at least one batch.");
  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 20, 0),
      "cross_chunk_shuffle_count needs to be greater than 0.");
}
143
144 struct InfiniteStreamDataset
145 : datasets::StreamDataset<InfiniteStreamDataset, std::vector<int>> {
get_batchInfiniteStreamDataset146 std::vector<int> get_batch(size_t batch_size) override {
147 std::vector<int> batch(batch_size);
148 for (auto& i : batch) {
149 i = counter++;
150 }
151 return batch;
152 }
153
sizeInfiniteStreamDataset154 torch::optional<size_t> size() const override {
155 return torch::nullopt;
156 }
157
158 size_t counter = 0;
159 };
160
TEST(DataTest, InfiniteStreamDataset) {
  const size_t kBatchSize = 13;

  // The stream produces 0,1,2,... shifted by +1 through the Lambda transform.
  auto dataset = InfiniteStreamDataset().map(
      transforms::Lambda<int>([](int x) { return x + 1; }));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      samplers::StreamSampler(/*epoch_size=*/39),
      kBatchSize);

  size_t batch_index = 0;
  for (auto& batch : *data_loader) {
    // 39 / 13 = exactly 3 full batches per epoch.
    ASSERT_LT(batch_index, 3);
    ASSERT_EQ(batch.size(), kBatchSize);
    for (const auto j : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(j), 1 + (batch_index * kBatchSize) + j);
    }
    batch_index += 1;
  }
  ASSERT_EQ(batch_index, 3);
}
183
TEST(DataTest, NoSequencerIsIdentity) {
  using namespace torch::data::detail::sequencers; // NOLINT
  // The no-op sequencer must hand back exactly what the getter produced.
  NoSequencer<int> no_sequencer;
  auto result = no_sequencer.next([] { return 5; });
  ASSERT_EQ(result.value(), 5);
}
190
TEST(DataTest, OrderedSequencerIsSetUpWell) {
  using namespace torch::data::detail::sequencers; // NOLINT
  // Minimal payload type: the sequencer only inspects `sequence_number`.
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);
  // Fresh sequencer: expects sequence 0 next and pre-sizes its reorder buffer.
  ASSERT_EQ(sequencer.next_sequence_number_, 0);
  ASSERT_EQ(sequencer.buffer_.size(), kMaxJobs);
}
201
TEST(DataTest, OrderedSequencerReOrdersValues) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);

  // Results arrive out of order; the sequencer must emit them in order.
  std::vector<size_t> v = {0, 2, 4, 3, 1};
  size_t index = 0;
  auto getter = [&v, &index]() { return S{v.at(index++)}; };

  // Let's say the sequence number matches for the batch one, then it should
  // return immediately.
  const auto batch = sequencer.next(getter);
  ASSERT_EQ(batch.value().sequence_number, 0);
  ASSERT_EQ(index, 1);

  // Now it should call the getter until it gets the next value.
  ASSERT_EQ(1, sequencer.next(getter).value().sequence_number);
  ASSERT_EQ(index, 5);

  // The next three should come in order.
  for (size_t i = 2; i <= 4; ++i) {
    // New value doesn't matter. In fact, it shouldn't be accessed.
    ASSERT_EQ(i, sequencer.next(getter).value().sequence_number);
    // The index doesn't change.
    ASSERT_EQ(index, 5);
  }
}
232
TEST(DataTest, BatchLambdaAppliesFunctionToBatch) {
  using InputBatch = std::vector<int>;
  using OutputBatch = std::string;
  DummyDataset d;
  // BatchLambda maps the whole batch at once, not element-wise.
  auto e = d.map(transforms::BatchLambda<InputBatch, OutputBatch>(
      [](std::vector<int> input) {
        return std::to_string(std::accumulate(input.begin(), input.end(), 0));
      }));
  // Indices {1..5} yield examples {2..6}; their sum is 20.
  ASSERT_EQ(e.get_batch({1, 2, 3, 4, 5}), std::string("20"));
}
243
TEST(DataTest, LambdaAppliesFunctionToExample) {
  // The static_cast disambiguates the std::to_string(int) overload so it can
  // be passed as a function pointer.
  auto d = DummyDataset().map(transforms::Lambda<int, std::string>(
      static_cast<std::string (*)(int)>(std::to_string)));
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(d.get_batch({0, 1, 2, 3, 4}), expected);
}
250
TEST(DataTest, CollateReducesBatch) {
  // Collate reduces a vector of examples into a single value (here: the sum).
  auto d =
      DummyDataset().map(transforms::Collate<int>([](std::vector<int> input) {
        return std::accumulate(input.begin(), input.end(), 0);
      }));
  // Indices {1..5} yield examples {2..6}; their sum is 20.
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}
258
TEST(DataTest, CollationReducesBatch) {
  // Same reduction as CollateReducesBatch, but via a Collation subclass
  // instead of a lambda.
  struct Summer : transforms::Collation<int> {
    int apply_batch(std::vector<int> input) override {
      return std::accumulate(input.begin(), input.end(), 0);
    }
  };
  auto d = DummyDataset().map(Summer{});
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}
268
TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
  // Successive next() calls partition 0..9 in order; exhaustion yields nullopt.
  samplers::SequentialSampler sampler(10);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({3, 4, 5, 6, 7}));
  ASSERT_EQ(sampler.next(2).value(), std::vector<size_t>({8, 9}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
276
TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
  // A request larger than the remainder returns only what is left.
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(100).value(), std::vector<size_t>({3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
283
TEST(DataTest, SequentialSamplerResetsWell) {
  // reset() restarts the epoch from index 0 with the same size.
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
292
TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
  // reset(new_size) both restarts the epoch and resizes it (grow and shrink).
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(
      sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
305
TEST(DataTest, CanSaveAndLoadSequentialSampler) {
  // Round-tripping a fresh sampler preserves index 0.
  {
    samplers::SequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  // Round-tripping after consuming 7 indices restores the mid-epoch position.
  {
    samplers::SequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}
330
TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) {
  samplers::RandomSampler sampler(10);

  // Every sampled index must lie in [0, 10). The original spelled this loop
  // out three times; factor it into one helper to remove the duplication.
  auto assert_indices_in_range = [](const std::vector<size_t>& indices) {
    for (auto i : indices) {
      ASSERT_GE(i, 0);
      ASSERT_LT(i, 10);
    }
  };

  assert_indices_in_range(sampler.next(3).value());
  assert_indices_in_range(sampler.next(5).value());
  assert_indices_in_range(sampler.next(2).value());

  // All 10 indices have been consumed; the sampler is exhausted.
  ASSERT_FALSE(sampler.next(10).has_value());
}
354
TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
  // A request larger than the remainder returns only the leftover count.
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_EQ(sampler.next(100).value().size(), 2);
  ASSERT_FALSE(sampler.next(2).has_value());
}
361
TEST(DataTest, RandomSamplerResetsWell) {
  // reset() makes the full population available again.
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}
370
TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
  // reset(new_size) restarts the epoch with a grown or shrunk population.
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}
382
TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
  // Serialization must capture the RNG state: a loaded sampler produces the
  // same permutation as the one it was saved from.
  {
    samplers::RandomSampler a(10);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);

    ASSERT_EQ(a.next(10).value(), b.next(10).value());
  }
  // Saving mid-epoch also preserves the position within the permutation.
  {
    samplers::RandomSampler a(10);
    a.next(3);
    ASSERT_EQ(a.index(), 3);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 3);

    auto b_sequence = b.next(10).value();
    ASSERT_EQ(b_sequence.size(), 7);
    ASSERT_EQ(a.next(10).value(), b_sequence);
  }
}
412
TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
  // A stream sampler yields counts, not indices: the requested batch size,
  // clamped to what remains of the epoch (10 + 2 + 85 = 97, so 3 remain).
  samplers::StreamSampler sampler(/*epoch_size=*/100);
  ASSERT_EQ(sampler.next(10).value(), 10);
  ASSERT_EQ(sampler.next(2).value(), 2);
  ASSERT_EQ(sampler.next(85).value(), 85);
  ASSERT_EQ(sampler.next(123).value(), 3);
  ASSERT_FALSE(sampler.next(1).has_value());
}
421
TEST(DataTest, StreamSamplerResetsWell) {
  // reset() restarts the epoch with the original epoch size.
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}
430
TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
  // reset(new_size) restarts the epoch with a different epoch size.
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}
442
TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
  // get(i) on a single-tensor dataset returns row i; row 2 of eye(5) is the
  // third basis vector.
  datasets::TensorDataset dataset(torch::eye(5));
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}
448
TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) {
  // NOTE(review): despite the test name, this passes a std::vector of row
  // chunks, not an initializer list — the vector overload is what is covered.
  std::vector<torch::Tensor> vector = torch::eye(5).chunk(5);
  datasets::TensorDataset dataset(vector);
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}
455
TEST(DataTest, StackTransformWorksForExample) {
  // Dataset whose example i is (row i of eye(4), row i of eye(4) + 1).
  struct D : public datasets::Dataset<D> {
    Example<> get(size_t index) override {
      return {tensor[index], 1 + tensor[index]};
    }

    torch::optional<size_t> size() const override {
      return tensor.size(0);
    }

    torch::Tensor tensor{torch::eye(4)};
  };

  // Stack collates a vector of Examples into one Example of stacked tensors.
  auto d = D().map(transforms::Stack<Example<>>());

  Example<> batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
  ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));

  Example<> second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
  ASSERT_TRUE(second.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 2, 4)));
}
479
TEST(DataTest, StackTransformWorksForTensorExample) {
  // Same stacking behavior, but for data-only (TensorExample) examples.
  auto d = datasets::TensorDataset(torch::eye(4))
               .map(transforms::Stack<TensorExample>());

  TensorExample batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));

  TensorExample second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
}
490
491 // Template classes cannot be nested in functions.
template <typename Target>
struct T : transforms::TensorTransform<Target> {
  // Doubles every data tensor; `Target` is untouched, so this transform works
  // for any target type.
  torch::Tensor operator()(torch::Tensor input) override {
    return input * 2;
  }
};
498
// Dataset whose example at i is (scalar tensor double(i), string of i).
struct TensorStringDataset
    : datasets::
          Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
  Example<torch::Tensor, std::string> get(size_t index) override {
    return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};
510
TEST(DataTest, TensorTransformWorksForAnyTargetType) {
  // The doubling TensorTransform touches only `data`; the string targets
  // pass through unchanged.
  auto d = TensorStringDataset().map(T<std::string>{});
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}
522
TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
  // Same behavior as TensorTransformWorksForAnyTargetType, but via the
  // TensorLambda convenience wrapper instead of a subclass.
  auto d = TensorStringDataset().map(transforms::TensorLambda<std::string>(
      [](torch::Tensor input) { return input * 2; }));
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}
535
// Dataset whose example at i is a ones tensor with i channels (a plain 4x4
// matrix for i == 0), paired with the channel count as the target.
struct DummyTensorDataset
    : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
  Example<torch::Tensor, int> get(size_t index) override {
    const auto channels = static_cast<int64_t>(index);
    torch::Tensor tensor =
        (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
    return {tensor, static_cast<int>(channels)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};
549
TEST(DataTest, NormalizeTransform) {
  // Normalize computes (x - mean) / stddev, broadcasting per-channel moments.
  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));

  // Works for zero (one implicit) channels
  std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
  ASSERT_EQ(output.size(), 1);
  // (1 - 0.5) / 0.1 = 5
  ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
      << output[0].data;

  // Works for one explicit channel
  output = dataset.get_batch(1);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 1);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;

  // Works for two channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
  output = dataset.get_batch(2);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 2);
  // Channel 0: (1 - 0.5) / 0.1 = 5; channel 1: (1 - 1.5) / 0.2 = -2.5.
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with one moment value
  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  // Channel 2: (1 - (-1.5)) / 0.2 = 12.5.
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/2)
                  .allclose(torch::ones({1, 4, 4}) * 12.5))
      << output[0].data;
}
609
// Move-only Example<> dataset used to prove that map() chains do not copy the
// underlying dataset (a copy would fail to compile).
struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
  UnCopyableDataset() = default;

  UnCopyableDataset(const UnCopyableDataset&) = delete;
  UnCopyableDataset& operator=(const UnCopyableDataset&) = delete;

  UnCopyableDataset(UnCopyableDataset&&) = default;
  UnCopyableDataset& operator=(UnCopyableDataset&&) = default;

  ~UnCopyableDataset() override = default;

  Example<> get(size_t index) override {
    return {
        torch::tensor({static_cast<int64_t>(index)}),
        torch::tensor({static_cast<int64_t>(index)})};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};
631
TEST(DataTest, MapDoesNotCopy) {
  // Compiles only if map() moves the (move-only) dataset. The three
  // transforms compose to +6.
  auto dataset = UnCopyableDataset()
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 1; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 2; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 3; }));

  // Index 1 yields 1; 1 + 6 = 7.
  auto data = dataset.get_batch(1).at(0).data;
  ASSERT_EQ(data.numel(), 1);
  ASSERT_EQ(data[0].item<float>(), 7);
}
645
TEST(DataTest, QueuePushAndPopFromSameThread) {
  // FIFO order must be preserved within a single thread.
  torch::data::detail::Queue<int> queue;
  for (int value : {1, 2}) {
    queue.push(value);
  }
  ASSERT_EQ(queue.pop(), 1);
  ASSERT_EQ(queue.pop(), 2);
}
653
TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
  // Popping an empty queue with a timeout must throw, and the message must
  // report the timeout that was used.
  torch::data::detail::Queue<int> queue;
  ASSERT_THROWS_WITH(
      queue.pop(10 * kMillisecond),
      "Timeout in DataLoader queue while waiting for next batch "
      "(timeout was 10 ms)");
}
661
TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
  using torch::data::detail::Queue;

  // First test: push batch and the pop in thread.
  {
    Queue<int> queue;
    queue.push(1);
    auto future =
        std::async(std::launch::async, [&queue] { return queue.pop(); });
    ASSERT_EQ(future.get(), 1);
  }

  // Second test: attempt to pop batch (and block), then push.
  {
    Queue<int> queue;
    // The sleep makes it likely pop() is already blocking when push() runs,
    // exercising the wake-up path.
    std::thread thread([&queue] {
      std::this_thread::sleep_for(20 * kMillisecond);
      queue.push(123);
    });
    ASSERT_EQ(queue.pop(), 123);
    thread.join();
  }
}
685
TEST(DataTest, QueueClearEmptiesTheQueue) {
  // clear() returns the number of dropped elements and leaves the queue empty,
  // so a subsequent timed pop() times out.
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  queue.push(3);
  ASSERT_EQ(queue.clear(), 3);
  ASSERT_THROWS_WITH(queue.pop(1 * kMillisecond), "Timeout");
}
694
TEST(DataTest, DataShuttleCanPushAndPopJob) {
  // Jobs flow through the shuttle in FIFO order.
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_job(2);
  ASSERT_EQ(shuttle.pop_job(), 1);
  ASSERT_EQ(shuttle.pop_job(), 2);
}
702
TEST(DataTest, DataShuttleCanPushAndPopResult) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  // pop_result() will only attempt to pop if there was a push_job() batch.
  shuttle.push_job(1);
  shuttle.push_job(2);

  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);

  shuttle.pop_job();
  shuttle.push_result(2);
  ASSERT_EQ(shuttle.pop_result().value(), 2);
}
717
TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
  // With no outstanding jobs, pop_result() returns nullopt instead of
  // blocking — before any job, and again once the in-flight count drops to 0.
  torch::data::detail::DataShuttle<int, int> shuttle;
  ASSERT_FALSE(shuttle.pop_result().has_value());
  shuttle.push_job(1);
  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);
  ASSERT_FALSE(shuttle.pop_result().has_value());
  ASSERT_FALSE(shuttle.pop_result().has_value());
}
728
TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
  // drain() discards pending work, so a queued result is no longer returned.
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_result(1);
  shuttle.drain();
  ASSERT_FALSE(shuttle.pop_result().has_value());
}
736
TEST(DataTest, DataShuttlePopResultTimesOut) {
  // A job is in flight but no result arrives, so the timed pop must throw.
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  ASSERT_THROWS_WITH(shuttle.pop_result(10 * kMillisecond), "Timeout");
}
742
// Move-only int dataset for the shared-dataset tests. NOTE(review): this is a
// near-duplicate of UnCopyableDataset above, differing only in element type
// (int vs Example<>) and the unused constructor argument.
struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
  UncopyableDataset(const std::string& /* unused */) {}

  UncopyableDataset(UncopyableDataset&&) = default;
  UncopyableDataset& operator=(UncopyableDataset&&) = default;

  UncopyableDataset(const UncopyableDataset&) = delete;
  UncopyableDataset& operator=(const UncopyableDataset&) = delete;

  int get(size_t index) override {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return 100;
  }
};
760
TEST(DataTest, SharedBatchDatasetReallyIsShared) {
  // This test will only compile if we really are not making any copies.
  // There is otherwise no logic to test and because it is not deterministic
  // how many and when worker threads access the shareddataset, we don't have
  // any additional assertions here.

  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          "uncopyable");

  auto data_loader = torch::data::make_data_loader(
      shared_dataset, torch::data::DataLoaderOptions().workers(3));

  // Iterate by const reference: copying each batch per iteration is wasted
  // work (clang-tidy: performance-for-range-copy).
  for (const auto& batch : *data_loader) {
    (void)batch; /* exhaust */
  }
}
778
TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
  // This will not compile if a copy is made.
  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          UncopyableDataset("uncopyable"));
  ASSERT_EQ(shared_dataset.size().value(), 100);
}
786
// Custom batch request: a list of dataset indices plus a constant offset to
// add to every fetched value.
struct TestIndex : public torch::data::samplers::CustomBatchRequest {
  explicit TestIndex(size_t offset, std::vector<size_t> index)
      : offset(offset), index(std::move(index)) {}
  size_t size() const override {
    return index.size();
  }
  size_t offset;
  std::vector<size_t> index;
};
796
// Dataset addressed by TestIndex: holds 0..size-1 and returns
// data[i] + offset for every i in the request.
struct TestIndexDataset
    : datasets::BatchDataset<TestIndexDataset, std::vector<int>, TestIndex> {
  explicit TestIndexDataset(size_t size) : data(size) {
    std::iota(data.begin(), data.end(), size_t(0));
  }
  std::vector<int> get_batch(TestIndex index) override {
    std::vector<int> batch;
    for (auto i : index.index) {
      batch.push_back(index.offset + data.at(i));
    }
    return batch;
  }
  torch::optional<size_t> size() const override {
    return data.size();
  }
  std::vector<int> data;
};
814
// Sampler producing TestIndex requests. Each next() emits indices
// 0..batch_size-1 with the batch size itself as the offset, until size_
// samples have been handed out.
struct TestIndexSampler : public samplers::Sampler<TestIndex> {
  explicit TestIndexSampler(size_t size) : size_(size) {}
  void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
  torch::optional<TestIndex> next(size_t batch_size) override {
    if (index_ >= size_) {
      return torch::nullopt;
    }
    std::vector<size_t> indices(batch_size);
    std::iota(indices.begin(), indices.end(), size_t(0));
    index_ += batch_size;
    return TestIndex(batch_size, std::move(indices));
  }
  // Serialization is not exercised by these tests; both are no-ops.
  void save(torch::serialize::OutputArchive& archive) const override {}
  void load(torch::serialize::InputArchive& archive) override {}
  size_t index_ = 0;
  size_t size_;
};
832
TEST(DataTest, CanUseCustomTypeAsIndexType) {
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader(
      TestIndexDataset(23), TestIndexSampler(23), kBatchSize);

  // TestIndexSampler always emits indices 0..kBatchSize-1 with an offset of
  // kBatchSize, so every batch should be {10, 11, ..., 19}. Iterate by const
  // reference to avoid copying each batch vector per iteration
  // (clang-tidy: performance-for-range-copy).
  for (const auto& batch : *data_loader) {
    for (const auto j : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(j), 10 + j);
    }
  }
}
844
TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
  // With a single replica, a full epoch must cover every index exactly once.
  size_t sample_count = 10;
  samplers::DistributedRandomSampler drs(sample_count);

  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = drs.next(3)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  // Sorted, the drawn indices must be exactly 0..9.
  std::sort(res.begin(), res.end());
  for (const auto i : c10::irange(res.size())) {
    ASSERT_EQ(res[i], i);
  }
}
862
TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  // Drains every replica's sampler, pools the indices, and checks the sorted
  // pool against the expected multiset.
  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;

    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          std::make_unique<samplers::DistributedRandomSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      // Each replica contributes exactly local_sample_count indices.
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    // allow_duplicates: each replica draws ceil(10/3) = 4, so two indices
    // appear twice in the pooled result.
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    // no duplicates: each replica draws floor(10/3) = 3, so one index is
    // dropped from the pooled result.
    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}
904
TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
  {
    // A freshly constructed sampler starts at index 0 and round-trips.
    samplers::DistributedRandomSampler original(10);
    ASSERT_EQ(original.index(), 0);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 0);
  }
  {
    // A partially consumed sampler (3 + 4 = 7 samples taken) keeps its
    // position across serialization.
    samplers::DistributedRandomSampler original(10);
    original.next(3);
    original.next(4);
    ASSERT_EQ(original.index(), 7);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 7);
  }
  {
    // The epoch counter is serialized as well.
    samplers::DistributedRandomSampler original(10);
    original.set_epoch(3);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.epoch(), 3);
  }
}
939
TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
  // With a single replica, the sequential sampler must yield every index in
  // [0, sample_count) exactly once.
  const size_t sample_count = 10;
  const size_t batch_size = 3;
  samplers::DistributedSequentialSampler dss(sample_count);

  // Drain the sampler and collect all produced indices.
  std::vector<size_t> indices;
  while (true) {
    torch::optional<std::vector<size_t>> batch = dss.next(batch_size);
    if (!batch.has_value()) {
      break;
    }
    indices.insert(indices.end(), batch->begin(), batch->end());
  }

  ASSERT_EQ(indices.size(), sample_count);

  // After sorting, the samples must be exactly 0..sample_count-1.
  std::sort(indices.begin(), indices.end());
  for (const auto i : c10::irange(indices.size())) {
    ASSERT_EQ(indices[i], i);
  }
}
958
TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  // Runs one sequential sampler per replica, drains them all, and checks that
  // the union of the per-replica samples matches `output` (sorted).
  // `local_sample_count` is the number of samples each replica must produce.
  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
        samplers;

    // One sampler per replica, each with its own rank `i`.
    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          std::make_unique<samplers::DistributedSequentialSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      // Every replica contributes exactly `local_sample_count` indices.
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    // allow_duplicates=true: ceil(10/3)=4 samples per replica, so indices 0
    // and 1 appear twice across replicas.
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    // allow_duplicates=false: floor(10/3)=3 per replica, index 9 is dropped.
    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}
1001
TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
  {
    // A freshly constructed sampler starts at index 0 and round-trips.
    samplers::DistributedSequentialSampler original(10);
    ASSERT_EQ(original.index(), 0);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedSequentialSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 0);
  }
  {
    // A partially consumed sampler (3 + 4 = 7 samples taken) keeps its
    // position across serialization.
    samplers::DistributedSequentialSampler original(10);
    original.next(3);
    original.next(4);
    ASSERT_EQ(original.index(), 7);
    std::stringstream buffer;
    torch::save(original, buffer);

    samplers::DistributedSequentialSampler restored(10);
    torch::load(restored, buffer);
    ASSERT_EQ(restored.index(), 7);
  }
}
1026
TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
  // Converting empty partial options into FullDataLoaderOptions must fill in
  // the documented defaults.
  DataLoaderOptions defaults;
  FullDataLoaderOptions resolved(defaults);
  ASSERT_EQ(resolved.batch_size, 1);
  ASSERT_FALSE(resolved.drop_last);
  ASSERT_EQ(resolved.workers, 0);
  ASSERT_EQ(resolved.max_jobs, 0);
  ASSERT_FALSE(resolved.timeout.has_value());
  ASSERT_TRUE(resolved.enforce_ordering);
}
1037
TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) {
  // When max_jobs is left unset, it defaults to twice the worker count.
  auto options = DataLoaderOptions(32).workers(10);
  FullDataLoaderOptions resolved(options);
  ASSERT_EQ(resolved.batch_size, 32);
  ASSERT_EQ(resolved.max_jobs, 2 * 10);
}
1044
TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) {
  // make_data_loader without explicit options must default to batch size 1.
  auto loader = torch::data::make_data_loader(
      DummyDataset().map(transforms::Lambda<int>([](int x) { return x + 1; })));
  ASSERT_EQ(loader->options().batch_size, 1);
}
1050
1051 struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
getUnsizedDataset1052 torch::data::Example<> get(size_t i) override {
1053 return {torch::ones(i), torch::ones(i)};
1054 }
sizeUnsizedDataset1055 torch::optional<size_t> size() const noexcept override {
1056 return torch::nullopt;
1057 }
1058 };
1059
// Constructing a sampler-backed data loader requires a sized dataset; an
// unsized dataset must be rejected with a descriptive error.
TEST(
    DataLoaderTest,
    MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) {
  ASSERT_THROWS_WITH(
      torch::data::make_data_loader(UnsizedDataset{}),
      "Expected the dataset to be sized in order to construct the Sampler");
}
1067
TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) {
  // Both valid and sentinel (past-the-end) iterators are reflexively equal.
  auto loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto first = loader->begin();
  ASSERT_EQ(first, first);
  auto sentinel = loader->end();
  ASSERT_EQ(sentinel, sentinel);
}
1075
TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) {
  // Two distinct valid iterators never compare equal, whether or not they
  // point at the same position.
  auto loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto first = loader->begin();
  auto second = loader->begin();
  ASSERT_NE(first, second);
  ++second;
  ASSERT_NE(first, second);
}
1084
TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) {
  // All past-the-end iterators of the same loader compare equal.
  auto loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto first_sentinel = loader->end();
  auto second_sentinel = loader->end();
  ASSERT_EQ(first_sentinel, second_sentinel);
}
1091
TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) {
  DummyDataset dataset;
  // Batch size of size/4 means the dataset is exhausted after exactly four
  // increments; only then does the iterator equal the sentinel.
  auto loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 4);
  auto iterator = loader->begin();
  auto sentinel = loader->end();
  for (const auto batch_index : c10::irange(4)) {
    (void)batch_index;
    ASSERT_NE(iterator, sentinel);
    ++iterator;
  }
  ASSERT_EQ(iterator, sentinel);
}
1108
TEST(DataLoaderTest, IteratorsShareState) {
  DummyDataset dataset;
  // Copies of a loader iterator share the underlying position: advancing one
  // copy advances the other as well (two batches of size/2 exhaust the set).
  auto loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto left = loader->begin();
  auto right = left;
  auto sentinel = loader->end();
  ASSERT_NE(left, sentinel);
  ASSERT_NE(right, sentinel);
  ++left;
  ASSERT_NE(left, sentinel);
  ASSERT_NE(right, sentinel);
  ++right;
  // One increment through each copy consumed both batches.
  ASSERT_EQ(left, sentinel);
  ASSERT_EQ(right, sentinel);
}
1125
TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) {
  DummyDataset dataset;
  auto loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset,
          // NOLINTNEXTLINE(bugprone-argument-comment)
          /*batch_size=*/1);
  auto iterator = loader->begin();
  // Dereferencing must be idempotent: the same batch comes back until the
  // iterator is incremented. DummyDataset yields 1-based values sequentially.
  for (int value = 1; value <= 3; ++value) {
    if (value > 1) {
      ++iterator;
    }
    std::vector<int> expected = {value};
    ASSERT_EQ(*iterator, expected);
    ASSERT_EQ(*iterator, expected);
  }
}
1146
TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
  // Minimal batch dataset: each "batch" is simply 1 + the first index, so
  // sequential iteration produces 1..10.
  struct D : datasets::BatchDataset<D, int> {
    int get_batch(torch::ArrayRef<size_t> indices) override {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return 1 + indices.front();
    }
    torch::optional<size_t> size() const override {
      return 10;
    }
  };

  D dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset, 1);
  // Loader iterators must satisfy the input-iterator requirements well enough
  // for standard algorithms such as std::copy.
  std::vector<int> values;
  std::copy(
      data_loader->begin(), data_loader->end(), std::back_inserter(values));
  std::vector<int> expected(dataset.size().value());
  std::iota(expected.begin(), expected.end(), size_t(1));
  ASSERT_EQ(values, expected);
}
1169
TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, DataLoaderOptions(1).workers(2));
  // `i` is deliberately kept alive (though otherwise unused): it is the
  // in-flight iterator that makes the second begin() call illegal.
  auto i = data_loader->begin();
  ASSERT_THROWS_WITH(
      data_loader->begin(),
      "Attempted to get a new DataLoader iterator "
      "while another iterator is not yet exhausted");
}
1180
TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  // One batch covers the entire dataset, so a single increment exhausts the
  // iterator; the next increment must throw.
  auto loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto iterator = loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++iterator);
  ASSERT_THROWS_WITH(
      ++iterator, "Attempted to increment iterator past the end");
}
1190
TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  // One batch covers the entire dataset; after a single increment the
  // iterator is past the end and dereferencing it must throw.
  auto loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto iterator = loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++iterator);
  ASSERT_THROWS_WITH(
      *iterator, "Attempted to dereference iterator that was past the end");
}
1201
TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) {
  DummyDataset dataset;
  // The sentinel returned by end() can never be incremented.
  auto loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto sentinel = loader->end();
  ASSERT_THROWS_WITH(
      ++sentinel,
      "Incrementing the DataLoader's past-the-end iterator is not allowed");
}
1211
TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) {
  DummyDataset dataset;
  // The sentinel returned by end() can never be dereferenced.
  auto loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto sentinel = loader->end();
  ASSERT_THROWS_WITH(
      *sentinel,
      "Dereferencing the DataLoader's past-the-end iterator is not allowed");
}
1221
TEST(DataLoaderTest, YieldsCorrectBatchSize) {
  DummyDataset dataset;
  // 100 examples at batch size 25 -> exactly four full batches.
  auto loader = torch::data::make_data_loader(dataset, 25);
  auto iterator = loader->begin();
  ASSERT_EQ(iterator->size(), 25);
  for (int batch = 1; batch < 4; ++batch) {
    ASSERT_EQ((++iterator)->size(), 25);
  }
  ASSERT_EQ(++iterator, loader->end());
}
1232
TEST(
    DataLoaderTest,
    ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) {
  DummyDataset dataset;
  // 100 examples at batch size 33 -> three full batches plus a final partial
  // batch of one, which is kept because drop_last is false.
  auto loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(false));
  auto iterator = loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  for (int batch = 1; batch < 3; ++batch) {
    ASSERT_EQ((++iterator)->size(), 33);
  }
  ASSERT_EQ((++iterator)->size(), 1);
  ASSERT_EQ(++iterator, loader->end());
}
1246
TEST(
    DataLoaderTest,
    DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) {
  DummyDataset dataset;
  // 100 examples at batch size 33 -> three full batches; the trailing partial
  // batch of one is discarded because drop_last is true.
  auto loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(true));
  auto iterator = loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  for (int batch = 1; batch < 3; ++batch) {
    ASSERT_EQ((++iterator)->size(), 33);
  }
  ASSERT_EQ(++iterator, loader->end());
}
1259
TEST(DataLoaderTest, RespectsTimeout) {
  // Condition variable + mutex the dataset blocks on, shared with the test
  // body so the worker thread can be released once the timeout has fired.
  struct Baton {
    std::condition_variable cv;
    std::mutex mutex;
  };

  // Fix: the CRTP self-type parameter must be the class being defined (`D`),
  // not `DummyDataset` — `Self`-based helpers like map() rely on it naming
  // the actual derived type.
  struct D : datasets::Dataset<D, int> {
    D(std::shared_ptr<Baton> b) : baton(std::move(b)) {}
    // Simulates a slow data read by blocking for up to one second.
    int get(size_t index) override {
      std::unique_lock<std::mutex> lock(baton->mutex);
      baton->cv.wait_for(lock, 1000 * kMillisecond);
      return 0;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    std::shared_ptr<Baton> baton;
  };

  auto baton = std::make_shared<Baton>();

  // 10ms timeout vs. a 1000ms blocking read: fetching the first batch must
  // fail with a timeout rather than wait for the dataset.
  auto data_loader = torch::data::make_data_loader(
      D{baton}, DataLoaderOptions().workers(1).timeout(10 * kMillisecond));

  auto start = std::chrono::system_clock::now();

  ASSERT_THROWS_WITH(*data_loader->begin(), "Timeout");
  // Wake the still-blocked worker so loader shutdown is prompt.
  baton->cv.notify_one();

  auto end = std::chrono::system_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
  // Finishing in under a second proves the timeout (not the dataset's full
  // wait) unblocked the main thread.
  ASSERT_LT(duration.count(), 1);
}
1293
1294 // stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
// One-shot rendezvous point: wait() blocks until `target` threads have
// arrived, then releases them all. Not reusable after the count hits zero.
struct Barrier {
  explicit Barrier(size_t target) : counter_(target) {}
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    // The last arriving thread wakes everyone; earlier arrivals block until
    // the counter (always read/written under the mutex) reaches zero.
    if (--counter_ == 0) {
      cv_.notify_all();
    } else {
      cv_.wait(lock, [this] { return this->counter_ == 0; });
    }
  }

  size_t counter_;
  std::condition_variable cv_;
  std::mutex mutex_;
};
1310
1311 // On the OrderingTest: This test is intended to verify that the
1312 // `enforce_ordering` option of the dataloader works correctly. The reason this
1313 // flag exists is because when the dataloader has multiple workers (threads)
1314 // enabled and this flag is not set, the order in which worker threads finish
1315 // loading their respective batch and push it back to the dataloader's main
1316 // thread (for outside consumption) is not deterministic. Imagine the sampler is
1317 // a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index
1318 // will be a single "job". Inside the dataloader, worker threads block until a
1319 // job is available. It is not deterministic which worker thread wakes up batch
1320 // to dequeue a particular batch. Further, some worker threads may take longer
1321 // than others to read the data for their index. As such, it could be that
1322 // worker thread 2 finishes before all other threads and returns its batch to
1323 // the main thread. In that case, the dataloader iterator would return the datum
1324 // at index 2 batch, and afterwards the datum from whatever thread finishes
1325 // next. As such, the user may see data from indices 2, 0, 3, 1. On another run
1326 // of the same dataloader on the same data, threads may be scheduled differently
1327 // and return in order 0, 2, 3, 1. To force this ordering to deterministically
1328 // be 0, 1, 2, 3, the `enforce_ordering` flag can be set to true. In that case,
1329 // the dataloader will use a *sequencer* internally which keeps track of which
1330 // datum is expected next, and buffers any other results until that next
1331 // expected value arrives. For example, workers 1, 2, 3 may finish before worker
1332 // 0. If `enforce_ordering` is true, the sequencer will internally buffer the
1333 // results from 1, 2, 3 until worker 0 finishes. Only then does the dataloader
1334 // return the datum from worker 0 to the user (and then datum 1 the next time,
1335 // then 2 and so on).
1336 //
1337 // The way the test works is that we start
1338 // `kNumberOfWorkers` workers in the dataloader, which each get an index from a
1339 // `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread
1340 // has a copy of the dataset, and thus `get_batch()` is called on the
1341 // thread-local copy in each worker. We want to simulate out-of-order completion
1342 // of these threads. For this, we batch set a barrier in the `get_batch()`
1343 // method to make sure every worker has some index to fetch assigned. Further,
1344 // each worker thread has a unique ID in `0...kNumberOfWorkers-1`.
1345 // There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in
1346 // which we want the worker threads to return. For this, an iterator into this
1347 // order is maintained. When the dereferenced iterator (the current order index)
1348 // matches the thread ID of a worker, it knows it can now return its index as
1349 // well as progress the iterator. Inside the dataloader, the sequencer should
1350 // buffer these indices such that they are ultimately returned in order.
1351
namespace ordering_test {
namespace {
const size_t kNumberOfWorkers = 10;
// The scrambled order in which worker threads are allowed to return their
// batch; the dataloader's sequencer must undo this scramble.
const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
    {3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
} // namespace

// Dataset whose get_batch() deliberately completes in the hard-coded order
// above, regardless of which index each worker was assigned. See the long
// comment preceding this namespace for the full test design.
struct Dataset : datasets::BatchDataset<Dataset, size_t> {
  Dataset() = default;

  // This copy constructor will be called when we copy the dataset into a
  // particular thread. Each worker copy receives a unique sequential ID from
  // the shared atomic counter.
  Dataset(const Dataset& other) {
    static std::atomic<size_t> counter{0};
    thread_id_ = counter.fetch_add(1);
  }

  Dataset(Dataset&& other) noexcept = default;
  Dataset& operator=(const Dataset& other) = delete;
  Dataset& operator=(Dataset&& other) noexcept = delete;

  // Blocks until (a) every worker has arrived (barrier) and (b) it is this
  // worker's turn according to kOrderInWhichWorkersReturnTheirBatch.
  // NOTE: the function-local statics make this a one-shot mechanism shared
  // by all copies; the test can only run once per process.
  size_t get_batch(torch::ArrayRef<size_t> indices) override {
    static Barrier barrier(kNumberOfWorkers);
    static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
    static std::condition_variable cv;
    static std::mutex mutex;

    // Wait for all threads to get an index batch and arrive here.
    barrier.wait();

    std::unique_lock<std::mutex> lock(mutex);
    cv.wait(lock, [this] { return *order_iterator == this->thread_id_; });
    ++order_iterator;
    lock.unlock();
    cv.notify_all();

    return indices.front();
  }

  torch::optional<size_t> size() const override {
    return kNumberOfWorkers;
  }

  size_t thread_id_ = 0;
};

} // namespace ordering_test
1399
TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
  auto data_loader = torch::data::make_data_loader(
      ordering_test::Dataset{},
      torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
      DataLoaderOptions()
          .batch_size(1)
          .workers(ordering_test::kNumberOfWorkers)
          .enforce_ordering(true));
  // Although workers complete in the scrambled order defined by
  // kOrderInWhichWorkersReturnTheirBatch, enforce_ordering makes the
  // sequencer hand batches back in sampler order 0..N-1.
  std::vector<size_t> actual;
  for (size_t index : *data_loader) {
    actual.push_back(index);
  }
  std::vector<size_t> expected(ordering_test::kNumberOfWorkers);
  std::iota(expected.begin(), expected.end(), size_t(0));
  ASSERT_EQ(expected, actual);
}
1416
TEST(DataLoaderTest, Reset) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto end = data_loader->end();

  // Restarting iteration via begin() must fully reset the loader: three
  // consecutive epochs each see exactly two batches before exhaustion.
  for (int epoch = 0; epoch < 3; ++epoch) {
    auto iterator = data_loader->begin();
    ASSERT_NE(iterator, end);
    ASSERT_NE(++iterator, end);
    ASSERT_EQ(++iterator, end);
  }
}
1438
TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
  // Fix: the CRTP self-type parameter must be the class being defined (`D`),
  // not `DummyDataset`.
  struct D : datasets::Dataset<D, int> {
    // Every read fails, so the worker thread is guaranteed to raise.
    int get(size_t index) override {
      throw std::invalid_argument("badness");
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
  };

  auto data_loader = torch::data::make_data_loader(
      D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
  auto iterator = data_loader->begin();

  // The worker's exception must surface on the main thread as a
  // WorkerException wrapping the original exception.
  bool caught = false;
  try {
    (void)*iterator;
  } catch (torch::data::WorkerException& e) {
    caught = true;
    ASSERT_EQ(
        e.what(),
        std::string("Caught exception in DataLoader worker thread. "
                    "Original message: badness"));
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        std::rethrow_exception(e.original_exception), std::invalid_argument);
  }
  // Fix: previously the test silently passed when nothing was thrown at all.
  ASSERT_TRUE(caught);
}
1465
TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Stateful dataset: yields 0..9 then signals exhaustion via nullopt.
  // reset() (called at the start of each epoch) rewinds the counter.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    // Serialization is irrelevant to this test; no-op implementations.
    void save(torch::serialize::OutputArchive& archive) const override{};
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(D{});

  // Ten epochs: each must see exactly 10 batches, proving reset() works.
  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  // Values never reach the exhaustion threshold.
  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
1501
TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  const int kNumberOfWorkers = 4;

  // Stateful dataset shared between worker threads; the mutex guards the
  // counter since multiple workers call get_batch() concurrently.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      std::lock_guard<std::mutex> lock(mutex);
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    // Serialization is irrelevant to this test; no-op implementations.
    void save(torch::serialize::OutputArchive& archive) const override{};
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
    std::mutex mutex;
  };

  // make_shared_dataset ensures all workers operate on the SAME dataset
  // instance instead of per-worker copies.
  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::make_shared_dataset<D>(),
      DataLoaderOptions().workers(kNumberOfWorkers));

  // Ten epochs: each must see exactly 10 batches in total across workers.
  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  // Values never reach the exhaustion threshold.
  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
1542
TEST(DataLoaderTest, StatefulDatasetWithMap) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Stateful dataset: yields 0..9 then signals exhaustion via nullopt.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    // Serialization is irrelevant to this test; no-op implementations.
    void save(torch::serialize::OutputArchive& archive) const override{};
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  // Two chained map() transforms over a stateful dataset:
  // int -> string -> tensor. Verifies transforms compose with statefulness.
  auto data_loader = torch::data::make_data_loader(
      D().map(transforms::BatchLambda<int, std::string>(
                  [](int x) { return std::to_string(x); }))
          .map(transforms::BatchLambda<std::string, torch::Tensor>(
              [](const std::string& x) {
                return torch::tensor(static_cast<int64_t>(std::stoi(x)));
              })),
      DataLoaderOptions{});

  // Ten epochs: each must see exactly 10 batches, proving reset() works
  // through the transform wrappers.
  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  // Round-tripped values never reach the exhaustion threshold.
  for (const torch::Tensor& t : *data_loader) {
    ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
1585
TEST(DataLoaderTest, StatefulDatasetWithCollate) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Stateful dataset producing a vector of Examples per batch: data tensors
  // of length batch_size+1 (ones) and targets of length batch_size-1 (zeros).
  struct D : datasets::StatefulDataset<D> {
    torch::optional<std::vector<Example<>>> get_batch(
        size_t batch_size) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        counter += batch_size;
        std::vector<Example<>> batch(
            /*count=*/batch_size,
            Example<>{
                torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
        return batch;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    // Serialization is irrelevant to this test; no-op implementations.
    void save(torch::serialize::OutputArchive& archive) const override{};
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto d = D().map(transforms::Stack<Example<>>());

  const size_t kBatchSize = 5;

  // Notice that the `get_batch()` of the dataset returns a vector<Example>, but
  // the `Stack` collation stacks the tensors into one.
  torch::optional<Example<>> batch = d.get_batch(kBatchSize);
  ASSERT_TRUE(batch.has_value());
  // Stacking prepends the batch dimension: (kBatchSize, per-example-size).
  ASSERT_EQ(batch->data.size(0), kBatchSize);
  ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
  ASSERT_EQ(batch->target.size(0), kBatchSize);
  ASSERT_EQ(batch->target.size(1), kBatchSize - 1);

  // Each row preserves the original example's contents.
  ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
  ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
}
1629
1630 // This test tests the core function for iterate through a chunk dataset. It
1631 // contains test cases with different parameter combination. (For example,
1632 // different prefetch count, batch size and data loader worker count). It
1633 // verifies the return batches size and content when the order is deterministic.
TEST(DataLoaderTest, ChunkDataSetGetBatch) {
  // different prefetch count for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t prefetch_counts[] = {1, 2, 3, 4};

  // different batch size for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t batch_sizes[] = {5, 7};

  // test with/without worker threads
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t dataloader_worker_counts[] = {0, 2};

  const size_t total_example_count = 35;
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  // test functionality across epoch boundary
  const int epoch_count = 2;

  // Sweep the full cross product of prefetch count x batch size x worker
  // count; each combination must yield every example exactly once per epoch.
  for (auto prefetch_count : prefetch_counts) {
    for (auto batch_size : batch_sizes) {
      for (auto dataloader_worker_count : dataloader_worker_counts) {
        // The sampler is passed twice: once for chunk order, once for the
        // example order within a chunk.
        datasets::SharedBatchDataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>
            dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
                DummyChunkDataReader,
                samplers::SequentialSampler,
                samplers::SequentialSampler>>(
                data_reader,
                sampler,
                sampler,
                datasets::ChunkDatasetOptions(prefetch_count, batch_size));

        auto data_loader = torch::data::make_data_loader(
            dataset,
            DataLoaderOptions(batch_size).workers(dataloader_worker_count));

        for (const auto epoch_index : c10::irange(epoch_count)) {
          (void)epoch_index; // Suppress unused variable warning
          // Tracks which of the 35 examples have been observed this epoch.
          std::vector<bool> result(total_example_count, false);
          int iteration_count = 0;
          for (auto iterator = data_loader->begin();
               iterator != data_loader->end();
               ++iterator, ++iteration_count) {
            DummyChunkDataReader::BatchType& batch = *iterator;
            ASSERT_EQ(batch.size(), batch_size);

            // When prefetch_count is equal to 1 and no worker thread, the batch
            // order is deterministic. So we can verify elements in each batch.
            if (prefetch_count == 1 && dataloader_worker_count == 0) {
              for (const auto j : c10::irange(batch_size)) {
                ASSERT_EQ(batch[j], iteration_count * batch_size + j);
              }
            }
            for (const auto j : c10::irange(batch_size)) {
              result[batch[j]] = true;
            }
          }

          // Regardless of ordering, every example must appear each epoch.
          for (auto data : result) {
            ASSERT_EQ(data, true);
          }
        }
      }
    }
  }
}
1704
// The ChunkDataset is created with batch size 5 but the loader requests 6;
// the mismatch must be reported when the first batch is fetched.
TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  const size_t requested_batch_size = 6;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(requested_batch_size).workers(0));

  std::string exception_msg =
      "The requested batch size does not match with the initialized batch "
      "size.\n The requested batch size is 6, while the dataset is created"
      " with batch size equal to 5";

  ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg);
}
1736
// A chunk reader that produces an empty chunk must not break iteration: the
// loader yields empty batches rather than crashing or hanging.
TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
  struct DummyEmptyChunkDataReader : datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    // Single chunk, and it is empty.
    BatchType read_chunk(size_t chunk_index) override {
      return {};
    }

    size_t chunk_count() override {
      return 1;
    };

    void reset() override{};
  };

  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  DummyEmptyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyEmptyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyEmptyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(batch_size).workers(0));

  // Any batch the loader does yield must be empty.
  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
       ++iterator) {
    ASSERT_EQ(iterator->size(), 0);
  }
}
1779
TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
  // Reader with 2 chunks of 10 zero-valued examples each (20 in total).
  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(10, 0);
      return batch_data;
    }

    size_t chunk_count() override {
      return 2;
    };

    void reset() override{};
  };

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t batch_sizes[] = {17, 30};
  D data_reader;
  samplers::SequentialSampler sampler(0);

  for (auto batch_size : batch_sizes) {
    datasets::SharedBatchDataset<datasets::ChunkDataset<
        D,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            D,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            datasets::ChunkDatasetOptions(1, batch_size));

    auto data_loader = torch::data::make_data_loader(
        dataset, DataLoaderOptions(batch_size).workers(0));

    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
         ++iterator) {
      // Bind by reference: no need to copy the batch just to inspect it.
      D::BatchType& batch = *iterator;
      // BUGFIX: the original declared `auto batch_size = batch.size();`,
      // shadowing the configured batch size. That made the 17-case check
      // tautological and the 30-case check dead code (a batch can never hold
      // more than the 20 available examples, so the shadowed value never
      // equaled 30). Compare against the configured batch size instead.
      //
      // With 20 examples in total:
      // - batch size 17 yields one full batch of 17 and a remainder of 3;
      // - batch size 30 exceeds the dataset size, so a single batch of 20
      //   is returned.
      if (batch_size == 17) {
        ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
      } else if (batch_size == 30) {
        ASSERT_EQ(batch.size(), 20);
      }
    }
  }
}
1832
TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
  // The chunk sampler owned by the dataset must be reachable through
  // chunk_sampler() and reflect how many chunks have been consumed.
  const size_t prefetch_count = 2;
  const size_t batch_size = 5;

  DummyChunkDataReader reader;
  samplers::SequentialSampler seq_sampler(0);

  using ChunkDatasetType = datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>;

  datasets::SharedBatchDataset<ChunkDatasetType> dataset =
      datasets::make_shared_dataset<ChunkDatasetType>(
          reader,
          seq_sampler,
          seq_sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();

  // Reduce every batch to the sum of its elements so iteration yields ints.
  auto summing_transform = transforms::BatchLambda<
      DummyChunkDataReader::BatchType,
      DummyChunkDataReader::DataType>(
      [](DummyChunkDataReader::BatchType batch) {
        return std::accumulate(batch.begin(), batch.end(), 0);
      });

  auto data_loader = torch::data::make_data_loader(
      dataset.map(summing_transform),
      DataLoaderOptions(batch_size).workers(0));

  // before we start, the index should be 0.
  ASSERT_EQ(chunk_sampler.index(), 0);

  size_t total = 0;
  for (auto it = data_loader->begin(); it != data_loader->end(); ++it) {
    total += *it;
  }
  ASSERT_EQ(total, 595); // sum([0, 35))
  // 3 chunks, and when exhausted the value is already incremented.
  ASSERT_EQ(chunk_sampler.index(), 3);
}
1875
TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
  // Creates a data loader over a chunk dataset, obtains an iterator, and then
  // returns without ever draining the batch buffer. The preloader threads are
  // left blocked on the full buffer; destruction must still shut everything
  // down cleanly instead of deadlocking.
  const size_t prefetch_count = 2;
  const size_t batch_size = 5;
  // This cache size makes the preloaders wait until `get_batch()` calls
  // drain the cache.
  const size_t cache_size = 10;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);
  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, cache_size));

  // Reduce each batch to the sum of its elements.
  auto data_loader = torch::data::make_data_loader(
      dataset.map(transforms::BatchLambda<
                  DummyChunkDataReader::BatchType,
                  DummyChunkDataReader::DataType>(
          [](DummyChunkDataReader::BatchType batch) {
            return std::accumulate(batch.begin(), batch.end(), 0);
          })),
      DataLoaderOptions(batch_size).workers(0));
  // simply creates the iterator but no iteration. chunk preloaders are waiting
  // to fill the batch buffer but it is not draining. Still we need to exit
  // cleanly.
  auto iterator = data_loader->begin();
}
1911
// Test ChunkDataset save function.
// Note [save/load ChunkDataset as ChunkSampler]:
// The chunk sampler inside ChunkDataset is used in a separate thread pool
// rather than the main thread, so it is very hard to accurately estimate its
// status when ChunkDataset::save/ChunkDataset::load is called. For the pure
// purpose of testing, we utilize the implementation fact that the file format
// for sampler serialization is the same as ChunkDataset serialization, and
// manually control the chunk sampler by calling the sampler's save/load
// methods for value validation. This is only for testing the specific
// save/load functionality. In a real user case, the user should still use the
// matching ChunkDataset::save and ChunkDataset::load methods.
TEST(DataLoaderTest, ChunkDatasetSave) {
  const size_t chunk_count_ = 6;
  const size_t chunk_size = 10;

  // Reader with 6 identical chunks, each holding 10 zero-valued examples.
  struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      return batch_data_;
    }

    size_t chunk_count() override {
      return chunk_count_;
    };

    void reset() override{};
    BatchType batch_data_ = BatchType(chunk_size, 0);
  };

  const size_t prefetch_count = 1;
  // Batch size equals chunk size (and the cache size below), so every
  // iteration consumes exactly one chunk — this is what makes the save-window
  // analysis below possible.
  const size_t batch_size = chunk_size;
  const size_t dataloader_worker_count = 0;
  samplers::SequentialSampler sampler(0);
  const int epoch_count = 2;

  DummyTestChunkDataReader data_reader;

  // tested save_intervals
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t save_intervals[] = {1, 2};

  using datasets::ChunkDatasetOptions;

  for (auto save_interval : save_intervals) {
    auto tempfile = c10::make_tempfile();

    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyTestChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyTestChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            ChunkDatasetOptions(
                prefetch_count, batch_size, chunk_size /*cache size*/));

    auto data_loader = torch::data::make_data_loader(
        dataset,
        DataLoaderOptions(batch_size).workers(dataloader_worker_count));

    for (const auto epoch_index : c10::irange(epoch_count)) {
      (void)epoch_index; // Suppress unused variable warning
      unsigned iteration_count = 0;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator, ++iteration_count) {
        if ((iteration_count + 1) % save_interval == 0) {
          torch::save(*dataset, tempfile.name);

          samplers::SequentialSampler new_sampler(0);

          // See Note [save/load ChunkDataset as ChunkSampler]
          torch::load(new_sampler, tempfile.name);

          // Verify the save logic. For ChunkDataset, the chunk data is stored
          // in a cache inside the dataset. One pool of threads is constantly
          // writing to the cache, and a different pool of threads is
          // constantly reading from it. Due to this asynchrony, at the time
          // of get_batch(), which chunk has been written to the cache is not
          // fully deterministic.
          // But we can still calculate a restricted window on the expected
          // output, hence verify the logic. In this test, the cache size is
          // configured to be the same as chunk size and batch size. So the
          // chunk data is written to the cache one by one: only when the
          // current batch is retrieved is the next chunk written. Now in
          // iteration 0, after the first batch is retrieved, when we save the
          // dataset status, there are three possible scenarios for the writer
          // thread:
          // 1. it hasn't started loading the next chunk data yet, so the
          // sequential sampler index is still 0;
          // 2. it started to load the second chunk, so the sequential sampler
          // index is at 1;
          // 3. it finished loading the second chunk and started to load the
          // third chunk; because the cache is still fully occupied by the
          // data from the second chunk, it is waiting to write to the cache.
          // At this point, the sampler index is at 2.
          // So now we have a window of [0, 2], which is where we expect the
          // saved sampler index to come from. Note that the sequential
          // sampler advances to the next index automatically in the call to
          // next(), so when the index is saved, it is the next index instead
          // of the current one. In other words, after getting the first index
          // from the sequential sampler, it has already moved to the second
          // index, and that is what gets saved. As a result, we need to
          // advance the window by one, giving the expected window [1, 3].
          // This analysis applies to every iteration, so in the general case
          // the saved index should fall into the range
          // [iteration_count + 1, iteration_count + 3], which is the
          // validation below.
          ASSERT_TRUE(
              new_sampler.index() >= iteration_count + 1 &&
              new_sampler.index() <= iteration_count + 3);
        }
      }
    }
  }
}
2033
2034 // Test ChunkDataset load function.
TEST(DataLoaderTest, ChunkDatasetLoad) {
  auto tempfile = c10::make_tempfile();

  const size_t prefetch_count = 1;
  const size_t batch_size = 10;
  const size_t dataloader_worker_count = 0;
  const size_t skipped_chunk = 2;

  DummyChunkDataReader reader;
  samplers::SequentialSampler seq_sampler(0);

  // Produce a checkpoint in which the sampler has already skipped 2 chunks.
  {
    seq_sampler.reset(reader.chunk_count());
    seq_sampler.next(skipped_chunk);

    // See Note [save/load ChunkDataset as ChunkSampler]
    torch::save(seq_sampler, tempfile.name);
  }

  // test functionality across epoch boundary. The first epoch should be
  // affected by the checkpoint, but the second should start normally.
  const int epoch_count = 2;

  using ChunkDatasetType = datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>;

  datasets::SharedBatchDataset<ChunkDatasetType> dataset =
      datasets::make_shared_dataset<ChunkDatasetType>(
          reader,
          seq_sampler,
          seq_sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, 20 /*cache size*/));

  torch::load(*dataset, tempfile.name);

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));

  for (const auto epoch_index : c10::irange(epoch_count)) {
    int batch_index = 0;

    // For the first epoch, batches come from the third chunk onward because
    // the checkpoint skipped the first two chunks (values start at 15). The
    // next epoch starts from the first batch again.
    int start_value = epoch_index == 0 ? 15 : 0;

    for (auto it = data_loader->begin(); it != data_loader->end();
         ++it, ++batch_index) {
      DummyChunkDataReader::BatchType& batch = *it;

      // The second epoch sees all 35 examples, so its fourth batch holds
      // only the trailing 5.
      size_t expected_size = (epoch_index > 0 && batch_index == 3) ? 5 : 10;
      std::vector<int> expected(expected_size);
      std::iota(expected.begin(), expected.end(), start_value);

      ASSERT_EQ(batch.size(), expected.size());
      ASSERT_TRUE(std::equal(batch.begin(), batch.end(), expected.begin()));

      start_value += batch_size;
    }
  }

  samplers::SequentialSampler new_sampler(0);

  // See Note [save/load ChunkDataset as ChunkSampler]
  torch::load(new_sampler, tempfile.name);

  ASSERT_EQ(new_sampler.index(), skipped_chunk);
}
2111
TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
  const size_t chunk_size = 5;
  const size_t batch_size = 5;

  // Example sampler that interleaves indices with stride `batch_size`: for a
  // size of 10 it yields 0, 5, 1, 6, 2, 7, 3, 8, 4, 9, so that examples from
  // the different chunks in a shuffle window alternate in the output.
  class S : public samplers::Sampler<> {
   public:
    explicit S(size_t size) : size_(size), index_(0){};

    void reset(torch::optional<size_t> new_size = torch::nullopt) override {
      if (new_size.has_value()) {
        size_ = *new_size;
      }
      indices_.resize(size_);
      size_t index = 0;

      // Repeatedly sample every 5 indices.
      for (const auto i : c10::irange(batch_size)) {
        for (size_t j = 0; j < size_ / batch_size; ++j) {
          indices_[index++] = i + batch_size * j;
        }
      }
      index_ = 0;
    }

    // Returns the next batch of indices, truncated at the end of the
    // precomputed sequence; nullopt once exhausted.
    torch::optional<std::vector<size_t>> next(size_t batch_size) override {
      const auto remaining_indices = size_ - index_;
      if (remaining_indices == 0) {
        return torch::nullopt;
      }
      auto return_size = std::min(batch_size, remaining_indices);
      std::vector<size_t> index_batch(
          indices_.begin() + index_, indices_.begin() + index_ + return_size);
      index_ += return_size;

      return index_batch;
    }

    // Serialization is not exercised by this test.
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}

   private:
    size_t size_;
    std::vector<size_t> indices_;
    size_t index_{0};
  };

  // Reader whose chunk at index i holds `chunk_size` copies of the value i,
  // so every output example reveals which chunk it came from.
  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
    D(size_t chunk_count) : chunk_count_(chunk_count) {}

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(chunk_size, chunk_index);
      return batch_data;
    }

    size_t chunk_count() override {
      return chunk_count_;
    };

    void reset() override{};
    size_t chunk_count_;
  };

  const size_t prefetch_count = 1;
  const size_t cache_size = 10;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t cross_chunk_shuffle_counts[] = {2, 3};
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t chunk_counts[] = {3, 4, 5};

  samplers::SequentialSampler chunk_sampler(0);
  S example_sampler(0);

  for (auto chunk_count : chunk_counts) {
    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
      D data_reader(chunk_count);

      datasets::SharedBatchDataset<
          datasets::ChunkDataset<D, samplers::SequentialSampler, S>>
          dataset = datasets::make_shared_dataset<
              datasets::ChunkDataset<D, samplers::SequentialSampler, S>>(
              data_reader,
              chunk_sampler,
              example_sampler,
              datasets::ChunkDatasetOptions(
                  prefetch_count,
                  batch_size,
                  cache_size,
                  cross_chunk_shuffle_count));

      auto data_loader = torch::data::make_data_loader(
          dataset, DataLoaderOptions(batch_size).workers(0));

      // Collect one full epoch of output in order.
      std::vector<int> result;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator) {
        auto batch_result = *iterator;
        std::copy(
            batch_result.begin(),
            batch_result.end(),
            std::back_inserter(result));
      }

      std::vector<int> expected_result;
      {
        // construct expected result: chunks are loaded in windows of
        // `cross_chunk_shuffle_count`, and within each window the example
        // sampler interleaves, so the output repeats the window's chunk ids
        // in turn, `chunk_size` times per window. The final window is
        // truncated when chunk_count is not a multiple of the window size.
        for (const auto i : c10::irange(
                 (chunk_count + cross_chunk_shuffle_count - 1) /
                 cross_chunk_shuffle_count)) {
          for (const auto j : c10::irange(chunk_size)) {
            (void)j; // Suppress unused variable warning
            for (const auto k : c10::irange(cross_chunk_shuffle_count)) {
              if (i * cross_chunk_shuffle_count + k < chunk_count) {
                expected_result.push_back(i * cross_chunk_shuffle_count + k);
              }
            }
          }
        }
      }

      ASSERT_EQ(result.size(), expected_result.size());
      ASSERT_TRUE(
          std::equal(result.begin(), result.end(), expected_result.begin()));
    }
  }
}
2240
TEST(DataLoaderTest, CustomPreprocessPolicy) {
  const size_t chunk_size = 5;
  const size_t batch_size = 10;

  // Reader whose chunks contain random values in [0, 100); any ordering seen
  // in the output must therefore come from the preprocessing policy.
  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
    D(size_t chunk_count) : chunk_count_(chunk_count) {}

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(chunk_size);
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand)
      auto rand_gen = []() { return std::rand() % 100; };
      std::generate(batch_data.begin(), batch_data.end(), rand_gen);
      return batch_data;
    }

    size_t chunk_count() override {
      return chunk_count_;
    };

    void reset() override{};
    size_t chunk_count_;
  };

  // custom preprocessing policy - sort the data ascendingly
  auto sorting_policy = [](std::vector<int>& raw_batch_data) {
    std::sort(raw_batch_data.begin(), raw_batch_data.end());
  };
  std::function<void(std::vector<int>&)> policy_function = sorting_policy;

  const size_t prefetch_count = 1;
  const size_t cache_size = 10;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t cross_chunk_shuffle_counts[] = {1, 2};
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t chunk_counts[] = {3, 4};

  samplers::SequentialSampler chunk_sampler(0);

  for (auto chunk_count : chunk_counts) {
    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
      D data_reader(chunk_count);

      datasets::SharedBatchDataset<datasets::ChunkDataset<
          D,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>
          dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
              D,
              samplers::SequentialSampler,
              samplers::SequentialSampler>>(
              data_reader,
              chunk_sampler,
              chunk_sampler,
              datasets::ChunkDatasetOptions(
                  prefetch_count,
                  batch_size,
                  cache_size,
                  cross_chunk_shuffle_count),
              policy_function);

      auto data_loader = torch::data::make_data_loader(
          dataset, DataLoaderOptions(batch_size).workers(0));

      std::vector<int> result;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator) {
        auto batch_result = *iterator;
        // NOTE(review): the checks assume the policy is applied per group of
        // `chunk_size * cross_chunk_shuffle_count` examples — confirm against
        // ChunkDataset's preprocessing implementation. When a batch spans
        // several such groups, each `chunk_size`-sized slice is verified
        // sorted on its own; otherwise the whole batch must be sorted.
        if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
          for (unsigned i = 0; i < batch_result.size(); i += chunk_size) {
            ASSERT_TRUE(std::is_sorted(
                batch_result.begin() + i,
                batch_result.begin() + i + chunk_size));
          }
        } else {
          ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end()));
        }
      }
    }
  }
}
2323