1 #pragma once
2
3 #include <torch/csrc/distributed/c10d/Store.hpp>
4
5 #include <chrono>
6 #include <cstdint>
7
8 #include <ATen/core/Tensor.h>
9 #include <ATen/core/ivalue.h>
10
11 #include <c10/macros/Macros.h>
12 #include <c10/util/intrusive_ptr.h>
13
14 namespace c10d {
15
// Base class for supplementary data potentially needed by ReduceOps.
// Deriving from torch::CustomClassHolder lets instances be held via
// c10::intrusive_ptr (see ReduceOp::supplement_ below) and exposed as a
// TorchScript custom class. The virtual destructor makes deletion through
// a base pointer safe for derived supplements.
struct TORCH_API _SupplementBase : torch::CustomClassHolder {
  ~_SupplementBase() override = default;
};
20
21 // Supplementary data specific to NCCL PREMUL_SUM
22 // The point of use in ProcessGroupNCCL knows how to unpack it.
23 struct NCCLPreMulSumSupplement : _SupplementBase {
24 double double_factor{0.0};
25 at::Tensor tensor_factor;
NCCLPreMulSumSupplementc10d::NCCLPreMulSumSupplement26 NCCLPreMulSumSupplement(double f) : double_factor{f} {}
NCCLPreMulSumSupplementc10d::NCCLPreMulSumSupplement27 NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
28 TORCH_CHECK_EQ(tensor_factor.numel(), 1);
29 }
30 };
31
32 // Other ReduceOps that need different supplementary data can also
33 // derive from _SupplementBase.
34 struct TORCH_API ReduceOp : torch::CustomClassHolder {
35 // note(crcrpar): RedOpType could be defined outside of `ReduceOp`
36 enum RedOpType : uint8_t {
37 SUM = 0,
38 AVG = 1,
39 PRODUCT = 2,
40 MIN = 3,
41 MAX = 4,
42 BAND = 5, // Bitwise AND
43 BOR = 6, // Bitwise OR
44 BXOR = 7, // Bitwise XOR
45 PREMUL_SUM = 8, // Multiply by a user-supplied constant before summing.
46 UNUSED = 9
47 };
48
49 ReduceOp() = default;
50
ReduceOpc10d::ReduceOp51 ReduceOp(RedOpType op) : op_(op) {
52 TORCH_INTERNAL_ASSERT(
53 op_ != PREMUL_SUM,
54 "Use `torch.distributed._make_nccl_premul_sum` to create an instance of ReduceOp with PREMUL_SUM");
55 }
56
ReduceOpc10d::ReduceOp57 ReduceOp(
58 RedOpType op,
59 const c10::intrusive_ptr<_SupplementBase>& optional_supplement) {
60 if (optional_supplement) {
61 op_ = op;
62 } else {
63 supplement_ = optional_supplement;
64 }
65 }
66
67 // The heap resource supplement_, if it exists, is managed by a
68 // c10::intrusive_ptr, so constructors and operator= can be simple
69 ReduceOp(const ReduceOp& other) = default;
70 ReduceOp& operator=(const ReduceOp& other) = default;
71
72 ReduceOp(ReduceOp&& other) = default;
73 ReduceOp& operator=(ReduceOp&& other) = default;
74
operator RedOpTypec10d::ReduceOp75 operator RedOpType() const {
76 return op_;
77 }
78
operator ==c10d::ReduceOp79 bool operator==(const std::uint8_t other) {
80 TORCH_INTERNAL_ASSERT(other < 9, "Invalid other op value");
81 return other == op_;
82 }
83
operator ==c10d::ReduceOp84 bool operator==(const ReduceOp::RedOpType other) {
85 return *this == static_cast<std::uint8_t>(other);
86 }
87
88 // todo(crcrpar): Handle `RedOpType::PREMUL_SUM` with its scaling factor.
operator ==c10d::ReduceOp89 bool operator==(const ReduceOp& other) {
90 return *this == other.op_;
91 }
92
93 RedOpType op_ = SUM;
94 // supplement_ is "type-erased" storage for optional supplementary
95 // data the op might need.
96 // The point of use will know the derived type supplement_ really is,
97 // and downcast its pointer to extract the data as the needed type(s).
98 // Right now, only PREMUL_SUM needs supplementary data, but the same
99 // mechanism could extend to support other nontrivial reduce ops with
100 // different supplementary payloads.
101 c10::intrusive_ptr<_SupplementBase> supplement_;
102 };
103
104 template <typename T>
makeNCCLPreMulSum(const T & factor)105 ReduceOp makeNCCLPreMulSum(const T& factor) {
106 ReduceOp rop;
107 rop.op_ = ReduceOp::PREMUL_SUM;
108 rop.supplement_ = c10::make_intrusive<NCCLPreMulSumSupplement>(factor);
109 return rop;
110 }
111
// Sentinel meaning "no timeout was specified" (-1 ms); consumers are
// expected to substitute their own default when they see this value.
constexpr auto kUnsetTimeout = std::chrono::milliseconds(-1);

// Options for broadcast: the root rank/tensor identify the data source.
struct BroadcastOptions {
  int64_t rootRank = 0;
  int64_t rootTensor = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

// Options for allreduce; sparseIndices is only meaningful for sparse
// allreduce variants.
struct AllreduceOptions {
  ReduceOp reduceOp = ReduceOp::SUM;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  std::optional<at::Tensor> sparseIndices = std::nullopt;
};

// Coalesced allreduce takes the same options as plain allreduce.
struct AllreduceCoalescedOptions : AllreduceOptions {};

// Options for reduce: the result lands on rootRank/rootTensor.
struct ReduceOptions {
  ReduceOp reduceOp = ReduceOp::SUM;
  int64_t rootRank = 0;
  int64_t rootTensor = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
};

// Options for allgather.
struct AllgatherOptions {
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

// Options for gather: results are collected on rootRank.
struct GatherOptions {
  int64_t rootRank = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
};

// Options for scatter: data is distributed from rootRank.
struct ScatterOptions {
  int64_t rootRank = 0;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

// Options for reduce-scatter.
struct ReduceScatterOptions {
  ReduceOp reduceOp = ReduceOp::SUM;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  bool asyncOp = true;
};

// Options for all-to-all.
struct AllToAllOptions {
  std::chrono::milliseconds timeout = kUnsetTimeout;
};

// Options for barrier; device_ids/device let callers pin the barrier to
// specific devices.
struct BarrierOptions {
  std::vector<int64_t> device_ids;
  std::chrono::milliseconds timeout = kUnsetTimeout;
  std::optional<at::Device> device;
};

// Bundle of parameters handed to third-party distributed backends at
// construction time (rank/size are within the group, not global).
struct DistributedBackendOptions {
  c10::intrusive_ptr<::c10d::Store> store;
  int group_rank;
  int group_size;
  std::chrono::duration<float> timeout;
  std::string group_id;
  std::vector<int64_t> global_ranks_in_group;
};
176
177 } // namespace c10d
178