#pragma once

#include <torch/csrc/distributed/c10d/Store.hpp>

#include <chrono>
#include <cstdint>

#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>

#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>

namespace c10d {

// Base class for supplementary data potentially needed by ReduceOps
struct TORCH_API _SupplementBase : torch::CustomClassHolder {
  ~_SupplementBase() override = default;
};

// Supplementary data specific to NCCL PREMUL_SUM
// The point of use in ProcessGroupNCCL knows how to unpack it.
struct NCCLPreMulSumSupplement : _SupplementBase {
  double double_factor{0.0};
  at::Tensor tensor_factor;
  NCCLPreMulSumSupplement(double f) : double_factor{f} {}
  NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
    TORCH_CHECK_EQ(tensor_factor.numel(), 1);
  }
};

// Other ReduceOps that need different supplementary data can also
// derive from _SupplementBase.
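// For illustration only (hypothetical type, not part of this header): an op
// that needed its own payload could add a sibling type along these lines:
//
//   struct MyScaledMaxSupplement : _SupplementBase {
//     double scale{1.0};
//     explicit MyScaledMaxSupplement(double s) : scale{s} {}
//   };
//
// Its point of use would downcast the stored supplement, just as
// ProcessGroupNCCL does for NCCLPreMulSumSupplement.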
struct TORCH_API ReduceOp : torch::CustomClassHolder {
  // note(crcrpar): RedOpType could be defined outside of `ReduceOp`
  enum RedOpType : uint8_t {
    SUM = 0,
    AVG = 1,
    PRODUCT = 2,
    MIN = 3,
    MAX = 4,
    BAND = 5, // Bitwise AND
    BOR = 6, // Bitwise OR
    BXOR = 7, // Bitwise XOR
    PREMUL_SUM = 8, // Multiply by a user-supplied constant before summing.
    UNUSED = 9
  };
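  // Spelled out, PREMUL_SUM: given a factor c and per-rank inputs
  // x_0, ..., x_{n-1}, the result is sum_i (c * x_i); each contribution is
  // scaled before it enters the reduction. With one uniform scalar factor
  // this equals c * sum_i x_i.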

  ReduceOp() = default;

  ReduceOp(RedOpType op) : op_(op) {
    TORCH_INTERNAL_ASSERT(
        op_ != PREMUL_SUM,
        "Use `torch.distributed._make_nccl_premul_sum` to create an instance of ReduceOp with PREMUL_SUM");
  }

  ReduceOp(
      RedOpType op,
      const c10::intrusive_ptr<_SupplementBase>& optional_supplement)
      : op_(op) {
    if (optional_supplement) {
      supplement_ = optional_supplement;
    }
  }

  // The heap resource supplement_, if it exists, is managed by a
  // c10::intrusive_ptr, so constructors and operator= can be simple
  ReduceOp(const ReduceOp& other) = default;
  ReduceOp& operator=(const ReduceOp& other) = default;

  ReduceOp(ReduceOp&& other) = default;
  ReduceOp& operator=(ReduceOp&& other) = default;

  operator RedOpType() const {
    return op_;
  }

  bool operator==(const std::uint8_t other) {
    TORCH_INTERNAL_ASSERT(other < 9, "Invalid other op value");
    return other == op_;
  }

  bool operator==(const ReduceOp::RedOpType other) {
    return *this == static_cast<std::uint8_t>(other);
  }

  // todo(crcrpar): Handle `RedOpType::PREMUL_SUM` with its scaling factor.
  bool operator==(const ReduceOp& other) {
    return *this == other.op_;
  }

  RedOpType op_ = SUM;
  // supplement_ is "type-erased" storage for optional supplementary
  // data the op might need.
  // The point of use will know the derived type supplement_ really is,
  // and downcast its pointer to extract the data as the needed type(s).
  // Right now, only PREMUL_SUM needs supplementary data, but the same
  // mechanism could extend to support other nontrivial reduce ops with
  // different supplementary payloads.
  c10::intrusive_ptr<_SupplementBase> supplement_;
};
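
// An illustrative sketch of how a point of use might unpack the supplement
// (the function name `consume` is hypothetical; the real unpacking lives in
// ProcessGroupNCCL):
//
//   void consume(const ReduceOp& op) {
//     if (op.op_ == ReduceOp::PREMUL_SUM && op.supplement_) {
//       auto* premul = reinterpret_cast<NCCLPreMulSumSupplement*>(
//           op.supplement_.get());
//       // Use premul->tensor_factor if it is defined, otherwise
//       // premul->double_factor.
//     }
//   }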

template <typename T>
ReduceOp makeNCCLPreMulSum(const T& factor) {
  ReduceOp rop;
  rop.op_ = ReduceOp::PREMUL_SUM;
  rop.supplement_ = c10::make_intrusive<NCCLPreMulSumSupplement>(factor);
  return rop;
}
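
// Example usage, a sketch (this mirrors what the Python helper
// torch.distributed._make_nccl_premul_sum constructs):
//
//   // Scale each rank's contribution by 0.5 before summing:
//   auto op = c10d::makeNCCLPreMulSum(0.5);
//   // Or carry the factor in a one-element tensor:
//   auto op2 = c10d::makeNCCLPreMulSum(at::scalar_tensor(0.5));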
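// Sentinel meaning "no timeout was explicitly set"; points of use that see
// kUnsetTimeout are expected to fall back to their backend's default timeout.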
constexpr auto kUnsetTimeout = std::chrono::milliseconds(-1);

struct BroadcastOptions {
  int64_t rootRank = 0;
  int64_t rootTensor = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

struct AllreduceOptions {
  ReduceOp reduceOp = ReduceOp::SUM;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  std::optional<at::Tensor> sparseIndices = std::nullopt;
};
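
// How a call site might typically fill in one of these option structs
// (a sketch; `pg` and `tensors` are assumed to be a c10d backend handle and
// a std::vector<at::Tensor>, respectively):
//
//   c10d::AllreduceOptions opts;
//   opts.reduceOp = c10d::ReduceOp::AVG;
//   opts.timeout = std::chrono::milliseconds(60000);
//   auto work = pg->allreduce(tensors, opts);
//   work->wait();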

struct AllreduceCoalescedOptions : AllreduceOptions {};

struct ReduceOptions {
  ReduceOp reduceOp = ReduceOp::SUM;
  int64_t rootRank = 0;
  int64_t rootTensor = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
};

struct AllgatherOptions {
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

struct GatherOptions {
  int64_t rootRank = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
};

struct ScatterOptions {
  int64_t rootRank = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

struct ReduceScatterOptions {
  ReduceOp reduceOp = ReduceOp::SUM;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

struct AllToAllOptions {
  std::chrono::milliseconds timeout = kUnsetTimeout;
};

struct BarrierOptions {
  std::vector<int64_t> device_ids;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  std::optional<at::Device> device;
};

struct DistributedBackendOptions {
  c10::intrusive_ptr<::c10d::Store> store;
  int group_rank;
  int group_size;
  std::chrono::duration<float> timeout;
  std::string group_id;
  std::vector<int64_t> global_ranks_in_group;
};

} // namespace c10d