xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/types.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 #include <atomic>
5 
6 namespace torch::distributed::rpc {
7 
8 using worker_id_t = int16_t;
9 using local_id_t = int64_t;
10 
11 bool getAllowJitRRefPickle();
12 TORCH_API void enableJitRRefPickle();
13 TORCH_API void disableJitRRefPickle();
14 
15 struct TORCH_API JitRRefPickleGuard {
16   JitRRefPickleGuard();
17   ~JitRRefPickleGuard();
18 };
19 
20 struct TORCH_API GloballyUniqueId final {
21   GloballyUniqueId(worker_id_t createdOn, local_id_t localId);
22   GloballyUniqueId(const GloballyUniqueId& other) = default;
23   GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete;
24 
25   bool operator==(const GloballyUniqueId& other) const;
26   bool operator!=(const GloballyUniqueId& other) const;
27 
28   at::IValue toIValue() const;
29   static GloballyUniqueId fromIValue(const at::IValue&);
30 
31   struct Hash {
operatorfinal::Hash32     size_t operator()(const GloballyUniqueId& key) const {
33       return (uint64_t(key.createdOn_) << kLocalIdBits) | key.localId_;
34     }
35   };
36 
37   static constexpr int kLocalIdBits = 48;
38 
39   const worker_id_t createdOn_;
40   const local_id_t localId_;
41 };
42 
43 TORCH_API std::ostream& operator<<(
44     std::ostream& os,
45     const GloballyUniqueId& globalId);
46 
47 using RRefId = GloballyUniqueId;
48 using ForkId = GloballyUniqueId;
49 using ProfilingId = GloballyUniqueId;
50 
51 struct TORCH_API SerializedPyObj final {
SerializedPyObjfinal52   SerializedPyObj(std::string&& payload, std::vector<at::Tensor>&& tensors)
53       : payload_(std::move(payload)), tensors_(std::move(tensors)) {}
54 
55   std::vector<at::IValue> toIValues() &&;
56   static SerializedPyObj fromIValues(std::vector<at::IValue> value);
57 
58   std::string payload_;
59   std::vector<at::Tensor> tensors_;
60 };
61 
62 } // namespace torch::distributed::rpc
63