xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/datasets/base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/data/example.h>
#include <torch/types.h>

#include <c10/util/ArrayRef.h>

#include <cstddef>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>
13 
// Forward declarations so that `BatchDataset::map()` below can name the
// `MapDataset` wrapper and the free `map()` factory before their definitions
// are visible. The factory takes both the dataset and the transform by value.
namespace torch::data::datasets {
template <typename S, typename T>
class MapDataset;

template <typename D, typename T>
MapDataset<D, T> map(D, T); // NOLINT
} // namespace torch::data::datasets
24 
25 namespace torch {
26 namespace data {
27 namespace datasets {
namespace detail {
/// Compile-time predicate: `is_optional<T>::value` is `true` exactly when `T`
/// is a specialization of `std::optional`, and `false` otherwise.
///
/// Used by `BatchDataset` to derive `is_stateful` from its batch type.
/// Requires `<optional>` to be included by this header (not merely reachable
/// through transitive includes) since `std::optional` is named directly here.
template <typename T>
struct is_optional : std::false_type {};

/// Partial specialization matching `std::optional<T>` for any `T`.
template <typename T>
struct is_optional<std::optional<T>> : std::true_type {};
} // namespace detail
34 
35 /// A dataset that can yield data only in batches.
36 template <
37     typename Self,
38     typename Batch = std::vector<Example<>>,
39     typename BatchRequest = ArrayRef<size_t>>
40 class BatchDataset {
41  public:
42   using SelfType = Self;
43   using BatchType = Batch;
44   using BatchRequestType = BatchRequest;
45   constexpr static bool is_stateful = detail::is_optional<BatchType>::value;
46 
47   virtual ~BatchDataset() = default;
48 
49   /// Returns a batch of data given an index.
50   virtual Batch get_batch(BatchRequest request) = 0;
51 
52   /// Returns the size of the dataset, or an empty std::optional if it is
53   /// unsized.
54   virtual std::optional<size_t> size() const = 0;
55 
56   /// Creates a `MapDataset` that applies the given `transform` to this dataset.
57   template <typename TransformType>
58   MapDataset<Self, TransformType> map(TransformType transform) & {
59     return datasets::map(static_cast<Self&>(*this), std::move(transform));
60   }
61 
62   /// Creates a `MapDataset` that applies the given `transform` to this dataset.
63   template <typename TransformType>
64   MapDataset<Self, TransformType> map(TransformType transform) && {
65     return datasets::map(
66         std::move(static_cast<Self&>(*this)), std::move(transform));
67   }
68 };
69 
70 /// A dataset that can yield data in batches, or as individual examples.
71 ///
72 /// A `Dataset` is a `BatchDataset`, because it supports random access and
73 /// therefore batched access is implemented (by default) by calling the random
74 /// access indexing function for each index in the requested batch of indices.
75 /// This can be customized.
76 template <typename Self, typename SingleExample = Example<>>
77 class Dataset : public BatchDataset<Self, std::vector<SingleExample>> {
78  public:
79   using ExampleType = SingleExample;
80 
81   /// Returns the example at the given index.
82   virtual ExampleType get(size_t index) = 0;
83 
84   /// Returns a batch of data.
85   /// The default implementation calls `get()` for every requested index
86   /// in the batch.
87   std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override {
88     std::vector<ExampleType> batch;
89     batch.reserve(indices.size());
90     for (const auto i : indices) {
91       batch.push_back(get(i));
92     }
93     return batch;
94   }
95 };
96 
97 /// A `StreamDataset` represents a dataset that is a potentially infinite
98 /// stream. It takes as batch index only a number, which is the batch size, and
99 /// yields that many elements from the stream.
100 template <typename Self, typename Batch = std::vector<Example<>>>
101 using StreamDataset = BatchDataset<Self, Batch, /*BatchRequest=*/size_t>;
102 } // namespace datasets
103 } // namespace data
104 } // namespace torch
105