xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/UCCTracing.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_C10D_UCC
4 
5 #include <torch/csrc/distributed/c10d/UCCUtils.hpp>
6 
7 namespace c10d {
8 
// Records one communication call on `_comms_tracer` when
// `torch_ucc_config.enable_comms_logger` is set; otherwise does nothing and
// none of the macro arguments are evaluated.
//
// Expands to a single `do { ... } while (0)` statement that forwards to
// `recordComms(opTypeToString(_opType), <address held by _work>, _rank,
// _comm_size, _inTensors, _outTensors)`.
//
// Requirements at the expansion site:
//   - `torch_ucc_config` and `opTypeToString` must be visible.
//   - `_comms_tracer` must support `->recordComms(...)`.
//   - `_work` must provide `.get()` returning a pointer; its address is passed
//     as an opaque `uintptr_t` request identifier.
//
// Every parameter is parenthesized so that expression arguments (for example
// `&tracer` or `cond ? a : b`) expand with the intended precedence.
#define RECORD_COMMS_TRACE(                                                    \
    _comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \
  do {                                                                         \
    if (torch_ucc_config.enable_comms_logger) {                                \
      (_comms_tracer)->recordComms(                                            \
          opTypeToString(_opType),                                             \
          reinterpret_cast<uintptr_t>((_work).get()),                          \
          (_rank),                                                             \
          (_comm_size),                                                        \
          (_inTensors),                                                        \
          (_outTensors));                                                      \
    }                                                                          \
  } while (0)
22 
23 // interfaces to collect communication traces
24 class TORCH_API CommTraceLogger : public torch::CustomClassHolder {
25  private:
26   std::vector<std::string> comms_trace_;
27   std::vector<std::string> curBlocks_; /* unused */
28   std::vector<int64_t> curOutSplitSizes_;
29   std::vector<int64_t> curInSplitSizes_;
30   int curRoot_ = -1;
31   unsigned long seqnum = 0;
32 
33  public:
34   void setCurBlock(const std::string& name); /* unused */
35   void popBlock(); /* unused */
36   // record root info if applicable, e.g., broadcast, gather, scatter
37   void recordOptionalInfo(int root = -1);
38   // record input/output splits of Alltoallv
39   void recordOptionalInfo(
40       const std::vector<int64_t>& outputSplitSizes = {},
41       const std::vector<int64_t>& inputSplitSizes = {});
42   // record essential comms information
43   void recordComms(
44       const std::string& collName,
45       const uintptr_t workReq = 0,
46       const int rank = -1,
47       const int world_size = -1,
48       const std::vector<at::Tensor>& inputTensors = {},
49       const std::vector<at::Tensor>& outputTensor = {});
50   // return collected comms traces
getCommsTrace()51   std::vector<std::string>& getCommsTrace() {
52     return comms_trace_;
53   }
54 };
55 
56 } // namespace c10d
57 
58 #endif // USE_C10D_UCC
59