// torch/csrc/api/src/data/samplers/distributed.cpp
#include <c10/util/irange.h>
#include <torch/data/samplers/distributed.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <optional>
#include <random>
#include <vector>

namespace torch {
namespace data {
namespace samplers {

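// DistributedRandomSampler hands each replica a contiguous slice of a
// globally shuffled index sequence. The shuffle is seeded with epoch_, so
// every replica produces the same permutation and the slices never overlap.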
DistributedRandomSampler::DistributedRandomSampler(
    size_t size,
    size_t num_replicas,
    size_t rank,
    bool allow_duplicates)
    : DistributedSampler(size, num_replicas, rank, allow_duplicates),
      begin_index_(0),
      end_index_(0),
      sample_index_(0) {
  // Shuffle once up front so the sampler is ready before the first epoch.
  reset(size_);
}

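// Returns the next batch of at most `batch_size` indices from this
// replica's slice, or std::nullopt once the slice is exhausted.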
std::optional<std::vector<size_t>> DistributedRandomSampler::next(
    size_t batch_size) {
  if (sample_index_ == end_index_) {
    return std::nullopt;
  }

  // Clamp the batch to the end of this replica's slice.
  size_t end = sample_index_ + batch_size;
  if (end > end_index_) {
    end = end_index_;
  }

  auto iter = all_indices_.begin();
  std::vector<size_t> res(iter + sample_index_, iter + end);
  sample_index_ = end;
  return res;
}

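// Re-shuffles with a generator seeded by epoch_, so all replicas arrive at
// the identical permutation, then rewinds to this rank's begin index.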
void DistributedRandomSampler::reset(std::optional<size_t> new_size) {
  size_ = new_size.value_or(size_);
  populate_indices();

  std::mt19937 rand(epoch_);
  std::shuffle(all_indices_.begin(), all_indices_.end(), rand);
  sample_index_ = begin_index_;
}

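// Builds the global index list, padding with wrapped-around indices when
// allow_duplicates is set so every replica gets the same sample count. For
// example, with size_ = 10 and num_replicas_ = 4, each replica draws 3
// samples, so the list grows to 12 entries and positions 10 and 11 reuse
// indices 0 and 1.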
void DistributedRandomSampler::populate_indices() {
  size_t num_local_samples = local_sample_count();
  size_t sample_count =
      num_replicas_ == 1 ? size_ : num_local_samples * num_replicas_;
  all_indices_.resize(sample_count);
  std::iota(std::begin(all_indices_), std::end(all_indices_), 0);
  for (const auto i : c10::irange(size_, sample_count)) {
    // Duplicate samples may have been added so that all replicas have the
    // same number of samples; wrap them back into the valid range.
    all_indices_[i] = i - size_;
  }
  begin_index_ = rank_ * num_local_samples;
  end_index_ = begin_index_ + num_local_samples;
  sample_index_ = begin_index_;
}

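// Serialization keeps only sample_index_ and epoch_; the index list itself
// is deterministic given those two values and is rebuilt on load.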
void DistributedRandomSampler::save(serialize::OutputArchive& archive) const {
  archive.write(
      "sample_index_",
      torch::tensor(static_cast<int64_t>(sample_index_)),
      /*is_buffer=*/true);
  archive.write(
      "epoch_",
      torch::tensor(static_cast<int64_t>(epoch_)),
      /*is_buffer=*/true);
}

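// Note the ordering: reset() reseeds the shuffle from the restored epoch_
// and rewinds sample_index_, so sample_index_ must be read back afterwards.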
void DistributedRandomSampler::load(serialize::InputArchive& archive) {
  auto tensor = torch::empty(1, torch::kInt64);
  archive.read("epoch_", tensor, /*is_buffer=*/true);
  epoch_ = tensor.item<int64_t>();
  // Call reset() after loading epoch_ to repopulate the indices.
  reset(size_);

  tensor = torch::empty(1, torch::kInt64);
  archive.read("sample_index_", tensor, /*is_buffer=*/true);
  sample_index_ = tensor.item<int64_t>();
}

size_t DistributedRandomSampler::index() const noexcept {
  return sample_index_;
}

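// DistributedSequentialSampler yields indices in order, with each replica
// walking its own contiguous [begin_index_, end_index_) slice.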
DistributedSequentialSampler::DistributedSequentialSampler(
    size_t size,
    size_t num_replicas,
    size_t rank,
    bool allow_duplicates)
    : DistributedSampler(size, num_replicas, rank, allow_duplicates),
      begin_index_(0),
      end_index_(0),
      sample_index_(0) {
  populate_indices();
}

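// Returns the next batch of at most `batch_size` sequential indices, or
// std::nullopt once this replica's slice is exhausted.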
std::optional<std::vector<size_t>> DistributedSequentialSampler::next(
    size_t batch_size) {
  if (sample_index_ == end_index_) {
    return std::nullopt;
  }

  size_t end = sample_index_ + batch_size;
  if (end > end_index_) {
    end = end_index_;
  }

  std::vector<size_t> res(end - sample_index_);
  std::iota(std::begin(res), std::end(res), sample_index_);
  // Indices past size_ are duplicate padding; wrap them back into range.
  if (end >= size_) {
    for (size_t& index : res) {
      index = index % size_;
    }
  }
  sample_index_ = end;
  return res;
}

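// Only repopulates when the dataset size actually changes; otherwise a
// reset just rewinds to the start of this replica's slice.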
void DistributedSequentialSampler::reset(std::optional<size_t> new_size) {
  size_t size = new_size.value_or(size_);
  if (size != size_) {
    size_ = size;
    populate_indices();
  } else {
    sample_index_ = begin_index_;
  }
}

void DistributedSequentialSampler::populate_indices() {
  begin_index_ = rank_ * local_sample_count();
  end_index_ = begin_index_ + local_sample_count();
  sample_index_ = begin_index_;
}

void DistributedSequentialSampler::save(
    serialize::OutputArchive& archive) const {
  archive.write(
      "sample_index_",
      torch::tensor(static_cast<int64_t>(sample_index_)),
      /*is_buffer=*/true);
}

void DistributedSequentialSampler::load(serialize::InputArchive& archive) {
  auto tensor = torch::empty(1, torch::kInt64);
  archive.read("sample_index_", tensor, /*is_buffer=*/true);
  sample_index_ = tensor.item<int64_t>();
}

size_t DistributedSequentialSampler::index() const noexcept {
  return sample_index_;
}

} // namespace samplers
} // namespace data
} // namespace torch
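
// A minimal usage sketch (illustrative only, not part of this translation
// unit): drive the sampler manually across epochs. Assumes the set_epoch()
// and reset() members declared in torch/data/samplers/distributed.h.
//
//   torch::data::samplers::DistributedRandomSampler sampler(
//       /*size=*/1000, /*num_replicas=*/4, /*rank=*/0);
//   for (size_t epoch = 0; epoch < 3; ++epoch) {
//     sampler.set_epoch(epoch); // keep replicas' shuffles in sync
//     sampler.reset();          // re-shuffle with the new epoch seed
//     while (auto batch = sampler.next(/*batch_size=*/32)) {
//       // batch->size() <= 32; feed *batch to the dataset
//     }
//   }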