1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/types.h> 5 6 #include <cstddef> 7 #include <mutex> 8 #include <vector> 9 10 namespace torch { 11 namespace serialize { 12 class OutputArchive; 13 class InputArchive; 14 } // namespace serialize 15 } // namespace torch 16 17 namespace torch { 18 namespace data { 19 namespace samplers { 20 /// A `Sampler` is an object that yields an index with which to access a 21 /// dataset. 22 template <typename BatchRequest = std::vector<size_t>> 23 class Sampler { 24 public: 25 using BatchRequestType = BatchRequest; 26 27 virtual ~Sampler() = default; 28 29 /// Resets the `Sampler`'s internal state. 30 /// Typically called before a new epoch. 31 /// Optionally, accepts a new size when reseting the sampler. 32 virtual void reset(std::optional<size_t> new_size) = 0; 33 34 /// Returns the next index if possible, or an empty optional if the 35 /// sampler is exhausted for this epoch. 36 virtual std::optional<BatchRequest> next(size_t batch_size) = 0; 37 38 /// Serializes the `Sampler` to the `archive`. 39 virtual void save(serialize::OutputArchive& archive) const = 0; 40 41 /// Deserializes the `Sampler` from the `archive`. 42 virtual void load(serialize::InputArchive& archive) = 0; 43 }; 44 45 } // namespace samplers 46 } // namespace data 47 } // namespace torch 48