xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/train/sequential.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/train/sequential.h>
2 #include <torch/types.h>
3 
4 #include <algorithm>
5 #include <cstddef>
6 #include <vector>
7 
8 namespace torch::jit::mobile {
SequentialSampler(size_t size)9 SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
10 
reset(std::optional<size_t> new_size)11 void SequentialSampler::reset(std::optional<size_t> new_size) {
12   if (new_size.has_value()) {
13     size_ = *new_size;
14   }
15   index_ = 0;
16 }
17 
next(size_t batch_size)18 optional<std::vector<size_t>> SequentialSampler::next(size_t batch_size) {
19   const auto remaining_indices = size_ - index_;
20   if (remaining_indices == 0) {
21     return nullopt;
22   }
23   std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
24   for (auto& i : index_batch) {
25     i = index_++;
26   }
27   return index_batch;
28 }
29 
save(serialize::OutputArchive & archive) const30 void SequentialSampler::save(serialize::OutputArchive& archive) const {
31   TORCH_CHECK(
32       false, "Serialization of SequentialSampler not supported on mobile.");
33 }
34 
load(serialize::InputArchive & archive)35 void SequentialSampler::load(serialize::InputArchive& archive) {
36   TORCH_CHECK(
37       false, "Serialization of SequentialSampler not supported on mobile.");
38 }
39 
index() const40 size_t SequentialSampler::index() const noexcept {
41   return index_;
42 }
43 
44 } // namespace torch::jit::mobile
45