1 #include <torch/csrc/jit/mobile/train/sequential.h> 2 #include <torch/types.h> 3 4 #include <algorithm> 5 #include <cstddef> 6 #include <vector> 7 8 namespace torch::jit::mobile { SequentialSampler(size_t size)9SequentialSampler::SequentialSampler(size_t size) : size_(size) {} 10 reset(std::optional<size_t> new_size)11void SequentialSampler::reset(std::optional<size_t> new_size) { 12 if (new_size.has_value()) { 13 size_ = *new_size; 14 } 15 index_ = 0; 16 } 17 next(size_t batch_size)18optional<std::vector<size_t>> SequentialSampler::next(size_t batch_size) { 19 const auto remaining_indices = size_ - index_; 20 if (remaining_indices == 0) { 21 return nullopt; 22 } 23 std::vector<size_t> index_batch(std::min(batch_size, remaining_indices)); 24 for (auto& i : index_batch) { 25 i = index_++; 26 } 27 return index_batch; 28 } 29 save(serialize::OutputArchive & archive) const30void SequentialSampler::save(serialize::OutputArchive& archive) const { 31 TORCH_CHECK( 32 false, "Serialization of SequentialSampler not supported on mobile."); 33 } 34 load(serialize::InputArchive & archive)35void SequentialSampler::load(serialize::InputArchive& archive) { 36 TORCH_CHECK( 37 false, "Serialization of SequentialSampler not supported on mobile."); 38 } 39 index() const40size_t SequentialSampler::index() const noexcept { 41 return index_; 42 } 43 44 } // namespace torch::jit::mobile 45