xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ParamCommsUtils.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 #include <ATen/record_function.h>
5 #include <c10/macros/Macros.h>
6 #include <c10/util/ThreadLocalDebugInfo.h>
7 #include <string>
8 #include <vector>
9 
10 namespace torch {
11 
12 class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
13  public:
14   ParamCommsDebugInfo() = default;
15   ParamCommsDebugInfo(
16       std::tuple<std::string, std::string> pgName,
17       int rank,
18       std::string&& collName,
19       int64_t inNelems,
20       int64_t outNelems,
21       at::ScalarType dType,
22       std::vector<int64_t> inSplitSizes,
23       std::vector<int64_t> outSplitSizes,
24       int globalRankStart,
25       int globalRankStride,
26       int worldSize);
27 
28   ~ParamCommsDebugInfo() override = default;
29 
getProcessGroupName() const30   const std::string getProcessGroupName() const {
31     return std::get<0>(pgName_);
32   }
33 
getProcessGroupDesc() const34   const std::string getProcessGroupDesc() const {
35     return std::get<1>(pgName_);
36   }
37 
getRank() const38   int getRank() const {
39     return rank_;
40   }
41 
getWorldSize() const42   int getWorldSize() const {
43     return worldSize_;
44   }
45 
getGlobalRankStart() const46   int getGlobalRankStart() const {
47     return globalRankStart_;
48   }
49 
getGlobalRankStride() const50   int getGlobalRankStride() const {
51     return globalRankStride_;
52   }
53 
getCollectiveName() const54   const std::string getCollectiveName() const {
55     return collectiveName_;
56   }
57 
getInMessageNelems() const58   int64_t getInMessageNelems() const {
59     return inMessageNelems_;
60   }
61 
getOutMessageNelems() const62   int64_t getOutMessageNelems() const {
63     return outMessageNelems_;
64   }
65 
getDType() const66   at::ScalarType getDType() const {
67     return dType_;
68   }
69 
getInputSplitSizes() const70   const std::vector<int64_t>& getInputSplitSizes() const {
71     return inputSplitSizes_;
72   }
73 
getOutputSplitSizes() const74   const std::vector<int64_t>& getOutputSplitSizes() const {
75     return outputSplitSizes_;
76   }
77 
getGroupRanks() const78   const std::vector<int64_t>& getGroupRanks() const {
79     return groupRanks_;
80   }
81 
82  private:
83   std::tuple<std::string, std::string> pgName_; // <group_name, group_desc>
84   int rank_{};
85   int worldSize_{};
86   std::string collectiveName_;
87   int64_t inMessageNelems_{};
88   int64_t outMessageNelems_{};
89   at::ScalarType dType_ = at::kByte;
90   std::vector<int64_t> inputSplitSizes_;
91   std::vector<int64_t> outputSplitSizes_;
92   int globalRankStart_{};
93   int globalRankStride_{};
94   std::vector<int64_t> groupRanks_{};
95 };
96 
// Records one collective communication call as a RecordFunction event named
// at::kParamCommsCallName, with a ParamCommsDebugInfo installed under
// c10::DebugInfoKind::PARAM_COMMS_INFO for the rest of the enclosing scope.
//
// The expansion is a sequence of declarations in the *enclosing* scope and is
// intentionally not wrapped in do { } while (0): the guard objects it declares
// must stay alive across the communication call that follows. Consequences:
//   * use it at most once per scope (it declares `paramCommsInfo`, `g`,
//     `paramList` and `paramInputs`);
//   * argument expressions must not refer to variables with those names.
//
// `collName` is forwarded to the constructor's `std::string&&` parameter, so
// it must be an rvalue (e.g. a string literal). The recorded collective name
// is read back from `paramCommsInfo` rather than evaluating `collName` a
// second time; re-evaluation would repeat side effects and, if the
// constructor moves from its argument, would record a moved-from string when
// the caller passes `std::move(x)`.
// The other arguments in the recorded list (pgName, rank, split sizes and the
// global-rank layout) are still evaluated twice, so keep them side-effect
// free. inNelems, outNelems and dType are stored only in the debug info; they
// are not part of the recorded inputs.
#define RECORD_PARAM_COMMS(                                                    \
    seq,                                                                       \
    pgName,                                                                    \
    rank,                                                                      \
    collName,                                                                  \
    inNelems,                                                                  \
    outNelems,                                                                 \
    dType,                                                                     \
    inSplitSizes,                                                              \
    outSplitSizes,                                                             \
    globalRankStart,                                                           \
    globalRankStride,                                                          \
    worldSize)                                                                 \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>(          \
      pgName,                                                                  \
      rank,                                                                    \
      collName,                                                                \
      inNelems,                                                                \
      outNelems,                                                               \
      dType,                                                                   \
      inSplitSizes,                                                            \
      outSplitSizes,                                                           \
      globalRankStart,                                                         \
      globalRankStride,                                                        \
      worldSize);                                                              \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = {                       \
      c10::IValue(seq),                                                        \
      pgName,                                                                  \
      rank,                                                                    \
      c10::IValue(paramCommsInfo->getCollectiveName()),                        \
      inSplitSizes,                                                            \
      outSplitSizes,                                                           \
      globalRankStart,                                                         \
      globalRankStride,                                                        \
      worldSize};                                                              \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \
  RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
135 
// Records one collective communication call together with its tensors, as a
// RecordFunction event named at::kParamCommsCallName (via
// RECORD_FUNCTION_WITH_INPUTS_OUTPUTS). `InputTensors` is recorded as the
// first input, ahead of the scalar metadata; `OutputTensors` is recorded as
// the single output. A ParamCommsDebugInfo is installed under
// c10::DebugInfoKind::PARAM_COMMS_INFO for the rest of the enclosing scope.
//
// The expansion is a sequence of declarations in the *enclosing* scope and is
// intentionally not wrapped in do { } while (0): the guard objects it declares
// must stay alive across the communication call that follows. Consequences:
//   * use it at most once per scope (it declares `paramCommsInfo`, `g`,
//     `paramList` and `paramInputs`);
//   * argument expressions must not refer to variables with those names.
//
// `collName` is forwarded to the constructor's `std::string&&` parameter, so
// it must be an rvalue (e.g. a string literal). The recorded collective name
// is read back from `paramCommsInfo` rather than evaluating `collName` a
// second time; re-evaluation would repeat side effects and, if the
// constructor moves from its argument, would record a moved-from string when
// the caller passes `std::move(x)`.
// The other arguments in the recorded list (pgName, rank, split sizes and the
// global-rank layout) are still evaluated twice, so keep them side-effect
// free. inNelems, outNelems and dType are stored only in the debug info; they
// are not part of the recorded inputs.
#define RECORD_PARAM_COMMS_DATA(                                               \
    seq,                                                                       \
    pgName,                                                                    \
    InputTensors,                                                              \
    OutputTensors,                                                             \
    rank,                                                                      \
    collName,                                                                  \
    inNelems,                                                                  \
    outNelems,                                                                 \
    dType,                                                                     \
    inSplitSizes,                                                              \
    outSplitSizes,                                                             \
    globalRankStart,                                                           \
    globalRankStride,                                                          \
    worldSize)                                                                 \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>(          \
      pgName,                                                                  \
      rank,                                                                    \
      collName,                                                                \
      inNelems,                                                                \
      outNelems,                                                               \
      dType,                                                                   \
      inSplitSizes,                                                            \
      outSplitSizes,                                                           \
      globalRankStart,                                                         \
      globalRankStride,                                                        \
      worldSize);                                                              \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = {                       \
      c10::IValue(InputTensors),                                               \
      c10::IValue(seq),                                                        \
      pgName,                                                                  \
      rank,                                                                    \
      c10::IValue(paramCommsInfo->getCollectiveName()),                        \
      inSplitSizes,                                                            \
      outSplitSizes,                                                           \
      globalRankStart,                                                         \
      globalRankStride,                                                        \
      worldSize};                                                              \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \
  RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(                                         \
      at::kParamCommsCallName,                                                 \
      paramInputs,                                                             \
      std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
180 } // namespace torch
181