xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ProcessGroup.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ThreadLocalState.h>
2 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
3 
4 #include <c10/util/Logging.h>
5 #include <fmt/format.h>
6 #include <string_view>
7 
8 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
9 #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
10 #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
11 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
12 #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
13 #include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
14 
15 namespace c10d {
16 
strToBackendType(std::string_view backend)17 static ProcessGroup::BackendType strToBackendType(std::string_view backend) {
18   if (backend == "undefined") {
19     return ProcessGroup::BackendType::UNDEFINED;
20   } else if (backend == "gloo") {
21     return ProcessGroup::BackendType::GLOO;
22   } else if (backend == "nccl") {
23     return ProcessGroup::BackendType::NCCL;
24   } else if (backend == "ucc") {
25     return ProcessGroup::BackendType::UCC;
26   } else if (backend == "mpi") {
27     return ProcessGroup::BackendType::MPI;
28   } else {
29     return ProcessGroup::BackendType::CUSTOM;
30   }
31 }
32 
opTypeToString(OpType opType)33 std::string opTypeToString(OpType opType) {
34   switch (opType) {
35     case OpType::BROADCAST:
36       return "BROADCAST";
37     case OpType::ALLREDUCE:
38       return "ALLREDUCE";
39     case OpType::ALLREDUCE_COALESCED:
40       return "ALLREDUCE_COALESCED";
41     case OpType::REDUCE:
42       return "REDUCE";
43     case OpType::ALLGATHER:
44       return "ALLGATHER";
45     case OpType::_ALLGATHER_BASE:
46       return "_ALLGATHER_BASE";
47     case OpType::ALLGATHER_COALESCED:
48       return "ALLGATHER_COALESCED";
49     case OpType::GATHER:
50       return "GATHER";
51     case OpType::SCATTER:
52       return "SCATTER";
53     case OpType::REDUCE_SCATTER:
54       return "REDUCE_SCATTER";
55     case OpType::ALLTOALL_BASE:
56       return "ALLTOALL_BASE";
57     case OpType::ALLTOALL:
58       return "ALLTOALL";
59     case OpType::SEND:
60       return "SEND";
61     case OpType::RECV:
62       return "RECV";
63     case OpType::RECVANYSOURCE:
64       return "RECVANYSOURCE";
65     case OpType::BARRIER:
66       return "BARRIER";
67     case OpType::UNKNOWN:
68       return "UNKNOWN";
69     case OpType::_REDUCE_SCATTER_BASE:
70       return "_REDUCE_SCATTER_BASE";
71     case OpType::COALESCED:
72       return "COALESCED";
73     case OpType::_ALLREDUCE_SPARSE:
74       return "_ALLREDUCE_SPARSE";
75     default:
76       TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
77   }
78   return "UNKNOWN";
79 }
80 
isP2POp(OpType opType,bool batchP2P)81 bool isP2POp(OpType opType, bool batchP2P /*= false*/) {
82   if (batchP2P)
83     return false;
84   return opType == OpType::SEND || opType == OpType::RECV ||
85       opType == OpType::RECVANYSOURCE;
86 }
87 
getBackend(c10::DeviceType deviceType)88 c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
89     c10::DeviceType deviceType) {
90   // If there is a backend associated with this device type then return it
91   if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) {
92     return deviceTypeToBackend_.at(deviceType);
93   }
94 
95   // Get the backend type associated with the device
96   ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED};
97   try {
98     backendType = deviceTypeToBackendType_.at(deviceType);
99   } catch (const std::out_of_range& e) {
100     TORCH_CHECK(
101         false, "No backend type associated with device type ", deviceType);
102   }
103 
104   // Check if the backend has already been initialized
105   if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) {
106     auto backend = backendTypeToBackend_.at(backendType);
107     deviceTypeToBackend_[deviceType] = backend;
108     return backend;
109   }
110 
111   TORCH_CHECK(
112       false,
113       "Could not retrieve or create the backend ",
114       backendType,
115       " for device type ",
116       deviceType);
117 }
118 
ProcessGroup(const c10::intrusive_ptr<::c10d::Store> & store,int rank,int size,c10::intrusive_ptr<Options> options)119 ProcessGroup::ProcessGroup(
120     const c10::intrusive_ptr<::c10d::Store>& store,
121     int rank,
122     int size,
123     c10::intrusive_ptr<Options> options)
124     : store_(store),
125       rank_(rank),
126       size_(size),
127       options_(std::move(options)),
128       backendType_(strToBackendType(options_->backend)),
129       dist_debug_level_(debug_level()) {
130   C10_LOG_API_USAGE_ONCE("c10d.process_group");
131 }
132 
ProcessGroup(int rank,int size)133 ProcessGroup::ProcessGroup(int rank, int size)
134     : rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {}
135 
// Out-of-line defaulted destructor: no custom teardown logic is written
// here, members are destroyed by the compiler-generated destructor.
136 ProcessGroup::~ProcessGroup() = default;
137 
// Logs an API-usage event whose key is "c10d.process_group_" followed by the
// value returned by getBackendName().
// NOTE(review): going by the macro name the event is presumably recorded at
// most once - confirm against the C10_LOG_API_USAGE_ONCE definition.
init()138 void ProcessGroup::init() {
139   C10_LOG_API_USAGE_ONCE(
140       fmt::format("c10d.process_group_{}", getBackendName()));
141 }
142 
getGroupName() const143 const std::string& ProcessGroup::getGroupName() const {
144   TORCH_CHECK(!deviceTypeToBackend_.empty(), "ProcessGroup name not set");
145   return deviceTypeToBackend_.begin()->second->getGroupUid();
146 }
147 
setGroupName(const std::string & name)148 void ProcessGroup::setGroupName(const std::string& name) {
149   for (auto& kv : deviceTypeToBackend_) {
150     kv.second->setGroupUid(name);
151   }
152 }
153 
getGroupDesc() const154 const std::string& ProcessGroup::getGroupDesc() const {
155   return pg_desc_;
156 }
157 
setGroupDesc(const std::string & name)158 void ProcessGroup::setGroupDesc(const std::string& name) {
159   pg_desc_ = name;
160   // Also set the group desc for all backends
161   for (auto& kv : deviceTypeToBackend_) {
162     kv.second->setGroupDesc(name);
163   }
164 }
165 
enableCollectivesTiming()166 void ProcessGroup::enableCollectivesTiming() {
167   for (auto& kv : deviceTypeToBackend_) {
168     kv.second->enableCollectivesTiming();
169   }
170 }
171 
// Drops the references this process group holds: resets the store handle
// and empties both backend maps (per device type and per backend type).
release_resources()172 void ProcessGroup::release_resources() {
// The store handle is reset first, then the two backend maps are cleared.
173   store_.reset();
174   deviceTypeToBackend_.clear();
175   backendTypeToBackend_.clear();
176 }
177 
178 } // namespace c10d
179