// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/samplers/base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/csrc/Export.h>
#include <torch/types.h>

#include <cstddef>
#include <mutex>
#include <optional>
#include <vector>
9 
namespace torch::serialize {
// Forward declarations only: the sampler interface takes archives by
// reference, so the full archive definitions are not needed here.
class OutputArchive;
class InputArchive;
} // namespace torch::serialize

namespace torch::data::samplers {
/// A `Sampler` is an object that yields batch requests -- by default a
/// `std::vector<size_t>` of indices -- with which to access a dataset.
///
/// This is a pure interface: every operation below is pure virtual and must
/// be provided by the concrete sampler.
///
/// \tparam BatchRequest The type handed back by `next()` to describe one
/// batch. Exposed to users of the class as `BatchRequestType`.
template <typename BatchRequest = std::vector<size_t>>
class Sampler {
 public:
  /// The type returned (wrapped in an optional) by `next()`.
  using BatchRequestType = BatchRequest;

  /// Virtual so that concrete samplers can be destroyed through a
  /// `Sampler` pointer or reference.
  virtual ~Sampler() = default;

  /// Resets the `Sampler`'s internal state.
  /// Typically called before a new epoch.
  /// Optionally, accepts a new size when resetting the sampler; an empty
  /// optional means no new size is supplied.
  virtual void reset(std::optional<size_t> new_size) = 0;

  /// Returns the next batch request if possible, or an empty optional if the
  /// sampler is exhausted for this epoch.
  /// `batch_size` is the batch size requested by the caller; how it is
  /// honored is up to the concrete sampler.
  virtual std::optional<BatchRequest> next(size_t batch_size) = 0;

  /// Serializes the `Sampler` to the `archive`.
  virtual void save(serialize::OutputArchive& archive) const = 0;

  /// Deserializes the `Sampler` from the `archive`.
  virtual void load(serialize::InputArchive& archive) = 0;
};

} // namespace torch::data::samplers
48