1 #pragma once 2 3 #include <torch/data/datasets/base.h> 4 #include <torch/data/example.h> 5 6 #include <cstddef> 7 #include <vector> 8 9 namespace torch { 10 namespace serialize { 11 class OutputArchive; 12 class InputArchive; 13 } // namespace serialize 14 } // namespace torch 15 16 namespace torch { 17 namespace data { 18 namespace datasets { 19 20 /// A stateful dataset is a dataset that maintains some internal state, which 21 /// will be `reset()` at the beginning of each epoch. Subclasses can override 22 /// the `reset()` method to configure this behavior. Further, the return type of 23 /// a stateful dataset's `get_batch()` method is always an `optional`. When the 24 /// stateful dataset wants to indicate to the dataloader that its epoch has 25 /// ended, it should return an empty optional. The dataloader knows to modify 26 /// its implementation based on whether the dataset is stateless or stateful. 27 /// 28 /// Note that when subclassing a from `StatefulDataset<Self, T>`, the return 29 /// type of `get_batch()`, which the subclass must override, will be 30 /// `optional<T>` (i.e. the type specified in the `StatefulDataset` 31 /// specialization is automatically boxed into an `optional` for the dataset's 32 /// `BatchType`). 33 template < 34 typename Self, 35 typename Batch = std::vector<Example<>>, 36 typename BatchRequest = size_t> 37 class StatefulDataset 38 : public BatchDataset<Self, std::optional<Batch>, BatchRequest> { 39 public: 40 /// Resets internal state of the dataset. 41 virtual void reset() = 0; 42 43 /// Saves the statefulDataset's state to OutputArchive. 44 virtual void save(serialize::OutputArchive& archive) const = 0; 45 46 /// Deserializes the statefulDataset's state from the `archive`. 47 virtual void load(serialize::InputArchive& archive) = 0; 48 }; 49 50 /// Serializes a statefulDataset to `OutputArchive`. 51 template <typename... Args> 52 serialize::OutputArchive& operator<<( 53 serialize::OutputArchive& archive, 54 const StatefulDataset<Args...>& statefulDataset) { 55 statefulDataset.save(archive); 56 return archive; 57 } 58 59 /// Deserializes a statefulDataset from an `InputArchive`. 60 template <typename... Args> 61 serialize::InputArchive& operator>>( 62 serialize::InputArchive& archive, 63 StatefulDataset<Args...>& statefulDataset) { 64 statefulDataset.load(archive); 65 return archive; 66 } 67 68 } // namespace datasets 69 } // namespace data 70 } // namespace torch 71