1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <atomic> 5 6 namespace torch::distributed::rpc { 7 8 using worker_id_t = int16_t; 9 using local_id_t = int64_t; 10 11 bool getAllowJitRRefPickle(); 12 TORCH_API void enableJitRRefPickle(); 13 TORCH_API void disableJitRRefPickle(); 14 15 struct TORCH_API JitRRefPickleGuard { 16 JitRRefPickleGuard(); 17 ~JitRRefPickleGuard(); 18 }; 19 20 struct TORCH_API GloballyUniqueId final { 21 GloballyUniqueId(worker_id_t createdOn, local_id_t localId); 22 GloballyUniqueId(const GloballyUniqueId& other) = default; 23 GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete; 24 25 bool operator==(const GloballyUniqueId& other) const; 26 bool operator!=(const GloballyUniqueId& other) const; 27 28 at::IValue toIValue() const; 29 static GloballyUniqueId fromIValue(const at::IValue&); 30 31 struct Hash { operatorfinal::Hash32 size_t operator()(const GloballyUniqueId& key) const { 33 return (uint64_t(key.createdOn_) << kLocalIdBits) | key.localId_; 34 } 35 }; 36 37 static constexpr int kLocalIdBits = 48; 38 39 const worker_id_t createdOn_; 40 const local_id_t localId_; 41 }; 42 43 TORCH_API std::ostream& operator<<( 44 std::ostream& os, 45 const GloballyUniqueId& globalId); 46 47 using RRefId = GloballyUniqueId; 48 using ForkId = GloballyUniqueId; 49 using ProfilingId = GloballyUniqueId; 50 51 struct TORCH_API SerializedPyObj final { SerializedPyObjfinal52 SerializedPyObj(std::string&& payload, std::vector<at::Tensor>&& tensors) 53 : payload_(std::move(payload)), tensors_(std::move(tensors)) {} 54 55 std::vector<at::IValue> toIValues() &&; 56 static SerializedPyObj fromIValues(std::vector<at::IValue> value); 57 58 std::string payload_; 59 std::vector<at::Tensor> tensors_; 60 }; 61 62 } // namespace torch::distributed::rpc 63