1 #pragma once 2 3 #include <torch/data/example.h> 4 #include <torch/types.h> 5 6 #include <c10/util/ArrayRef.h> 7 8 #include <cstddef> 9 #include <cstdint> 10 #include <type_traits> 11 #include <utility> 12 #include <vector> 13 14 namespace torch { 15 namespace data { 16 namespace datasets { 17 template <typename S, typename T> 18 class MapDataset; 19 template <typename D, typename T> 20 MapDataset<D, T> map(D, T); // NOLINT 21 } // namespace datasets 22 } // namespace data 23 } // namespace torch 24 25 namespace torch { 26 namespace data { 27 namespace datasets { 28 namespace detail { 29 template <typename T> 30 struct is_optional : std::false_type {}; 31 template <typename T> 32 struct is_optional<std::optional<T>> : std::true_type {}; 33 } // namespace detail 34 35 /// A dataset that can yield data only in batches. 36 template < 37 typename Self, 38 typename Batch = std::vector<Example<>>, 39 typename BatchRequest = ArrayRef<size_t>> 40 class BatchDataset { 41 public: 42 using SelfType = Self; 43 using BatchType = Batch; 44 using BatchRequestType = BatchRequest; 45 constexpr static bool is_stateful = detail::is_optional<BatchType>::value; 46 47 virtual ~BatchDataset() = default; 48 49 /// Returns a batch of data given an index. 50 virtual Batch get_batch(BatchRequest request) = 0; 51 52 /// Returns the size of the dataset, or an empty std::optional if it is 53 /// unsized. 54 virtual std::optional<size_t> size() const = 0; 55 56 /// Creates a `MapDataset` that applies the given `transform` to this dataset. 57 template <typename TransformType> 58 MapDataset<Self, TransformType> map(TransformType transform) & { 59 return datasets::map(static_cast<Self&>(*this), std::move(transform)); 60 } 61 62 /// Creates a `MapDataset` that applies the given `transform` to this dataset. 63 template <typename TransformType> 64 MapDataset<Self, TransformType> map(TransformType transform) && { 65 return datasets::map( 66 std::move(static_cast<Self&>(*this)), std::move(transform)); 67 } 68 }; 69 70 /// A dataset that can yield data in batches, or as individual examples. 71 /// 72 /// A `Dataset` is a `BatchDataset`, because it supports random access and 73 /// therefore batched access is implemented (by default) by calling the random 74 /// access indexing function for each index in the requested batch of indices. 75 /// This can be customized. 76 template <typename Self, typename SingleExample = Example<>> 77 class Dataset : public BatchDataset<Self, std::vector<SingleExample>> { 78 public: 79 using ExampleType = SingleExample; 80 81 /// Returns the example at the given index. 82 virtual ExampleType get(size_t index) = 0; 83 84 /// Returns a batch of data. 85 /// The default implementation calls `get()` for every requested index 86 /// in the batch. 87 std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override { 88 std::vector<ExampleType> batch; 89 batch.reserve(indices.size()); 90 for (const auto i : indices) { 91 batch.push_back(get(i)); 92 } 93 return batch; 94 } 95 }; 96 97 /// A `StreamDataset` represents a dataset that is a potentially infinite 98 /// stream. It takes as batch index only a number, which is the batch size, and 99 /// yields that many elements from the stream. 100 template <typename Self, typename Batch = std::vector<Example<>>> 101 using StreamDataset = BatchDataset<Self, Batch, /*BatchRequest=*/size_t>; 102 } // namespace datasets 103 } // namespace data 104 } // namespace torch 105