1 #include <ATen/ThreadLocalState.h> 2 #include <torch/csrc/distributed/c10d/sequence_num.hpp> 3 4 #include <c10/util/Logging.h> 5 6 namespace c10d { 7 SequenceNum::SequenceNum() = default; 8 SequenceNum(const uint64_t num)9SequenceNum::SequenceNum(const uint64_t num) : num_(num) {} 10 SequenceNum(const SequenceNum & other)11SequenceNum::SequenceNum(const SequenceNum& other) { 12 if (!other.isSet()) { 13 num_ = std::nullopt; 14 } else { 15 num_ = other.get(); 16 } 17 } 18 get() const19uint64_t SequenceNum::get() const { 20 std::lock_guard<std::mutex> lock(lock_); 21 return *num_; 22 } 23 increment()24void SequenceNum::increment() { 25 std::lock_guard<std::mutex> lock(lock_); 26 TORCH_CHECK(num_ != std::nullopt); 27 num_ = ++(*num_); 28 } 29 30 // Implemented without above get() and increment() so we don't repeatedly lock 31 // and unblock. getAndIncrement()32uint64_t SequenceNum::getAndIncrement() { 33 uint64_t curVal = 0; 34 std::lock_guard<std::mutex> lock(lock_); 35 TORCH_CHECK(num_ != std::nullopt); 36 curVal = *num_; 37 num_ = ++(*num_); 38 return curVal; 39 } 40 set(const uint64_t num)41void SequenceNum::set(const uint64_t num) { 42 std::lock_guard<std::mutex> lock(lock_); 43 num_ = num; 44 } 45 isSet() const46bool SequenceNum::isSet() const { 47 std::lock_guard<std::mutex> lock(lock_); 48 return num_ != std::nullopt; 49 } 50 operator =(const SequenceNum & other)51SequenceNum& SequenceNum::operator=(const SequenceNum& other) { 52 std::lock_guard<std::mutex> lock(lock_); 53 if (!other.isSet()) { 54 num_ = std::nullopt; 55 } else { 56 num_ = other.get(); 57 } 58 return *this; 59 } 60 61 } // namespace c10d 62