xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/datasets/stateful.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/data/datasets/base.h>
#include <torch/data/example.h>

#include <cstddef>
#include <optional>
#include <vector>
8 
// Forward declarations of the archive types that `save()`/`load()` and the
// stream operators in this header take by reference; their full definitions
// are not required just to declare reference parameters.
namespace torch::serialize {
class OutputArchive;
class InputArchive;
} // namespace torch::serialize
15 
16 namespace torch {
17 namespace data {
18 namespace datasets {
19 
20 /// A stateful dataset is a dataset that maintains some internal state, which
21 /// will be `reset()` at the beginning of each epoch. Subclasses can override
22 /// the `reset()` method to configure this behavior. Further, the return type of
23 /// a stateful dataset's `get_batch()` method is always an `optional`. When the
24 /// stateful dataset wants to indicate to the dataloader that its epoch has
25 /// ended, it should return an empty optional. The dataloader knows to modify
26 /// its implementation based on whether the dataset is stateless or stateful.
27 ///
28 /// Note that when subclassing a from `StatefulDataset<Self, T>`, the return
29 /// type of `get_batch()`, which the subclass must override, will be
30 /// `optional<T>` (i.e. the type specified in the `StatefulDataset`
31 /// specialization is automatically boxed into an `optional` for the dataset's
32 /// `BatchType`).
33 template <
34     typename Self,
35     typename Batch = std::vector<Example<>>,
36     typename BatchRequest = size_t>
37 class StatefulDataset
38     : public BatchDataset<Self, std::optional<Batch>, BatchRequest> {
39  public:
40   /// Resets internal state of the dataset.
41   virtual void reset() = 0;
42 
43   /// Saves the statefulDataset's state to OutputArchive.
44   virtual void save(serialize::OutputArchive& archive) const = 0;
45 
46   /// Deserializes the statefulDataset's state from the `archive`.
47   virtual void load(serialize::InputArchive& archive) = 0;
48 };
49 
50 /// Serializes a statefulDataset to `OutputArchive`.
51 template <typename... Args>
52 serialize::OutputArchive& operator<<(
53     serialize::OutputArchive& archive,
54     const StatefulDataset<Args...>& statefulDataset) {
55   statefulDataset.save(archive);
56   return archive;
57 }
58 
59 /// Deserializes a statefulDataset from an `InputArchive`.
60 template <typename... Args>
61 serialize::InputArchive& operator>>(
62     serialize::InputArchive& archive,
63     StatefulDataset<Args...>& statefulDataset) {
64   statefulDataset.load(archive);
65   return archive;
66 }
67 
68 } // namespace datasets
69 } // namespace data
70 } // namespace torch
71