1 #pragma once 2 3 #include <torch/data/datasets/base.h> 4 5 #include <memory> 6 #include <utility> 7 8 namespace torch { 9 namespace data { 10 namespace datasets { 11 12 /// A dataset that wraps another dataset in a shared pointer and implements the 13 /// `BatchDataset` API, delegating all calls to the shared instance. This is 14 /// useful when you want all worker threads in the dataloader to access the same 15 /// dataset instance. The dataset must take care of synchronization and 16 /// thread-safe access itself. 17 /// 18 /// Use `torch::data::datasets::make_shared_dataset()` to create a new 19 /// `SharedBatchDataset` like you would a `std::shared_ptr`. 20 template <typename UnderlyingDataset> 21 class SharedBatchDataset : public BatchDataset< 22 SharedBatchDataset<UnderlyingDataset>, 23 typename UnderlyingDataset::BatchType, 24 typename UnderlyingDataset::BatchRequestType> { 25 public: 26 using BatchType = typename UnderlyingDataset::BatchType; 27 using BatchRequestType = typename UnderlyingDataset::BatchRequestType; 28 29 /// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the 30 /// `UnderlyingDataset`. SharedBatchDataset(std::shared_ptr<UnderlyingDataset> shared_dataset)31 /* implicit */ SharedBatchDataset( 32 std::shared_ptr<UnderlyingDataset> shared_dataset) 33 : dataset_(std::move(shared_dataset)) {} 34 35 /// Calls `get_batch` on the underlying dataset. get_batch(BatchRequestType request)36 BatchType get_batch(BatchRequestType request) override { 37 return dataset_->get_batch(std::move(request)); 38 } 39 40 /// Returns the `size` from the underlying dataset. size()41 std::optional<size_t> size() const override { 42 return dataset_->size(); 43 } 44 45 /// Accesses the underlying dataset. 46 UnderlyingDataset& operator*() { 47 return *dataset_; 48 } 49 50 /// Accesses the underlying dataset. 51 const UnderlyingDataset& operator*() const { 52 return *dataset_; 53 } 54 55 /// Accesses the underlying dataset. 56 UnderlyingDataset* operator->() { 57 return dataset_.get(); 58 } 59 60 /// Accesses the underlying dataset. 61 const UnderlyingDataset* operator->() const { 62 return dataset_.get(); 63 } 64 65 /// Calls `reset()` on the underlying dataset. reset()66 void reset() { 67 dataset_->reset(); 68 } 69 70 private: 71 std::shared_ptr<UnderlyingDataset> dataset_; 72 }; 73 74 /// Constructs a new `SharedBatchDataset` by creating a 75 /// `shared_ptr<UnderlyingDatase>`. All arguments are forwarded to 76 /// `make_shared<UnderlyingDataset>`. 77 template <typename UnderlyingDataset, typename... Args> make_shared_dataset(Args &&...args)78SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) { 79 return std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...); 80 } 81 } // namespace datasets 82 } // namespace data 83 } // namespace torch 84