// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/data/samplers/sequential.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/data/samplers/sequential.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

9 namespace torch {
10 namespace data {
11 namespace samplers {
SequentialSampler(size_t size)12 SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
13 
reset(std::optional<size_t> new_size)14 void SequentialSampler::reset(std::optional<size_t> new_size) {
15   if (new_size.has_value()) {
16     size_ = *new_size;
17   }
18   index_ = 0;
19 }
20 
next(size_t batch_size)21 std::optional<std::vector<size_t>> SequentialSampler::next(size_t batch_size) {
22   const auto remaining_indices = size_ - index_;
23   if (remaining_indices == 0) {
24     return nullopt;
25   }
26   std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
27   for (auto& i : index_batch) {
28     i = index_++;
29   }
30   return index_batch;
31 }
32 
save(serialize::OutputArchive & archive) const33 void SequentialSampler::save(serialize::OutputArchive& archive) const {
34   archive.write(
35       "index",
36       torch::tensor(static_cast<int64_t>(index_), torch::kInt64),
37       /*is_buffer=*/true);
38 }
39 
load(serialize::InputArchive & archive)40 void SequentialSampler::load(serialize::InputArchive& archive) {
41   auto tensor = torch::empty(1, torch::kInt64);
42   archive.read(
43       "index",
44       tensor,
45       /*is_buffer=*/true);
46   index_ = tensor.item<int64_t>();
47 }
48 
index() const49 size_t SequentialSampler::index() const noexcept {
50   return index_;
51 }
52 
53 } // namespace samplers
54 } // namespace data
55 } // namespace torch
56