#ifdef _WIN32
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>

#include <processthreadsapi.h>
#else
#include <unistd.h>
#endif // _WIN32

#include <fmt/format.h>
#include <fmt/ranges.h>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <map>
#include <mutex>
#include <sstream>
#include <stack>
#include <vector>

#include <ATen/core/TensorBody.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/stack.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
#include <torch/csrc/profiler/util.h>

#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#endif // USE_DISTRIBUTED

using namespace at;

// Collective property attributes
// https://github.com/pytorch/pytorch/issues/124674
#ifdef USE_DISTRIBUTED
constexpr auto kETCommsName = "collective_name";
constexpr auto kETInMsgNelems = "in_msg_nelems";
constexpr auto kETOutMsgNelems = "out_msg_nelems";
constexpr auto kETInSplit = "in_split_size";
constexpr auto kETOutSplit = "out_split_size";
constexpr auto kETGlobalRankStart = "global_rank_start";
constexpr auto kETGlobalRankStride = "global_rank_stride";
constexpr auto kETGroupSize = "pg_size";
constexpr auto kETProcessGroupName = "pg_name";
constexpr auto kETProcessGroupDesc = "pg_desc";
#endif // USE_DISTRIBUTED

namespace torch::profiler::impl {

//******************************************************************************
// JSON output utility functions. To be merged with PyTorch profiler.
//******************************************************************************
template <typename T>
inline std::string vectorToString(const std::vector<T>& v) {
  return fmt::format("[{}]", fmt::join(v, ","));
}
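
// Illustrative examples (not in the original source):
//   vectorToString(std::vector<int>{1, 2, 3})  -> "[1,2,3]"
//   vectorToString(std::vector<int>{})         -> "[]"
// The result is a valid JSON array as long as each element formats to a
// valid JSON value.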

std::string json_str_escape(const std::string& str);

constexpr size_t kMaxNumElements = 4096;

inline std::string getScalarValue(const c10::IValue& val) {
  if (val.isDouble()) {
    double d_val = val.toDouble();
    if (std::isinf(d_val) || std::isnan(d_val)) {
      return fmt::format("\"{}\"", std::to_string(d_val));
    } else {
      return std::to_string(d_val);
    }
  } else if (val.isInt()) {
    return std::to_string(val.toInt());
  } else if (val.isBool()) {
    return val.toBool() ? "true" : "false";
  } else if (val.isString()) {
    const std::string& str_val = val.toStringRef();
    return fmt::format("\"{}\"", json_str_escape(str_val));
  } else if (val.isDevice()) {
    return fmt::format("\"{}\"", val.toDevice().str());
  }
  return fmt::format("\"<{}>\"", val.tagKind());
}
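
// Illustrative examples (not in the original source):
//   IValue(42)        -> 42
//   IValue(true)      -> true
//   IValue("hi")      -> "hi" (JSON-escaped and quoted)
//   IValue(1.0 / 0.0) -> quoted (typically "inf"), since inf/nan are not
//                        valid JSON numbers
// Unsupported tag kinds fall through to a quoted "<TagKind>" placeholder.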

inline int32_t processId() {
#ifndef _WIN32
  return static_cast<int32_t>(getpid());
#else
  return static_cast<int32_t>(GetCurrentProcessId());
#endif
}

//******************************************************************************
// Main ExecutionTraceObserver implementation.
//******************************************************************************

// ExecutionTraceObserver contains all the states of the observer. Some of them
// are shared between the enter and exit RecordFunction callbacks, and some
// data, like the `opStack`, may be accessed across different threads, so we
// should be careful about data races. A global mutex `gMutex` is used to avoid
// these races, at the cost of performance when many threads are involved. We
// may optimize this further with thread-local state, fine-grained locking, or
// thread-safe containers.
struct TORCH_API ExecutionTraceObserver { // NOLINT
  using ID = size_t;

  // Mapping of each thread to its own operator stack
  std::map<size_t, std::stack<ID>> opStack{};
  // Uses the underlying TensorImpl object pointer as the key and maps it to
  // its unique ID.
  std::map<const void*, ID> objectId{};
  // Observer run state.
  enum class RunState { uninitialized, disabled, enabled };

  // Mutex for multithreaded access to the shared containers.
  std::recursive_mutex gMutex{};
  // Stream to write output JSON.
  std::ofstream out{};

  // Full path to the output file.
  std::string fileName{};

  // RecordFunction callback handle for this observer.
  CallbackHandle cbHandle{INVALID_CALLBACK_HANDLE};

  // Process ID.
  int32_t pid{-1};
  std::string recordTime{};

  ExecutionTraceObserver() = default;

  // Returns a new unique ID.
  ID getNewID() {
    return id_++;
  }

  RunState getState() const {
    return state_;
  }

  void setState(RunState newState) {
    if (state_ == RunState::uninitialized ||
        callbackShouldBeEnabled(state_) != callbackShouldBeEnabled(newState)) {
      if (callbackShouldBeEnabled(newState)) {
        reenableCallback(cbHandle);
      } else {
        disableCallback(cbHandle);
      }
    }
    state_ = newState;
  }

 private:
  static bool callbackShouldBeEnabled(RunState run_state) {
    return run_state == ExecutionTraceObserver::RunState::enabled;
  }

  // Must use accessors to change this so that we can keep the
  // RecordFunction callback in sync with the state.
  RunState state_{RunState::uninitialized};

  // All tensors and operators have a unique ID assigned. Increment the ID for
  // each new tensor or operator node.
  // 0 -> uninitialized
  // 1 -> root ID
  // 2 ... -> regular node ID
  std::atomic<ID> id_{2};
};

// Use a singleton manager here to allow initializing and deleting the
// observer object.
using ObserverManager = GlobalStateManager<ExecutionTraceObserver>;

// Uninitialized node has id = 0
const ExecutionTraceObserver::ID kUninitializedId{0};
// Root node has id = 1
const ExecutionTraceObserver::ID kRootId{1};

struct FunctionCallContext : public ObserverContext { // NOLINT
  std::string name;
  std::string kernelBackend;
  std::string kernelFile;
  ExecutionTraceObserver::ID opId{kUninitializedId};
  ExecutionTraceObserver::ID parentId{kUninitializedId};
  ExecutionTraceObserver::ID fwParentId{kUninitializedId};
  std::vector<std::string> inputTypes;
  std::vector<std::string> inputShapes;
  std::vector<std::string> inputStrides;
  std::vector<std::string> inputValues;
};

// Opens the JSON file to write the execution trace.
static std::ofstream openOutputFile(const std::string& name) {
  std::ofstream stream;
  stream.open(name, std::ofstream::out | std::ofstream::trunc);
  if (!stream) {
    LOG(ERROR) << "Failed to open '" << name << "'";
  } else {
    VLOG(1) << "PyTorch Execution Trace: writing to " << name;
  }
  return stream;
}

#ifdef USE_DISTRIBUTED
static inline std::string getAttrJson(
    const std::string& name,
    const std::string& type,
    const std::string& value) {
  // Note: name and type are not quoted here, but value should be if it is a
  // string.
  return fmt::format(
      R"JSON(
  {{"name": "{}", "type": "{}", "value": {}}})JSON",
      name,
      type,
      value);
}
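
// For example (illustrative), getAttrJson("pg_size", "uint64", "8") yields
// the fragment:
//   {"name": "pg_size", "type": "uint64", "value": 8}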
#endif

static void writeJsonNode(
    std::ofstream& out,
    const std::string& name,
    const uint64_t id,
    const uint64_t rf_id,
    const uint64_t parent,
    const uint64_t fw_parent,
    const int64_t seq_id,
    const uint64_t scope,
    const uint64_t tid,
    const uint64_t fw_tid,
    const std::string& inputs = "[]",
    const std::string& inputShapes = "[]",
    const std::string& inputStrides = "[]",
    const std::string& inputTypes = "[]",
    const std::string& outputs = "[]",
    const std::string& output_shapes = "[]",
    const std::string& output_strides = "[]",
    const std::string& output_types = "[]",
    const std::string& operator_schema = "",
    const std::string& kernelBackend = "",
    const std::string& kernelFile = "",
    const std::string& additional_attrs = "") {
  out << fmt::format(
      R"JSON(
    {{
      "id": {}, "name": "{}", "ctrl_deps": {},
      "inputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
      "outputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
      "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}{}]
    }})JSON",
      id,
      name,
      parent,
      inputs,
      inputShapes,
      inputTypes,
      inputStrides,
      outputs,
      output_shapes,
      output_types,
      output_strides,
      rf_id,
      fw_parent,
      seq_id,
      scope,
      tid,
      fw_tid,
      operator_schema,
      kernelBackend,
      kernelFile,
      additional_attrs);
}
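
// Illustrative output of writeJsonNode (values are made up; the layout
// follows the template above):
//   {
//     "id": 42, "name": "aten::add", "ctrl_deps": 7,
//     "inputs": {"values": [...], "shapes": [...], "types": [...], "strides": [...]},
//     "outputs": {"values": [...], "shapes": [...], "types": [...], "strides": [...]},
//     "attrs": [{"name": "rf_id", "type": "uint64", "value": 101}, ...]
//   }
// Note the argument order: the template's "values"/"shapes"/"types"/"strides"
// slots are filled by inputs, inputShapes, inputTypes, and inputStrides, so
// types and strides are deliberately swapped relative to the parameter
// declaration order.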

inline std::string timeString(const std::time_t timepoint) {
  std::ostringstream oss;
  oss << std::put_time(std::localtime(&timepoint), "%Y-%m-%d %X"); // NOLINT
  return oss.str();
}

static bool initExecutionTraceStart(ExecutionTraceObserver& ob) {
  ob.out = openOutputFile(ob.fileName);
  // If somehow the output stream failed to open, finish the observer here.
  if (!ob.out) {
    LOG(WARNING) << "Failed to open output file: " << ob.fileName;
    return false;
  }

  // Wall clock time for the first op collection time.
  const auto current_time = std::chrono::system_clock::now();
  ob.recordTime =
      timeString(std::chrono::system_clock::to_time_t(current_time));
  // Start timestamp using steady_clock for measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();

  ob.out << fmt::format(
      R"JSON({{
  "schema": "1.1.1-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {},
  "nodes": [)JSON",
      ob.pid,
      ob.recordTime,
      timestamp);
  return true;
}
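
// Illustrative file preamble produced here (made-up pid and timestamps):
//   {
//     "schema": "1.1.1-chakra.0.0.4", "pid": 12345,
//     "time": "2024-01-01 12:00:00", "start_ts": 1700000000000,
//     "nodes": [
// The matching "]" and the "finish_ts" field are appended by
// finalizeExecutionTraceOutput below when the observer is removed.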

// Write out Execution Trace to file
static void finalizeExecutionTraceOutput(ExecutionTraceObserver& ob) {
  writeJsonNode(
      ob.out,
      "[pytorch|profiler|execution_trace|process]",
      kRootId,
      0, // rf_id
      kRootId, // parent is self
      0, // fw_parent
      -1, // seq_id
      static_cast<std::underlying_type_t<RecordScope>>(RecordScope::USER_SCOPE),
      0, // tid
      0); // fw_tid

  // Finish timestamp using steady_clock for measurement.
  const auto timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  ob.out << fmt::format(
      R"JSON(
  ],
  "finish_ts": {}
}})JSON",
      timestamp);

  ob.out.close();
  VLOG(1) << "PyTorch Execution Trace: written to file " << ob.fileName;
}


inline ExecutionTraceObserver::ID getObjectID(
    ExecutionTraceObserver& ob,
    const void* t) {
  auto iter = ob.objectId.find(t);
  if (iter == ob.objectId.end()) {
    ExecutionTraceObserver::ID objectId = ob.getNewID();
    ob.objectId[t] = objectId;
    return objectId;
  }

  return iter->second;
}

inline std::tuple<std::string, std::string, std::string, std::string>
convertIValue(
    ExecutionTraceObserver& ob,
    const c10::IValue& val,
    const bool baseType = true,
    const size_t maxArrayLen = kMaxNumElements) {
  std::string type = val.tagKind();
  if (val.isTensor()) {
    std::string tensor_shape, tensor_stride, tensor_type, tensor_value;

    const auto& tensor = val.toTensor();
    const auto tensor_impl = tensor.unsafeGetTensorImpl();
    if (tensor.defined() && !tensor_impl->has_symbolic_sizes_strides()) {
      // tensor shape
      tensor_shape = vectorToString(tensor.sizes().vec());
      // tensor strides
      tensor_stride = vectorToString(tensor.strides().vec());
    } else {
      tensor_shape = "[]";
      tensor_stride = "[]";
    }
    // tensor dtype
    type = type + fmt::format("({})", std::string(tensor.dtype().name()));
    tensor_type = baseType ? fmt::format("\"{}\"", type) : type;

    ExecutionTraceObserver::ID tensor_id = getObjectID(ob, tensor_impl);
    ExecutionTraceObserver::ID storage_id = 0;
    size_t offset = 0;
    size_t numel = 0;
    size_t itemsize = 0;
    std::string device_str = "";
    // symbolic sizes/strides implies t->storage_offset() will fail
    if (tensor_impl->has_storage() &&
        !tensor_impl->has_symbolic_sizes_strides()) {
      auto& t_storage = tensor_impl->storage();
      storage_id = getObjectID(ob, t_storage.data());
      offset = tensor_impl->storage_offset();
      numel = tensor_impl->numel();
      itemsize = tensor_impl->itemsize();
      device_str = tensor_impl->device().str();
    }
    tensor_value = fmt::format(
        "[{},{},{},{},{},\"{}\"]",
        tensor_id,
        storage_id,
        offset,
        numel,
        itemsize,
        device_str);
    return std::make_tuple(
        tensor_shape, tensor_stride, tensor_type, tensor_value);
  } else if (val.isTuple()) {
    const auto& val_tuple = val.toTupleRef().elements();
    size_t tuple_size = val_tuple.size();
    std::vector<std::string> shape_array;
    std::vector<std::string> stride_array;
    std::vector<std::string> type_array;
    std::vector<std::string> value_array;
    for (const auto j : c10::irange(tuple_size)) {
      auto tuple = convertIValue(ob, val_tuple[j], false, maxArrayLen);
      shape_array.push_back(std::get<0>(tuple));
      stride_array.push_back(std::get<1>(tuple));
      type_array.push_back(std::get<2>(tuple));
      value_array.push_back(std::get<3>(tuple));
    }
    type = type + vectorToString(type_array);
    std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type;
    return std::make_tuple(
        vectorToString(shape_array),
        vectorToString(stride_array),
        tensor_type,
        vectorToString(value_array));
  } else if (val.isList()) {
    const auto& val_list = val.toList();
    size_t list_size = val_list.size();
    std::vector<std::string> shape_array;
    std::vector<std::string> stride_array;
    std::vector<std::string> type_array;
    std::vector<std::string> value_array;
    for (const auto j : c10::irange(list_size)) {
      auto tuple = convertIValue(ob, val_list.get(j), false, maxArrayLen);
      shape_array.push_back(std::get<0>(tuple));
      stride_array.push_back(std::get<1>(tuple));
      type_array.push_back(std::get<2>(tuple));
      value_array.push_back(std::get<3>(tuple));
      if (j >= maxArrayLen) {
        LOG(WARNING) << "list size=" << val_list.size()
                     << " exceeded maxArrayLen=" << maxArrayLen;
        break;
      }
    }
    type = type + vectorToString(type_array);
    std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type;
    return std::make_tuple(
        vectorToString(shape_array),
        vectorToString(stride_array),
        tensor_type,
        vectorToString(value_array));
  } else {
    std::string tensor_shape = "[]";
    std::string tensor_stride = "[]";
    std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type;
    std::string tensor_value = getScalarValue(val);

    return std::make_tuple(
        tensor_shape, tensor_stride, tensor_type, tensor_value);
  }
}
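
// Illustrative result (not in the original source) for a defined 2x3 float
// CPU tensor; the two IDs depend on observer state:
//   shape  -> "[2,3]"
//   stride -> "[3,1]"
//   type   -> "\"Tensor(float)\"" (quoted because baseType is true)
//   value  -> "[<tensor_id>,<storage_id>,0,6,4,\"cpu\"]"
//             (offset, numel, itemsize, device)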

inline void appendValueInfo(
    ExecutionTraceObserver& ob,
    const c10::IValue& val,
    std::vector<std::string>& shapes,
    std::vector<std::string>& strides,
    std::vector<std::string>& types,
    std::vector<std::string>& values) {
  auto tuple = convertIValue(ob, val, true);

  shapes.push_back(std::get<0>(tuple));
  strides.push_back(std::get<1>(tuple));
  types.push_back(std::get<2>(tuple));
  values.push_back(std::get<3>(tuple));
}

inline void handleKernelBackendInfo(
    FunctionCallContext& fc,
    const RecordFunction& fn) {
  // Triton kernel-related information is in kwinputs.
  const auto& kwinputs = fn.kwinputs();
  if (kwinputs.find("kernel_backend") != kwinputs.end()) {
    fc.kernelBackend = kwinputs.at("kernel_backend").toStringRef();
    if (fc.kernelBackend == "triton") {
      // Check that the kernel file is present before accessing it.
      TORCH_INTERNAL_ASSERT(
          kwinputs.find("kernel_file") != kwinputs.end(),
          "kernel file is missing in triton kernel");
      fc.kernelFile = kwinputs.at("kernel_file").toStringRef();
      // Remove the path from the file name.
      if (fc.kernelFile.find_last_of('/') != std::string::npos) {
        fc.kernelFile =
            fc.kernelFile.substr(fc.kernelFile.find_last_of('/') + 1);
      }

      // get grid information
      TORCH_INTERNAL_ASSERT(
          kwinputs.find("grid") != kwinputs.end(),
          "grid is missing in triton kernel");
      fc.inputValues.emplace_back(
          "\"" + kwinputs.at("grid").toStringRef() + "\"");
      fc.inputTypes.emplace_back("\"String\"");
      fc.inputShapes.emplace_back("[]");

      // get stream information
      TORCH_INTERNAL_ASSERT(
          kwinputs.find("stream") != kwinputs.end(),
          "stream is missing in triton kernel");
      fc.inputValues.emplace_back(
          std::to_string(kwinputs.at("stream").toInt()));
      fc.inputTypes.emplace_back("\"Int\"");
      fc.inputShapes.emplace_back("[]");
    }
  }
}
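
// Illustrative kwinputs (assumed shape, based only on the keys read above)
// for a Triton kernel launch:
//   kernel_backend = "triton", kernel_file = "/tmp/abc/my_kernel.py",
//   grid = "grid_wrapper_for_op_0" (hypothetical name), stream = 7
// This would record kernelFile as "my_kernel.py" (path stripped) and append
// the grid string and stream integer to the op's input values.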

// Additional attributes for communication collectives.
inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT
  std::vector<std::string> attrs;

#ifdef USE_DISTRIBUTED
  // We rely on the ParamCommsDebugInfo object available in thread-local info.
  auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
  if (debugInfo == nullptr) {
    LOG(WARNING) << "ParamCommsDebugInfo not available for function: "
                 << fn.name();
    return ", " + getAttrJson("debug", "string", "\"missing comms info\"");
  }

  // Get the NcclMeta from the record function; this uses the
  // ParamCommsDebugInfo above.
  auto meta = saveNcclMeta(fn, false /*truncate*/);

  auto addAttr =
      [&](const char* commsMetaName, const char* etMetaName, const char* type) {
        auto it = meta.find(commsMetaName);
        if (it != meta.end()) {
          attrs.push_back(getAttrJson(etMetaName, type, it->second));
        }
      };

  addAttr(kCommsName, kETCommsName, "string");
  addAttr(kDtype, kDtype, "string");

  addAttr(kInMsgNelems, kETInMsgNelems, "uint64");
  addAttr(kOutMsgNelems, kETOutMsgNelems, "uint64");

  // The following two metadata entries are lists.
  addAttr(kInSplit, kETInSplit, "string");
  addAttr(kOutSplit, kETOutSplit, "string");

  addAttr(kGlobalRankStart, kETGlobalRankStart, "uint64");
  addAttr(kGlobalRankStride, kETGlobalRankStride, "uint64");

  // pg_name is a string.
  addAttr(kProcessGroupName, kETProcessGroupName, "string");
  addAttr(kProcessGroupDesc, kETProcessGroupDesc, "string");

  addAttr(kGroupSize, kETGroupSize, "uint64");

#endif // USE_DISTRIBUTED

  // XXX consider using a string stream?
  return attrs.empty() ? "" : fmt::format(", {}", fmt::join(attrs, ", "));
}
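
// Illustrative return value (made-up numbers), assuming comms metadata was
// found:
//   , {"name": "collective_name", "type": "string", "value": "allreduce"},
//     {"name": "in_msg_nelems", "type": "uint64", "value": 1024}, ...
// The leading ", " lets writeJsonNode splice this directly after its fixed
// attrs; an empty string is returned when nothing was collected.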

static void recordOperatorStart(
    ExecutionTraceObserver& ob,
    FunctionCallContext& fc,
    const RecordFunction& fn) {
  auto tid = fn.threadId();

  try {
    const std::lock_guard<std::recursive_mutex> lock(ob.gMutex);

    // If the current thread stack is empty, push the root node to the stack
    // first.
    if (ob.opStack[tid].empty()) {
      auto thread_node_id = ob.getNewID();
      ob.opStack[tid].push(thread_node_id);
      writeJsonNode(
          ob.out,
          "[pytorch|profiler|execution_trace|thread]",
          thread_node_id,
          0, // rf_id
          kRootId,
          0, // fw_parent
          -1, // seq_id
          static_cast<std::underlying_type_t<RecordScope>>(
              RecordScope::USER_SCOPE),
          tid,
          0); // fw_tid
      ob.out << ",";
    }
    fc.name = fn.name();
    auto num_inputs = fn.num_inputs();
    const auto inputs = fn.inputs();

    VLOG(2) << "inputs: " << num_inputs << " " << inputs.size() << '\n';
    // We have two cases: for an unboxed kernel, num_inputs == inputs.size();
    // for a boxed kernel using the stack, there could be more elements on the
    // stack from previous ops.
    // TORCH_INTERNAL_ASSERT(num_inputs <= inputs.size());
    if (num_inputs > inputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " expected num_inputs=" << num_inputs
                   << " > inputs.size()=" << inputs.size();
      return;
    }
    // Need to account for Stack mode where the inputs are at the end.
    size_t input_start = inputs.size() - num_inputs;

    for (const auto i : c10::irange(input_start, inputs.size())) {
      appendValueInfo(
          ob,
          inputs[i],
          fc.inputShapes,
          fc.inputStrides,
          fc.inputTypes,
          fc.inputValues);
    }

    handleKernelBackendInfo(fc, fn);

    fc.parentId = ob.opStack[tid].top();
    // Get the parent id from the forward stack; this can be different for
    // autograd ops, which may execute on a different thread than the original
    // thread (which should have the parent op on the stack).
    auto fw_tid = fn.forwardThreadId();
    if (fw_tid != 0) {
      fc.fwParentId = ob.opStack[fw_tid].top();
    }
    // Input object IDs were assigned above, before the op's own ID.
    fc.opId = ob.getNewID();
    ob.opStack[tid].push(fc.opId);

  } catch (const std::exception& e) {
    LOG(WARNING) << "Exception in execution trace observer: " << e.what();
  }
}

static std::unique_ptr<ObserverContext> onFunctionEnter(
    const RecordFunction& fn) {
  using RunState = ExecutionTraceObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob != nullptr && ob->getState() == RunState::enabled) {
    // record op
    auto fc_ptr = std::make_unique<FunctionCallContext>();
    recordOperatorStart(*ob, *fc_ptr, fn);
    return fc_ptr;
  }
  return nullptr;
}

inline std::string json_str_escape(const std::string& str) {
  std::ostringstream ostream;
  for (char ch : str) {
    if (ch == '"') {
      ostream << "\\\"";
    } else if (ch == '\\') {
      ostream << "\\\\";
    } else if (ch == '\b') {
      ostream << "\\b";
    } else if (ch == '\f') {
      ostream << "\\f";
    } else if (ch == '\n') {
      ostream << "\\n";
    } else if (ch == '\r') {
      ostream << "\\r";
    } else if (ch == '\t') {
      ostream << "\\t";
    } else if ('\x00' <= ch && ch <= '\x1f') {
      ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0')
              << static_cast<int>(ch);
    } else {
      ostream << ch;
    }
  }
  return ostream.str();
}
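
// For example (illustrative): an input containing a double quote, a
// backslash, or a newline comes back with them encoded as the two-character
// sequences \", \\, and \n respectively; other control characters below 0x20
// are emitted as \u00XX escapes.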

static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
  using RunState = ExecutionTraceObserver::RunState;
  auto ob = ObserverManager::get();
  if (ob == nullptr || ctx_ptr == nullptr) {
    return;
  }
  if (ob->getState() == RunState::enabled) {
    auto fc_ptr = dynamic_cast<FunctionCallContext*>(ctx_ptr);
    // TORCH_INTERNAL_ASSERT(fc_ptr != nullptr);
    if (fc_ptr == nullptr) {
      LOG(WARNING) << "FunctionCallContext is nullptr.";
      return;
    }
    auto& fc = *fc_ptr;

    auto outputs = fn.outputs();
    auto num_outputs = fn.num_outputs();
    // We have two cases: for an unboxed kernel, num_outputs ==
    // outputs.size(); for a boxed kernel using the stack, there could be more
    // elements on the stack from previous ops.
    VLOG(2) << "outputs: " << num_outputs << " " << outputs.size() << '\n';
    // TORCH_INTERNAL_ASSERT(num_outputs <= outputs.size());
    if (num_outputs > outputs.size()) {
      LOG(WARNING) << "RecordFunction " << fc.name
                   << " num_outputs=" << num_outputs
                   << " > outputs.size()=" << outputs.size();
      return;
    }
    // Need to account for Stack mode where the outputs are at the end.
    size_t output_start = outputs.size() - num_outputs;

    std::vector<std::string> output_types;
    std::vector<std::string> output_strides;
    std::vector<std::string> output_shapes;
    std::vector<std::string> output_values;
    try {
      const std::lock_guard<std::recursive_mutex> lock(ob->gMutex);
      // Remove the current op id from the stack.
      ob->opStack[fn.threadId()].pop();
      for (const auto i : c10::irange(output_start, outputs.size())) {
        appendValueInfo(
            *ob,
            outputs[i],
            output_shapes,
            output_strides,
            output_types,
            output_values);
      }

      std::string op_schema_str{};
      const auto op_schema = fn.operator_schema();
      if (op_schema.has_value()) {
        op_schema_str = json_str_escape(c10::toString(op_schema.value()));
      }

      const std::string additional_attrs =
          fn.isNcclMeta() ? getCommsNodeAttrs(fn) : "";

      writeJsonNode(
          ob->out,
          fc.name,
          fc.opId,
          fn.handle(),
          fc.parentId,
          fc.fwParentId,
          fn.seqNr(),
          static_cast<std::underlying_type_t<RecordScope>>(fn.scope()),
          fn.threadId(),
          fn.forwardThreadId(),
          vectorToString(fc.inputValues),
          vectorToString(fc.inputShapes),
          vectorToString(fc.inputStrides),
          vectorToString(fc.inputTypes),
          vectorToString(output_values),
          vectorToString(output_shapes),
          vectorToString(output_strides),
          vectorToString(output_types),
          op_schema_str,
          fc.kernelBackend,
          fc.kernelFile,
          additional_attrs);
      ob->out << ",";
    } catch (const std::exception& e) {
      LOG(WARNING) << "Exception in execution trace observer: [" << fc.name
                   << " (" << fc.opId << ")] " << e.what();
    }
  }
}

// Add execution trace observer callback functions to the RecordFunction global
// observers.
bool addExecutionTraceObserver(const std::string& output_file_path) {
  // Check if the observer is already initialized.
  if (ObserverManager::get() == nullptr) {
    ObserverManager::push(std::make_shared<ExecutionTraceObserver>());
    auto& ob = *ObserverManager::get();
    ob.pid = processId();
    // Set output
    ob.fileName = output_file_path;
    if (!initExecutionTraceStart(ob)) {
      return false;
    }

    ob.cbHandle = addGlobalCallback(
        RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
            .needsInputs(true)
            .needsOutputs(true)
            .needsIds(true));
    // Default to disabled.
    ob.setState(ExecutionTraceObserver::RunState::disabled);

    VLOG(1) << "PyTorch Execution Trace: added observer, output="
            << output_file_path;
  } else if (ObserverManager::get()->cbHandle != INVALID_CALLBACK_HANDLE) {
    LOG(WARNING) << "Execution trace observer is already registered.";
  }
  return true;
}
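
// Illustrative lifecycle (not part of the original source) from client code;
// registration defaults to the disabled state, so enabling is a separate step:
//
//   torch::profiler::impl::addExecutionTraceObserver("et_trace.json");
//   torch::profiler::impl::enableExecutionTraceObserver();
//   // ... run the workload; ops are captured via RecordFunction callbacks ...
//   torch::profiler::impl::disableExecutionTraceObserver();
//   torch::profiler::impl::removeExecutionTraceObserver(); // finalizes JSON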

void removeExecutionTraceObserver() {
  auto ob = ObserverManager::get();
  if (ob != nullptr) {
    if (ob->getState() != ExecutionTraceObserver::RunState::disabled) {
      disableExecutionTraceObserver();
    }

    if (ob->cbHandle != INVALID_CALLBACK_HANDLE) {
      finalizeExecutionTraceOutput(*ob);
      removeCallback(ob->cbHandle);
      ob->cbHandle = INVALID_CALLBACK_HANDLE;
      // Release the current ET observer object and reset.
      TORCH_INTERNAL_ASSERT(
          ObserverManager::pop() != nullptr,
          "Global state ptr cannot be null before resetting");
      VLOG(1) << "PyTorch Execution Trace: removed observer";
    } else {
      LOG(WARNING) << "Execution trace observer was not registered.";
    }
  } else {
    LOG(WARNING) << "Execution trace observer was not initialized.";
  }
}

void enableExecutionTraceObserver() {
  LOG(WARNING) << "Enabling Execution Trace Observer";
  auto& ob = *ObserverManager::get();
  // Make sure we are not already enabled.
  if (ob.getState() == ExecutionTraceObserver::RunState::enabled) {
    LOG(WARNING)
        << "Trying to enable Execution Trace Observer when it's already enabled.";
  } else {
    ob.setState(ExecutionTraceObserver::RunState::enabled);
  }
}

void disableExecutionTraceObserver() {
  LOG(WARNING) << "Disabling Execution Trace Observer";
  auto& ob = *ObserverManager::get();
  if (ob.getState() != ExecutionTraceObserver::RunState::disabled) {
    ob.setState(ExecutionTraceObserver::RunState::disabled);
  } else {
    LOG(WARNING)
        << "Trying to disable Execution Trace Observer when it's already disabled.";
  }
}
} // namespace torch::profiler::impl