xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/sequence_num.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ThreadLocalState.h>
2 #include <torch/csrc/distributed/c10d/sequence_num.hpp>
3 
4 #include <c10/util/Logging.h>
5 
6 namespace c10d {
7 SequenceNum::SequenceNum() = default;
8 
SequenceNum(const uint64_t num)9 SequenceNum::SequenceNum(const uint64_t num) : num_(num) {}
10 
// Copy constructor: snapshots `other`'s value (or its unset state) under
// `other`'s lock.
//
// The value is read with a single lock acquisition. Calling other.isSet()
// followed by other.get() would lock twice, and `other` could be reset to an
// empty value in between, making get() dereference an empty optional.
SequenceNum::SequenceNum(const SequenceNum& other) {
  std::lock_guard<std::mutex> lock(other.lock_);
  num_ = other.num_;
}
18 
get() const19 uint64_t SequenceNum::get() const {
20   std::lock_guard<std::mutex> lock(lock_);
21   return *num_;
22 }
23 
increment()24 void SequenceNum::increment() {
25   std::lock_guard<std::mutex> lock(lock_);
26   TORCH_CHECK(num_ != std::nullopt);
27   num_ = ++(*num_);
28 }
29 
30 // Implemented without above get() and increment() so we don't repeatedly lock
31 // and unblock.
getAndIncrement()32 uint64_t SequenceNum::getAndIncrement() {
33   uint64_t curVal = 0;
34   std::lock_guard<std::mutex> lock(lock_);
35   TORCH_CHECK(num_ != std::nullopt);
36   curVal = *num_;
37   num_ = ++(*num_);
38   return curVal;
39 }
40 
set(const uint64_t num)41 void SequenceNum::set(const uint64_t num) {
42   std::lock_guard<std::mutex> lock(lock_);
43   num_ = num;
44 }
45 
isSet() const46 bool SequenceNum::isSet() const {
47   std::lock_guard<std::mutex> lock(lock_);
48   return num_ != std::nullopt;
49 }
50 
// Copy assignment: replaces this object's value (or unset state) with
// `other`'s.
//
// - Self-assignment returns early. Otherwise lock_ would be locked twice,
//   which is undefined behavior for a non-recursive std::mutex (in practice, a
//   deadlock).
// - Both mutexes are acquired together with std::scoped_lock, which uses a
//   deadlock-avoidance algorithm. Concurrent `a = b` and `b = a` therefore
//   cannot deadlock the way "lock this, then lock other" can.
// - num_ is copied under `other`'s lock in one step, so `other` cannot change
//   between checking whether it is set and reading its value.
SequenceNum& SequenceNum::operator=(const SequenceNum& other) {
  if (this == &other) {
    return *this;
  }
  std::scoped_lock lock(lock_, other.lock_);
  num_ = other.num_;
  return *this;
}
60 
61 } // namespace c10d
62