#include <c10/util/irange.h>
#include <torch/data/samplers/distributed.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <optional>
#include <random>
#include <vector>
10
11 namespace torch {
12 namespace data {
13 namespace samplers {
14
// Constructs the sampler and immediately performs the first (epoch-0)
// shuffle via reset(), so next() is usable right after construction.
DistributedRandomSampler::DistributedRandomSampler(
    size_t size,
    size_t num_replicas,
    size_t rank,
    bool allow_duplicates)
    : DistributedSampler(size, num_replicas, rank, allow_duplicates),
      begin_index_(0),
      end_index_(0),
      sample_index_(0) {
  // shuffle first time.
  reset(size_);
}
27
next(size_t batch_size)28 std::optional<std::vector<size_t>> DistributedRandomSampler::next(
29 size_t batch_size) {
30 if (sample_index_ == end_index_) {
31 return nullopt;
32 }
33
34 size_t end = sample_index_ + batch_size;
35 if (end > end_index_) {
36 end = end_index_;
37 }
38
39 auto iter = all_indices_.begin();
40 std::vector<size_t> res(iter + sample_index_, iter + end);
41 sample_index_ = end;
42 return res;
43 }
44
reset(std::optional<size_t> new_size)45 void DistributedRandomSampler::reset(std::optional<size_t> new_size) {
46 size_ = new_size.value_or(size_);
47 populate_indices();
48
49 std::mt19937 rand(epoch_);
50 std::shuffle(all_indices_.begin(), all_indices_.end(), rand);
51 sample_index_ = begin_index_;
52 }
53
populate_indices()54 void DistributedRandomSampler::populate_indices() {
55 size_t num_local_samples = local_sample_count();
56 size_t sample_count =
57 num_replicas_ == 1 ? size_ : num_local_samples * num_replicas_;
58 all_indices_.resize(sample_count);
59 std::iota(std::begin(all_indices_), std::end(all_indices_), 0);
60 for (const auto i : c10::irange(size_, sample_count)) {
61 // we may have added duplicate samples to make all
62 // replicas to have the same number of samples.
63 all_indices_[i] = i - size_;
64 }
65 begin_index_ = rank_ * num_local_samples;
66 end_index_ = begin_index_ + num_local_samples;
67 sample_index_ = begin_index_;
68 }
69
// Serializes resumption state: the position within the current epoch and
// the epoch counter itself (the epoch seeds the shuffle in reset(), so both
// are needed to reproduce the exact sample order on restore).
void DistributedRandomSampler::save(serialize::OutputArchive& archive) const {
  archive.write(
      "sample_index_",
      torch::tensor(static_cast<int64_t>(sample_index_)),
      /*is_buffer=*/true);
  archive.write(
      "epoch_",
      torch::tensor(static_cast<int64_t>(epoch_)),
      /*is_buffer=*/true);
}
80
// Restores state written by save(). The order is load-bearing: epoch_ must
// be read first because reset() seeds the shuffle from it, and reset()
// overwrites sample_index_, so sample_index_ is restored only afterwards.
void DistributedRandomSampler::load(serialize::InputArchive& archive) {
  auto tensor = torch::empty(1, torch::kInt64);
  archive.read("epoch_", tensor, /*is_buffer=*/true);
  epoch_ = tensor.item<int64_t>();
  // call reset() after loading epoch_ to populate indices.
  reset(size_);

  tensor = torch::empty(1, torch::kInt64);
  archive.read("sample_index_", tensor, /*is_buffer=*/true);
  sample_index_ = tensor.item<int64_t>();
}
92
// Current position within all_indices_ (an absolute index, i.e. offset by
// this replica's begin_index_).
size_t DistributedRandomSampler::index() const noexcept {
  return sample_index_;
}
96
// Constructs the sampler and carves out this replica's contiguous index
// range immediately, so next() is usable without an explicit reset().
DistributedSequentialSampler::DistributedSequentialSampler(
    size_t size,
    size_t num_replicas,
    size_t rank,
    bool allow_duplicates)
    : DistributedSampler(size, num_replicas, rank, allow_duplicates),
      begin_index_(0),
      end_index_(0),
      sample_index_(0) {
  populate_indices();
}
108
next(size_t batch_size)109 std::optional<std::vector<size_t>> DistributedSequentialSampler::next(
110 size_t batch_size) {
111 if (sample_index_ == end_index_) {
112 return nullopt;
113 }
114
115 size_t end = sample_index_ + batch_size;
116 if (end > end_index_) {
117 end = end_index_;
118 }
119
120 std::vector<size_t> res(end - sample_index_);
121 std::iota(std::begin(res), std::end(res), sample_index_);
122 if (end >= size_) {
123 for (size_t& index : res) {
124 index = index % size_;
125 }
126 }
127 sample_index_ = end;
128 return res;
129 }
130
reset(std::optional<size_t> new_size)131 void DistributedSequentialSampler::reset(std::optional<size_t> new_size) {
132 size_t size = new_size.value_or(size_);
133 if (size != size_) {
134 size_ = size;
135 populate_indices();
136 } else {
137 sample_index_ = begin_index_;
138 }
139 }
140
populate_indices()141 void DistributedSequentialSampler::populate_indices() {
142 begin_index_ = rank_ * local_sample_count();
143 end_index_ = begin_index_ + local_sample_count();
144 sample_index_ = begin_index_;
145 }
146
// Serializes resumption state. Only the cursor is needed: unlike the random
// sampler, the sequential order is fully determined by size/replica/rank.
void DistributedSequentialSampler::save(
    serialize::OutputArchive& archive) const {
  archive.write(
      "sample_index_",
      torch::tensor(static_cast<int64_t>(sample_index_)),
      /*is_buffer=*/true);
}
154
load(serialize::InputArchive & archive)155 void DistributedSequentialSampler::load(serialize::InputArchive& archive) {
156 auto tensor = torch::empty(1, torch::kInt64);
157 archive.read("sample_index_", tensor, /*is_buffer=*/true);
158 sample_index_ = tensor.item<int64_t>();
159 }
160
// Current position within the dataset (an absolute index, i.e. offset by
// this replica's begin_index_).
size_t DistributedSequentialSampler::index() const noexcept {
  return sample_index_;
}
164
165 } // namespace samplers
166 } // namespace data
167 } // namespace torch
168