xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/util.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstddef>
#include <cstdint>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include <ATen/record_function.h>
#include <c10/macros/Macros.h>
#include <c10/util/hash.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
16 
// SOFT_ASSERT(cond, ...) is an expression of type bool that is `true` when
// `cond` holds. When it does not hold, the failure is first reported through
// torch::profiler::impl::logSoftAssert (calling function, file, line,
// stringified condition and the message built by ::c10::str from the
// remaining arguments). After that, one of two things happens:
//   - if softAssertRaises() is true, the condition is re-checked with
//     TORCH_INTERNAL_ASSERT (expected to raise);
//   - otherwise the message is emitted with TORCH_WARN_ONCE.
// In either case, if control returns, the expression is `false`.
//
// The body is an immediately-invoked lambda so the macro can be used as an
// expression. `__func__` is passed in as the lambda's argument because,
// written inside the lambda body, it would name the closure's call operator
// ("operator()") instead of the function that used the macro. As a result,
// the macro must be used inside a function body, where `__func__` is
// defined.
//
// NOTE(review): on the raising path `cond` is evaluated a second time by
// TORCH_INTERNAL_ASSERT, so conditions with side effects are unsafe here.
//
// TODO: replace with pytorch/rfcs#43 when it is ready.
#define SOFT_ASSERT(cond, ...)                         \
  [&](const char* soft_assert_caller_) -> bool {       \
    if (C10_UNLIKELY(!(cond))) {                       \
      torch::profiler::impl::logSoftAssert(            \
          soft_assert_caller_,                         \
          __FILE__,                                    \
          static_cast<uint32_t>(__LINE__),             \
          #cond,                                       \
          ::c10::str(__VA_ARGS__));                    \
      if (torch::profiler::impl::softAssertRaises()) { \
        TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__);      \
      } else {                                         \
        TORCH_WARN_ONCE(__VA_ARGS__);                  \
      }                                                \
      return false;                                    \
    }                                                  \
    return true;                                       \
  }(__func__)
36 
37 namespace torch::profiler::impl {
38 TORCH_API bool softAssertRaises();
39 TORCH_API void setSoftAssertRaises(std::optional<bool> value);
40 TORCH_API void logSoftAssert(
41     const char* func,
42     const char* file,
43     uint32_t line,
44     const char* cond,
45     const char* args);
logSoftAssert(const char * func,const char * file,uint32_t line,const char * cond,::c10::detail::CompileTimeEmptyString args)46 TORCH_API inline void logSoftAssert(
47     const char* func,
48     const char* file,
49     uint32_t line,
50     const char* cond,
51     ::c10::detail::CompileTimeEmptyString args) {
52   logSoftAssert(func, file, line, cond, (const char*)args);
53 }
54 TORCH_API void logSoftAssert(
55     const char* func,
56     const char* file,
57     uint32_t line,
58     const char* cond,
59     const std::string& args);
60 
/// An input's recorded sizes: either one size vector, or a vector of size
/// vectors (presumably one per tensor in a tensor-list input — TODO confirm).
using shape =
    std::variant<std::vector<int64_t>, std::vector<std::vector<int64_t>>>;

/// Display cap for tensor-list inputs (presumably the maximum number of
/// entries rendered — TODO confirm against util.cpp). Declared `inline` so
/// every translation unit including this header shares one definition
/// instead of receiving its own internal-linkage copy.
inline constexpr int TENSOR_LIST_DISPLAY_LENGTH_LIMIT = 30;
64 
65 std::string getNvtxStr(
66     const char* name,
67     int64_t sequence_nr,
68     const std::vector<std::vector<int64_t>>& shapes,
69     at::RecordFunctionHandle op_id = 0,
70     const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids =
71         {});
72 
73 struct TORCH_API FileLineFunc {
74   std::string filename;
75   size_t line;
76   std::string funcname;
77 };
78 
79 TORCH_API std::vector<FileLineFunc> prepareCallstack(
80     const std::vector<jit::StackEntry>& cs);
81 TORCH_API std::vector<std::string> callstackStr(
82     const std::vector<FileLineFunc>& cs);
83 TORCH_API std::string stacksToStr(
84     const std::vector<std::string>& stacks,
85     const char* delim);
86 TORCH_API std::vector<std::vector<int64_t>> inputSizes(
87     const at::RecordFunction& fn,
88     const bool flatten_list_enabled = false);
89 TORCH_API std::string variantShapesToStr(const std::vector<shape>& shapes);
90 TORCH_API std::string shapesToStr(
91     const std::vector<std::vector<int64_t>>& shapes);
92 TORCH_API std::string strListToStr(const std::vector<std::string>& types);
93 TORCH_API std::string inputOpIdsToStr(
94     const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids);
95 TORCH_API std::string ivalueToStr(const c10::IValue& val, bool isString);
96 TORCH_API std::string ivalueListToStr(const std::vector<c10::IValue>& list);
97 TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);
98 
99 std::unordered_map<std::string, c10::IValue> TORCH_API
100 saveExtraArgs(const at::RecordFunction& fn);
101 std::unordered_map<std::string, std::string> TORCH_API
102 saveNcclMeta(const at::RecordFunction& fn, bool truncate = true);
103 
104 uint64_t TORCH_API computeFlops(
105     const std::string& op_name,
106     const std::unordered_map<std::string, c10::IValue>& extra_args);
107 
108 std::string shapeToStr(const std::vector<int64_t>& shape);
109 
110 template <typename T>
111 class TORCH_API GlobalStateManager {
112  public:
singleton()113   static GlobalStateManager& singleton() {
114     static GlobalStateManager singleton_;
115     return singleton_;
116   }
117 
push(std::shared_ptr<T> && state)118   static void push(std::shared_ptr<T>&& state) {
119     if (singleton().state_) {
120       LOG(WARNING) << "GlobalStatePtr already exists!";
121     } else {
122       singleton().state_ = std::move(state);
123     }
124   }
125 
get()126   static auto* get() {
127     return singleton().state_.get();
128   }
129 
pop()130   static std::shared_ptr<T> pop() {
131     auto out = singleton().state_;
132     singleton().state_.reset();
133     return out;
134   }
135 
136  private:
137   GlobalStateManager() = default;
138 
139   std::shared_ptr<T> state_;
140 };
141 
142 struct HashCombine {
143   template <typename T0, typename T1>
operatorHashCombine144   size_t operator()(const std::pair<T0, T1>& i) {
145     return c10::get_hash((*this)(i.first), (*this)(i.second));
146   }
147 
148   template <typename... Args>
operatorHashCombine149   size_t operator()(const std::tuple<Args...>& i) {
150     return c10::get_hash(i);
151   }
152 
153   template <typename T>
operatorHashCombine154   size_t operator()(const T& i) {
155     return c10::get_hash(i);
156   }
157 };
158 
#ifdef USE_DISTRIBUTED
// Key strings for communication-op metadata (presumably the keys of the map
// returned by saveNcclMeta() — TODO confirm).
// `inline` gives each constant a single definition shared by all including
// translation units, rather than one internal-linkage copy per TU.
// `auto` keeps the original type (`const char* const`).
inline constexpr auto kCommsName = "Collective name";
inline constexpr auto kDtype = "dtype";
inline constexpr auto kInMsgNelems = "In msg nelems";
inline constexpr auto kOutMsgNelems = "Out msg nelems";
inline constexpr auto kInSplit = "In split size";
inline constexpr auto kOutSplit = "Out split size";
inline constexpr auto kGlobalRankStart = "Global rank start";
inline constexpr auto kGlobalRankStride = "Global rank stride";
inline constexpr auto kGroupSize = "Group size";
inline constexpr auto kProcessGroupName = "Process Group Name";
inline constexpr auto kProcessGroupDesc = "Process Group Description";
inline constexpr auto kGroupRanks = "Process Group Ranks";
inline constexpr auto kRank = "Rank";
inline constexpr auto kP2pSrc = "Src Rank";
inline constexpr auto kP2pDst = "Dst Rank";
#endif // USE_DISTRIBUTED
176 
177 } // namespace torch::profiler::impl
178