xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/NCCLUtils.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_C10D_NCCL
4 
5 #include <stdio.h>
6 #include <stdlib.h>
7 
8 #include <memory>
9 #include <mutex>
10 #include <thread>
11 
12 #include <ATen/ATen.h>
13 #include <ATen/cuda/CUDAEvent.h>
14 #include <c10/util/Exception.h>
15 #include <nccl.h>
16 #include <torch/csrc/distributed/c10d/TraceUtils.h>
17 #include <optional>
18 
// NCCL feature-detection macros.
//
// Each block defines a feature macro when the NCCL headers in use (identified
// by NCCL_MAJOR / NCCL_MINOR) are new enough to provide the corresponding API.
// Every feature is also enabled for NCCL_MAJOR >= 3 so the checks keep working
// beyond the 2.x series; without NCCL_MAJOR nothing is defined.

// Non-blocking communicators (ncclConfig_t with blocking = 0) need NCCL 2.14+.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 14)
#define NCCL_HAS_COMM_NONBLOCKING
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define NCCL_HAS_COMM_NONBLOCKING
#endif

// Communicator splitting (ncclCommSplit) needs NCCL 2.18+.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 18)
#define NCCL_HAS_COMM_SPLIT
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define NCCL_HAS_COMM_SPLIT
#endif

// ncclGetLastError() is enabled only for NCCL versions 2.13+
// ncclRemoteError only exists in NCCL versions 2.13+
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 13)
#define ENABLE_NCCL_GET_LAST_ERROR
#define NCCL_REMOTE_ERROR
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_GET_LAST_ERROR
#define NCCL_REMOTE_ERROR
#endif

// Error checking is enabled only for NCCL versions 2.4+ since ncclCommAbort()
// and ncclCommGetAsyncError() are not supported in earlier versions.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 4)
#define ENABLE_NCCL_ERROR_CHECKING
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_ERROR_CHECKING
#endif

// P2P is enabled only for NCCL versions 2.7+ since ncclSend()
// and ncclRecv() are not supported in earlier versions.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
#define ENABLE_NCCL_P2P_SUPPORT
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_P2P_SUPPORT
#endif

// Custom pre-multiplied-sum reduction ops (destroyed via ncclRedOpDestroy)
// need NCCL 2.11+.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 11)
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif

// Communicator CTA/CGA configuration needs NCCL 2.17+.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 17)
#define NCCL_HAS_COMM_CTA_CGA
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define NCCL_HAS_COMM_CTA_CGA
#endif

// Buffer registration (ncclCommRegister / ncclCommDeregister) needs NCCL 2.19+,
// or can be forced on by defining NCCL_REGISTRATION_SUPPORTED.
#if defined(NCCL_REGISTRATION_SUPPORTED) ||                              \
    ((defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
      (NCCL_MINOR >= 19)))
#define NCCL_HAS_COMM_REGISTER
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define NCCL_HAS_COMM_REGISTER
#endif
79 
// Macro to throw on a non-successful NCCL return value.
//
// Evaluates `cmd` exactly once into a local ncclResult_t. Any value other than
// ncclSuccess raises DistBackendError whose message contains the expansion
// site (__FILE__:__LINE__), ncclGetErrorWithVersion(result) and
// getNcclErrorDetailStr(result, failureReason). `failureReason` is passed as
// the std::optional<std::string> argument of getNcclErrorDetailStr.
// The do { ... } while (0) wrapper makes the expansion a single statement.
// Note: the macro declares locals named `result` and `err`, so `cmd` and
// `failureReason` should not refer to caller variables with those names.
#define C10D_NCCL_CHECK(cmd, failureReason)                                   \
  do {                                                                        \
    ncclResult_t result = cmd;                                                \
    if (result != ncclSuccess) {                                              \
      std::string err = "NCCL error in: " + std::string(__FILE__) + ":" +     \
          std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
          "\n" + getNcclErrorDetailStr(result, failureReason);                \
      TORCH_CHECK_WITH(DistBackendError, false, err);                         \
    }                                                                         \
  } while (0)
91 
// Macro to throw on a non-successful NCCL return value for NONBLOCKING calls.
//
// Same as C10D_NCCL_CHECK except that ncclInProgress is also accepted: on a
// non-blocking communicator the call may still be completing, and this macro
// does not wait for it (see C10D_NCCL_CHECK_TIMEOUT for a waiting variant).
// Any other non-success value raises DistBackendError.
// Note: declares locals `result` and `err` inside the expansion.
#define C10D_NCCL_CHECK_NONBLOCKING(cmd, failureReason)                       \
  do {                                                                        \
    ncclResult_t result = cmd;                                                \
    if (result != ncclSuccess && result != ncclInProgress) {                  \
      std::string err = "NCCL error in: " + std::string(__FILE__) + ":" +     \
          std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
          "\n" + getNcclErrorDetailStr(result, failureReason);                \
      TORCH_CHECK_WITH(DistBackendError, false, err);                         \
    }                                                                         \
  } while (0)
103 
// Macro to throw on a non-successful NCCL return value, non-blocking.
//
// Evaluates `cmd` once. While the result is ncclInProgress, the raw
// ncclComm_t `comm` is polled with ncclCommGetAsyncError. If
// nccl_nonblocking_timeout() is positive and more than that many seconds
// elapse, DistBackendError is thrown ("NCCL timeout in: ..."). Any final
// result other than ncclSuccess also throws DistBackendError.
//
// The body is wrapped in do { ... } while (0) so the expansion is a single
// statement: its locals (result, startTimepoint, ...) do not leak into or
// collide with the caller's scope, and it is safe in an unbraced if/else.
#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason)                     \
  do {                                                                        \
    ncclResult_t result = cmd;                                                \
    auto startTimepoint = std::chrono::steady_clock::now();                   \
    while (result == ncclInProgress) {                                        \
      if (nccl_nonblocking_timeout() > 0) {                                   \
        auto currentTimepoint = std::chrono::steady_clock::now();             \
        auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(  \
                               currentTimepoint - startTimepoint)             \
                               .count();                                      \
        if (timeElapsed > nccl_nonblocking_timeout()) {                       \
          std::string err = "NCCL timeout in: " + std::string(__FILE__) +     \
              ":" + std::to_string(__LINE__) + ", " +                         \
              ncclGetErrorWithVersion(result) + "\n" +                        \
              getNcclErrorDetailStr(result, failureReason);                   \
          TORCH_CHECK_WITH(DistBackendError, false, err);                     \
        }                                                                     \
      }                                                                       \
      ncclCommGetAsyncError(comm, &result);                                   \
    }                                                                         \
    if (result != ncclSuccess) {                                              \
      std::string err = "NCCL error in: " + std::string(__FILE__) + ":" +     \
          std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
          "\n" + getNcclErrorDetailStr(result, failureReason);                \
      TORCH_CHECK_WITH(DistBackendError, false, err);                         \
    }                                                                         \
  } while (0)
130 
// Waiting check for a group-completing call (the name refers to
// ncclGroupEnd()). Unlike C10D_NCCL_CHECK_TIMEOUT, `comm` is an object
// exposing getNcclComm() (presumably an NCCLComm pointer); the underlying
// ncclComm_t is polled with ncclCommGetAsyncError while the state is
// ncclInProgress. Throws DistBackendError when nccl_nonblocking_timeout() is
// positive and more than that many seconds elapse, or when the final state is
// anything other than ncclSuccess.
//
// Wrapped in do { ... } while (0) so the expansion is a single statement and
// its locals (state, startTimepoint, ...) stay out of the caller's scope.
#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason)           \
  do {                                                                       \
    ncclResult_t state = cmd;                                                \
    auto startTimepoint = std::chrono::steady_clock::now();                  \
    if (state == ncclInProgress) {                                           \
      do {                                                                   \
        if (nccl_nonblocking_timeout() > 0) {                                \
          auto currentTimepoint = std::chrono::steady_clock::now();          \
          auto timeElapsed =                                                 \
              std::chrono::duration_cast<std::chrono::seconds>(              \
                  currentTimepoint - startTimepoint)                         \
                  .count();                                                  \
          if (timeElapsed > nccl_nonblocking_timeout()) {                    \
            std::string err = "NCCL timeout in: " + std::string(__FILE__) +  \
                ":" + std::to_string(__LINE__) + ", " +                      \
                ncclGetErrorWithVersion(state) + "\n" +                      \
                getNcclErrorDetailStr(state, failureReason);                 \
            TORCH_CHECK_WITH(DistBackendError, false, err);                  \
          }                                                                  \
        }                                                                    \
        ncclCommGetAsyncError(comm->getNcclComm(), &state);                  \
      } while (state == ncclInProgress);                                     \
    }                                                                        \
    if (state != ncclSuccess) {                                              \
      std::string err = "NCCL error in: " + std::string(__FILE__) + ":" +    \
          std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
          "\n" + getNcclErrorDetailStr(state, failureReason);                \
      TORCH_CHECK_WITH(DistBackendError, false, err);                        \
    }                                                                        \
  } while (0)
158 
// Macro to print and abort on a non-successful NCCL return value.
//
// For contexts that must not throw (e.g. destructors): if `cmd` does not
// return ncclSuccess, the expansion site and ncclGetErrorWithVersion(result)
// are written to stderr and the process is terminated with abort().
// Note: declares locals `result` and `err` inside the expansion.
#define C10D_NCCL_ASSERT(cmd)                            \
  do {                                                   \
    ncclResult_t result = cmd;                           \
    if (result != ncclSuccess) {                         \
      std::string err = ncclGetErrorWithVersion(result); \
      fprintf(                                           \
          stderr,                                        \
          "NCCL error in: %s:%d, %s\n",                  \
          __FILE__,                                      \
          __LINE__,                                      \
          err.c_str());                                  \
      abort();                                           \
    }                                                    \
  } while (0)
174 
175 namespace c10d {
// Defines, for each `name`, two variables initialized from the same literal:
// a c10::IValue `name` and a std::string `name##_str`. They are `static`
// namespace-scope variables in a header, so each translation unit that
// includes this file gets its own copy.
#define DEFINE_CONSTANT(name, value) \
  static c10::IValue name = value;   \
  static std::string name##_str = value;
// Update whenever changing contents or formatting of the dump
// (minor when adding fields, major when changing existing fields)
// Also update both JSON and Pickle dumps to make use of the newly defined
// field(s).
DEFINE_CONSTANT(version_val, "2.4");
// Field names (keys) used in the flight-recorder dump.
DEFINE_CONSTANT(entries_key, "entries");
DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state");
DEFINE_CONSTANT(version_key, "version");
DEFINE_CONSTANT(pg_config_key, "pg_config");
DEFINE_CONSTANT(pg_status_key, "pg_status");
DEFINE_CONSTANT(record_id_key, "record_id");
DEFINE_CONSTANT(pg_id_key, "pg_id");
DEFINE_CONSTANT(pg_name_key, "process_group");
DEFINE_CONSTANT(collective_seq_id_key, "collective_seq_id");
DEFINE_CONSTANT(p2p_seq_id_key, "p2p_seq_id");
DEFINE_CONSTANT(is_p2p_key, "is_p2p");
DEFINE_CONSTANT(op_id_key, "op_id");
DEFINE_CONSTANT(profiling_name_key, "profiling_name");
DEFINE_CONSTANT(input_sizes_key, "input_sizes");
DEFINE_CONSTANT(input_dtypes_key, "input_dtypes");
DEFINE_CONSTANT(output_sizes_key, "output_sizes");
DEFINE_CONSTANT(output_dtypes_key, "output_dtypes");
DEFINE_CONSTANT(time_created_key, "time_created_ns");
DEFINE_CONSTANT(duration_key, "duration_ms");
DEFINE_CONSTANT(timeout_key, "timeout_ms");
DEFINE_CONSTANT(frames_key, "frames");
DEFINE_CONSTANT(state_key, "state");
DEFINE_CONSTANT(line_key, "line");
DEFINE_CONSTANT(name_key, "name");
DEFINE_CONSTANT(filename_key, "filename");
DEFINE_CONSTANT(retired_key, "retired");
DEFINE_CONSTANT(time_discovered_started_key, "time_discovered_started_ns");
DEFINE_CONSTANT(time_discovered_completed_key, "time_discovered_completed_ns");
// Entry state values (presumably stored under state_key — confirm in the
// dump implementation).
DEFINE_CONSTANT(completed_state, "completed");
DEFINE_CONSTANT(scheduled_state, "scheduled");
DEFINE_CONSTANT(started_state, "started");
#undef DEFINE_CONSTANT
216 
// Hash over a list of tensors (exactly what is hashed is defined in the .cpp).
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
// NCCL version string (definition not in this header).
TORCH_API std::string getNcclVersion();
// Error text for `error`, used in the messages of the C10D_NCCL_* macros.
TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error);
// Whether communicators should be created in non-blocking mode; consulted by
// NCCLComm::create(..., ncclConfig_t&) to set config.blocking = 0.
bool nccl_use_nonblocking();
// Timeout, in seconds, for waiting on non-blocking NCCL calls in the
// C10D_NCCL_CHECK_TIMEOUT* macros; a value <= 0 disables the timeout.
int nccl_nonblocking_timeout();

// Provides additional detail into NCCL error codes based on when these are
// thrown in the NCCL codebase.
TORCH_API std::string getNcclErrorDetailStr(
    ncclResult_t error,
    std::optional<std::string> processGroupFailureReason = std::nullopt);
228 
229 // Write NCCL debug info to local disk or any storage users define.
230 // There are some constrains we set for the debug info writer:
231 // 1. The writer should only be registered once.
232 // 2. Once registered, users cannot change it including un-register.
233 // 3. It is recommended to register the customized writer in the trainer setup,
234 //    If users don't register before calling launchAsyncDebugDump, then users
235 //    lose the chance to register (and the default writer will be
236 //    auto-registered).
237 class TORCH_API DebugInfoWriter {
238  public:
239   virtual ~DebugInfoWriter() = default;
240   virtual void write(const std::string& ncclTrace);
241   static DebugInfoWriter& getWriter(int rank);
242   static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);
getWriterTarget()243   virtual std::string getWriterTarget() {
244     return filename_;
245   }
246 
247  protected:
DebugInfoWriter(std::string namePrefix,int rank)248   DebugInfoWriter(std::string namePrefix, int rank) {
249     filename_ = c10::str(namePrefix, rank);
250   }
251   std::string filename_;
252 
253  private:
254   static std::unique_ptr<DebugInfoWriter> writer_;
255   static std::atomic<bool> hasWriterRegistered_;
256 };
257 
258 // RAII wrapper for NCCL communicator
259 class NCCLComm {
260  public:
NCCLComm(ncclComm_t ncclComm)261   explicit NCCLComm(ncclComm_t ncclComm)
262       : ncclComm_(ncclComm),
263         aborted_(false),
264         ncclAsyncErr_(ncclSuccess),
265         commFailureReason_(std::nullopt),
266         initialized_(false) {}
267 
NCCLComm()268   NCCLComm() : NCCLComm(nullptr) {}
269 
~NCCLComm()270   ~NCCLComm() noexcept {
271     // Add lock in this destructor, as aborted_ needs to be read after memory
272     // barrier here.
273     std::unique_lock<std::mutex> lock(mutex_);
274     if (ncclComm_ && initialized_ && !aborted_) {
275 #ifdef ENABLE_NCCL_ERROR_CHECKING
276       // Use ncclCommAbort instead of ncclCommDestroy here since
277       // ncclCommDestroy could block forever waiting for work to complete on
278       // the communicator.
279       C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
280 #else
281       C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
282 #endif
283     }
284   }
285 
create(int numRanks,int rank,ncclUniqueId commId)286   static std::shared_ptr<NCCLComm> create(
287       int numRanks,
288       int rank,
289       ncclUniqueId commId) {
290     auto comm = std::make_shared<NCCLComm>();
291     C10D_NCCL_CHECK(
292         ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
293         std::nullopt);
294     comm->ncclId_ = commId;
295     comm->rank_ = rank;
296     comm->initialized_ = true;
297     return comm;
298   }
299 
300 #ifdef NCCL_HAS_COMM_NONBLOCKING
create(int numRanks,int rank,ncclUniqueId commId,ncclConfig_t & config)301   static std::shared_ptr<NCCLComm> create(
302       int numRanks,
303       int rank,
304       ncclUniqueId commId,
305       ncclConfig_t& config) {
306     auto comm = std::make_shared<NCCLComm>();
307     bool isInitialized = false;
308     if (nccl_use_nonblocking()) {
309       config.blocking = 0;
310       LOG(INFO) << "Rank " << rank
311                 << ": creating NCCL communicator in nonblocking mode";
312       C10D_NCCL_CHECK_NONBLOCKING(
313           ncclCommInitRankConfig(
314               &(comm->ncclComm_), numRanks, commId, rank, &config),
315           std::nullopt);
316     } else {
317       C10D_NCCL_CHECK(
318           ncclCommInitRankConfig(
319               &(comm->ncclComm_), numRanks, commId, rank, &config),
320           std::nullopt);
321       // under blocking mode, comm is initialized after NCCL CHECK
322       isInitialized = true;
323     }
324     comm->ncclId_ = commId;
325     comm->rank_ = rank;
326     comm->initialized_ = isInitialized;
327     return comm;
328   }
329 
330   static std::shared_ptr<NCCLComm> split(
331       NCCLComm* source,
332       int color_id,
333       int rank,
334       ncclConfig_t& config,
335       std::vector<uint64_t>& ranks_ull);
336 #endif
337 
338 #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
ncclCommDump()339   std::unordered_map<std::string, std::string> ncclCommDump() {
340     std::unordered_map<std::string, std::string> dump;
341     if (isAborted()) {
342       LOG(INFO) << "Communicator was aborted before trying to dump its state.";
343       return dump;
344     }
345     C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
346     return dump;
347   }
348 #endif
349 
getNcclId()350   ncclUniqueId getNcclId() {
351     return ncclId_;
352   }
353 
354   // Must not be copyable
355   NCCLComm(const NCCLComm&) = delete;
356   NCCLComm& operator=(const NCCLComm&) = delete;
357 
358   // Do not support move assignment as there is no valid use case
359   NCCLComm& operator=(NCCLComm&& other) = delete;
360 
361   // Move constructable
NCCLComm(NCCLComm && other)362   NCCLComm(NCCLComm&& other) {
363     // Using other's lock, as it reads other's states
364     // Can not use this.mutex_, as this object is being constructed.
365     std::unique_lock<std::mutex> lock(other.mutex_);
366     std::swap(ncclComm_, other.ncclComm_);
367     std::swap(aborted_, other.aborted_);
368     std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
369     std::swap(initialized_, other.initialized_);
370   }
371 
372   ncclComm_t getNcclComm();
373 
getNcclCommFailureReason() const374   std::optional<std::string> getNcclCommFailureReason() const {
375     std::unique_lock<std::mutex> lock(mutex_);
376     return commFailureReason_;
377   }
378 
ncclCommAbort(std::optional<std::string> commFailureReason=std::nullopt)379   void ncclCommAbort(
380       std::optional<std::string> commFailureReason = std::nullopt) {
381     std::unique_lock<std::mutex> lock(mutex_);
382 #ifdef ENABLE_NCCL_ERROR_CHECKING
383     if (aborted_ && !initialized_) {
384       // Should not abort twice.
385       return;
386     }
387 
388 #ifdef NCCL_HAS_COMM_REGISTER
389     // Deregister all registered segments before aborting.
390     for (auto& it : registeredSegmentHandles_) {
391       void* handle = it.second;
392       C10D_NCCL_CHECK(
393           ::ncclCommDeregister(ncclComm_, handle),
394           c10::str(
395               "Failed to deregister segment handle ",
396               handle,
397               " on ncclComm_ ",
398               ncclComm_));
399     }
400     registeredSegmentHandles_.clear();
401 #endif
402 
403     // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
404     // timeout)
405     commFailureReason_ = commFailureReason;
406     LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
407               << (commFailureReason ? *commFailureReason
408                                     : "No abort reason provided.");
409 #ifndef NCCL_HAS_COMM_NONBLOCKING
410     C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
411 #else
412     C10D_NCCL_CHECK_TIMEOUT(
413         ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
414 #endif
415     aborted_ = true;
416     ncclComm_ = nullptr;
417 
418     // Set an appropriate error so that we avoid using the communicator.
419     if (ncclAsyncErr_ == ncclSuccess) {
420       ncclAsyncErr_ = ncclSystemError;
421     }
422 #else
423     // This is a NOOP, if error checks are disabled.
424     return;
425 #endif
426   }
427 
isAborted() const428   bool isAborted() const {
429     std::unique_lock<std::mutex> lock(mutex_);
430     return aborted_;
431   }
432 
getCommSplitCounter() const433   uint64_t getCommSplitCounter() const {
434     return ncclCommSplitCounter_;
435   }
436 
checkForNcclError()437   ncclResult_t checkForNcclError() {
438     std::unique_lock<std::mutex> lock(mutex_);
439 #ifdef ENABLE_NCCL_ERROR_CHECKING
440     if (ncclAsyncErr_ != ncclSuccess) {
441       return ncclAsyncErr_;
442     }
443     C10D_NCCL_CHECK(
444         ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
445     return ncclAsyncErr_;
446 #else
447     // Always return success, if error checks are disabled.
448     return ncclSuccess;
449 #endif
450   }
451 
registerSegment(void * ptr,size_t size)452   ncclResult_t registerSegment(void* ptr, size_t size) {
453     std::unique_lock<std::mutex> lock(mutex_);
454 #ifdef NCCL_HAS_COMM_REGISTER
455     // We register only segments from cache allocator
456     // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
457     // maps to a unique handle and should not be registered before the current
458     // ptr is deregistered and freed.
459     TORCH_CHECK(
460         registeredSegmentHandles_.count(ptr) == 0,
461         "Segment with ptr ",
462         ptr,
463         " has already been registered on ncclComm_ ",
464         ncclComm_);
465 
466     void* handle;
467     C10D_NCCL_CHECK(
468         ncclCommRegister(ncclComm_, ptr, size, &handle),
469         c10::str(
470             "Failed to register segment with ptr ",
471             ptr,
472             ", size ",
473             size,
474             " on ncclComm_ ",
475             ncclComm_));
476     registeredSegmentHandles_[ptr] = handle;
477     return ncclSuccess;
478 #else
479     return ncclInvalidUsage;
480 #endif
481   }
482 
deregisterSegment(void * ptr)483   ncclResult_t deregisterSegment(void* ptr) {
484     std::unique_lock<std::mutex> lock(mutex_);
485 #ifdef NCCL_HAS_COMM_REGISTER
486     TORCH_CHECK(
487         registeredSegmentHandles_.count(ptr) == 1,
488         "Segment with ptr ",
489         ptr,
490         " is not registered on ncclComm_ ",
491         ncclComm_);
492 
493     void* handle = registeredSegmentHandles_[ptr];
494     C10D_NCCL_CHECK(
495         ncclCommDeregister(ncclComm_, handle),
496         c10::str(
497             "Failed to deregister segment handle ",
498             handle,
499             ", with ptr ",
500             ptr,
501             " on ncclComm_ ",
502             ncclComm_));
503     registeredSegmentHandles_.erase(ptr);
504     return ncclSuccess;
505 #else
506     return ncclInvalidUsage;
507 #endif
508   }
509 
510   friend class ProcessGroupNCCL;
511 
512  protected:
513   // a helper function to wait until the communicator is initialized;
514   void waitUntilInitialized(int timeoutSecs);
515   ncclComm_t ncclComm_;
516   // Unique nccl_id for this communicator.
517   ncclUniqueId ncclId_;
518   bool aborted_;
519   uint64_t ncclCommSplitCounter_{0};
520   ncclResult_t ncclAsyncErr_;
521   mutable std::mutex mutex_;
522   // Rank that this communicator corresponds to.
523   int rank_;
524   // Optional reason for communicator failure, provided by ProcessGroupNCCL for
525   // better error messaging.
526   std::optional<std::string> commFailureReason_;
527   bool initialized_{false};
528 #ifdef NCCL_HAS_COMM_REGISTER
529   // Stores handlers for tensors registered by NCCL
530   std::unordered_map<void*, void*> registeredSegmentHandles_;
531 #endif
532 };
533 
534 // Helper that automatically cleans up premul sums.
535 struct ncclRedOpRAII {
536   ncclRedOpRAII() = default;
ncclRedOpRAIIc10d::ncclRedOpRAII537   ncclRedOpRAII(ncclRedOp_t op) : op_(op) {}
ncclRedOpRAIIc10d::ncclRedOpRAII538   ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm)
539       : op_(op), comm_(comm), premul_sum_(true) {}
540   ncclRedOpRAII(const ncclRedOpRAII&) = delete;
541   ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete;
ncclRedOpRAIIc10d::ncclRedOpRAII542   ncclRedOpRAII(ncclRedOpRAII&& tmp) : ncclRedOpRAII() {
543     std::swap(tmp.op_, this->op_);
544     std::swap(tmp.comm_, this->comm_);
545     std::swap(tmp.premul_sum_, this->premul_sum_);
546   }
547 #if defined(ENABLE_NCCL_PREMUL_SUM_SUPPORT)
~ncclRedOpRAIIc10d::ncclRedOpRAII548   ~ncclRedOpRAII() {
549     if (premul_sum_) {
550       ncclRedOpDestroy(op_, comm_);
551     }
552   }
553 #endif
operator ncclRedOp_tc10d::ncclRedOpRAII554   operator ncclRedOp_t() const {
555     return op_;
556   }
557   ncclRedOp_t op_;
558   ncclComm_t comm_;
559   bool premul_sum_ = false;
560 };
561 
/* Helper used by work::getDuration() and nccl flight recorder */
// Returns the time elapsed between `ncclStartEvent` and `ncclEndEvent`
// (presumably in milliseconds, matching the "duration_ms" dump key — confirm
// against the definition).
float getDurationFromEvent(
    at::cuda::CUDAEvent& ncclStartEvent,
    at::cuda::CUDAEvent& ncclEndEvent);
566 
567 struct NCCLTraceBuffer {
getc10d::NCCLTraceBuffer568   static NCCLTraceBuffer* get() {
569     // intentionally leak on exit
570     // because this will hold python state that may get destructed
571     static NCCLTraceBuffer* instance = new NCCLTraceBuffer();
572     return instance;
573   }
NCCLTraceBufferc10d::NCCLTraceBuffer574   NCCLTraceBuffer() {
575     max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0);
576     capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false);
577     enabled_ = max_entries_ > 0;
578   }
579   using Event = at::cuda::CUDAEvent;
580   struct Entry {
581     size_t id_; // incremented id in the trace buffer
582                 // used to figure out where in the circular entries
583                 // buffer this entry will be located to
584                 // update state information
585     size_t pg_id_;
586     std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
587 
588     // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1
589     // per coalesced group).
590     // collective_seq_id only increments for true collective operations (over
591     // all ranks in the group). p2p_seq_id only increments over non-collective
592     // operations in the group. op_id refers to logical operations (e.g. one per
593     // op inside coalesced group)
594     size_t collective_seq_id_;
595     size_t p2p_seq_id_;
596     size_t op_id_;
597     std::string profiling_name_;
598 
599     std::shared_ptr<torch::CapturedTraceback> traceback_;
600     // we borrow pointers to start_ and end_ so we can query the state
601     // on reporting. However, once the event is completed, the call
602     // to `complete` will clear these.
603     Event *start_, *end_;
604 
605     // timestamp when the entry was created, likely close to the time the work
606     // was 'enqueued'- not necessarily started
607     c10::time_t time_created_;
608 
609     // configured timeout for this entry
610     c10::time_t timeout_ms_;
611 
612     // Is this a P2P event?
613     bool isP2P_;
614 
615     std::optional<float> duration_;
616 
617     // timestamp when our CPU threads discovered that the kernel started.
618     // will always be _after_ it actually started, and can be very late
619     // if the watchdog thread got stuck on CUDA APIs.
620     std::optional<c10::time_t> time_discovered_started_;
621 
622     // timestamp when our CPU threads discovered that the kernel completed.
623     // will always be _after_ it actually complated, and can be the same time
624     // as the discovery of the start if the watchdog thread is stuck on CUDA
625     // APIs
626     std::optional<c10::time_t> time_discovered_completed_;
627 
628     // size information for input/output tensors
629     c10::SmallVector<int, 4> input_dims_;
630     std::vector<c10::ScalarType> input_dtypes_;
631     c10::SmallVector<int, 4> output_dims_;
632     std::vector<c10::ScalarType> output_dtypes_;
633     c10::SmallVector<int64_t, 8> sizes_; // flattened from inputs, outputs
634     bool retired_ = false; // is this work entry no longer in the workMetaList_?
635                            // a retired but not completed event has timed out
636   };
637 
638   bool enabled_ = false;
639   bool capture_cpp_stack_ = false;
640   std::mutex mutex_;
641   std::vector<Entry> entries_;
642   size_t max_entries_ = 0;
643   size_t next_ = 0;
644   size_t id_ = 0;
645   std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_ = {};
646   std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
647       pg_name_to_ranks_ = {};
648 
649   std::optional<size_t> record(
650       size_t pg_id,
651       const std::tuple<std::string, std::string>& pg_name,
652       size_t collective_seq_id,
653       size_t p2p_seq_id,
654       size_t op_id,
655       std::string profiling_name,
656       const std::vector<at::Tensor>& inputs,
657       const std::vector<at::Tensor>& outputs,
658       Event* start,
659       Event* end,
660       std::chrono::milliseconds timeout_ms,
661       std::shared_ptr<ProcessGroupStatus> pg_status,
662       bool isP2P);
663 
664   void record_pg_ranks(
665       const std::tuple<std::string, std::string>& pg_name,
666       std::vector<uint64_t> ranks);
667 
668   void update_state(Entry& r);
669 
670   std::vector<Entry> dump_entries();
671 
672   /*
673   Mark an Event as completed and free its events.
674   This is called by the watchdog thread, and is asynchronous from the
675   perspective of the main thread.
676   compute_duration defaults to true since retire_id is only called in the
677   watchdog thread, which is currently a place we call cuda APIs which may hang,
678   but care should be taken to avoid computing duration in any function that must
679   never hang. (timing must also be enabled for compute_duration - see
680   TORCH_NCCL_ENABLE_TIMING).
681   */
682   void retire_id(std::optional<size_t> id, bool compute_duration = true);
683 
684   const c10::List<c10::IValue> getCollectiveTrace(
685       bool includeStacktraces,
686       bool onlyActive);
687 
688   // dump pg_entries
689   const c10::Dict<c10::IValue, c10::IValue> getPgConfig();
690 
691   const std::map<std::string, std::map<std::string, std::string>>
692   getPgConfigJson();
693 
694   // dump pg_status
695   const c10::Dict<c10::IValue, c10::IValue> getPgStatus();
696 
697   const std::map<std::string, std::map<std::string, std::string>>
698   getPgStatusJson();
699 
700   std::string dump_json(
701       const std::optional<std::unordered_map<
702           std::string,
703           std::unordered_map<std::string, std::string>>>& ncclDumpMap,
704       bool includeCollectives,
705       bool onlyActive);
706 
707   // dump all collectives + ncclDumpMap
708   std::string dump(
709       const std::optional<std::unordered_map<
710           std::string,
711           std::unordered_map<std::string, std::string>>>& ncclDumpMap,
712       bool includeCollectives,
713       bool includeStackTraces,
714       bool onlyActive);
715 };
716 } // namespace c10d
717 
718 #endif // USE_C10D_NCCL
719