#pragma once

#ifdef USE_C10D_GLOO

#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

class TORCH_API ProcessGroupWrapper : public Backend {
 public:
  explicit ProcessGroupWrapper(
      const c10::intrusive_ptr<Backend>& backend,
      c10::intrusive_ptr<Backend> glooBackend);

  const std::string getBackendName() const override;

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& data,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  // This function is deprecated and will be moved out of ProcessGroup to
  // comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
      override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for the Gloo and NCCL backends currently.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should
  // be in sync across ranks. If the returned number is not consistent across
  // the group, it may indicate some sort of collective desynchronization.
  // Delegates to the underlying backend.
  uint64_t getSequenceNumberForGroup() override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts) override;

  void startCoalescing() override;

  c10::intrusive_ptr<Work> endCoalescing() override;

  c10::intrusive_ptr<Backend> getWrappedPg() const;

 private:
  // Underlying process group to which the application's actual collectives
  // are dispatched.
  c10::intrusive_ptr<Backend> backend_;
  // Gloo process group responsible for internal coordination such as the
  // monitored barrier, sequence number checking, and collective fingerprint
  // collection.
  c10::intrusive_ptr<Backend> glooBackend_;
  // Conducts several checks to ensure that the underlying collective is
  // well-formed, with the goal of notifying the user about incorrect
  // collective use in the application.
  void runCollectiveChecks(
      OpType op_type,
      const std::vector<at::Tensor>& tensors);
};
} // namespace c10d

#endif // USE_C10D_GLOO
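
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header). It assumes a
// c10d::Store along with rank and size are already available, and that a
// "real" backend (e.g. a ProcessGroupNCCL instance) named `realBackend` has
// been created elsewhere; `store`, `rank`, `size`, and `realBackend` are
// placeholder names, not APIs defined in this file.
//
//   // Build a Gloo helper group for the wrapper's internal coordination.
//   auto glooOpts = c10d::ProcessGroupGloo::Options::create();
//   glooOpts->devices.push_back(
//       c10d::ProcessGroupGloo::createDefaultDevice());
//   auto glooPg = c10::make_intrusive<c10d::ProcessGroupGloo>(
//       store, rank, size, glooOpts);
//
//   // Wrap the real backend. Collectives issued through the wrapper are
//   // first shape/type/op fingerprint-checked across ranks via the Gloo
//   // helper group, then dispatched to `realBackend`.
//   auto wrapper = c10::make_intrusive<c10d::ProcessGroupWrapper>(
//       realBackend, glooPg);
//
//   std::vector<at::Tensor> tensors = {at::ones({8})};
//   wrapper->allreduce(tensors)->wait();
// ---------------------------------------------------------------------------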