// torch/csrc/api/include/torch/data/datasets/shared.h
#pragma once

#include <torch/data/datasets/base.h>

#include <memory>
#include <utility>

namespace torch {
namespace data {
namespace datasets {

/// A dataset that wraps another dataset in a shared pointer and implements the
/// `BatchDataset` API, delegating all calls to the shared instance. This is
/// useful when you want all worker threads in the dataloader to access the same
/// dataset instance. The dataset must take care of synchronization and
/// thread-safe access itself.
///
/// Use `torch::data::datasets::make_shared_dataset()` to create a new
/// `SharedBatchDataset` like you would a `std::shared_ptr`.
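///
/// A minimal usage sketch (the `MyDataset` type and its constructor
/// argument are hypothetical, shown only for illustration):
///
///   auto dataset =
///       torch::data::datasets::make_shared_dataset<MyDataset>("/data");
///   // Every dataloader worker then sees the same underlying instance.
///   auto batch = dataset.get_batch({0, 1, 2});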
template <typename UnderlyingDataset>
class SharedBatchDataset : public BatchDataset<
                               SharedBatchDataset<UnderlyingDataset>,
                               typename UnderlyingDataset::BatchType,
                               typename UnderlyingDataset::BatchRequestType> {
 public:
  using BatchType = typename UnderlyingDataset::BatchType;
  using BatchRequestType = typename UnderlyingDataset::BatchRequestType;

  /// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the
  /// `UnderlyingDataset`.
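  ///
  /// For instance (a sketch; `MyDataset` stands in for any concrete
  /// dataset type):
  ///
  ///   SharedBatchDataset<MyDataset> dataset =
  ///       std::make_shared<MyDataset>();  // implicit conversion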
  /* implicit */ SharedBatchDataset(
      std::shared_ptr<UnderlyingDataset> shared_dataset)
      : dataset_(std::move(shared_dataset)) {}

  /// Calls `get_batch` on the underlying dataset.
  BatchType get_batch(BatchRequestType request) override {
    return dataset_->get_batch(std::move(request));
  }

  /// Returns the `size` from the underlying dataset.
  std::optional<size_t> size() const override {
    return dataset_->size();
  }

  /// Accesses the underlying dataset.
  UnderlyingDataset& operator*() {
    return *dataset_;
  }

  /// Accesses the underlying dataset.
  const UnderlyingDataset& operator*() const {
    return *dataset_;
  }

  /// Accesses the underlying dataset.
  UnderlyingDataset* operator->() {
    return dataset_.get();
  }

  /// Accesses the underlying dataset.
  const UnderlyingDataset* operator->() const {
    return dataset_.get();
  }

  /// Calls `reset()` on the underlying dataset.
  void reset() {
    dataset_->reset();
  }

 private:
  std::shared_ptr<UnderlyingDataset> dataset_;
};

/// Constructs a new `SharedBatchDataset` by creating a
/// `shared_ptr<UnderlyingDataset>`. All arguments are forwarded to
/// `std::make_shared<UnderlyingDataset>`.
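///
/// For example, assuming a user-defined `CsvDataset` whose constructor
/// takes a file path (a hypothetical type, shown only to illustrate
/// argument forwarding):
///
///   auto dataset = make_shared_dataset<CsvDataset>("train.csv");
///   // "train.csv" is perfectly forwarded to `CsvDataset`'s constructor,
///   // and the resulting `shared_ptr` is wrapped in a `SharedBatchDataset`.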
template <typename UnderlyingDataset, typename... Args>
SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) {
  return std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...);
}
} // namespace datasets
} // namespace data
} // namespace torch