1 #pragma once
2
#include <cstddef>
#include <cstdint>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include <ATen/record_function.h>
#include <c10/macros/Macros.h>
#include <c10/util/hash.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
16
// TODO: replace with pytorch/rfcs#43 when it is ready.
//
// SOFT_ASSERT(cond, msg...) is an expression that evaluates to `true` when
// `cond` holds. When `cond` is false it:
//   1. reports the failure through torch::profiler::impl::logSoftAssert(),
//      passing the calling function's name, __FILE__, __LINE__, the
//      stringified condition and ::c10::str(msg...);
//   2. raises via TORCH_INTERNAL_ASSERT if softAssertRaises() is true,
//      otherwise emits TORCH_WARN_ONCE(msg...);
//   3. evaluates to `false` (when it did not raise).
//
// The body is an immediately-invoked lambda so the macro can be used inside
// expressions, e.g. `if (!SOFT_ASSERT(x > 0, "bad x")) { return; }`.
//
// `__func__` is captured in the lambda introducer on purpose: the
// init-capture is evaluated in the enclosing function, whereas `__func__`
// written inside the lambda body names the lambda's call operator
// ("operator()") instead of the function that used SOFT_ASSERT.
//
// NOTE: in raising mode `cond` is evaluated a second time by
// TORCH_INTERNAL_ASSERT, so it should be free of side effects.
#define SOFT_ASSERT(cond, ...)                               \
  [&, soft_assert_caller_func_ = __func__]() -> bool {       \
    if (C10_UNLIKELY(!(cond))) {                             \
      torch::profiler::impl::logSoftAssert(                  \
          soft_assert_caller_func_,                          \
          __FILE__,                                          \
          static_cast<uint32_t>(__LINE__),                   \
          #cond,                                             \
          ::c10::str(__VA_ARGS__));                          \
      if (torch::profiler::impl::softAssertRaises()) {       \
        TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__);            \
      } else {                                               \
        TORCH_WARN_ONCE(__VA_ARGS__);                        \
      }                                                      \
      return false;                                          \
    }                                                        \
    return true;                                             \
  }()
36
37 namespace torch::profiler::impl {
38 TORCH_API bool softAssertRaises();
39 TORCH_API void setSoftAssertRaises(std::optional<bool> value);
40 TORCH_API void logSoftAssert(
41 const char* func,
42 const char* file,
43 uint32_t line,
44 const char* cond,
45 const char* args);
logSoftAssert(const char * func,const char * file,uint32_t line,const char * cond,::c10::detail::CompileTimeEmptyString args)46 TORCH_API inline void logSoftAssert(
47 const char* func,
48 const char* file,
49 uint32_t line,
50 const char* cond,
51 ::c10::detail::CompileTimeEmptyString args) {
52 logSoftAssert(func, file, line, cond, (const char*)args);
53 }
54 TORCH_API void logSoftAssert(
55 const char* func,
56 const char* file,
57 uint32_t line,
58 const char* cond,
59 const std::string& args);
60
61 using shape =
62 std::variant<std::vector<int64_t>, std::vector<std::vector<int64_t>>>;
63 constexpr int TENSOR_LIST_DISPLAY_LENGTH_LIMIT = 30;
64
65 std::string getNvtxStr(
66 const char* name,
67 int64_t sequence_nr,
68 const std::vector<std::vector<int64_t>>& shapes,
69 at::RecordFunctionHandle op_id = 0,
70 const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids =
71 {});
72
73 struct TORCH_API FileLineFunc {
74 std::string filename;
75 size_t line;
76 std::string funcname;
77 };
78
79 TORCH_API std::vector<FileLineFunc> prepareCallstack(
80 const std::vector<jit::StackEntry>& cs);
81 TORCH_API std::vector<std::string> callstackStr(
82 const std::vector<FileLineFunc>& cs);
83 TORCH_API std::string stacksToStr(
84 const std::vector<std::string>& stacks,
85 const char* delim);
86 TORCH_API std::vector<std::vector<int64_t>> inputSizes(
87 const at::RecordFunction& fn,
88 const bool flatten_list_enabled = false);
89 TORCH_API std::string variantShapesToStr(const std::vector<shape>& shapes);
90 TORCH_API std::string shapesToStr(
91 const std::vector<std::vector<int64_t>>& shapes);
92 TORCH_API std::string strListToStr(const std::vector<std::string>& types);
93 TORCH_API std::string inputOpIdsToStr(
94 const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids);
95 TORCH_API std::string ivalueToStr(const c10::IValue& val, bool isString);
96 TORCH_API std::string ivalueListToStr(const std::vector<c10::IValue>& list);
97 TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);
98
99 std::unordered_map<std::string, c10::IValue> TORCH_API
100 saveExtraArgs(const at::RecordFunction& fn);
101 std::unordered_map<std::string, std::string> TORCH_API
102 saveNcclMeta(const at::RecordFunction& fn, bool truncate = true);
103
104 uint64_t TORCH_API computeFlops(
105 const std::string& op_name,
106 const std::unordered_map<std::string, c10::IValue>& extra_args);
107
108 std::string shapeToStr(const std::vector<int64_t>& shape);
109
110 template <typename T>
111 class TORCH_API GlobalStateManager {
112 public:
singleton()113 static GlobalStateManager& singleton() {
114 static GlobalStateManager singleton_;
115 return singleton_;
116 }
117
push(std::shared_ptr<T> && state)118 static void push(std::shared_ptr<T>&& state) {
119 if (singleton().state_) {
120 LOG(WARNING) << "GlobalStatePtr already exists!";
121 } else {
122 singleton().state_ = std::move(state);
123 }
124 }
125
get()126 static auto* get() {
127 return singleton().state_.get();
128 }
129
pop()130 static std::shared_ptr<T> pop() {
131 auto out = singleton().state_;
132 singleton().state_.reset();
133 return out;
134 }
135
136 private:
137 GlobalStateManager() = default;
138
139 std::shared_ptr<T> state_;
140 };
141
142 struct HashCombine {
143 template <typename T0, typename T1>
operatorHashCombine144 size_t operator()(const std::pair<T0, T1>& i) {
145 return c10::get_hash((*this)(i.first), (*this)(i.second));
146 }
147
148 template <typename... Args>
operatorHashCombine149 size_t operator()(const std::tuple<Args...>& i) {
150 return c10::get_hash(i);
151 }
152
153 template <typename T>
operatorHashCombine154 size_t operator()(const T& i) {
155 return c10::get_hash(i);
156 }
157 };
158
159 #ifdef USE_DISTRIBUTED
160 constexpr auto kCommsName = "Collective name";
161 constexpr auto kDtype = "dtype";
162 constexpr auto kInMsgNelems = "In msg nelems";
163 constexpr auto kOutMsgNelems = "Out msg nelems";
164 constexpr auto kInSplit = "In split size";
165 constexpr auto kOutSplit = "Out split size";
166 constexpr auto kGlobalRankStart = "Global rank start";
167 constexpr auto kGlobalRankStride = "Global rank stride";
168 constexpr auto kGroupSize = "Group size";
169 constexpr auto kProcessGroupName = "Process Group Name";
170 constexpr auto kProcessGroupDesc = "Process Group Description";
171 constexpr auto kGroupRanks = "Process Group Ranks";
172 constexpr auto kRank = "Rank";
173 constexpr auto kP2pSrc = "Src Rank";
174 constexpr auto kP2pDst = "Dst Rank";
175 #endif // USE_DISTRIBUTED
176
177 } // namespace torch::profiler::impl
178