#pragma once

#ifdef USE_C10D_GLOO

#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

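// Debugging wrapper around a Backend: every collective issued through this
// class is first validated against the other ranks (via a side-channel Gloo
// group) and then dispatched to the wrapped backend, so that mismatched
// collective calls surface as descriptive errors rather than hangs. In
// PyTorch this wrapper is enabled with TORCH_DISTRIBUTED_DEBUG=DETAIL.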
class TORCH_API ProcessGroupWrapper : public Backend {
 public:
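  // `backend` is the process group that application collectives are
  // dispatched to; `glooBackend` is a Gloo group used only for the internal
  // coordination and consistency checks described below.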
  explicit ProcessGroupWrapper(
      const c10::intrusive_ptr<Backend>& backend,
      c10::intrusive_ptr<Backend> glooBackend);

  const std::string getBackendName() const override;

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& data,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

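  // Flattened variant of allgather: the output is a single contiguous
  // buffer sized world_size times the input buffer, rather than a list of
  // per-rank tensors.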
  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  // This function is deprecated and will be moved out of ProcessGroup to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

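  // Barrier that, unlike barrier(), reports which ranks failed to reach it
  // within the timeout. Runs on the internal Gloo group. With
  // waitAllRanks=true, all straggling ranks are collected and reported
  // together instead of failing on the first missing rank.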
  void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
      override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for the GLOO and NCCL backends currently.
  // The wrapper does not maintain its own sequence number; it delegates to
  // the underlying backend.
  void setSequenceNumberForGroup() override;
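
  // A sketch of the rank-0-creates-then-broadcasts pattern described above,
  // using the c10d::Store key/value API (`toBytes`/`fromBytes` are
  // hypothetical serialization helpers, not part of this class):
  //
  //   if (rank == 0) {
  //     store->set("seq_num", toBytes(uint64_t{0}));
  //   }
  //   // get() blocks (with timeout) until rank 0 has set the key.
  //   uint64_t seq = fromBytes(store->get("seq_num"));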

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override; // delegates to the backend

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

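  // Flattened variant of reduce_scatter: the input is a single contiguous
  // buffer sized world_size times the output buffer.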
  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts) override;

  void startCoalescing() override;

  c10::intrusive_ptr<Work> endCoalescing() override;

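  // Returns the wrapped process group that collectives are dispatched to.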
  c10::intrusive_ptr<Backend> getWrappedPg() const;

 private:
  // Underlying process group that actual application collectives will be
  // dispatched to.
  c10::intrusive_ptr<Backend> backend_;
  // Gloo process group responsible for internal coordination such as the
  // monitored barrier, sequence number checking, and collective fingerprint
  // collection.
  c10::intrusive_ptr<Backend> glooBackend_;
  // Conducts several checks to ensure that the underlying collective is
  // well-formed, with the goal of notifying the user about incorrect
  // collective use in the application.
  void runCollectiveChecks(
      OpType op_type,
      const std::vector<at::Tensor>& tensors);
};
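
// Minimal usage sketch (illustrative, not a prescribed setup). Assumes
// `store`, `rank`, and `size` come from the application's rendezvous, and
// `realBackend` is the backend (e.g. NCCL) whose collectives should be
// checked; the ProcessGroupGloo constructor shown uses its default options:
//
//   auto glooPg = c10::make_intrusive<ProcessGroupGloo>(store, rank, size);
//   auto pg = c10::make_intrusive<ProcessGroupWrapper>(realBackend, glooPg);
//   std::vector<at::Tensor> tensors{at::ones({8})};
//   // Checked against peers via glooPg, then run on realBackend.
//   pg->allreduce(tensors)->wait();
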
} // namespace c10d

#endif // USE_C10D_GLOO