1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/data/samplers/base.h> 5 #include <torch/types.h> 6 7 #include <cstddef> 8 #include <vector> 9 10 namespace torch::serialize { 11 class OutputArchive; 12 class InputArchive; 13 } // namespace torch::serialize 14 15 namespace torch::jit::mobile { 16 17 /// A lighter `Sampler` that returns indices sequentially and cannot be 18 /// serialized. 19 class TORCH_API SequentialSampler : public torch::data::samplers::Sampler<> { 20 public: 21 /// Creates a `SequentialSampler` that will return indices in the range 22 /// `0...size - 1`. 23 explicit SequentialSampler(size_t size); 24 25 /// Resets the `SequentialSampler` to zero. 26 void reset(std::optional<size_t> new_size = std::nullopt) override; 27 28 /// Returns the next batch of indices. 29 std::optional<std::vector<size_t>> next(size_t batch_size) override; 30 31 /// Not supported for mobile SequentialSampler 32 void save(serialize::OutputArchive& archive) const override; 33 34 /// Not supported for mobile SequentialSampler 35 void load(serialize::InputArchive& archive) override; 36 37 /// Returns the current index of the `SequentialSampler`. 38 size_t index() const noexcept; 39 40 private: 41 size_t size_; 42 size_t index_{0}; 43 }; 44 45 } // namespace torch::jit::mobile 46