1 #include <ATen/ThreadLocalState.h>
2 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
3
4 #include <c10/util/Logging.h>
5 #include <fmt/format.h>
6 #include <string_view>
7
8 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
9 #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
10 #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
11 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
12 #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
13 #include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
14
15 namespace c10d {
16
strToBackendType(std::string_view backend)17 static ProcessGroup::BackendType strToBackendType(std::string_view backend) {
18 if (backend == "undefined") {
19 return ProcessGroup::BackendType::UNDEFINED;
20 } else if (backend == "gloo") {
21 return ProcessGroup::BackendType::GLOO;
22 } else if (backend == "nccl") {
23 return ProcessGroup::BackendType::NCCL;
24 } else if (backend == "ucc") {
25 return ProcessGroup::BackendType::UCC;
26 } else if (backend == "mpi") {
27 return ProcessGroup::BackendType::MPI;
28 } else {
29 return ProcessGroup::BackendType::CUSTOM;
30 }
31 }
32
opTypeToString(OpType opType)33 std::string opTypeToString(OpType opType) {
34 switch (opType) {
35 case OpType::BROADCAST:
36 return "BROADCAST";
37 case OpType::ALLREDUCE:
38 return "ALLREDUCE";
39 case OpType::ALLREDUCE_COALESCED:
40 return "ALLREDUCE_COALESCED";
41 case OpType::REDUCE:
42 return "REDUCE";
43 case OpType::ALLGATHER:
44 return "ALLGATHER";
45 case OpType::_ALLGATHER_BASE:
46 return "_ALLGATHER_BASE";
47 case OpType::ALLGATHER_COALESCED:
48 return "ALLGATHER_COALESCED";
49 case OpType::GATHER:
50 return "GATHER";
51 case OpType::SCATTER:
52 return "SCATTER";
53 case OpType::REDUCE_SCATTER:
54 return "REDUCE_SCATTER";
55 case OpType::ALLTOALL_BASE:
56 return "ALLTOALL_BASE";
57 case OpType::ALLTOALL:
58 return "ALLTOALL";
59 case OpType::SEND:
60 return "SEND";
61 case OpType::RECV:
62 return "RECV";
63 case OpType::RECVANYSOURCE:
64 return "RECVANYSOURCE";
65 case OpType::BARRIER:
66 return "BARRIER";
67 case OpType::UNKNOWN:
68 return "UNKNOWN";
69 case OpType::_REDUCE_SCATTER_BASE:
70 return "_REDUCE_SCATTER_BASE";
71 case OpType::COALESCED:
72 return "COALESCED";
73 case OpType::_ALLREDUCE_SPARSE:
74 return "_ALLREDUCE_SPARSE";
75 default:
76 TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
77 }
78 return "UNKNOWN";
79 }
80
isP2POp(OpType opType,bool batchP2P)81 bool isP2POp(OpType opType, bool batchP2P /*= false*/) {
82 if (batchP2P)
83 return false;
84 return opType == OpType::SEND || opType == OpType::RECV ||
85 opType == OpType::RECVANYSOURCE;
86 }
87
getBackend(c10::DeviceType deviceType)88 c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
89 c10::DeviceType deviceType) {
90 // If there is a backend associated with this device type then return it
91 if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) {
92 return deviceTypeToBackend_.at(deviceType);
93 }
94
95 // Get the backend type associated with the device
96 ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED};
97 try {
98 backendType = deviceTypeToBackendType_.at(deviceType);
99 } catch (const std::out_of_range& e) {
100 TORCH_CHECK(
101 false, "No backend type associated with device type ", deviceType);
102 }
103
104 // Check if the backend has already been initialized
105 if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) {
106 auto backend = backendTypeToBackend_.at(backendType);
107 deviceTypeToBackend_[deviceType] = backend;
108 return backend;
109 }
110
111 TORCH_CHECK(
112 false,
113 "Could not retrieve or create the backend ",
114 backendType,
115 " for device type ",
116 deviceType);
117 }
118
ProcessGroup(const c10::intrusive_ptr<::c10d::Store> & store,int rank,int size,c10::intrusive_ptr<Options> options)119 ProcessGroup::ProcessGroup(
120 const c10::intrusive_ptr<::c10d::Store>& store,
121 int rank,
122 int size,
123 c10::intrusive_ptr<Options> options)
124 : store_(store),
125 rank_(rank),
126 size_(size),
127 options_(std::move(options)),
128 backendType_(strToBackendType(options_->backend)),
129 dist_debug_level_(debug_level()) {
130 C10_LOG_API_USAGE_ONCE("c10d.process_group");
131 }
132
ProcessGroup(int rank,int size)133 ProcessGroup::ProcessGroup(int rank, int size)
134 : rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {}
135
136 ProcessGroup::~ProcessGroup() = default;
137
init()138 void ProcessGroup::init() {
139 C10_LOG_API_USAGE_ONCE(
140 fmt::format("c10d.process_group_{}", getBackendName()));
141 }
142
getGroupName() const143 const std::string& ProcessGroup::getGroupName() const {
144 TORCH_CHECK(!deviceTypeToBackend_.empty(), "ProcessGroup name not set");
145 return deviceTypeToBackend_.begin()->second->getGroupUid();
146 }
147
setGroupName(const std::string & name)148 void ProcessGroup::setGroupName(const std::string& name) {
149 for (auto& kv : deviceTypeToBackend_) {
150 kv.second->setGroupUid(name);
151 }
152 }
153
getGroupDesc() const154 const std::string& ProcessGroup::getGroupDesc() const {
155 return pg_desc_;
156 }
157
setGroupDesc(const std::string & name)158 void ProcessGroup::setGroupDesc(const std::string& name) {
159 pg_desc_ = name;
160 // Also set the group desc for all backends
161 for (auto& kv : deviceTypeToBackend_) {
162 kv.second->setGroupDesc(name);
163 }
164 }
165
enableCollectivesTiming()166 void ProcessGroup::enableCollectivesTiming() {
167 for (auto& kv : deviceTypeToBackend_) {
168 kv.second->enableCollectivesTiming();
169 }
170 }
171
release_resources()172 void ProcessGroup::release_resources() {
173 store_.reset();
174 deviceTypeToBackend_.clear();
175 backendTypeToBackend_.clear();
176 }
177
178 } // namespace c10d
179