1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <ATen/record_function.h> 5 #include <c10/macros/Macros.h> 6 #include <c10/util/ThreadLocalDebugInfo.h> 7 #include <string> 8 #include <vector> 9 10 namespace torch { 11 12 class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { 13 public: 14 ParamCommsDebugInfo() = default; 15 ParamCommsDebugInfo( 16 std::tuple<std::string, std::string> pgName, 17 int rank, 18 std::string&& collName, 19 int64_t inNelems, 20 int64_t outNelems, 21 at::ScalarType dType, 22 std::vector<int64_t> inSplitSizes, 23 std::vector<int64_t> outSplitSizes, 24 int globalRankStart, 25 int globalRankStride, 26 int worldSize); 27 28 ~ParamCommsDebugInfo() override = default; 29 getProcessGroupName() const30 const std::string getProcessGroupName() const { 31 return std::get<0>(pgName_); 32 } 33 getProcessGroupDesc() const34 const std::string getProcessGroupDesc() const { 35 return std::get<1>(pgName_); 36 } 37 getRank() const38 int getRank() const { 39 return rank_; 40 } 41 getWorldSize() const42 int getWorldSize() const { 43 return worldSize_; 44 } 45 getGlobalRankStart() const46 int getGlobalRankStart() const { 47 return globalRankStart_; 48 } 49 getGlobalRankStride() const50 int getGlobalRankStride() const { 51 return globalRankStride_; 52 } 53 getCollectiveName() const54 const std::string getCollectiveName() const { 55 return collectiveName_; 56 } 57 getInMessageNelems() const58 int64_t getInMessageNelems() const { 59 return inMessageNelems_; 60 } 61 getOutMessageNelems() const62 int64_t getOutMessageNelems() const { 63 return outMessageNelems_; 64 } 65 getDType() const66 at::ScalarType getDType() const { 67 return dType_; 68 } 69 getInputSplitSizes() const70 const std::vector<int64_t>& getInputSplitSizes() const { 71 return inputSplitSizes_; 72 } 73 getOutputSplitSizes() const74 const std::vector<int64_t>& getOutputSplitSizes() const { 75 return outputSplitSizes_; 76 } 77 getGroupRanks() const78 const std::vector<int64_t>& getGroupRanks() const { 79 return groupRanks_; 80 } 81 82 private: 83 std::tuple<std::string, std::string> pgName_; // <group_name, group_desc> 84 int rank_{}; 85 int worldSize_{}; 86 std::string collectiveName_; 87 int64_t inMessageNelems_{}; 88 int64_t outMessageNelems_{}; 89 at::ScalarType dType_ = at::kByte; 90 std::vector<int64_t> inputSplitSizes_; 91 std::vector<int64_t> outputSplitSizes_; 92 int globalRankStart_{}; 93 int globalRankStride_{}; 94 std::vector<int64_t> groupRanks_{}; 95 }; 96 97 #define RECORD_PARAM_COMMS( \ 98 seq, \ 99 pgName, \ 100 rank, \ 101 collName, \ 102 inNelems, \ 103 outNelems, \ 104 dType, \ 105 inSplitSizes, \ 106 outSplitSizes, \ 107 globalRankStart, \ 108 globalRankStride, \ 109 worldSize) \ 110 auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \ 111 pgName, \ 112 rank, \ 113 collName, \ 114 inNelems, \ 115 outNelems, \ 116 dType, \ 117 inSplitSizes, \ 118 outSplitSizes, \ 119 globalRankStart, \ 120 globalRankStride, \ 121 worldSize); \ 122 c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ 123 std::initializer_list<const c10::IValue> paramList = { \ 124 c10::IValue(seq), \ 125 pgName, \ 126 rank, \ 127 collName, \ 128 inSplitSizes, \ 129 outSplitSizes, \ 130 globalRankStart, \ 131 globalRankStride, \ 132 worldSize}; \ 133 c10::ArrayRef<const c10::IValue> paramInputs(paramList); \ 134 RECORD_FUNCTION(at::kParamCommsCallName, paramInputs); 135 136 #define RECORD_PARAM_COMMS_DATA( \ 137 seq, \ 138 pgName, \ 139 InputTensors, \ 140 OutputTensors, \ 141 rank, \ 142 collName, \ 143 inNelems, \ 144 outNelems, \ 145 dType, \ 146 inSplitSizes, \ 147 outSplitSizes, \ 148 globalRankStart, \ 149 globalRankStride, \ 150 worldSize) \ 151 auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \ 152 pgName, \ 153 rank, \ 154 collName, \ 155 inNelems, \ 156 outNelems, \ 157 dType, \ 158 inSplitSizes, \ 159 outSplitSizes, \ 160 globalRankStart, \ 161 globalRankStride, \ 162 worldSize); \ 163 c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ 164 std::initializer_list<const c10::IValue> paramList = { \ 165 c10::IValue(InputTensors), \ 166 c10::IValue(seq), \ 167 pgName, \ 168 rank, \ 169 collName, \ 170 inSplitSizes, \ 171 outSplitSizes, \ 172 globalRankStart, \ 173 globalRankStride, \ 174 worldSize}; \ 175 c10::ArrayRef<const c10::IValue> paramInputs(paramList); \ 176 RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \ 177 at::kParamCommsCallName, \ 178 paramInputs, \ 179 std::vector<c10::IValue>(1, c10::IValue(OutputTensors))); 180 } // namespace torch 181