xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_C10D_NCCL
4 
5 #if defined(__linux__)
6 #include <fcntl.h>
7 #include <sys/stat.h>
8 #include <sys/types.h>
9 #include <unistd.h>
10 #endif
11 
12 #include <atomic>
13 #include <chrono>
14 #include <future>
15 #include <iostream>
16 #include <list>
17 #include <mutex>
18 #include <thread>
19 #include <unordered_map>
20 
21 #include <torch/csrc/distributed/c10d/Backend.hpp>
22 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
23 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
24 #include <torch/csrc/distributed/c10d/Store.hpp>
25 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
26 
27 #include <ATen/DynamicLibrary.h>
28 #include <ATen/cuda/CUDAContext.h>
29 #include <ATen/cuda/CUDAEvent.h>
30 #include <c10/core/Stream.h>
31 #include <c10/core/StreamGuard.h>
32 #include <c10/cuda/CUDACachingAllocator.h>
33 #include <c10/cuda/CUDAGuard.h>
34 #include <c10/cuda/CUDAStream.h>
35 
36 #include <torch/custom_class.h>
37 
38 namespace c10d {
39 
40 // Control broadcasting of NCCL uniqueId
41 static std::vector<std::string> TORCH_NCCL_BCAST_UNIQUEID = {
42     "TORCH_NCCL_BCAST_UNIQUEID"};
43 
44 // Control whether to always use high priority streams
45 static std::vector<std::string> TORCH_NCCL_HIGH_PRIORITY = {
46     "TORCH_NCCL_HIGH_PRIORITY"};
47 
48 // Control whether or not wait() is blocking or non-blocking.
49 static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = {
50     "TORCH_NCCL_BLOCKING_WAIT",
51     "NCCL_BLOCKING_WAIT"};
52 
53 // TODO: We want to eventually remove this variable and make users use
54 // the default value (3 - SkipCleanUp).
55 // Control whether or not we perform Async Error Handling with NCCL.
56 static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = {
57     "TORCH_NCCL_ASYNC_ERROR_HANDLING",
58     "NCCL_ASYNC_ERROR_HANDLING"};
59 
60 // Control whether dumping debug info on watchdog
61 // timeout is enabled. This variable must be set together with
62 // TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
63 static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = {
64     "TORCH_NCCL_DUMP_ON_TIMEOUT"};
65 
66 // Control whether Desync Debug is enabled. This variable must be set
67 // together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
68 static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = {
69     "TORCH_NCCL_DESYNC_DEBUG",
70     "NCCL_DESYNC_DEBUG"};
71 
72 // Enable recording start-events for all ProcessGroupNCCL collectives, and
73 // compute accurate per-collective timing. (Note: end-events are recorded by
74 // default.) Turning on this flag can increase the chances of a watchdog
75 // hang due to performing a CUDA event query, which eventually calls the
76 // cudaEventElapsedTime() API.
77 static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = {
78     "TORCH_NCCL_ENABLE_TIMING",
79     "NCCL_ENABLE_TIMING"};
80 
81 // Enable monitoring thread which aborts the process when the ProcessGroupNCCL
82 // Watchdog thread gets stuck and no heartbeat is detected after
83 // TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL
84 // APIs that may hang. It is useful for preventing jobs from being stuck
85 // longer than necessary and tying up cluster resources.
86 static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = {
87     "TORCH_NCCL_ENABLE_MONITORING"};
88 
89 // Control the watchdog heartbeat timeout period after which the monitoring
90 // thread will abort the process.
91 static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = {
92     "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"};
93 
94 // Whether to rethrow CUDA Errors in the watchdog (default true)
95 static std::vector<std::string> TORCH_NCCL_RETHROW_CUDA_ERRORS = {
96     "TORCH_NCCL_RETHROW_CUDA_ERRORS"};
97 
98 // The maximum number of events we store in the flight recorder's ring buffer.
99 // (One event could be the start or end of a collective, for example).
100 static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = {
101     "TORCH_NCCL_TRACE_BUFFER_SIZE"};
102 
103 // Control how much extra time we will wait for dumping the debugging info
104 // before we exit and throw a timeout exception.
105 static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
106     "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"};
107 
108 // Control the interval inside the monitoring thread to check the coordinated
109 // signal from other ranks, e.g. to dump the debugging information.
110 static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
111     "TORCH_NCCL_COORD_CHECK_MILSEC"};
112 
113 // Whether to log C++ stack traces on unclean shutdown (default true)
114 static std::vector<std::string> TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN = {
115     "TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN"};
116 
117 // Control whether to use the CudaEventCache for collectives in the watchdog
118 // thread. We have noticed in the past that destroying a CudaEvent while the
119 // CUDA global lock is held can cause a hang.
120 static std::vector<std::string> TORCH_NCCL_CUDA_EVENT_CACHE = {
121     "TORCH_NCCL_CUDA_EVENT_CACHE"};
122 
123 static std::vector<std::string> TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"};
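// Each of the vectors above lists the accepted environment variable names for
// one knob, in priority order (the first name that is set wins; the older
// NCCL_* spellings are kept as deprecated aliases). A minimal reading sketch:
// getCvarInt/getCvarString appear in the DumpPipe code below, and getCvarBool
// is assumed here to follow the same pattern:
//
//   bool blockingWait = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, /*def=*/false);
//   int traceBufSize = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, /*def=*/0);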
124 
125 constexpr const char* NCCL_BACKEND_NAME = "nccl";
126 
127 constexpr const char* EXCEPTION_DUMP = "exception_dump";
128 
129 constexpr const int kWorkStatusUpdatePeriodMs = 30 * 1000; // 30 seconds
130 
131 constexpr auto kProcessGroupNCCLDefaultTimeout =
132     std::chrono::milliseconds(10 * 60 * 1000);
133 
134 // NoHandling: do not handle asynchronous NCCL errors
135 // TearDown: tear down process upon error, see `WorkNCCL::handleException`
136 // CleanUpOnly: just clean up collectives and abort communicators without
137 // tearing down the process
138 // SkipCleanUp: (this is a temporary option and can be removed in the future)
139 // tear down the process without cleaning up NCCL communicators. This should
140 // be used as a last resort in case `ncclCommAbort` itself is hanging.
141 enum ErrorHandlingMode {
142   NoHandling = 0,
143   TearDown = 1,
144   CleanUpOnly = 2,
145   SkipCleanUp = 3
146 };
147 
148 #define SHOULD_CLEAN_UP(a) (a != NoHandling && a != SkipCleanUp)
149 
150 #define SHOULD_TEAR_DOWN(a) (a != NoHandling && a != CleanUpOnly)
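// For illustration, the two macros evaluate as follows for each mode (these
// follow directly from the definitions above):
//
//   static_assert(!SHOULD_CLEAN_UP(NoHandling) && !SHOULD_TEAR_DOWN(NoHandling));
//   static_assert(SHOULD_CLEAN_UP(TearDown) && SHOULD_TEAR_DOWN(TearDown));
//   static_assert(SHOULD_CLEAN_UP(CleanUpOnly) && !SHOULD_TEAR_DOWN(CleanUpOnly));
//   static_assert(!SHOULD_CLEAN_UP(SkipCleanUp) && SHOULD_TEAR_DOWN(SkipCleanUp));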
151 
152 #define PRINT_COLLECTIVE_HASH_SIGNATURE(phase, opType, numel, hashValue)      \
153   LOG(WARNING) << logPrefix() << "Hash of " << phase << " to NCCL " << opType \
154                << " with size " << numel << " is " << hashValue;
155 
156 // If set, ProcessGroupNCCL doesn't use recordStream calls to ensure
157 // caching allocator safety for tensors used on both user-facing and
158 // internal comm streams.
159 // Instead, it stashes live references to those tensors until after
160 // user-facing streams are synced with comm streams.
161 // See stashed_for_allocator_safety_ below.
162 static std::vector<std::string> TORCH_NCCL_AVOID_RECORD_STREAMS = {
163     "TORCH_NCCL_AVOID_RECORD_STREAMS"};
164 
165 // If set, ProcessGroupNCCL registers postAlloc and preFree hooks to cuda cache
166 // allocator so that whenever a tensor is allocated or freed, ProcessGroupNCCL
167 // can register/deregister the tensor on all available NCCL communicators.
168 static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
169     {"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK",
170      "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"};
171 
172 #if defined(__linux__)
173 struct DumpPipe {
174   DumpPipe(int rank) {
175     std::string fileStem =
176         getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
177     if (fileStem.empty() ||
178         getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
179       return;
180     }
181     TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_PIPE_FILE is empty");
182     std::string filename = c10::str(fileStem, rank, ".pipe");
183     TORCH_CHECK(
184         unlink(filename.c_str()) != -1 || errno == ENOENT,
185         "Error removing existing named pipe ",
186         filename);
187     TORCH_CHECK(
188         mkfifo(filename.c_str(), 0666) != -1,
189         "Error creating named pipe ",
190         filename);
191     fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
192     TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
193     LOG(INFO) << "Pipe file " << filename
194               << " has been opened, write to it to trigger NCCL Debug Dump.";
195   }
196   bool shouldDump() {
197     if (fd_ == -1) {
198       return false;
199     }
200     char buf[128];
201     // non-blocking from O_NONBLOCK above.
202     // Ignore EINTR because we will poll this
203     // again later anyway.
204     ssize_t bytesRead = read(fd_, &buf, 128);
205     return bytesRead > 0;
206   }
207   ~DumpPipe() {
208     if (fd_ != -1) {
209       close(fd_);
210     }
211   }
212 
213  private:
214   int fd_ = -1;
215 };
216 #else
217 struct DumpPipe {
218   DumpPipe(int rank) {}
219   bool shouldDump() {
220     return false;
221   }
222 };
223 #endif
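// Usage sketch for DumpPipe (illustrative, not part of the original header):
// with TORCH_NCCL_DEBUG_INFO_PIPE_FILE=/tmp/dump and
// TORCH_NCCL_TRACE_BUFFER_SIZE > 0, rank 0 opens the FIFO "/tmp/dump0.pipe".
// An external process can then request a flight-recorder dump by writing any
// byte to it (POSIX only):
//
//   int fd = open("/tmp/dump0.pipe", O_WRONLY | O_NONBLOCK);
//   if (fd != -1) {
//     write(fd, "1", 1); // the watchdog's next shouldDump() poll returns true
//     close(fd);
//   }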
224 
225 // ProcessGroupNCCL implements NCCL bindings for c10d.
226 //
227 // All functions of the class are expected to be called in the same order
228 // across all processes in the process group.  This is the only way that we
229 // can guarantee to match up the same calls among all processes.
230 //
231 // All NCCL functions provided by this class are asynchronous functions. More
232 // specifically, each NCCL call is scheduled on a separate CUDA stream that is
233 // different from the current CUDA stream. This is for the purpose of
234 // potentially achieving concurrency and better performance. As a result,
235 // it is the callers' responsibility to make sure that the CUDA stream their
236 // code works on waits for the NCCL operations issued by
237 // this class.
238 //
239 // This can be done by calling:
240 //
241 // either WorkNCCL::wait() or WorkNCCL::synchronize(); both achieve the same
242 // functionality and are synonyms.
243 //
244 // Also note that WorkNCCL::finishedGPUExecution() is a helper function only
245 // provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
246 // finished execution on the GPU (not just scheduled).
247 //
248 // Example of using the NCCL process group
249 //
250 //   ProcessGroupNCCL pg(store, rank, size);
251 //   c10::intrusive_ptr<Work> work = pg.allreduce(tensors);
252 //
253 //   // At this point, the NCCL kernel has already been queued successfully
254 //   // Now, let the current stream wait for NCCL to finish; this function is
255 //   // an async operation as well
256 //
257 //   work->wait();
258 //
259 //   // Now continue on other work in the current stream.
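// A polling alternative to wait() (illustrative sketch, not part of the
// original header): the host can test completion without blocking.
//
//   auto work = pg.allreduce(tensors);
//   while (!work->isCompleted()) {
//     // overlap other host-side work here
//   }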
260 class TORCH_API ProcessGroupNCCL : public Backend {
261  public:
262   class WorkNCCL : public Work, public std::enable_shared_from_this<WorkNCCL> {
263    public:
264     friend struct WorkInfo;
265 
266     // Constructor takes a single CUDA device
267     WorkNCCL(
268         const std::string& pgUID,
269         const std::string& pgDesc,
270         at::Device& device,
271         int rank,
272         OpType opType,
273         uint64_t seq,
274         const char* profilingTitle = nullptr,
275         const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt,
276         bool desyncDebug = false,
277         bool enableTiming = false,
278         bool cudaEventCacheEnabled = false,
279         DebugLevel distDebugLevel = DebugLevel::Off);
280     // Copy constructor doing a partial copy without outputs_. The cleanup
281     // thread monitors and removes finished works. However, it would deadlock
282     // when destructing outputs_ tensors that are view tensors in the autograd graph.
283     WorkNCCL(const WorkNCCL& w);
284 
285     ~WorkNCCL() override;
286 
287     // Checks if the NCCL kernel has started to execute.
288     bool isStarted();
289 
290     // Checks if the request has completed. In this specific case of NCCL, it checks
291     // if the NCCL operation has completed on the GPU in its own NCCL stream.
292     // Non-blocking operation.
293     bool isCompleted() override;
294 
295     bool isSuccess() const override;
296 
297     // Same as calling synchronize() for NCCL work.
298     bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
299 
300     void abort() override;
301 
302     // Let the current stream wait on the completion of the NCCL work.
303     // Throws on exceptions. Blocking operation, which will wait for work
304     // completion.
305     void synchronize() override;
306 
307     // Synchronize streams by blocking each on the NCCL stream
308     void synchronizeStream();
309 
310     // Helper function to handle exception (throw if needed).
311     void handleException(ErrorHandlingMode asyncErrorHandling);
312 
313     // Helper function that checks if the NCCL kernels have finished
314     // execution on the GPUs
315     bool finishedGPUExecution();
316 
317     // Get a Future object that will be marked as completed internally.
318     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
319 
320     float getDuration() const override;
321 
322     uint64_t getSequencenumber() const override;
323 
324     const std::string& logPrefix() const;
325 
326     // Helper function that sets an exception_ptr on the WorkNCCL object.
327     void setException(std::exception_ptr exception_ptr);
328 
329     // Helper function that returns True if the WorkNCCL object has timed out
330     // and False otherwise.
331     // In case of timeout, set exception on the WorkNCCL object.
332     bool checkTimeout(
333         std::optional<std::chrono::milliseconds> timeout = std::nullopt);
334 
335     std::vector<at::Tensor> result() override;
336 
337    protected:
338     // The process group unique id
339     std::string pgUID_;
340 
341     // The process group description
342     std::string pgDesc_;
343 
344     // The cached CUDA device to operate on
345     at::Device device_;
346 
347     // The start CUDA event of NCCL operator tracking this work item. These
348     // start CUDA events are needed by desync debugging if enabled.
349     std::shared_ptr<at::cuda::CUDAEvent> ncclStartEvent_;
350 
351     // The end CUDA event of NCCL operator tracking this work item.
352     std::shared_ptr<at::cuda::CUDAEvent> ncclEndEvent_;
353 
354     // The NCCL communicator used for this work item.
355     std::shared_ptr<NCCLComm> ncclComm_;
356 
357     // Tensor used for barrier op
358     at::Tensor barrierTensor_;
359 
360     // Clone of blockingWait_ from ProcessGroupNCCL.
361     bool blockingWait_ = false;
362 
363     // Clone of avoidRecordStreams_ from ProcessGroupNCCL.
364     bool avoidRecordStreams_ = false;
365 
366     // Clone of opTimeout_ from ProcessGroupNCCL.
367     std::chrono::milliseconds opTimeout_;
368 
369     // Ephemeral timeouts are owned by exactly one work,
370     // and reset after that work completes.
371     // There may be more than one ephemeral timeout active at the same time,
372     // and this variable is used to track the ownership of ephemeral timeout.
373     std::chrono::milliseconds ownedEphermeralTimeout_ =
374         std::chrono::milliseconds(0);
375 
376     // Time point representing when the work started.
377     std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
378 
379     // Record the collective sequential number.
380     uint64_t seq_;
381 
382     // Indicates if the nccl start event has been updated to the store trace.
383     // This will be used by desync debug.
384     bool startTraceUpdated_{false};
385 
386     // Record collective sizes for debug. We only record the size on the first
387     // device as multi-device per process is deprecated
388     size_t numelIn_ = -1;
389     size_t numelOut_ = -1;
390 
391     // Wrapper method for the static checkForNCCLErrors which can be overridden
392     // for tests.
393     virtual std::exception_ptr checkForNCCLErrors();
394 
395     friend std::ostream& operator<<(
396         std::ostream& output,
397         const WorkNCCL& workNCCL);
398 
399    private:
400     // Helper function for synchronize
401     void synchronizeInternal(std::chrono::milliseconds timeout);
402 
403     // Checks for NCCL errors and sets an appropriate exception_ptr.
404     void checkAndSetException();
405 
406     // Just checks whether GPU execution has started, without modifying
407     // exception_ptr.
408     bool startedGPUExecutionInternal() const;
409 
410     // Just checks whether GPU execution has completed, without modifying
411     // exception_ptr.
412     bool finishedGPUExecutionInternal() const;
413 
414     // Reference to the store so that we can write aborted communicators
415     // to the store.
416     c10::intrusive_ptr<Store> store_;
417 
418     // Store a reference to NCCL collective's outputs, used by result and to
419     // give a more descriptive message when representing the Work as a string.
420     std::shared_ptr<std::vector<at::Tensor>> outputs_;
421 
422     // TORCH_NCCL_AVOID_RECORD_STREAMS implementation helper.
423     // Stores references to participating non-output tensors (ie inputs,
424     // flattened intermediates).
425     // We'll clear this list in synchronizeStream, just after user-facing
426     // stream(s) are synced with the nccl work stream(s).
427     // By keeping these refs (as well as outputs_) alive until after the
428     // collective's work rejoins the user-facing streams, we achieve
429     // caching allocator safety without any recordStream calls.
430     // For in-place collectives, some refs stashed here may alias outputs_,
431     // but that doesn't do any harm.
432     std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_;
433 
434     // The future returned by getFuture.
435     c10::intrusive_ptr<at::ivalue::Future> future_;
436 
437     bool timingEnabled_;
438     // unique id used to tell the trace buffer that this
439     // work has completed
440     std::optional<uint64_t> trace_id_;
441     DebugLevel distDebugLevel_;
442     friend class ProcessGroupNCCL;
443   };
444 
445   class CUDAEventCache {
446    public:
447     CUDAEventCache();
448     std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
449     static CUDAEventCache& get();
450 
451    private:
452     std::mutex cacheMutex_;
453     // NOTE: We intentionally store raw pointers so that
454     // we do not attempt to destroy the event objects on process exit,
455     // because CUDA may already be gone.
456     std::vector<at::cuda::CUDAEvent*>
457         eventsArray_[2]; // 0 for timing=false, 1 for timing=true
458   };
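  // A usage sketch for CUDAEventCache (illustrative; the actual call sites are
  // in ProcessGroupNCCL.cpp): when TORCH_NCCL_CUDA_EVENT_CACHE is enabled,
  // WorkNCCL obtains its start/end events from the cache instead of creating
  // fresh CUDAEvents, so the watchdog thread never destroys an event outright.
  //
  //   std::shared_ptr<at::cuda::CUDAEvent> endEvent =
  //       ProcessGroupNCCL::CUDAEventCache::get().create(/*timing=*/false);
  //   endEvent->record(ncclStream); // record on the comm stream
  //   // when the shared_ptr is released, the event is expected to be returned
  //   // to the cache rather than destroyed (see the raw-pointer note above).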
459 
460   struct Options : Backend::Options {
461     // NOTE: timeout in ProcessGroupNCCL::Options denotes the timeout for
462     // operations. This is only used when blockingWait_ is enabled.
463     explicit Options(bool is_high_priority_stream = false);
464 
465     // return intrusive_ptr of the object
466     static c10::intrusive_ptr<Options> create(
467         bool is_high_priority_stream = false) {
468       return c10::make_intrusive<Options>(is_high_priority_stream);
469     }
470 
471     // Schedule NCCL operations on high priority CUDA streams
472     bool is_high_priority_stream;
473 
474 #ifdef NCCL_HAS_COMM_NONBLOCKING
475     // Configure ranks
476     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
477 #endif
478 
479     // Optional "parent" backend and color to create communicators from
480     // via `ncclCommSplit`
481     std::shared_ptr<ProcessGroupNCCL> split_from;
482     int64_t split_color{0};
483     std::vector<uint64_t> global_ranks_in_group;
484     std::string group_name;
485   };
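  // A construction sketch for Options (illustrative; the fields are those
  // declared above, and `timeout` is inherited from Backend::Options):
  //
  //   auto opts = ProcessGroupNCCL::Options::create(/*is_high_priority_stream=*/true);
  //   opts->timeout = std::chrono::milliseconds(30 * 60 * 1000);
  //   opts->group_name = "my_subgroup";
  //   // opts->split_from / opts->split_color may be set so the new PG's
  //   // communicator is created via ncclCommSplit from an existing parent PG.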
486 
487   // If you wish to create multiple process groups, each with a potentially
488   // different rank and size, you can do so by passing a new store instance
489   // to each one. If you have only a single store object, you can
490   // use the `c10d::PrefixStore` to derive scoped instances.
491   // This is also what the Python API in torch.distributed does.
492   //
493   // The process group instance keeps a reference to the store because
494   // it may be used long after the constructor runs. In fact, the constructor
495   // doesn't create any NCCL communicators. A single NCCL communicator can
496   // only be used on a specific set of devices, and is therefore created
497   // on-demand when a collective runs. If another collective is executed later,
498   // against a different set of devices, the process group creates another NCCL
499   // communicator. These NCCL communicators are cached and reused if possible.
500   //
501   ProcessGroupNCCL(
502       const c10::intrusive_ptr<Store>& store,
503       int rank,
504       int size,
505       c10::intrusive_ptr<Options> options = Options::create());
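  // A sketch of the PrefixStore pattern described above (illustrative names):
  //
  //   auto pg1 = c10::make_intrusive<ProcessGroupNCCL>(
  //       c10::make_intrusive<PrefixStore>("pg1", store), rank, size);
  //   auto pg2 = c10::make_intrusive<ProcessGroupNCCL>(
  //       c10::make_intrusive<PrefixStore>("pg2", store), rank, size);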
506 
507   // This constructor includes the deprecated `groupName` argument.
508   // If you have existing code that uses the `groupName`, you can replace
509   // it by specifying a `c10d::PrefixStore(groupName, store)` for store.
510   C10_DEPRECATED ProcessGroupNCCL(
511       const c10::intrusive_ptr<Store>& store,
512       int rank,
513       int size,
514       const std::string& groupName,
515       c10::intrusive_ptr<Options> options = Options::create())
516       : ProcessGroupNCCL(store, rank, size, options) {}
517 
518   ~ProcessGroupNCCL() override;
519 
520   // This function returns a local uid for ProcessGroupNCCL.
521   uint64_t getUid() {
522     return static_cast<uint64_t>(local_id_);
523   }
524 
525   c10::intrusive_ptr<Options> getOptions() {
526     return options_;
527   }
528 
529   const std::string getBackendName() const override {
530     return std::string(NCCL_BACKEND_NAME);
531   }
532 
533   bool supportsSplitting() const override {
534     return true;
535   }
536 
537   void startCoalescing() override;
538 
539   c10::intrusive_ptr<Work> endCoalescing() override;
540 
541   // For specifying a composite optype, such as ALLGATHER and REDUCE_SCATTER
542   c10::intrusive_ptr<Work> endCoalescing(OpType optype);
543 
544   c10::intrusive_ptr<Work> broadcast(
545       std::vector<at::Tensor>& tensors,
546       const BroadcastOptions& opts = BroadcastOptions()) override;
547 
548   c10::intrusive_ptr<Work> _broadcast_oop(
549       at::Tensor& outputTensors,
550       at::Tensor& inputTensors,
551       const BroadcastOptions& opts = BroadcastOptions());
552 
553   c10::intrusive_ptr<Work> allreduce_sparse(
554       std::vector<at::Tensor>& tensors,
555       const AllreduceOptions& opts = AllreduceOptions()) override;
556 
557   c10::intrusive_ptr<Work> allreduce(
558       std::vector<at::Tensor>& tensors,
559       const AllreduceOptions& opts = AllreduceOptions()) override;
560 
561   c10::intrusive_ptr<Work> allreduce_coalesced(
562       std::vector<at::Tensor>& tensors,
563       const AllreduceCoalescedOptions& opts =
564           AllreduceCoalescedOptions()) override;
565 
566   c10::intrusive_ptr<Work> reduce(
567       std::vector<at::Tensor>& tensors,
568       const ReduceOptions& opts = ReduceOptions()) override;
569 
570   c10::intrusive_ptr<Work> _reduce_oop(
571       at::Tensor& outputTensors,
572       at::Tensor& inputTensors,
573       const ReduceOptions& opts = ReduceOptions());
574 
575   c10::intrusive_ptr<Work> allgather(
576       std::vector<std::vector<at::Tensor>>& outputTensors,
577       std::vector<at::Tensor>& inputTensors,
578       const AllgatherOptions& opts = AllgatherOptions()) override;
579 
580   c10::intrusive_ptr<Work> _allgather_base(
581       at::Tensor& outputbuffer,
582       at::Tensor& inputbuffer,
583       const AllgatherOptions& opts = AllgatherOptions()) override;
584 
585   c10::intrusive_ptr<Work> allgather_coalesced(
586       std::vector<std::vector<at::Tensor>>& outputTensorLists,
587       std::vector<at::Tensor>& inputTensors,
588       const AllgatherOptions& opts = AllgatherOptions()) override;
589 
590   c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
591       std::vector<at::Tensor>& outputs,
592       std::vector<at::Tensor>& inputs,
593       const AllgatherOptions& opts = AllgatherOptions()) override;
594 
595   c10::intrusive_ptr<Work> reduce_scatter(
596       std::vector<at::Tensor>& outputTensors,
597       std::vector<std::vector<at::Tensor>>& inputTensors,
598       const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
599 
600   c10::intrusive_ptr<Work> _reduce_scatter_base(
601       at::Tensor& outputTensor,
602       at::Tensor& inputTensor,
603       const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
604 
605   c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
606       std::vector<at::Tensor>& outputs,
607       std::vector<at::Tensor>& inputs,
608       const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
609 
610   c10::intrusive_ptr<Work> barrier(
611       const BarrierOptions& opts = BarrierOptions()) override;
612 
613   c10::intrusive_ptr<Work> alltoall_base(
614       at::Tensor& outputTensor,
615       at::Tensor& inputTensor,
616       std::vector<int64_t>& outputSplitSizes,
617       std::vector<int64_t>& inputSplitSizes,
618       const AllToAllOptions& opts = AllToAllOptions()) override;
619 
620   c10::intrusive_ptr<Work> alltoall(
621       std::vector<at::Tensor>& outputTensors,
622       std::vector<at::Tensor>& inputTensors,
623       const AllToAllOptions& opts = AllToAllOptions()) override;
624 
625   c10::intrusive_ptr<Work> send(
626       std::vector<at::Tensor>& tensors,
627       int dstRank,
628       int tag) override;
629 
630   c10::intrusive_ptr<Work> recv(
631       std::vector<at::Tensor>& tensors,
632       int srcRank,
633       int tag) override;
634 
635   void groupStart();
636 
637   void groupEnd();
638 
639   void groupEndNonblocking(std::shared_ptr<NCCLComm> comm);
640 
641   c10::intrusive_ptr<Work> gather(
642       std::vector<std::vector<at::Tensor>>& outputTensors,
643       std::vector<at::Tensor>& inputTensors,
644       const GatherOptions& opts = GatherOptions()) override;
645 
646   c10::intrusive_ptr<Work> scatter(
647       std::vector<at::Tensor>& outputTensors,
648       std::vector<std::vector<at::Tensor>>& inputTensors,
649       const ScatterOptions& opts = ScatterOptions()) override;
650 
651   // Unsupported Ops
652   c10::intrusive_ptr<Work> recvAnysource(
653       std::vector<at::Tensor>& tensors,
654       int tag) override;
655 
656   // Agrees on an initial sequence number for the whole group by having rank 0
657   // create it and broadcast it to other ranks using the store.
658   void setSequenceNumberForGroup() override;
659 
660   // Retrieves the current sequence number for the whole group, which should be
661   // in sync. If the returned number is not consistent across the group, it
662   // may indicate that there is some sort of collective desynchronization.
663   uint64_t getSequenceNumberForGroup() override;
664 
665   // Return the total number of splits the communicators held by this process
666   // group have performed.  Counts ncclCommCreateFromRanks() for ncclx v2.21.5+
667   uint64_t getCommSplitCounter() const;
668 
669   void registerOnCompletionHook(
670       std::function<void(std::shared_ptr<WorkInfo>)>&& hook) override;
671   void waitForPendingWorks() override;
672 
673   void enableCollectivesTiming() override;
674 
675   // Helper function for iteratively aborting communicators in the provided map
676   void abortCommsFromMap(
677       std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
678       std::optional<std::string> abortReason);
679 
680   c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
681 
682   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
683   // instead of relying on ProcessGroupNCCL destructor.
684   // return true if abort is successful, otherwise false
685   bool abort(std::optional<std::string> abortReason = std::nullopt);
686 
687   void shutdown(std::optional<std::string> reason = std::nullopt);
688 
689   void eagerConnectSingleDevice(at::Device device) override;
690 
691   void performNocolorSplit(at::Device device);
692 
693   // This method adds a temporary extension for the timeout period,
694   // applying to all collectives between the calling of this API and
695   // the completion of the first collective on the GPU. While this feature
696   // provides flexibility in specific scenarios, it introduces statefulness
697   // to timeout setting. Therefore, it is advisable to use this API sparingly
698   // and consider alternative approaches, such as directly setting the timeout
699   // or utilizing a barrier collective (one can set any timeout to the barrier),
700   // whenever feasible.
701   void addEphemeralTimeout(const std::chrono::milliseconds& timeout);
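  // A usage sketch (illustrative): extend the timeout only around a known-slow
  // phase instead of raising the process group's timeout permanently.
  //
  //   pg->addEphemeralTimeout(std::chrono::minutes(20));
  //   auto work = pg->allreduce(tensors); // gets base timeout + 20 minutes
  //   work->wait(); // the extension resets once this work completes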
702 
703   // This function is only intended for testing purposes because we don't
704   // want to expose the `WorkNCCL` via pybind. It verifies whether the
705   // `opTimeout_` of the provided WorkNCCL instance is the same as the specified
706   // timeout.
707   bool verifyWorkTimeoutForTest(
708       const c10::intrusive_ptr<Work> work,
709       const std::chrono::milliseconds& timeout);
710 
711  protected:
712   // Helper that broadcasts nccl unique ID to all ranks through the store
713   void broadcastUniqueNCCLID(
714       ncclUniqueId* ncclID,
715       bool isSingleP2POp,
716       const std::string& devicesKey,
717       int p2pRank);
718 
719   // Helper that either looks up the cached NCCL communicators or creates
720   // a new set of NCCL communicators as a cache entry
721   std::shared_ptr<NCCLComm> getNCCLComm(
722       const std::string& deviceKey,
723       at::Device& device,
724       OpType opType,
725       int p2pRank = 0,
726       bool isSendRecvSelf = false);
727 
728   // Wrapper method which can be overridden for tests.
729   virtual std::exception_ptr checkForNCCLErrors(
730       std::shared_ptr<NCCLComm>& ncclComm);
731 
732   // Ensure that if record is true, the work obj will be enqueued via
733   // workEnqueue
734   virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
735       at::Device& device,
736       int rank,
737       OpType opType,
738       const char* profilingTitle = nullptr,
739       const std::vector<at::Tensor>& inputs = {},
740       const std::vector<at::Tensor>& outputs = {},
741       bool record = false);
742 
743   // In the timeout case, we will dump debug info such as the NCCL flight
744   // recorder to storage. Down the road, if we have more complicated or blocking
745   // operations, we might need to use a side thread to do it.
746   bool dumpDebuggingInfo();
747 
748  private:
749   int globalRankStart;
750   int globalRankStride;
751 
752   // Helper that encapsulates work shared across all collective communication
753   // primitives.  The callbacks have the following signatures:
754   //
755   //    ncclResult_t fn(at::Tensor& input, at::Tensor& output,
756   //                    ncclComm_t, at::cuda::CUDAStream&);
757   //    void {pre,post}(std::vector<at::cuda::CUDAStream&>);
758   template <typename Fn>
759   c10::intrusive_ptr<Work> collective(
760       at::Tensor& input,
761       at::Tensor& output,
762       Fn fn,
763       OpType opType,
764       const char* profilingTitle = nullptr,
765       bool avoidRecordStreams = false,
766       bool nanCheck = true);
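  // A sketch of an `fn` callback satisfying the contract documented above
  // (illustrative; the real allreduce body in ProcessGroupNCCL.cpp also maps
  // the tensor dtype and reduce op instead of hard-coding them):
  //
  //   auto fn = [&](at::Tensor& input, at::Tensor& output,
  //                 ncclComm_t comm, at::cuda::CUDAStream& stream) {
  //     return ncclAllReduce(
  //         input.data_ptr(), output.data_ptr(), input.numel(),
  //         ncclFloat, ncclSum, comm, stream.stream());
  //   };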
767 
768   template <typename Fn, typename PreProcess, typename PostProcess>
769   c10::intrusive_ptr<Work> collective(
770       at::Tensor& input,
771       at::Tensor& output,
772       Fn fn,
773       PreProcess pre,
774       PostProcess post,
775       OpType opType,
776       const char* profilingTitle = nullptr,
777       bool avoidRecordStreams = false,
778       bool nanCheck = true);
779 
780   template <typename Fn, typename PreProcess, typename PostProcess>
781   c10::intrusive_ptr<Work> collective(
782       std::vector<at::Tensor>& inputs,
783       std::vector<at::Tensor>& outputs,
784       Fn fn,
785       PreProcess pre,
786       PostProcess post,
787       OpType opType,
788       const char* profilingTitle = nullptr,
789       bool avoidRecordStreams = false,
790       bool nanCheck = true);
791 
792   template <typename Fn>
793   c10::intrusive_ptr<Work> collectiveCoalesced(
794       std::vector<at::Tensor>& input,
795       std::vector<at::Tensor>& output,
796       Fn fn,
797       OpType opType,
798       const char* profilingTitle = nullptr,
799       bool avoidRecordStreams = false);
800 
801   // Helper that encapsulates work shared across point-to-point communication
802   // primitives. It is the same structure as the helper used for collective
803   // communication primitives.
804   template <typename Fn>
805   c10::intrusive_ptr<Work> pointToPoint(
806       at::Tensor& tensor,
807       Fn fn,
808       int peer,
809       OpType opType,
810       const char* profilingTitle = nullptr);
811 
812   template <typename Fn, typename PreProcess, typename PostProcess>
813   c10::intrusive_ptr<Work> pointToPoint(
814       at::Tensor& tensor,
815       Fn fn,
816       int peer,
817       OpType opType,
818       PreProcess pre,
819       PostProcess post,
820       const char* profilingTitle);
821 
822   c10::intrusive_ptr<Work> allreduce_impl(
823       at::Tensor& tensor,
824       const AllreduceOptions& opts = AllreduceOptions());
825 
826   // Checks for NCCL errors on each of the communicators and returns an
827   // appropriate exception_ptr (nullptr if no errors).
828   static std::exception_ptr checkForNCCLErrorsInternal(
829       std::shared_ptr<NCCLComm>& ncclComm);
830 
831   // Function that runs as part of a separate thread and checks for errors on
832   // NCCL communicators. We need a separate thread to check for NCCL errors
833   // since we can't rely on the user calling certain methods like wait(),
834   // isCompleted() etc. to detect and remediate errors. In addition to this, we
835   // need a mechanism to safely abort and remove NCCL communicators from our
836   // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
837   // class. Attempting to modify the communicator cache from the WorkNCCL class
838   // might run into issues with object lifetime since the ProcessGroupNCCL
839   // object might get destroyed before the WorkNCCL object.
840   void ncclCommWatchdog();
841 
842   // Return the CUDA device most likely associated with this backend.
843   // If we aren't bound to a specific device, there is no strict
844   // guarantee that this heuristic is the correct assignment of ranks
845   // to GPUs that Python layers use, but in practice it tends to be.
846   // Fortunately we don't rely on this for correctness of any tensor
847   // operations, just for ancillary uses like barriers.
848   at::Device guessDeviceForRank() const;
849 
850   // Destroys initialized NCCL communicators in devNCCLCommMap_ given by input
851   // key. Throws if there are no communicators to destroy. Also removes
852   // communicators from the cache and clears used device indices.
853   void destroyNCCLComms(const std::string& devNCCLCommMapKey);
854 
855   // Watchdog's inside loop.
856   // Takes care of cleaning up completed work, and aborting upon failure or
857   // timeout.
858   void watchdogHandler();
859 
860   void runHookLoop();
861 
862   // Desync debug helper
863   void logWorkStart(WorkNCCL& work);
864 
865   // Desync debug helper
866   void logWorkEnd(WorkNCCL& work);
867 
868   // Generates a prefix that is unique to this process group and rank, for
869   // disambiguating logs
870   std::string createLogPrefix() const;
871 
872   // Returns the unique prefix created in createLogPrefix
873   const std::string& logPrefix() const;
874 
875   // Returns the global rank of the device. This function assumes that users
876   // always create a default global process group (PG) which includes all
877   // devices. It is called in the constructor of ProcessGroupNCCL, so it always
878   // returns the rank_ of the very first PG created, aka the default global PG.
879   const int& globalRank() const;
880 
881   // Returns the global ranks of a PG.
882   const std::vector<uint64_t>& groupRanks() const;
883 
884   // Util function to assign timeout to each work.
885   void assignTimeoutToWork(
886       const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
887       const c10::intrusive_ptr<Options>& option);
888 
889  protected:
890   // Function that runs as part of a separate thread aside from watchdog
891   // thread because we need to check the heartbeat from watchdog thread
892   // so that when we get stuck in some NCCL/CUDA calls,
893   // we can dump the debugging information and abort the process.
894   virtual void heartbeatMonitor();
895 
896   // Function that directly trigger std::abort so that the whole process
897   // gets terminated.
898   virtual void terminateProcess(std::string errMsg);
899 
900   // A helper function to wait for a future to complete or timeout.
901   void waitForFutureOrTimeout(
902       std::future<bool>& fut,
903       const std::chrono::milliseconds& timeOutMilSec,
904       const std::string& futDescription,
905       bool throwException = false,
906       bool log = false);
907 
908   // When the watchdog times out, this function will be called and return debug info
909   // for users. For now we only get information from retrieveDesyncReport.
910   // We are working on enabling more useful debug information for watchdog
911   // timeout.
912   virtual std::string getNCCLWatchdogDebugInfo();
913 
914   std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg);
915 
916   std::string getNCCLWatchdogTimeoutExitMsg(const std::string& exitReason);
917 
918   static const int64_t kWatchdogThreadSleepMillis;
919 
920   // The store is used to broadcast the NCCL unique ID of rank 0. This store
921   // comes with a prefix and is different across ProcessGroupNCCL instances
922   // (aka, different ProcessGroups).
923   c10::intrusive_ptr<Store> store_;
924 
925   // Reference to the store without prefix so that keys are the same across all
926   // ProcessGroupNCCL instances, and (key, value) pairs written to the store are
927   // global.
928   c10::intrusive_ptr<Store> globalStore_;
929 
930   bool storeError_{false};
931 
932   // The lock which protects the write/read of
933   // ephemeralTimeoutActive_/ephemeralTimeoutInflight_.
934   // TODO(fduwjj): We need to have an audit on all mutexes we are adding here.
935   // And consolidate them if possible.
936   std::mutex mtxTimeoutExtension_;
937 
938   // The ephemeral timeout added on top of existing timeout for works issued
939   // before first work finishes.
940   std::chrono::milliseconds ephemeralTimeoutActive_ =
941       std::chrono::milliseconds(0);
942 
943   // The ephemeral timeout addition which has been already applied to work.
944   std::chrono::milliseconds ephemeralTimeoutInflight_ =
945       std::chrono::milliseconds(0);
946 
947   const c10::intrusive_ptr<Options> options_;
948 
949   // The number of NCCL communicators that have been created during
950   // the lifetime of this process group. This sequence number is
951   // used to scope keys used in the store.
952   uint64_t ncclCommCounter_{0};
953 
954   // The store keys to trace the last NCCL collective kernel CUDA events - start
955   // event and end event respectively. These are used to do desync root cause
956   // analysis.
957   const std::string traceKeyStart_;
958   const std::string traceKeyEnd_;
959 
960   // The NCCL communicators that the process group has cached.
961   //
962   // For collective operations:
963   // The key is a list of GPU devices that an operation is operating on.
964   // The GPU devices are stored in a device sequence and the cached NCCL
965   // communicator is associated with this GPU device sequence.
966   //
967   // e.g. If the process group op only uses device 0, then the used device
968   // string stored (the key of the hashmap) would be "0".
969   //
970   //      If the process group op uses devices 0 - 7 and each tensor of the
971   //      input tensor list is on device 0, 1, 2, 3, 4, 5, 6, 7 respectively,
972   //      then the used device string (key) stored would be
973   //      "0,1,2,3,4,5,6,7"
974   //
975   //      If the process group op uses devices 0 - 7 and each tensor of the
976   //      input tensor list is on device 0, 4, 5, 6, 7, 1, 2, 3 respectively,
977   //      then the used device string stored would be
978   //      "0,4,5,6,7,1,2,3"
979   //
980   //      Note that the order of the device for the tensor list matters.
981   //
982   // For point-to-point operations:
983   // The key is a string of my current rank and the peer process rank.
984   // e.g. If process 1 and process 2 are involved in a point-to-point
985   // communication, the key will be "1:2" on both processes. Note: this is for
986   // the scenario where there is only 1 GPU per process. When it comes to
987   // multiple GPUs per process, this part may need to redesigned.
988   // TODO: we probably need a separate map for P2P comms
989   std::unordered_map<std::string, std::shared_ptr<NCCLComm>> devNCCLCommMap_;
990 
991   // The NCCL communicators currently in process of being initialized.
992   std::unordered_map<std::string, std::shared_ptr<NCCLComm>>
993       inInitializationCommMap_;
994 
995   // Mutex to guard maps like devNCCLCommMap_.
996   std::mutex mutex_;
997 
998   // Heartbeat of watchdog thread.
999   std::atomic_uint64_t heartbeat_;
1000 
1001   // The time interval used for deciding whether there is no watchdog heartbeat.
1002   int heartbeatTimeoutInSec_;
1003 
1004   // timeout for the dump to finish.
1005   int waitTimeoutDumpInMilSec_;
1006 
1007   // Interval for checking coordinated signals in ProcessGroupNCCL from other ranks
1008   // e.g., trigger the dump of the debugging info for timeout when notified.
1009   int coordCheckIntervalMilSec_;
1010 
1011   // Size of ring buffer where we store NCCL Traces for debugging.
1012   int ncclTraceBufferSize_;
1013 
1014   // We gate the heartbeat monitor thread so that we can roll it out gradually.
1015   std::atomic<bool> monitorThreadEnabled_;
1016 
1017   // We gate the cudaEventCache so that we can roll it out gradually.
1018   std::atomic<bool> cudaEventCacheEnabled_;
1019 
1020   // Monitor thread which checks the heartbeat of Watchdog thread.
1021   // If the monitor thread finds there is no heartbeat, it will dump debug info
1022   // and then kill the watchdog thread to avoid hang.
1023   std::thread ncclHeartbeatMonitorThread_;
1024 
1025   // Watchdog thread which looks for errors on the cached NCCL communicators.
1026   std::thread ncclCommWatchdogThread_;
1027 
1028   std::thread onCompletionHookThread_;
1029 
1030   // Whether or not we should terminate the watchdog and workCleanup threads.
1031   std::atomic<bool> terminateProcessGroup_;
1032 
1033   // Whether or not we should terminate the heartbeat monitoring threads.
1034   std::atomic<bool> terminateHeartbeatMonitorThread_;
1035 
1036   // Whether we are in the shutdown mode when we are trying to get debug info,
1037   // such as desync report.
1038   std::atomic<bool> collectiveDebugInfoMode_;
1039 
1040   // Whether there are hooks pending to be fired
1041   std::atomic<bool> hasPendingHooks_;
1042 
1043   // This is the signal from watchdog threads to indicate whether the monitor
1044   // thread should dump. Making it static so that it is accessible from all the
1045   // PGs. With this flag, monitor thread would dump debug info under any one of
1046   // the three conditions:
1047   //
1048   // 1: watchdog thread of any PG detects a collective timeout.
1049   // 2: timeout signal is received from other ranks through tcpstore.
1050   // 3: current PG's watchdog heartbeat timeout occurs.
1051   //
1052   // Note that only the monitor thread from PG0 will dump the debug info for
1053   // case one and two so that the debug info is only dumped once.
1054   static std::atomic<bool> shouldDump_;
1055 
1056   // Mutex to Guard workMetaList_
1057   std::mutex workMetaListMutex_;
1058 
1059   // Mutex to Guard monitorWakeUpCV_
1060   std::mutex monitorMutex_;
1061 
1062   bool writeDebugInfo_ = false;
1063 
1064   // Condition Variable for watchdog thread sleep
1065   std::condition_variable workMetaListCV_;
1066 
1067   // Condition Variable for monitor thread to wake up early
1068   std::condition_variable monitorWakeUpCV_;
1069 
1070   // List to store WorkNCCL objects
1071   std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
1072 
1073   std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;
1074 
1075   // Mutex to Guard completedWorkList_
1076   std::mutex completedWorkListMutex_;
1077 
1078   // Condition Variable for the hook thread to be notified of completed work
1079   std::condition_variable completedWorkListCV_;
1080 
1081   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;
1082 
1083   // Add Work Pointer to workMetaList_
1084   void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);
1085 
1086   // The CUDA streams used by NCCL kernels
1087   std::unordered_map<std::string, at::cuda::CUDAStream> ncclStreams_;
1088 
1089   // The CUDA events used to sync NCCL streams
1090   std::unordered_map<std::string, at::cuda::CUDAEvent> ncclEvents_;
1091 
1092   // Device Indexes used for all collectives in this group
1093   std::set<int> usedDeviceIdxs_;
1094 
1095   // Flag to denote if a coalescing groupStart/groupEnd block is active
1096   int coalescing_state_ = 0;
1097 
1098   // Stores device indexes for all collectives run inside a coalescing block
1099   at::Device coalescedDevice_ = at::Device("cuda");
1100 
1101   // Stores communicators for all collectives run inside a coalescing block
1102   std::shared_ptr<NCCLComm> coalescedComm_ = nullptr;
1103 
1104   // map from the key: "group name + pg counter (ID)" to the
1105   // unique NCCL ID count. This needs to be group and pg specific
1106   //
1107   // For each process group, we need a uniform unique NCCL ID counter to ensure
1108   // that NCCL operation in this process group can be completed successfully.
1109   // Since each process group ID belongs to a group name, the key to this map
1110   // is a combination of group name and ProcessGroupNCCL ID.
1111   static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_;
1112 
1113   // map from group name to the pg counter (ID) within that group
1114   //
1115   // For each group with the "group name" (which is the key), we need to
1116   // keep track of a unique process group ID when creating a new
1117   // ProcessGroupNCCL for this "group name". Therefore, the value of this
1118   // map keeps the unique ProcessGroupNCCL's ID for a specific group with
1119   // the "group name". The reason we need a per-group process group ID counter
1120   // is that different groups can have different ranks and we need to ensure that
1121   // each group has its own uniform process group ID for all its ranks.
1122   static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;
1123 
1124   // Whether or not wait() and synchronize() are blocking operations that wait
1125   // for the operation to complete.
1126   bool blockingWait_ = false;
1127 
1128   // Whether or not to hook the cache allocator to register all allocated
1129   // tensors
1130   bool useTensorRegisterAllocatorHook_ = false;
1131 
1132   // Whether or not the workCleanupThread is used to perform async error
1133   // handling.
1134   ErrorHandlingMode asyncErrorHandling_ = NoHandling;
1135 
1136   // Whether or not to enable timeout root cause analysis.
1137   bool desyncDebug_;
1138 
1139   // Whether or not to dump debug info on exception including both watchdog
1140   // timeout and nccl errors.
1141   bool dumpOnTimeoutOrEx_;
1142 
1143   // Whether or not to enable nan check for input tensors to collectives.
1144   bool enableNanCheck_;
1145 
1146   // Whether or not to print C++ stack traces to logs on unclean shutdown.
1147   bool logCppStackOnUncleanShutdown_;
1148 
1149   // Whether or not to create start CUDAEvent and enable timing for start
1150   // and end events. Note that enableTiming_ is always true if desyncDebug_
1151   // is set to true.
1152   std::atomic<bool> enableTiming_;
1153 
1154   // Flag to enable printing the hash value of the input/output of collectives
1155   // for verification.
1156   std::atomic<bool> enableCollecticeHashDebug_;
1157 
1158   // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set
1159   bool avoidRecordStreams_ = false;
1160 
1161   // Whether the NCCL watchdog should rethrow CUDA errors.
1162   bool rethrowCUDAErrors_ = false;
1163 
1164   // Set of communicators that this process group has aborted and whose
1165   // ncclUniqueId has been written to the store. We don't need a lock
1166   // for this map since only the watchdog thread accesses this set. The
1167   // set contains the string representation of ncclUniqueId.
1168   std::unordered_set<std::string> abortedComms_;
1169 
1170   // The number of active ncclGroupStart() calls. This counter will be increased
1171   // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
1172   // is called.
1173   static thread_local uint64_t ncclActiveGroupCounter_;
1174 
1175   // Counting for the sequential number of NCCL collective call.
1176   // (specifically, how many actual kernels we launched, which differs from
1177   // op_id_ when coalescing is enabled)
1178   uint64_t seqCollective_{0};
1179 
1180   // Counting for the sequential number of NCCL P2P calls.
1181   uint64_t seqP2P_{0};
1182 
1183   // Incrementing counter for logical operations (collective or p2p) issued on
1184   // the ProcessGroup
1185   uint64_t op_id_{0};
1186 
1187   std::exception_ptr watchDogException_ = nullptr;
1188 
1189   // The number of ProcessGroupNCCL created on the current rank.
1190   size_t local_id_;
1191 
1192   std::string logPrefix_;
1193 
1194   c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
1195 
1196   // Number of devices on this node.
1197   int localDeviceCount_{0};
1198 
1199   std::shared_ptr<ProcessGroupStatus> pgStatus_ =
1200       std::make_shared<ProcessGroupStatus>();
1201 };
1202 
1203 // Dumps the NCCL comm traces and additional information about the Process
1204 // Group.
1205 TORCH_API std::string dump_nccl_trace(
1206     bool includeCollectives,
1207     bool includeStackTraces,
1208     bool onlyActive);
1209 
1210 // Dumps the NCCL comm traces and additional information about the Process
1211 // Group in JSON formatted string.
1212 // We don't include stack traces in JSON format as it is far too much data.
1213 TORCH_API std::string dump_nccl_trace_json(
1214     bool includeCollectives,
1215     bool onlyActive);
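// A call sketch (illustrative): these free functions can be invoked directly
// from C++, e.g. by debug tooling:
//
//   std::string trace = c10d::dump_nccl_trace(
//       /*includeCollectives=*/true, /*includeStackTraces=*/true, /*onlyActive=*/false);
//   std::string json = c10d::dump_nccl_trace_json(
//       /*includeCollectives=*/true, /*onlyActive=*/false);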
1216 
1217 // Gets a mutable reference to a global optional function. Heartbeat Monitor
1218 // will use this function to dump traces, if available. Inside fbcode, we
1219 // store a function here that uses an internal tool for process tracing
1220 TORCH_API std::optional<
1221     std::function<void(std::function<void(const std::string&)>)>>&
1222 get_cpp_trace_dumper();
1223 
1224 // Similar to get_cpp_trace_dumper, this stores a function defined in
1225 // torch-python layer that lets us check whether the GIL can be acquired,
1226 // helpful for instrumenting in cases where a hang was observed.
1227 typedef bool (*gil_checker_t)();
1228 
1229 TORCH_API gil_checker_t& get_gil_checker();
1230 } // namespace c10d
1231 
1232 #endif // USE_C10D_NCCL
1233