#ifdef _WIN32
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>

#include <processthreadsapi.h>
#else
#include <unistd.h>
#endif // _WIN32

#include <fmt/format.h>
#include <fmt/ranges.h>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <map>
#include <mutex>
#include <sstream>
#include <stack>
#include <vector>

#include <ATen/core/TensorBody.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/stack.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
#include <torch/csrc/profiler/util.h>

#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#endif // USE_DISTRIBUTED

using namespace at;

// Collective property attributes
// https://github.com/pytorch/pytorch/issues/124674
#ifdef USE_DISTRIBUTED
constexpr auto kETCommsName = "collective_name";
constexpr auto kETInMsgNelems = "in_msg_nelems";
constexpr auto kETOutMsgNelems = "out_msg_nelems";
constexpr auto kETInSplit = "in_split_size";
constexpr auto kETOutSplit = "out_split_size";
constexpr auto kETGlobalRankStart = "global_rank_start";
constexpr auto kETGlobalRankStride = "global_rank_stride";
constexpr auto kETGroupSize = "pg_size";
constexpr auto kETProcessGroupName = "pg_name";
constexpr auto kETProcessGroupDesc = "pg_desc";
#endif // USE_DISTRIBUTED

namespace torch::profiler::impl {

//******************************************************************************
// JSON output utility functions. To be merged with PyTorch profiler.
//******************************************************************************
template <typename T>
inline std::string vectorToString(const std::vector<T>& v) {
  return fmt::format("[{}]", fmt::join(v, ","));
}
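// Example (illustrative): vectorToString(std::vector<int>{1, 2, 3}) returns
// "[1,2,3]", and an empty vector yields "[]".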

std::string json_str_escape(const std::string& str);

constexpr size_t kMaxNumElements = 4096;

inline std::string getScalarValue(const c10::IValue& val) {
  if (val.isDouble()) {
    double d_val = val.toDouble();
    if (std::isinf(d_val) || std::isnan(d_val)) {
      return fmt::format("\"{}\"", std::to_string(d_val));
    } else {
      return std::to_string(d_val);
    }
  } else if (val.isInt()) {
    return std::to_string(val.toInt());
  } else if (val.isBool()) {
    return val.toBool() ? "true" : "false";
  } else if (val.isString()) {
    const std::string& str_val = val.toStringRef();
    return fmt::format("\"{}\"", json_str_escape(str_val));
  } else if (val.isDevice()) {
    return fmt::format("\"{}\"", val.toDevice().str());
  }
  return fmt::format("\"<{}>\"", val.tagKind());
}
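// Example outputs (illustrative; std::to_string formatting of non-finite
// values is platform-dependent):
//   getScalarValue(c10::IValue(1.5))   -> 1.500000
//   getScalarValue(c10::IValue(1e308 * 10)) -> "inf" (quoted, since JSON has no Inf/NaN)
//   getScalarValue(c10::IValue(true))  -> true
//   getScalarValue(c10::IValue("a"))   -> "a" (escaped and quoted)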

inline int32_t processId() {
#ifndef _WIN32
  return static_cast<int32_t>(getpid());
#else
  return static_cast<int32_t>(GetCurrentProcessId());
#endif
}

//******************************************************************************
// Main ExecutionTraceObserver implementation.
//******************************************************************************

// ExecutionTraceObserver contains all the state of the observer. Some of it is
// shared between the enter and exit RecordFunction callbacks, and some data,
// like the `opStack`, may be accessed across different threads, so we must be
// careful about data races. A global mutex `gMutex` is used to avoid these
// races, at the cost of performance when many threads are involved. This could
// be optimized further with thread-local state, fine-grained locking, or
// thread-safe containers.
struct TORCH_API ExecutionTraceObserver { // NOLINT
  using ID = size_t;

  // Mapping of each thread to its own operator stack
  std::map<size_t, std::stack<ID>> opStack{};
  // Maps the underlying TensorImpl object pointer (the key) to its unique ID.
  std::map<const void*, ID> objectId{};
  // Observer run state.
  enum class RunState { uninitialized, disabled, enabled };

  // Mutex for multithreaded access to the shared containers.
  std::recursive_mutex gMutex{};
  // Stream to write output JSON.
  std::ofstream out{};

  // Full path to the output file.
  std::string fileName{};

  // RecordFunction callback handle for this observer.
  CallbackHandle cbHandle{INVALID_CALLBACK_HANDLE};

  // Process ID.
  int32_t pid{-1};
  std::string recordTime{};

  ExecutionTraceObserver() = default;

  // Returns a new unique ID.
  ID getNewID() {
    return id_++;
  }

  RunState getState() const {
    return state_;
  }

  void setState(RunState newState) {
    if (state_ == RunState::uninitialized ||
        callbackShouldBeEnabled(state_) != callbackShouldBeEnabled(newState)) {
      if (callbackShouldBeEnabled(newState)) {
        reenableCallback(cbHandle);
      } else {
        disableCallback(cbHandle);
      }
    }
    state_ = newState;
  }

 private:
  static bool callbackShouldBeEnabled(RunState run_state) {
    return run_state == ExecutionTraceObserver::RunState::enabled;
  }

  // Must use accessors to change this so that we can keep the
  // RecordFunction callback in sync with the state.
  RunState state_{RunState::uninitialized};

  // Every tensor and operator is assigned a unique ID; the counter is
  // incremented for each new tensor or operator node.
  // 0 -> uninitialized
  // 1 -> root ID
  // 2 ... -> regular node ID
  std::atomic<ID> id_{2};
};

// A singleton manager is used to allow initializing and deleting the observer
// object.
using ObserverManager = GlobalStateManager<ExecutionTraceObserver>;

// Uninitialized node has id = 0
const ExecutionTraceObserver::ID kUninitializedId{0};
// Root node has id = 1
const ExecutionTraceObserver::ID kRootId{1};

struct FunctionCallContext : public ObserverContext { // NOLINT
  std::string name;
  std::string kernelBackend;
  std::string kernelFile;
  ExecutionTraceObserver::ID opId{kUninitializedId};
  ExecutionTraceObserver::ID parentId{kUninitializedId};
  ExecutionTraceObserver::ID fwParentId{kUninitializedId};
  std::vector<std::string> inputTypes;
  std::vector<std::string> inputShapes;
  std::vector<std::string> inputStrides;
  std::vector<std::string> inputValues;
};

// Opens the JSON file to write the execution trace.
static std::ofstream openOutputFile(const std::string& name) {
  std::ofstream stream;
  stream.open(name, std::ofstream::out | std::ofstream::trunc);
  if (!stream) {
    LOG(ERROR) << "Failed to open '" << name << "'";
  } else {
    VLOG(1) << "PyTorch Execution Trace: writing to " << name;
  }
  return stream;
}

#ifdef USE_DISTRIBUTED
static inline std::string getAttrJson(
    const std::string& name,
    const std::string& type,
    const std::string& value) {
  // Note: name and type are not quoted, but value must be quoted if it is a
  // string.
  return fmt::format(
      R"JSON(
  {{"name": "{}", "type": "{}", "value": {}}})JSON",
      name,
      type,
      value);
}
#endif
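// Example (illustrative): getAttrJson("pg_size", "uint64", "8") yields
//   {"name": "pg_size", "type": "uint64", "value": 8}
// whereas a string value must be passed pre-quoted, e.g. "\"allreduce\"".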

static void writeJsonNode(
    std::ofstream& out,
    const std::string& name,
    const uint64_t id,
    const uint64_t rf_id,
    const uint64_t parent,
    const uint64_t fw_parent,
    const int64_t seq_id,
    const uint64_t scope,
    const uint64_t tid,
    const uint64_t fw_tid,
    const std::string& inputs = "[]",
    const std::string& inputShapes = "[]",
    const std::string& inputStrides = "[]",
    const std::string& inputTypes = "[]",
    const std::string& outputs = "[]",
    const std::string& output_shapes = "[]",
    const std::string& output_strides = "[]",
    const std::string& output_types = "[]",
    const std::string& operator_schema = "",
    const std::string& kernelBackend = "",
    const std::string& kernelFile = "",
    const std::string& additional_attrs = "") {
  out << fmt::format(
      R"JSON(
    {{
      "id": {}, "name": "{}", "ctrl_deps": {},
      "inputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
      "outputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
      "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}{}]
    }})JSON",
      id,
      name,
      parent,
      inputs,
      inputShapes,
      inputTypes,
      inputStrides,
      outputs,
      output_shapes,
      output_types,
      output_strides,
      rf_id,
      fw_parent,
      seq_id,
      scope,
      tid,
      fw_tid,
      operator_schema,
      kernelBackend,
      kernelFile,
      additional_attrs);
}
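// A node emitted by writeJsonNode looks like the following (illustrative IDs
// and a hypothetical aten::matmul op; array contents elided):
//   {
//     "id": 5, "name": "aten::matmul", "ctrl_deps": 2,
//     "inputs": {"values": [...], "shapes": [...], "types": [...], "strides": [...]},
//     "outputs": {"values": [...], "shapes": [...], "types": [...], "strides": [...]},
//     "attrs": [{"name": "rf_id", "type": "uint64", "value": 42}, ...]
//   }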

inline std::string timeString(const std::time_t timepoint) {
  std::ostringstream oss;
  oss << std::put_time(std::localtime(&timepoint), "%Y-%m-%d %X"); // NOLINT
  return oss.str();
}

static bool initExecutionTraceStart(ExecutionTraceObserver& ob) {
  ob.out = openOutputFile(ob.fileName);
  // If the output stream failed to open, finish the observer here.
  if (!ob.out) {
    LOG(WARNING) << "Failed to open output file: " << ob.fileName;
    return false;
  }

  // Wall-clock time at which the first op collection starts.
  const auto current_time = std::chrono::system_clock::now();
  ob.recordTime =
      timeString(std::chrono::system_clock::to_time_t(current_time));
  // Start timestamp using steady_clock for measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();

  ob.out << fmt::format(
      R"JSON({{
  "schema": "1.1.1-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {},
  "nodes": [)JSON",
      ob.pid,
      ob.recordTime,
      timestamp);
  return true;
}
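// The resulting file header looks like this (illustrative pid and
// timestamps):
//   {
//     "schema": "1.1.1-chakra.0.0.4", "pid": 12345, "time": "2024-05-01 13:45:30", "start_ts": 1714567530123,
//     "nodes": [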

// Write out Execution Trace to file
static void finalizeExecutionTraceOutput(ExecutionTraceObserver& ob) {
  writeJsonNode(
      ob.out,
      "[pytorch|profiler|execution_trace|process]",
      kRootId,
      0, // rf_id
      kRootId, // parent is self
      0, // fw_parent
      -1, // seq_id
      static_cast<std::underlying_type_t<RecordScope>>(RecordScope::USER_SCOPE),
      0, // tid
      0); // fw_tid

  // Finish timestamp using steady_clock for measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  ob.out << fmt::format(
      R"JSON(
  ],
  "finish_ts": {}
}})JSON",
      timestamp);

  ob.out.close();
  VLOG(1) << "PyTorch Execution Trace: written to file " << ob.fileName;
}
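// finalizeExecutionTraceOutput appends the root process node, then closes the
// "nodes" array and the top-level object, so the file ends with (illustrative
// timestamp):
//   ],
//   "finish_ts": 1714567600456
//   }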

inline ExecutionTraceObserver::ID getObjectID(
    ExecutionTraceObserver& ob,
    const void* t) {
  auto iter = ob.objectId.find(t);
  if (iter == ob.objectId.end()) {
    ExecutionTraceObserver::ID objectId = ob.getNewID();
    ob.objectId[t] = objectId;
    return objectId;
  }

  return iter->second;
}

inline std::tuple<std::string, std::string, std::string, std::string>
convertIValue(
    ExecutionTraceObserver& ob,
    const c10::IValue& val,
    const bool baseType = true,
    const size_t maxArrayLen = kMaxNumElements) {
  std::string type = val.tagKind();
  if (val.isTensor()) {
    std::string tensor_shape, tensor_stride, tensor_type, tensor_value;

    const auto& tensor = val.toTensor();
    const auto tensor_impl = tensor.unsafeGetTensorImpl();
    if (tensor.defined() && !tensor_impl->has_symbolic_sizes_strides()) {
      // tensor shape
      tensor_shape = vectorToString(tensor.sizes().vec());
      // tensor strides
      tensor_stride = vectorToString(tensor.strides().vec());
    } else {
      tensor_shape = "[]";
      tensor_stride = "[]";
    }
    // tensor dtype
    type = type + fmt::format("({})", std::string(tensor.dtype().name()));
    tensor_type = baseType ? fmt::format("\"{}\"", type) : type;

    ExecutionTraceObserver::ID tensor_id = getObjectID(ob, tensor_impl);
    ExecutionTraceObserver::ID storage_id = 0;
    size_t offset = 0;
    size_t numel = 0;
    size_t itemsize = 0;
    std::string device_str;
    // Symbolic sizes/strides imply that t->storage_offset() will fail.
    if (tensor_impl->has_storage() &&
        !tensor_impl->has_symbolic_sizes_strides()) {
      auto& t_storage = tensor_impl->storage();
      storage_id = getObjectID(ob, t_storage.data());
      offset = tensor_impl->storage_offset();
      numel = tensor_impl->numel();
      itemsize = tensor_impl->itemsize();
      device_str = tensor_impl->device().str();
    }
    tensor_value = fmt::format(
        "[{},{},{},{},{},\"{}\"]",
        tensor_id,
        storage_id,
        offset,
        numel,
        itemsize,
        device_str);
    return std::make_tuple(
        tensor_shape, tensor_stride, tensor_type, tensor_value);
  } else if (val.isTuple()) {
    const auto& val_tuple = val.toTupleRef().elements();
    size_t tuple_size = val_tuple.size();
    std::vector<std::string> shape_array;
    std::vector<std::string> stride_array;
    std::vector<std::string> type_array;
    std::vector<std::string> value_array;
    for (const auto j : c10::irange(tuple_size)) {
      auto tuple = convertIValue(ob, val_tuple[j], false, maxArrayLen);
      shape_array.push_back(std::get<0>(tuple));
      stride_array.push_back(std::get<1>(tuple));
      type_array.push_back(std::get<2>(tuple));
      value_array.push_back(std::get<3>(tuple));
    }
    type = type + vectorToString(type_array);
    std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type;
    return std::make_tuple(
        vectorToString(shape_array),
        vectorToString(stride_array),
        tensor_type,
        vectorToString(value_array));
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    size_t list_size = val_list.size();
    std::vector<std::string> shape_array;
    std::vector<std::string> stride_array;
    std::vector<std::string> type_array;
    std::vector<std::string> value_array;
    for (const auto j : c10::irange(list_size)) {
      // Truncate before converting the element so that at most maxArrayLen
      // entries are recorded.
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
      auto tuple = convertIValue(ob, val_list.get(j), false, maxArrayLen);
      shape_array.push_back(std::get<0>(tuple));
      stride_array.push_back(std::get<1>(tuple));
      type_array.push_back(std::get<2>(tuple));
      value_array.push_back(std::get<3>(tuple));
    }
    type = type + vectorToString(type_array);
    std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type;
    return std::make_tuple(
        vectorToString(shape_array),
        vectorToString(stride_array),
        tensor_type,
        vectorToString(value_array));
  } else {
    std::string tensor_shape = "[]";
    std::string tensor_stride = "[]";
    std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type;
    std::string tensor_value = getScalarValue(val);

    return std::make_tuple(
        tensor_shape, tensor_stride, tensor_type, tensor_value);
  }
}
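// Example (illustrative, with hypothetical IDs): a defined 2x3 float CUDA
// tensor with tensor id 10 and storage id 11 converts to roughly
//   shape:  [2,3]
//   stride: [3,1]
//   type:   "Tensor(float)"
//   value:  [10,11,0,6,4,"cuda:0"]  // [tensor_id, storage_id, offset, numel, itemsize, device]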

inline void appendValueInfo(
    ExecutionTraceObserver& ob,
    const c10::IValue& val,
    std::vector<std::string>& shapes,
    std::vector<std::string>& strides,
    std::vector<std::string>& types,
    std::vector<std::string>& values) {
  auto tuple = convertIValue(ob, val, true);

  shapes.push_back(std::get<0>(tuple));
  strides.push_back(std::get<1>(tuple));
  types.push_back(std::get<2>(tuple));
  values.push_back(std::get<3>(tuple));
}

inline void handleKernelBackendInfo(
    FunctionCallContext& fc,
    const RecordFunction& fn) {
  // Triton kernel-related information is in kwinputs.
  const auto& kwinputs = fn.kwinputs();
  if (kwinputs.find("kernel_backend") != kwinputs.end()) {
    fc.kernelBackend = kwinputs.at("kernel_backend").toStringRef();
    if (fc.kernelBackend == "triton") {
      // Check that the kernel file is present before accessing it.
      TORCH_INTERNAL_ASSERT(
          kwinputs.find("kernel_file") != kwinputs.end(),
          "kernel file is missing in triton kernel");
      fc.kernelFile = kwinputs.at("kernel_file").toStringRef();
      // Remove the path from the file name.
      if (fc.kernelFile.find_last_of('/') != std::string::npos) {
        fc.kernelFile =
            fc.kernelFile.substr(fc.kernelFile.find_last_of('/') + 1);
      }

      // get grid information
      TORCH_INTERNAL_ASSERT(
          kwinputs.find("grid") != kwinputs.end(),
          "grid is missing in triton kernel");
      fc.inputValues.emplace_back(
          "\"" + kwinputs.at("grid").toStringRef() + "\"");
      fc.inputTypes.emplace_back("\"String\"");
      fc.inputShapes.emplace_back("[]");

      // get stream information
      TORCH_INTERNAL_ASSERT(
          kwinputs.find("stream") != kwinputs.end(),
          "stream is missing in triton kernel");
      fc.inputValues.emplace_back(
          std::to_string(kwinputs.at("stream").toInt()));
      fc.inputTypes.emplace_back("\"Int\"");
      fc.inputShapes.emplace_back("[]");
    }
  }
}
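// Example (illustrative, hypothetical values): a Triton kernel launch is
// expected to carry kwinputs such as
//   {"kernel_backend": "triton", "kernel_file": "/tmp/xyz/kernel.py",
//    "grid": "grid(1024)", "stream": 7}
// The code above strips the directory so fc.kernelFile becomes "kernel.py"
// and records grid/stream as extra string/int inputs.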

// Additional attributes for communication collectives
inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT
  std::vector<std::string> attrs;

#ifdef USE_DISTRIBUTED
  // We rely on the ParamCommsDebugInfo object that is available in
  // thread-local info.
  auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
  if (debugInfo == nullptr) {
    LOG(WARNING) << "ParamCommsDebugInfo not available for function: "
                 << fn.name();
    return ", " + getAttrJson("debug", "string", "\"missing comms info\"");
  }

  // Get NcclMeta from the record function; this uses the ParamCommsDebugInfo
  // above.
  auto meta = saveNcclMeta(fn, false /*truncate*/);

  auto addAttr =
      [&](const char* commsMetaName, const char* etMetaName, const char* type) {
        auto it = meta.find(commsMetaName);
        if (it != meta.end()) {
          attrs.push_back(getAttrJson(etMetaName, type, it->second));
        }
      };

  addAttr(kCommsName, kETCommsName, "string");
  addAttr(kDtype, kDtype, "string");

  addAttr(kInMsgNelems, kETInMsgNelems, "uint64");
  addAttr(kOutMsgNelems, kETOutMsgNelems, "uint64");

  // The following two metadata entries are lists.
  addAttr(kInSplit, kETInSplit, "string");
  addAttr(kOutSplit, kETOutSplit, "string");

  addAttr(kGlobalRankStart, kETGlobalRankStart, "uint64");
  addAttr(kGlobalRankStride, kETGlobalRankStride, "uint64");

  // pg_name is a string.
  addAttr(kProcessGroupName, kETProcessGroupName, "string");
  addAttr(kProcessGroupDesc, kETProcessGroupDesc, "string");

  addAttr(kGroupSize, kETGroupSize, "uint64");

#endif // USE_DISTRIBUTED

  // XXX consider using a string stream?
  return attrs.empty() ? "" : fmt::format(", {}", fmt::join(attrs, ", "));
}
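// Example (illustrative; actual values come from ParamCommsDebugInfo): for an
// all-reduce, the returned suffix could look like
//   , {"name": "collective_name", "type": "string", "value": "allreduce"},
//     {"name": "in_msg_nelems", "type": "uint64", "value": 1048576}, ...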

static void recordOperatorStart(
    ExecutionTraceObserver& ob,
    FunctionCallContext& fc,
    const RecordFunction& fn) {
  auto tid = fn.threadId();

  try {
    const std::lock_guard<std::recursive_mutex> lock(ob.gMutex);

    // If the current thread's stack is empty, push the root node first.
    if (ob.opStack[tid].empty()) {
      auto thread_node_id = ob.getNewID();
      ob.opStack[tid].push(thread_node_id);
      writeJsonNode(
          ob.out,
          "[pytorch|profiler|execution_trace|thread]",
          thread_node_id,
          0, // rf_id
          kRootId,
          0, // fw_parent
          -1, // seq_id
          static_cast<std::underlying_type_t<RecordScope>>(
              RecordScope::USER_SCOPE),
          tid,
          0); // fw_tid
      ob.out << ",";
    }
    fc.name = fn.name();
    auto num_inputs = fn.num_inputs();
    const auto inputs = fn.inputs();

    VLOG(2) << "inputs: " << num_inputs << " " << inputs.size() << '\n';
    // We have two cases: for an unboxed kernel, num_inputs == inputs.size();
    // for a boxed kernel using the stack, there could be more elements on the
    // stack from previous ops.
    // TORCH_INTERNAL_ASSERT(num_inputs <= inputs.size());
    if (num_inputs > inputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " expected num_inputs=" << num_inputs
                   << " > inputs.size()=" << inputs.size();
      return;
    }
    // Account for Stack mode, where the inputs are at the end.
    size_t input_start = inputs.size() - num_inputs;

    for (const auto i : c10::irange(input_start, inputs.size())) {
      appendValueInfo(
          ob,
          inputs[i],
          fc.inputShapes,
          fc.inputStrides,
          fc.inputTypes,
          fc.inputValues);
    }

    handleKernelBackendInfo(fc, fn);

    fc.parentId = ob.opStack[tid].top();
    // Get the parent id from the forward stack; this can be different for
    // autograd ops, which may execute on a different thread than the original
    // thread (which should have the parent op on the stack).
    auto fw_tid = fn.forwardThreadId();
    if (fw_tid != 0) {
      fc.fwParentId = ob.opStack[fw_tid].top();
    }
    // Input nodes were assigned IDs above, so they have id < opId; output
    // nodes assigned at exit will have id > opId.
    fc.opId = ob.getNewID();
    ob.opStack[tid].push(fc.opId);

  } catch (const std::exception& e) {
    LOG(WARNING) << "Exception in execution trace observer: " << e.what();
  }
}

static std::unique_ptr<ObserverContext> onFunctionEnter(
    const RecordFunction& fn) {
  using RunState = ExecutionTraceObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob != nullptr && ob->getState() == RunState::enabled) {
    // record op
    auto fc_ptr = std::make_unique<FunctionCallContext>();
    recordOperatorStart(*ob, *fc_ptr, fn);
    return fc_ptr;
  }
  return nullptr;
}

inline std::string json_str_escape(const std::string& str) {
  std::ostringstream ostream;
  for (char ch : str) {
    if (ch == '"') {
      ostream << "\\\"";
    } else if (ch == '\\') {
      ostream << "\\\\";
    } else if (ch == '\b') {
      ostream << "\\b";
    } else if (ch == '\f') {
      ostream << "\\f";
    } else if (ch == '\n') {
      ostream << "\\n";
    } else if (ch == '\r') {
      ostream << "\\r";
    } else if (ch == '\t') {
      ostream << "\\t";
    } else if ('\x00' <= ch && ch <= '\x1f') {
      ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
              << static_cast<int>(ch);
    } else {
      ostream << ch;
    }
  }
  return ostream.str();
}
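// Example (illustrative): the 3-character input a"b followed by a newline is
// escaped to the 6 characters a\"b\n (backslashes literal), and the control
// byte 0x01 becomes the six characters \u0001.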

static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
  using RunState = ExecutionTraceObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob == nullptr || ctx_ptr == nullptr) {
    return;
  }
  if (ob->getState() == RunState::enabled) {
    auto fc_ptr = dynamic_cast<FunctionCallContext*>(ctx_ptr);
    // TORCH_INTERNAL_ASSERT(fc_ptr != nullptr);
    if (fc_ptr == nullptr) {
      LOG(WARNING) << "FunctionCallContext is nullptr.";
      return;
    }
    auto& fc = *fc_ptr;

    auto outputs = fn.outputs();
    auto num_outputs = fn.num_outputs();
    // We have two cases: for an unboxed kernel, num_outputs ==
    // outputs.size(); for a boxed kernel using the stack, there could be more
    // elements on the stack from previous ops.
    VLOG(2) << "outputs: " << num_outputs << " " << outputs.size() << '\n';
    // TORCH_INTERNAL_ASSERT(num_outputs <= outputs.size());
    if (num_outputs > outputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " num_outputs=" << num_outputs
                   << " > outputs.size()=" << outputs.size();
      return;
    }
    // Account for Stack mode, where the outputs are at the end.
    size_t output_start = outputs.size() - num_outputs;

    std::vector<std::string> output_types;
    std::vector<std::string> output_strides;
    std::vector<std::string> output_shapes;
    std::vector<std::string> output_values;
    try {
      const std::lock_guard<std::recursive_mutex> lock(ob->gMutex);
      // Remove the current op id from the stack.
      ob->opStack[fn.threadId()].pop();
      for (const auto i : c10::irange(output_start, outputs.size())) {
        appendValueInfo(
            *ob,
            outputs[i],
            output_shapes,
            output_strides,
            output_types,
            output_values);
      }

      std::string op_schema_str{};
      const auto op_schema = fn.operator_schema();
      if (op_schema.has_value()) {
        op_schema_str = json_str_escape(c10::toString(op_schema.value()));
      }

      const std::string additional_attrs =
          fn.isNcclMeta() ? getCommsNodeAttrs(fn) : "";

      writeJsonNode(
          ob->out,
          fc.name,
          fc.opId,
          fn.handle(),
          fc.parentId,
          fc.fwParentId,
          fn.seqNr(),
          static_cast<std::underlying_type_t<RecordScope>>(fn.scope()),
          fn.threadId(),
          fn.forwardThreadId(),
          vectorToString(fc.inputValues),
          vectorToString(fc.inputShapes),
          vectorToString(fc.inputStrides),
          vectorToString(fc.inputTypes),
          vectorToString(output_values),
          vectorToString(output_shapes),
          vectorToString(output_strides),
          vectorToString(output_types),
          op_schema_str,
          fc.kernelBackend,
          fc.kernelFile,
          additional_attrs);
      ob->out << ",";
    } catch (const std::exception& e) {
      LOG(WARNING) << "Exception in execution trace observer: [" << fc.name
                   << " (" << fc.opId << ")] " << e.what();
    }
  }
}

// Add execution trace observer callback functions to the RecordFunction global
// observers.
bool addExecutionTraceObserver(const std::string& output_file_path) {
  // Check if the observer is already initialized.
  if (ObserverManager::get() == nullptr) {
    ObserverManager::push(std::make_shared<ExecutionTraceObserver>());
    auto& ob = *ObserverManager::get();
    ob.pid = processId();
    // Set output
    ob.fileName = output_file_path;
    if (!initExecutionTraceStart(ob)) {
      return false;
    }

    ob.cbHandle = addGlobalCallback(
        RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
            .needsInputs(true)
            .needsOutputs(true)
            .needsIds(true));
    // Default to disabled.
    ob.setState(ExecutionTraceObserver::RunState::disabled);

    VLOG(1) << "PyTorch Execution Trace: added observer, output="
            << output_file_path;
  } else if (ObserverManager::get()->cbHandle != INVALID_CALLBACK_HANDLE) {
    LOG(WARNING) << "Execution trace observer is already registered.";
  }
  return true;
}

void removeExecutionTraceObserver() {
  auto ob = ObserverManager::get();
  if (ob != nullptr) {
    if (ob->getState() != ExecutionTraceObserver::RunState::disabled) {
      disableExecutionTraceObserver();
    }

    if (ob->cbHandle != INVALID_CALLBACK_HANDLE) {
      finalizeExecutionTraceOutput(*ob);
      removeCallback(ob->cbHandle);
      ob->cbHandle = INVALID_CALLBACK_HANDLE;
      // Release the current ET observer object and reset.
      TORCH_INTERNAL_ASSERT(
          ObserverManager::pop() != nullptr,
          "Global state ptr cannot be null before resetting");
      VLOG(1) << "PyTorch Execution Trace: removed observer";
    } else {
      LOG(WARNING) << "Execution trace observer was not registered.";
    }
  } else {
    LOG(WARNING) << "Execution trace observer was not initialized.";
  }
}

void enableExecutionTraceObserver() {
  LOG(WARNING) << "Enabling Execution Trace Observer";
  auto& ob = *ObserverManager::get();
  // Make sure we are not already enabled.
  if (ob.getState() == ExecutionTraceObserver::RunState::enabled) {
    LOG(WARNING)
        << "Trying to enable Execution Trace Observer when it's already enabled.";
  } else {
    ob.setState(ExecutionTraceObserver::RunState::enabled);
  }
}

void disableExecutionTraceObserver() {
  LOG(WARNING) << "Disabling Execution Trace Observer";
  auto& ob = *ObserverManager::get();
  if (ob.getState() != ExecutionTraceObserver::RunState::disabled) {
    ob.setState(ExecutionTraceObserver::RunState::disabled);
  } else {
    LOG(WARNING)
        << "Trying to disable Execution Trace Observer when it's already disabled.";
  }
}
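
// Typical lifecycle (illustrative sketch; the output path is hypothetical):
//   addExecutionTraceObserver("/tmp/pytorch_et.json"); // register callbacks, write header
//   enableExecutionTraceObserver();   // start recording op enter/exit
//   /* ... run the model ... */
//   disableExecutionTraceObserver();  // stop recording
//   removeExecutionTraceObserver();   // write footer, close file, unregister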
} // namespace torch::profiler::impl