#pragma once

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <chrono>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Work.hpp>
// *************************************************************************
// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
// versions 1.7 and 1.8.
// PLEASE DO NOT ADD ANY DEPENDENCIES.
// SEE RFC: https://github.com/pytorch/pytorch/issues/39662
// *************************************************************************

constexpr auto kProcessGroupDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

// ProcessGroup is a base class that captures collective and point to
// point communication in a fixed set of processes.
//
// The functions specified in the class below describe the API alone;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default), and returns an
// object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. For members of the
// process group to find each other (referred to as rendezvous from
// hereon), a Store is used to exchange bootstrap information.
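//
// Illustrative usage only (not part of this header), assuming `pg` is a
// c10::intrusive_ptr<c10d::ProcessGroup> whose backend(s) have already been
// registered:
//
//   std::vector<at::Tensor> tensors = {at::ones({4})};
//   c10::intrusive_ptr<c10d::Work> work = pg->allreduce(tensors);
//   work->wait(); // blocks until the collective completes (or raises on error)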

class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
  // ProcessGroup Options is a base struct that defines the basic options
  // when constructing a ProcessGroup. Each ProcessGroup subclass should
  // extend this struct and define its options if it wants to provide more
  // config options (beyond the basic ones defined here) to the end user.
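  //
  // For illustration only (hypothetical subclass), one way a backend could
  // extend it:
  //
  //   struct MyBackendOptions : ProcessGroup::Options {
  //     MyBackendOptions() : Options(/*backend=*/"my_backend") {}
  //     // extra, backend-specific knobs go here
  //     bool useHighPriorityStream = false;
  //   };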
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    const std::string backend;
  };

  enum BackendType : uint8_t {
    UNDEFINED = 0,
    GLOO = 1,
    NCCL = 2,
    UCC = 3,
    MPI = 4,
    CUSTOM = 5,
  };

  // Not used; kept for backwards compatibility and only used for the TypeDef
  // in Ops.cpp.
  explicit ProcessGroup(int rank, int size);

  explicit ProcessGroup(
      const c10::intrusive_ptr<::c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options);
  ~ProcessGroup() override;

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

  // Returns a unique opaque ID of this process group object.
  int64_t getID() const {
    return reinterpret_cast<std::intptr_t>(this);
  }

  // Returns a unique opaque ID of a backend for the specific backend type
  // that can correlate with this process group's collectives.
  int64_t getBackendID(BackendType backend_type) const {
    return reinterpret_cast<std::intptr_t>(getBackend(backend_type).get());
  }

  virtual const std::string getBackendName() const {
    return options_->backend;
  };

  BackendType getBackendType() const {
    return backendType_;
  };

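  // Coalescing usage (illustrative sketch, assuming a backend that supports
  // coalescing such as NCCL): collectives issued between startCoalescing()
  // and endCoalescing() may be batched into a single grouped launch.
  //
  //   pg->startCoalescing(c10::DeviceType::CUDA);
  //   pg->allreduce(bucketA); // bucketA/bucketB are hypothetical tensor lists
  //   pg->allreduce(bucketB);
  //   auto work = pg->endCoalescing(c10::DeviceType::CUDA);
  //   work->wait();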
  virtual void startCoalescing(c10::DeviceType deviceType) {
    // Only NCCL has implemented startCoalescing, so only execute for NCCL
    // backends.
    auto backend = getBackend(deviceType);
    backend->startCoalescing();
  }

  virtual c10::intrusive_ptr<Work> endCoalescing(c10::DeviceType deviceType) {
    // Only NCCL has implemented endCoalescing, so only execute for NCCL
    // backends.
    auto backend = getBackend(deviceType);
    auto work = backend->endCoalescing();
    return work;
  }

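  // Usage sketch (illustrative): rank opts.rootRank supplies the data and all
  // ranks receive it in place; `payload` is a hypothetical tensor.
  //
  //   BroadcastOptions bopts;
  //   bopts.rootRank = 0;
  //   std::vector<at::Tensor> bufs = {payload};
  //   pg->broadcast(bufs, bopts)->wait();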
  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::broadcast_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    at::TensorList,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t,
                    int64_t,
                    bool,
                    int64_t)>();
    // It's awkward to unbox the opts here and box them again in the custom C++
    // op. But it's also complicated to make opts a CustomClassHolder. Leave
    // it as it is for now.
    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.rootTensor,
        opts.asyncOp,
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allreduce_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    at::TensorList,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    const c10::intrusive_ptr<::c10d::ReduceOp>&,
                    const std::optional<at::Tensor>& sparse_indices,
                    int64_t)>();

    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.sparseIndices,
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::allreduce_coalesced_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const c10::intrusive_ptr<::c10d::ReduceOp>&,
                             int64_t)>();

    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::reduce_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const c10::intrusive_ptr<::c10d::ReduceOp>&,
                             int64_t,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.rootRank,
        opts.rootTensor,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::allgather_", "")
                         .typed<std::tuple<
                             std::vector<std::vector<at::Tensor>>,
                             c10::intrusive_ptr<Work>>(
                             const std::vector<std::vector<at::Tensor>>&,
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t)>();

    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
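  // Shape sketch (illustrative, hypothetical tensors): with world size W = 2
  // and an input of shape {4, 8}, the output buffer holds every rank's input
  // back to back, with rank r's rows at [r * 4, (r + 1) * 4).
  //
  //   at::Tensor in = at::ones({4, 8});
  //   at::Tensor out = at::empty({2 * 4, 8});
  //   pg->_allgather_base(out, in)->wait();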
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_allgather_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                bool,
                int64_t)>();

    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp,
        opts.timeout.count()));
  }

  // This function is deprecated and will be moved out of ProcessGroup to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const std::vector<std::vector<at::Tensor>>&,
                const at::TensorList&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    return op.call(
        outputTensorLists,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
  }

  // This function is a coalesced version of `allgather_into_tensor` (currently
  // still named `_allgather_base`). Each tensor in the vector corresponds to
  // an input/output of one `allgather_into_tensor` operation.
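  // Illustrative sketch: each (outputTensors[i], inputTensors[i]) pair behaves
  // like one allgather_into_tensor call, but the backend may coalesce them.
  // `in0/in1/out0/out1` are hypothetical tensors; each output must be sized
  // world_size times its corresponding input.
  //
  //   std::vector<at::Tensor> ins = {in0, in1};
  //   std::vector<at::Tensor> outs = {out0, out1};
  //   pg->allgather_into_tensor_coalesced(outs, ins)->wait();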
  virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_into_tensor_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const at::TensorList,
                const at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::gather_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             const std::vector<std::vector<at::Tensor>>&,
                             const at::TensorList&,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::scatter_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const std::vector<std::vector<at::Tensor>>&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t,
                    bool,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.asyncOp,
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_scatter_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const std::vector<std::vector<at::Tensor>>&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    const c10::intrusive_ptr<::c10d::ReduceOp>&,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                bool,
                int64_t)>();
    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));
  }

  // This function is a coalesced version of `reduce_scatter_tensor` (currently
  // still named `_reduce_scatter_base`). Each tensor in the vector
  // corresponds to an input/output of one `reduce_scatter_tensor` operation.
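  // Illustrative sketch, mirroring the coalesced allgather above: each pair
  // performs one reduce_scatter_tensor, so every input is sized world_size
  // times its corresponding output (tensor names are hypothetical).
  //
  //   ReduceScatterOptions ropts;
  //   ropts.reduceOp = ReduceOp::SUM;
  //   pg->reduce_scatter_tensor_coalesced(outs, ins, ropts)->wait();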
  virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_scatter_tensor_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const at::TensorList,
                const at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                int64_t)>();

    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::alltoall_base_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::Tensor&,
                             at::Tensor&,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             std::vector<int64_t>,
                             std::vector<int64_t>,
                             int64_t)>();
    return op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::alltoall_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const at::TensorList&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& opts,
      bool wait_all_ranks = false) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::monitored_barrier_", "")
                         .typed<void(
                             at::Tensor,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const std::vector<int64_t>&,
                             int64_t,
                             bool)>();
    // Default to using the CPU implementation; monitored barrier is only
    // supported by the GLOO backend.
    at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU));
    op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count(),
        wait_all_ranks);
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only implemented
  // for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendType = getBackendType();
    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendType == ProcessGroup::BackendType::GLOO ||
        backendType == ProcessGroup::BackendType::NCCL ||
        backendType == ProcessGroup::BackendType::UCC) {
      getDefaultBackend()->setSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendType = getBackendType();

    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendType == ProcessGroup::BackendType::GLOO ||
        backendType == ProcessGroup::BackendType::NCCL ||
        backendType == ProcessGroup::BackendType::UCC) {
      return getDefaultBackend()->getSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::send", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::recv_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::recv_any_source_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
  }

  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) {
    static at::Tensor tensor;
    // TODO: if nccl was specified then use it
    auto device = opts.device;
    if (device.has_value()) {
      // set device tensor from argument
      tensor = at::empty(
          {1}, at::TensorOptions().device(device.value()).dtype(at::kByte));
    } else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
      // set cuda tensor
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte));
    } else {
      // Default to using cpu implementation
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CPU).dtype(at::kByte));
    }

    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::barrier", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::Tensor,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const std::vector<int64_t>&,
                             int64_t)>();

    return op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count());
  }

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  bool hasBackends() {
    return !deviceTypeToBackendType_.empty();
  }

  void setBackend(
      c10::DeviceType deviceType,
      BackendType backendType,
      const std::optional<c10::intrusive_ptr<Backend>>& backend) {
    // TODO: should we add these entries after the backend setting succeeds?
    deviceTypeToBackendType_[deviceType] = backendType;
    deviceTypes_.insert(deviceType);
    // if the backendType is already set then reuse it for this device
    if (backendTypeToBackend_.find(backendType) !=
        backendTypeToBackend_.end()) {
      auto existingBackend = backendTypeToBackend_.at(backendType);
      deviceTypeToBackend_[deviceType] = existingBackend;
      TORCH_CHECK(
          existingBackend->getBoundDeviceId() ==
          (*backend)->getBoundDeviceId());
    } else {
      // check if backend has value
      if (backend.has_value()) {
        deviceTypeToBackend_[deviceType] = backend.value();
        backendTypeToBackend_[backendType] = backend.value();
        (*backend)->setBoundDeviceId(bound_device_id_);
      }
    }
  }

  c10::intrusive_ptr<Backend> getDefaultBackend() const {
    TORCH_CHECK(
        backendTypeToBackend_.find(backendType_) != backendTypeToBackend_.end(),
        "Could not find the default backend type ",
        backendType_,
        " for Process Group with name ",
        getBackendName(),
        ".");
    return backendTypeToBackend_.at(backendType_);
  }

  c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);

  c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
    TORCH_CHECK(
        backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end(),
        "Could not find backend type ",
        backendType,
        ".");
    return backendTypeToBackend_.at(backendType);
  }

  // Return device types supported by this ProcessGroup.
  // Note: the return type is `Device` rather than `DeviceType` for the purpose
  // of easy comparison at Python level. The `Device` will have default index
  // (-1).
  std::vector<c10::Device> getDeviceTypes() const {
    std::vector<c10::Device> devices;
    devices.reserve(deviceTypes_.size());
    for (auto& dt : deviceTypes_) {
      devices.emplace_back(dt);
    }
    return devices;
  }

  void registerOnCompletionHook(
      std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
    getDefaultBackend()->registerOnCompletionHook(std::move(hook));
  }

  void waitForPendingWorks() {
    getDefaultBackend()->waitForPendingWorks();
  }

  bool hasHooks() const {
    return getDefaultBackend()->hasHooks();
  }

  const std::string& getGroupName() const;
  void setGroupName(const std::string& name);
  const std::string& getGroupDesc() const;
  void setGroupDesc(const std::string& name);
  void enableCollectivesTiming();

  void release_resources() override;

  // ProcessGroups optionally can be "bound" to a specific device.
  // Currently this is only for nccl and allows for some opt-in
  // optimizations such as automatic use of ncclCommSplit.  The device
  // is specified in `init_process_group` and eventually makes it
  // here and then down into the actual backend instances.
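  //
  // Illustrative sketch (the Python-side entry point is init_process_group's
  // device_id argument; shown here only as the C++-level setter):
  //
  //   pg->setBoundDeviceId(at::Device(at::kCUDA, 0)); // device must carry an index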
  std::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
  }

  void setBoundDeviceId(std::optional<at::Device> device) {
    if (device) {
      TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
    }
    bound_device_id_ = device;
  }

 protected:
  // Implementations of this interface need to call this to set up
  // appropriate logging etc.
  void init();

  c10::intrusive_ptr<c10d::Store> store_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int rank_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int size_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const c10::intrusive_ptr<Options> options_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const BackendType backendType_;
  std::string pg_desc_;

  // Debug level setting. It is parsed once when ProcessGroup is constructed and
  // remains the same across use of this process group.
  DebugLevel dist_debug_level_{DebugLevel::Off};

  // Backend classes for this ProcessGroup
  std::unordered_set<c10::DeviceType> deviceTypes_;
  std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
  std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
      deviceTypeToBackend_;
  std::unordered_map<BackendType, c10::intrusive_ptr<Backend>>
      backendTypeToBackend_;

  std::optional<at::Device> bound_device_id_;
};

} // namespace c10d