xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/train/sequential.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/csrc/Export.h>
#include <torch/data/samplers/base.h>
#include <torch/types.h>

#include <cstddef>
#include <optional>
#include <vector>
9 
// Forward declarations only: the archive types appear in this header solely as
// reference parameters, so their complete definitions are not required here.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
14 
15 namespace torch::jit::mobile {
16 
17 /// A lighter `Sampler` that returns indices sequentially and cannot be
18 /// serialized.
19 class TORCH_API SequentialSampler : public torch::data::samplers::Sampler<> {
20  public:
21   /// Creates a `SequentialSampler` that will return indices in the range
22   /// `0...size - 1`.
23   explicit SequentialSampler(size_t size);
24 
25   /// Resets the `SequentialSampler` to zero.
26   void reset(std::optional<size_t> new_size = std::nullopt) override;
27 
28   /// Returns the next batch of indices.
29   std::optional<std::vector<size_t>> next(size_t batch_size) override;
30 
31   /// Not supported for mobile SequentialSampler
32   void save(serialize::OutputArchive& archive) const override;
33 
34   /// Not supported for mobile SequentialSampler
35   void load(serialize::InputArchive& archive) override;
36 
37   /// Returns the current index of the `SequentialSampler`.
38   size_t index() const noexcept;
39 
40  private:
41   size_t size_;
42   size_t index_{0};
43 };
44 
45 } // namespace torch::jit::mobile
46