#pragma once

#ifdef USE_C10D_UCC

#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

namespace c10d {

// Record one collective-communication call into `_comms_tracer`, but only
// when comms logging is enabled (torch_ucc_config.enable_comms_logger).
// `_work` is a smart-pointer-like handle to the in-flight request; its raw
// pointer value is used as a unique identifier for the call.
#define RECORD_COMMS_TRACE(                                                    \
    _comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \
  do {                                                                         \
    if (torch_ucc_config.enable_comms_logger) {                                \
      _comms_tracer->recordComms(                                              \
          opTypeToString(_opType),                                             \
          reinterpret_cast<uintptr_t>(_work.get()),                            \
          _rank,                                                               \
          _comm_size,                                                          \
          _inTensors,                                                          \
          _outTensors);                                                        \
    }                                                                          \
  } while (0)

// Interfaces to collect communication traces for collectives issued through
// the UCC process group. Definitions live in the corresponding .cpp file.
class TORCH_API CommTraceLogger : public torch::CustomClassHolder {
 private:
  std::vector<std::string> comms_trace_; // accumulated trace entries
  std::vector<std::string> curBlocks_; /* unused */
  std::vector<int64_t> curOutSplitSizes_; // pending output splits (Alltoallv)
  std::vector<int64_t> curInSplitSizes_; // pending input splits (Alltoallv)
  int curRoot_ = -1; // pending root rank (broadcast/gather/scatter); -1 = none
  // Sequence counter for recorded collectives (presumably incremented per
  // recordComms() call — verify against the definition). NOTE(review): name
  // lacks the trailing-underscore member convention used above; left unchanged
  // because out-of-line definitions reference it by this name.
  unsigned long seqnum = 0;

 public:
  void setCurBlock(const std::string& name); /* unused */
  void popBlock(); /* unused */
  // record root info if applicable, e.g., broadcast, gather, scatter
  void recordOptionalInfo(int root = -1);
  // record input/output splits of Alltoallv
  void recordOptionalInfo(
      const std::vector<int64_t>& outputSplitSizes = {},
      const std::vector<int64_t>& inputSplitSizes = {});
  // Record essential comms information for one collective: operation name,
  // request id (raw work-request pointer), caller rank, world size, and the
  // tensors involved. Presumably also consumes any pending optional info set
  // via recordOptionalInfo() — confirm against the definition.
  // (Parameter renamed outputTensor -> outputTensors for consistency with
  // inputTensors; declaration-only, so callers and the definition are
  // unaffected.)
  void recordComms(
      const std::string& collName,
      const uintptr_t workReq = 0,
      const int rank = -1,
      const int world_size = -1,
      const std::vector<at::Tensor>& inputTensors = {},
      const std::vector<at::Tensor>& outputTensors = {});
  // Return the comms traces collected so far (reference into this logger).
  std::vector<std::string>& getCommsTrace() {
    return comms_trace_;
  }
};

} // namespace c10d

#endif // USE_C10D_UCC