1 #include <torch/csrc/profiler/collection.h>
2 #include <torch/csrc/profiler/orchestration/vulkan.h>
3
4 #include <algorithm>
5 #include <functional>
6 #include <limits>
7 #include <memory>
8 #include <queue>
9 #include <type_traits>
10 #include <utility>
11
12 #include <fmt/format.h>
13
14 #ifdef USE_KINETO
15 #include <libkineto.h>
16 #endif
17
18 #include <ATen/Context.h>
19 #include <ATen/record_function.h>
20 #include <c10/util/Exception.h>
21 #include <c10/util/flat_hash_map.h>
22 #include <c10/util/overloaded.h>
23 #include <torch/csrc/jit/runtime/interpreter.h>
24 #include <torch/csrc/profiler/data_flow.h>
25 #include <torch/csrc/profiler/kineto_shim.h>
26
27 namespace torch::profiler::impl {
// Shorthand for a shared, owning pointer to a profiler `Result` and for a
// uniquely owned Kineto activity trace wrapper.
using result_ptr_t = std::shared_ptr<Result>;
using trace_ptr_t =
    std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>;
31
RawTensorMetadataBase(const at::Tensor & t)32 RawTensorMetadataBase::RawTensorMetadataBase(const at::Tensor& t)
33 : data_{t.has_storage() ? t.storage().data() : nullptr},
34 dtype_{t.scalar_type()},
35 layout_{t.layout()},
36 size_dim_{static_cast<uint32_t>(t.sizes().size())} {
37 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
38 t.sizes().size() <= std::numeric_limits<uint32_t>::max(),
39 "Cannot profile Tensors of size > uint32 max. Got dim: ",
40 t.sizes().size());
41 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
42 t.sizes().size() == t.strides().size(),
43 "Tensor has mismatching sizes and strides. Sizes: ",
44 t.sizes().size(),
45 " Strides: ",
46 t.strides().size());
47 }
48
RawTensorMetadata(const at::Tensor & t)49 RawTensorMetadata::RawTensorMetadata(const at::Tensor& t)
50 : RawTensorMetadataBase(t),
51 weak_self_{WeakTensor(t)},
52 device_type_{t.device().type()},
53 device_index_{t.device().index()} {}
54
TensorMetadata(const RawTensorMetadata & r,std::vector<int64_t> sizes,std::vector<int64_t> strides)55 TensorMetadata::TensorMetadata(
56 const RawTensorMetadata& r,
57 std::vector<int64_t> sizes,
58 std::vector<int64_t> strides)
59 // NOLINTNEXTLINE(cppcoreguidelines-slicing)
60 : RawTensorMetadataBase(r),
61 weak_self_{r.weak_self_.value_or(WeakTensor(at::Tensor()))},
62 device_{r.device_type_, r.device_index_},
63 sizes_{std::move(sizes)},
64 strides_{std::move(strides)} {
65 SOFT_ASSERT(r.weak_self_.has_value());
66 }
67
68 // ============================================================================
69 // == PyTorch Ops =============================================================
70 // ============================================================================
71
72 namespace {
73 struct TagToIOType {
74 InputOutputEncoder::Tag tag;
75 InputOutputEncoder::IOType io_type;
76 };
77
78 constexpr int tagCount = ((int)InputOutputEncoder::Tag::TERMINATOR) + 1;
79 constexpr std::array<TagToIOType, tagCount> tag_map = {{
80 {InputOutputEncoder::Tag::Tensor, InputOutputEncoder::IOType::Shapes},
81 {InputOutputEncoder::Tag::UndefinedTensor,
82 InputOutputEncoder::IOType::Shapes},
83 {InputOutputEncoder::Tag::TensorListBegin,
84 InputOutputEncoder::IOType::Shapes},
85 {InputOutputEncoder::Tag::ScalarList,
86 InputOutputEncoder::IOType::ConcreteInputs},
87 {InputOutputEncoder::Tag::Scalar, InputOutputEncoder::IOType::Shapes},
88 {InputOutputEncoder::Tag::Other, InputOutputEncoder::IOType::Shapes},
89 {InputOutputEncoder::Tag::TERMINATOR, InputOutputEncoder::IOType::None},
90 }};
91
allTagsMapped(int idx=0)92 constexpr bool allTagsMapped(int idx = 0) {
93 return tag_map[idx].tag == InputOutputEncoder::Tag::TERMINATOR ||
94 ((idx == (int)tag_map[idx].tag) && allTagsMapped(idx + 1));
95 }
96 static_assert(allTagsMapped(), "tag_map is out of order");
97
tagToIOType(InputOutputEncoder::Tag tag)98 constexpr InputOutputEncoder::IOType tagToIOType(InputOutputEncoder::Tag tag) {
99 return tag_map[(int)tag].io_type;
100 }
101 } // namespace
102
103 // ----------------------------
104 // | Input / Output encoder |
105 // ----------------------------
push(c10::ArrayRef<const c10::IValue> values)106 void InputOutputEncoder::push(c10::ArrayRef<const c10::IValue> values) {
107 for (const auto& value : values) {
108 if (value.isTensor()) {
109 push(value.toTensor());
110 } else if (value.isScalar()) {
111 tags_.emplace_back(Tag::Scalar);
112 // Scalars are small enough that they are stored in ivalues without an
113 // extra memory alloc
114 // TODO: further optimize this by maybe giving Profiler access to the
115 // guts of IValue.
116 ivalues_.emplace_back(value);
117 } else if (value.isTensorList()) {
118 tags_.emplace_back(Tag::TensorListBegin);
119 for (const auto& t : value.toTensorList()) {
120 push(t);
121 }
122 tags_.emplace_back(Tag::TERMINATOR);
123 } else if (isSupportedScalarList(value)) {
124 tags_.emplace_back(Tag::ScalarList);
125 ivalues_.emplace_back(value);
126 } else {
127 tags_.emplace_back(Tag::Other);
128 }
129 }
130 tags_.emplace_back(Tag::TERMINATOR);
131 }
132
push(const at::Tensor & t)133 void InputOutputEncoder::push(const at::Tensor& t) {
134 // TODO fix nested and symbolic sizes
135 if (t.defined() && !t.is_nested() &&
136 !t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
137 tags_.emplace_back(Tag::Tensor);
138 tensor_metadata_.emplace_back(t);
139 tensor_sizes_strides_.copy(t.sizes());
140 if (t.layout() == at::kStrided) {
141 // Only Strided layout tensors have strides
142 tensor_sizes_strides_.copy(t.strides());
143 }
144 } else {
145 tags_.emplace_back(Tag::UndefinedTensor);
146 }
147 }
148
isSupportedScalarList(const c10::IValue & list_candidate)149 bool InputOutputEncoder::isSupportedScalarList(
150 const c10::IValue& list_candidate) {
151 // Scalar list can be very long. If a list is too long, we shouldn't
152 // collect it. This function checks whether the list is a scalar list
153 // and whether its length is sufficiently short.
154
155 if (!get_record_concrete_inputs_enabled()) {
156 return false;
157 }
158
159 if (!list_candidate.isList()) {
160 return false;
161 }
162 auto list_ref = list_candidate.toListRef();
163 if (C10_UNLIKELY(list_ref.empty())) {
164 return true;
165 }
166 if (C10_UNLIKELY(!list_ref[0].isScalar())) {
167 return false;
168 }
169 if (C10_UNLIKELY(list_ref.size() > SCALAR_LIST_LENGTH_LIMIT)) {
170 return false;
171 }
172 return true;
173 }
174
// This function returns a lambda which is a custom-iterator-like getter.
// Each invocation of the lambda decodes the input values of one op: the tags
// up to the next top-level TERMINATOR. It returns one entry per argument.
//
// io_type is used to filter the ivalues between 'Shapes' and 'Concrete Args'.
// An argument whose tag maps (via `tagToIOType`) to `io_type` is returned as
// its decoded value; any other argument becomes std::nullopt so positions
// stay aligned.
// Shapes are used to represent the shapes of tensors. We save only the shapes
// of the tensors because tensors can be large.
// Concrete args are separated to clarify that they are the actual values.
//
// The lambda captures `this` and iterators into this encoder's buffers.
// NOTE(review): it presumably must not be used after `clear()` or after the
// encoder is destroyed; confirm against callers.
auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) {
  return [this,
          tag_it = tags_.begin(),
          tensor_metadata_it = tensor_metadata_.begin(),
          tensor_size_strides_it = tensor_sizes_strides_.begin(),
          ivals_it = ivalues_.begin(),
          io_type]() mutable {
    // Rebuilds one TensorMetadata from the next raw metadata record, the next
    // `size_dim_` sizes and, for strided layouts, the next `size_dim_`
    // strides. If any buffer runs out early it logs a warning and returns
    // what was decoded so far.
    auto decode_tensor = [&]() -> TensorMetadata {
      std::vector<int64_t> sizes;
      std::vector<int64_t> strides;
      if (tensor_metadata_it.exhausted()) {
        LOG(WARNING)
            << "Tensor metadata exhausted prematurely. Reported shapes may be inaccurate!";
        return {RawTensorMetadata(), sizes, strides};
      }
      const auto& raw_metadata = *tensor_metadata_it++;
      for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) {
        if (tensor_size_strides_it.exhausted()) {
          LOG(WARNING)
              << "Expected Tensor Size mismatch with raw Tensor metadata. Reported shapes may be inaccurate!";
          return {raw_metadata, sizes, strides};
        }
        sizes.push_back(*tensor_size_strides_it++);
      }
      // Strides were only recorded for strided layouts.
      if (raw_metadata.layout_ == at::kStrided) {
        for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) {
          if (tensor_size_strides_it.exhausted()) {
            LOG(WARNING)
                << "Expected Tensor Strides mismatch with raw Tensor metadata. Reported shapes may be inaccurate!";
            return {raw_metadata, sizes, strides};
          }
          strides.push_back(*tensor_size_strides_it++);
        }
      }
      return {raw_metadata, sizes, strides};
    };

    std::vector<op_input_t> out;
    // Appends `input` if `tag` belongs to the requested IOType, otherwise an
    // empty slot.
    auto push_value = [&out, io_type](const Tag& tag, op_input_t input) {
      if (io_type == tagToIOType(tag)) {
        out.emplace_back(std::move(input));
      } else {
        out.emplace_back(std::nullopt);
      }
    };

    bool terminate = false;
    while (!terminate && tag_it != tags_.end()) {
      switch (*tag_it) {
        case Tag::Tensor:
          push_value(*tag_it, decode_tensor());
          break;

        case Tag::TensorListBegin: {
          // Walk the list up to its own TERMINATOR. Tensors are decoded even
          // if the list is later discarded, which keeps the metadata and
          // size/stride iterators in step with the tags.
          std::vector<TensorMetadata> arg;
          bool found_undefined = false;
          while (*(++tag_it) != Tag::TERMINATOR) {
            if (*tag_it == Tag::UndefinedTensor) {
              found_undefined = true;
              continue;
            }
            TORCH_INTERNAL_ASSERT(*tag_it == Tag::Tensor, (int)(*tag_it));
            arg.emplace_back(decode_tensor());
          }
          // A list with any undefined tensor is reported as std::nullopt.
          // Here `*tag_it` is TERMINATOR, which tag_map maps to IOType::None.
          if (found_undefined) {
            push_value(*tag_it, std::nullopt);
          } else {
            push_value(Tag::TensorListBegin, std::move(arg));
          }
        } break;

        case Tag::ScalarList:
        case Tag::Scalar:
          push_value(*tag_it, *ivals_it++);
          break;

        case Tag::UndefinedTensor:
        case Tag::Other:
          push_value(*tag_it, std::nullopt);
          break;

        case Tag::TERMINATOR:
          // This marks the end of this op.
          terminate = true;
          break;

        default:
          break;
      }
      ++tag_it;
    }
    return out;
  };
}
276
getInputShapeGenerator()277 auto InputOutputEncoder::getInputShapeGenerator() {
278 return getIValueGenerator(IOType::Shapes);
279 }
280
getConcreteInputGenerator()281 auto InputOutputEncoder::getConcreteInputGenerator() {
282 return getIValueGenerator(IOType::ConcreteInputs);
283 }
284
clear()285 void InputOutputEncoder::clear() {
286 tags_.clear();
287 tensor_metadata_.clear();
288 tensor_sizes_strides_.clear();
289 ivalues_.clear();
290 }
291
292 // ---------------------------------------------------
293 // | Correlation ID tracking (OpList & EventBlock) |
294 // ---------------------------------------------------
295 template <typename T, size_t ChunkSize>
EventBlock()296 ThreadLocalSubqueue::TorchOpStorage::EventBlock<T, ChunkSize>::EventBlock() {
297 static std::atomic<uint64_t> counter_{0};
298 id_start_ = 1 + ChunkSize * counter_++;
299 }
300
301 template <class... Args>
302 std::pair<KinetoObserverContext::Event*, uint64_t> ThreadLocalSubqueue::
emplace_back(Args &&...args)303 TorchOpStorage::OpList::emplace_back(Args&&... args) {
304 auto event_ptr = AppendOnlyList::emplace_back(std::forward<Args>(args)...);
305 auto corr_id = buffer_last_->correlation_id(event_ptr);
306 return {event_ptr, corr_id};
307 }
308
correlationID(const OpList::Iterator & e)309 uint64_t ThreadLocalSubqueue::TorchOpStorage::OpList::correlationID(
310 const OpList::Iterator& e) {
311 return e.address().first->correlation_id(&*e);
312 }
313
314 template <typename T, size_t ChunkSize>
315 uint64_t ThreadLocalSubqueue::TorchOpStorage::EventBlock<T, ChunkSize>::
correlation_id(const T * ptr) const316 correlation_id(const T* ptr) const {
317 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
318 ptr >= this->data() && ptr < this->data() + ChunkSize);
319 return id_start_ + (ptr - this->data());
320 }
321
322 // ---------------------------------
323 // | Collection (Observer logic) |
324 // ---------------------------------
// Observer entry point for the start of a RecordFunction `fn` on this thread.
//
// Appends a TorchOp event (basic fields taken from `fn`) and pushes its
// correlation ID to Kineto: the user variant for USER_SCOPE, the regular one
// otherwise. Depending on `config_` it also records into the per-op side
// buffers:
//  - inputs and kwinputs (report_input_shapes);
//  - JIT call stack (with_stack) and module hierarchy (with_modules), only in
//    non-lite/non-mobile builds and never for BACKWARD_FUNCTION scopes;
//  - extra args (with_flops);
//  - NCCL metadata: an entry is pushed for every op, empty for non-NCCL ops.
// In the GPU / PrivateUse1 fallback states it also records a device start
// event.
//
// Returns a KinetoObserverContext wrapping the new event (and its fallback
// event, when one was recorded).
std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
    const at::RecordFunction& fn) {
  auto [event, corr_id] = torch_ops_.op_events_.emplace_back(
      torch::profiler::impl::TorchOpBasicFields{
          fn.seqNr(),
          fn.forwardThreadId(),
          fn.scope(),
          fn.isAsync(),
          fn.handle(),
          fn.debugHandle(),
          fn.name()});
  if (config_.report_input_shapes) {
    torch_ops_.inputs_outputs_.push(fn.inputs());
    torch_ops_.kwinputs_.emplace_back(fn.kwinputs());
  }
  if (fn.scope() == at::RecordScope::USER_SCOPE) {
    torch::profiler::impl::kineto::pushUserCorrelationId(corr_id);
  } else {
    torch::profiler::impl::kineto::pushCorrelationId(corr_id);
  }

#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
  // backward nodes source range corresponds to the forward node
  // TODO: consider using C++ stack trace
  if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
    auto cs = torch::profiler::impl::prepareCallstack(jit::currentCallstack());
    torch_ops_.jit_stack_.emplace_back(callstackStr(cs));
  }
  if (config_.with_modules &&
      fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
    torch_ops_.jit_modules_.emplace_back(jit::currentModuleHierarchy());
  }
#endif
  if (config_.with_flops) {
    torch_ops_.extra_args_.emplace_back(
        torch::profiler::impl::saveExtraArgs(fn));
  }

  // Record NCCL metadata for specific CPU ops; other ops get an empty entry.
  fn.isNcclMeta() ? torch_ops_.extra_meta_.emplace_back(
                        torch::profiler::impl::saveNcclMeta(fn))
                  : torch_ops_.extra_meta_.emplace_back();

  auto out = std::make_unique<KinetoObserverContext>(event);

  // NOTE(review): CUDA event-recording failures are logged and swallowed,
  // while the PrivateUse1 path lets exceptions propagate; confirm this
  // asymmetry is intended.
  if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) {
    try {
      out->fallback_ = torch_ops_.device_fallback_.emplace_back();
      torch::profiler::impl::cudaStubs()->record(
          nullptr, &out->fallback_->device_event_start_, nullptr);
    } catch (const std::exception& e) {
      LOG(WARNING) << "Failed to record CUDA event. " << e.what();
    }
  } else if (config_.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
    out->fallback_ = torch_ops_.device_fallback_.emplace_back();
    torch::profiler::impl::privateuse1Stubs()->record(
        nullptr, &out->fallback_->device_event_start_, nullptr);
  }

  // The start timestamp is taken after all of the bookkeeping above.
  event->start_time_ = c10::getApproximateTime();
  event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS();
  if (!config_.experimental_config.performance_events.empty()) {
    const size_t n = config_.experimental_config.performance_events.size();
    event->counters_ = std::make_unique<perf_counters_t>(n, 0);
    perf_profiler_->Enable();
  }
  return out;
}
393
394 // ---------------
395 // | Collation |
396 // ---------------
namespace {
// Hands out the elements of `container` one per call by moving them out
// ("stealing"). Once the container is exhausted it returns a value-initialized
// element instead. The container is cleared when this helper is destroyed.
//
// Because destruction mutates the referenced container, the helper is neither
// copyable nor movable. A copy would hand out elements the original had
// already moved from, and would clear the container a second time. C++17
// guaranteed copy elision still allows `auto x = StealOrDefault<T>(c);`.
template <typename T>
struct StealOrDefault {
  explicit StealOrDefault(T& container)
      : container_{container}, it_{container.begin()} {}

  StealOrDefault(const StealOrDefault&) = delete;
  StealOrDefault& operator=(const StealOrDefault&) = delete;
  StealOrDefault(StealOrDefault&&) = delete;
  StealOrDefault& operator=(StealOrDefault&&) = delete;

  ~StealOrDefault() {
    container_.get().clear();
  }

  // Returns the next element (moved out of the container), or a
  // value-initialized `value_type` once every element has been consumed.
  typename T::Iterator::value_type operator()() {
    if (it_.exhausted()) {
      return typename T::Iterator::value_type();
    }
    auto result = std::move(*it_);
    ++it_;
    return result;
  }

  std::reference_wrapper<T> container_;
  typename T::Iterator it_;
};
} // namespace
421
// Substring searched for in op names by `materialize`: every op whose name
// contains it gets a ProfilerStepInfo entry.
std::string profilerStepString = "ProfilerStep#";
423
materialize(std::vector<std::shared_ptr<Result>> & out,std::vector<ProfilerStepInfo> & step_info,const std::function<c10::time_t (c10::approx_time_t)> & time_converter,const uint64_t tid,const kineto::DeviceAndResource & kineto_info)424 void ThreadLocalSubqueue::TorchOpStorage::materialize(
425 std::vector<std::shared_ptr<Result>>& out,
426 std::vector<ProfilerStepInfo>& step_info,
427 const std::function<c10::time_t(c10::approx_time_t)>& time_converter,
428 const uint64_t tid,
429 const kineto::DeviceAndResource& kineto_info) {
430 // Plumb Autograd info to the top level annotation.
431 auto it = op_events_.begin();
432 for (C10_UNUSED const auto _ :
433 c10::irange(static_cast<int64_t>(op_events_.size()) - 1)) {
434 auto& first = it->basic_fields_;
435 auto& second = (++it)->basic_fields_;
436 if (first.scope_ == at::RecordScope::FUNCTION &&
437 second.scope_ == at::RecordScope::BACKWARD_FUNCTION &&
438 first.name_.rfind("autograd::engine::evaluate_function: ", 0) == 0) {
439 first.sequence_number_ = second.sequence_number_;
440 first.forward_tid_ = second.forward_tid_;
441 }
442 }
443
444 // `AccumulateGrad` is an important marker for profile analysis; however the
445 // annotation relies on `c10::demangle` which is platform dependent. In
446 // particular, Windows will add a "struct " prefix.
447 const std::string accumulate_grad = "torch::autograd::AccumulateGrad";
448 const std::string windows_pattern = std::string("struct ") + accumulate_grad;
449 for (auto& event : op_events_) {
450 auto& name = event.basic_fields_.name_;
451 auto position = name.find(windows_pattern);
452 if (position != std::string::npos) {
453 name.replace(position, windows_pattern.size(), accumulate_grad);
454 }
455 }
456
457 auto input_shape_getter = inputs_outputs_.getInputShapeGenerator();
458 auto concrete_input_getter = inputs_outputs_.getConcreteInputGenerator();
459
460 // TODO: CTAD will take care of template args when we move to C++17
461 auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_);
462 auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_);
463 auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_);
464 auto extra_meta = StealOrDefault<decltype(extra_meta_)>(extra_meta_);
465 auto kwinputs = StealOrDefault<decltype(kwinputs_)>(kwinputs_);
466 auto gpu_fallback =
467 StealOrDefault<decltype(device_fallback_)>(device_fallback_);
468
469 for (auto event = op_events_.begin(); event != op_events_.end(); ++event) {
470 ExtraFields<EventType::TorchOp> e{
471 std::move(event->basic_fields_),
472 ThreadLocalSubqueue::TorchOpStorage::OpList::correlationID(event),
473 time_converter(event->end_time_),
474 input_shape_getter(),
475 concrete_input_getter(),
476 jit_stack(),
477 jit_module(),
478 extra_args(),
479 extra_meta(),
480 kwinputs(),
481 gpu_fallback(),
482 event->allow_tf32_cublas_,
483 std::move(event->counters_)};
484
485 if (e.name_.find(profilerStepString) != std::string::npos) {
486 step_info.emplace_back(
487 time_converter(event->start_time_),
488 time_converter(event->end_time_),
489 out.size());
490 }
491 out.emplace_back(Result::create(
492 time_converter(event->start_time_), tid, kineto_info, std::move(e)));
493 }
494
495 op_events_.clear();
496 inputs_outputs_.clear();
497 }
498
499 template <size_t BlockSize>
materialize_vulkan(std::vector<std::shared_ptr<Result>> & out,AppendOnlyList<ExtraFields<EventType::Vulkan>::raw_event_t,BlockSize> & raw_events,const std::function<c10::time_t (c10::approx_time_t)> & time_converter,const uint64_t tid,const kineto::DeviceAndResource & kineto_info)500 void materialize_vulkan(
501 std::vector<std::shared_ptr<Result>>& out,
502 AppendOnlyList<ExtraFields<EventType::Vulkan>::raw_event_t, BlockSize>&
503 raw_events,
504 const std::function<c10::time_t(c10::approx_time_t)>& time_converter,
505 const uint64_t tid,
506 const kineto::DeviceAndResource& kineto_info) {
507 for (const auto& i : raw_events) {
508 const auto name_and_duration_ns =
509 torch::profiler::impl::vulkan::getShaderNameAndDurationNs(i.second);
510
511 out.emplace_back(Result::create(
512 /*start_time_ns_=*/time_converter(i.first),
513 /*start_tid_=*/tid,
514 /*kineto_info_=*/kineto_info,
515 /*extra_fields_=*/
516 ExtraFields<EventType::Vulkan>{
517 /*name_=*/std::get<0>(name_and_duration_ns),
518 /*duration_ns_=*/
519 static_cast<int64_t>(std::get<1>(name_and_duration_ns)),
520 /*in_tree_building_=*/false}));
521 }
522 }
523
524 namespace {
525 // See `RecordQueue::getSubqueue()` for an overview of this cache.
526 struct SubQueueThreadCache {
527 uint32_t key_;
528 ThreadLocalSubqueue* ref_;
529 };
530
531 // The astute observer will note that this leaves a dangling reference; nothing
532 // in the teardown of `RecordQueue` or `ThreadLocalSubqueue` clears this value.
533 // (And the raw pointer in `SubQueueThreadCache` will not extend the lifetime
534 // of `*ref_`.) This is safe, however, because `getSubqueue` will check
535 // `sub_queue_cache_.key_` before attempting to access `ref_`, and if `key_`
536 // does not match the RecordQueue's *unique* `id_` it will evict
537 // `sub_queue_cache_` and fall back to a different mechanism.
538 std::atomic<uint32_t> queue_id_{0};
539 thread_local SubQueueThreadCache sub_queue_cache_{0, nullptr};
540
toString(const ExtraFields<EventType::PyCall> & e)541 std::string toString(const ExtraFields<EventType::PyCall>& e) {
542 if (e.module_.has_value()) {
543 return fmt::format(
544 "nn.Module: {}_{}", e.module_->cls_name_.str(), e.module_->id_);
545 }
546 return fmt::format(
547 "{}({}): {}",
548 e.callsite_.filename_.str(),
549 e.callsite_.line_no_,
550 e.callsite_.funcname_.str());
551 }
552
scopeToType(at::RecordScope scope)553 auto scopeToType(at::RecordScope scope) {
554 return scope == at::RecordScope::USER_SCOPE
555 ? libkineto::ActivityType::USER_ANNOTATION
556 : libkineto::ActivityType::CPU_OP;
557 }
558
torchOpEndNS(const ExtraFields<EventType::TorchOp> & e,const bool finished,const std::weak_ptr<Result> & parent)559 int64_t torchOpEndNS(
560 const ExtraFields<EventType::TorchOp>& e,
561 const bool finished,
562 const std::weak_ptr<Result>& parent) {
563 if (finished && e.end_time_ns_ == std::numeric_limits<c10::time_t>::min()) {
564 auto p = parent.lock();
565 if (p) {
566 return p->endTimeNS();
567 }
568 }
569 return e.end_time_ns_;
570 }
571
kinetoEventCorrelationID(const ExtraFields<EventType::Kineto> & e,const std::weak_ptr<Result> & parent)572 auto kinetoEventCorrelationID(
573 const ExtraFields<EventType::Kineto>& e,
574 const std::weak_ptr<Result>& parent) {
575 if (e.correlation_id_) {
576 return e.correlation_id_;
577 }
578 auto p = parent.lock();
579 return p ? p->correlationID() : 0;
580 }
581 } // namespace
582
// Builds a visitor lambda for `Result::visit` that handles
// `ExtraFields<EventType::event_type>` by returning `expr`; the visited fields
// are available to `expr` as `e`. The `(void)e;` avoids unused-parameter
// warnings when `expr` ignores `e`. The lambda captures by reference, so
// `expr` may also use members of the enclosing Result.
#define ATTRIBUTE(event_type, expr)                  \
  [&](const ExtraFields<EventType::event_type>& e) { \
    (void)e;                                         \
    return expr;                                     \
  }
588
name() const589 std::string Result::name() const {
590 return visit(c10::overloaded(
591 ATTRIBUTE(Vulkan, std::string(e.name_)),
592 ATTRIBUTE(Allocation, std::string("[memory]")),
593 ATTRIBUTE(OutOfMemory, std::string("[OutOfMemory]")),
594 ATTRIBUTE(PyCall, toString(e)),
595 ATTRIBUTE(PyCCall, std::string(e.function_name_.str())),
596 [](const auto& e) -> std::string { return e.name_; }));
597 }
598
kinetoType() const599 libkineto::ActivityType Result::kinetoType() const {
600 return visit(c10::overloaded(
601 ATTRIBUTE(TorchOp, scopeToType(e.scope_)),
602 ATTRIBUTE(Backend, scopeToType(e.scope_)),
603 ATTRIBUTE(Vulkan, libkineto::ActivityType::CPU_OP),
604 ATTRIBUTE(Allocation, libkineto::ActivityType::CPU_INSTANT_EVENT),
605 ATTRIBUTE(OutOfMemory, libkineto::ActivityType::CPU_INSTANT_EVENT),
606 ATTRIBUTE(PyCall, libkineto::ActivityType::PYTHON_FUNCTION),
607 ATTRIBUTE(PyCCall, libkineto::ActivityType::PYTHON_FUNCTION),
608 ATTRIBUTE(Kineto, e.activity_type_)));
609 }
610
correlationID() const611 uint64_t Result::correlationID() const {
612 return visit(c10::overloaded(
613 ATTRIBUTE(TorchOp, e.correlation_id_),
614 ATTRIBUTE(Kineto, kinetoEventCorrelationID(e, parent_)),
615 [&](const auto&) -> uint64_t { return 0; }));
616 }
617
// End time of this event in nanoseconds, computed per event type:
//  - TorchOp: its recorded end time, or its parent's when the op is finished
//    but has no end time (see torchOpEndNS);
//  - Backend: `end_time_us_ * 1000`;
//  - Vulkan: start + duration, or just the start while `in_tree_building_`;
//  - Allocation / OutOfMemory: the start time;
//  - Kineto: start + duration;
//  - anything else: its `end_time_ns_` member.
// For a finished result, an end time earlier than the start time fails a
// SOFT_ASSERT and the start time is returned instead.
int64_t Result::endTimeNS() const {
  auto end_time_ns = visit(c10::overloaded(
      ATTRIBUTE(TorchOp, torchOpEndNS(e, finished_, parent_)),
      ATTRIBUTE(Backend, e.end_time_us_ * 1000),
      ATTRIBUTE(
          Vulkan, start_time_ns_ + (e.in_tree_building_ ? 0 : e.duration_ns_)),
      ATTRIBUTE(Allocation, start_time_ns_),
      ATTRIBUTE(OutOfMemory, start_time_ns_),
      ATTRIBUTE(Kineto, start_time_ns_ + e.duration_ns_),
      [&](const auto& e) -> int64_t { return e.end_time_ns_; }));

  // In rare cases we're willing to tolerate ops which are missing an end time
  // so long as they can borrow their parent's end time. A consequence of this,
  // however, is that `endTimeNS` may not make sense until tree construction is
  // complete.
  auto end_time_is_valid =
      !finished_ || SOFT_ASSERT(end_time_ns >= start_time_ns_, name());
  return end_time_is_valid ? end_time_ns : start_time_ns_;
}
637
endTID() const638 uint64_t Result::endTID() const {
639 return visit(c10::overloaded(
640 ATTRIBUTE(TorchOp, e.end_tid_),
641 [&](const auto&) -> uint64_t { return start_tid_; }));
642 }
643
deviceType() const644 c10::DeviceType Result::deviceType() const {
645 using torch::autograd::profiler::deviceTypeFromActivity;
646 return visit(c10::overloaded(
647 ATTRIBUTE(Vulkan, c10::DeviceType::Vulkan),
648 ATTRIBUTE(Allocation, e.device_type_),
649 ATTRIBUTE(OutOfMemory, e.device_type_),
650 ATTRIBUTE(Kineto, deviceTypeFromActivity(e.activity_type_)),
651 [&](const auto&) { return c10::DeviceType::CPU; }));
652 }
// Undefine the visitor helper macro so it does not leak into later code.
#undef ATTRIBUTE
654
ThreadLocalSubqueue(const uint64_t tid,ProfilerConfig config)655 ThreadLocalSubqueue::ThreadLocalSubqueue(
656 const uint64_t tid,
657 ProfilerConfig config)
658 : tid_{tid},
659 config_{std::move(config)},
660 kineto_info_{kineto::kineto_ids()} {
661 torch::profiler::impl::kineto::recordThreadInfo();
662 if (!config_.experimental_config.performance_events.empty()) {
663 perf_profiler_ =
664 std::make_unique<torch::profiler::impl::linux_perf::PerfProfiler>();
665 perf_profiler_->Configure(config_.experimental_config.performance_events);
666 }
667 }
668
RecordQueue(ProfilerConfig config,std::set<ActivityType> activities)669 RecordQueue::RecordQueue(
670 ProfilerConfig config,
671 std::set<ActivityType> activities)
672 : id_(++queue_id_),
673 config_{std::move(config)},
674 activities_{std::move(activities)} {
675 if (tracePython()) {
676 python_tracer_ = python_tracer::PythonTracerBase::make(this);
677 }
678 }
679
tracePython() const680 bool RecordQueue::tracePython() const {
681 return config_.with_stack && activities_.count(ActivityType::CPU);
682 }
683
getSubqueue()684 ThreadLocalSubqueue* RecordQueue::getSubqueue() {
685 // In the most common case, a thread will want to write to the same sub-queue
686 // that it wrote to last call. The only time that isn't true is if:
687 // A) The profiler context has ended and we are in a new one.
688 // B) Two profilers are active in different TLS contexts, and this thread
689 // is a worker helping with intra-op parallelism.
690 // Since we expect this to be the OVERWHELMINGLY common case (>99%), we add a
691 // special thread_local cache so that we can skip the overall `flat_hash_map`
692 // (and corresponding lock).
693 if (id_ == sub_queue_cache_.key_) {
694 return sub_queue_cache_.ref_;
695 }
696
697 const auto tid = at::RecordFunction::currentThreadId();
698 std::lock_guard<std::mutex> guard(sub_queue_mutex_);
699 auto it = sub_queues_.find(tid);
700 if (it == sub_queues_.end()) {
701 it = sub_queues_
702 .emplace(tid, std::make_unique<ThreadLocalSubqueue>(tid, config_))
703 .first;
704 }
705
706 sub_queue_cache_ = SubQueueThreadCache{id_, it->second.get()};
707 return it->second.get();
708 }
709
stop()710 void RecordQueue::stop() {
711 if (python_tracer_) {
712 python_tracer_->stop();
713 }
714 }
715
restart()716 void RecordQueue::restart() {
717 if (python_tracer_) {
718 python_tracer_->restart();
719 }
720 }
721
722 namespace {
mark_finished(std::shared_ptr<Result> & r)723 void mark_finished(std::shared_ptr<Result>& r) {
724 TORCH_INTERNAL_ASSERT(!r->finished_, r->name());
725 r->finished_ = true;
726 TORCH_INTERNAL_ASSERT(r->endTimeNS() >= r->start_time_ns_, r->name());
727 }
728
#ifdef USE_KINETO
// Packs (thread id, sequence number) into one 64-bit key. The low 16 bits of
// `tid` fill the top 16 bits, and the low 48 bits of `seqNr` fill the rest.
// Assumes fewer than 2^16 threads and fewer than 2^48 ops.
static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
  constexpr uint64_t kSeqMask = (uint64_t{1} << 48) - 1;
  return (tid << 48) | (seqNr & kSeqMask);
}

// Links one TorchOp activity into the forward/backward flow graph.
//
// Backward op (forward_tid_ > 0):
//   Looks up the forward activity registered under
//   (forward_tid_, sequence_number_). If found, both activities get the next
//   link id and the fwd/bwd flow type, and the forward side is marked as the
//   flow start. The entry is then erased so later backward ops with the same
//   key do not add extra "end" flow events.
//
// Forward op (otherwise, with a non-zero start_tid_):
//   Registers itself under (start_tid_, sequence_number_). An existing entry
//   is replaced only if this op starts no earlier. The sequence number is
//   only advanced when a backward "Node" is created (via
//   at::sequence_number::get_and_increment()), so among ops sharing a
//   sequence number the latest-starting one is the likeliest to have
//   launched the backward op.
void generateForwardBackwardLink(
    const Result& profiler_result,
    uint64_t& fwd_bwd_link_id,
    libkineto::GenericTraceActivity& activity,
    std::unordered_map<uint64_t, libkineto::GenericTraceActivity*>&
        tidSeq2activity) {
  const auto& fields =
      std::get<ExtraFields<EventType::TorchOp>>(profiler_result.extra_fields_);

  const bool is_backward = fields.forward_tid_ > 0;
  if (is_backward) {
    auto match = tidSeq2activity.find(
        getForwardThreadKey(fields.forward_tid_, fields.sequence_number_));
    if (match == tidSeq2activity.end()) {
      return;
    }
    libkineto::GenericTraceActivity* fwd = match->second;
    fwd->flow.start = true;
    activity.flow.id = fwd->flow.id = fwd_bwd_link_id;
    activity.flow.type = fwd->flow.type = libkineto::kLinkFwdBwd;
    ++fwd_bwd_link_id;
    tidSeq2activity.erase(match);
    return;
  }

  if (profiler_result.start_tid_ == 0) {
    return;
  }
  const uint64_t key = getForwardThreadKey(
      profiler_result.start_tid_, fields.sequence_number_);
  auto match = tidSeq2activity.find(key);
  if (match == tidSeq2activity.end()) {
    tidSeq2activity.emplace(key, &activity);
  } else if (activity.startTime >= match->second->startTime) {
    match->second = &activity;
  }
}
#endif // USE_KINETO
783
generateForwardBackwardLinks(std::unique_ptr<torch::profiler::impl::kineto::trace_t> & cpu_trace,const std::vector<std::shared_ptr<Result>> & results)784 void generateForwardBackwardLinks(
785 std::unique_ptr<torch::profiler::impl::kineto::trace_t>& cpu_trace,
786 const std::vector<std::shared_ptr<Result>>& results){
787 #ifndef USE_KINETO
788 }
789 #else // USE_KINETO
790 TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == results.size());
791
792 // startThreadId_seqNum to pointer of activity.
793 // Low-16bits of startThreadId and low-48bits seqNum are concatenated into
794 // one uint64_t variable as key.
795
796 std::unordered_map<uint64_t, libkineto::GenericTraceActivity*> tidSeq2activity;
797 uint64_t fwd_bwd_link_id = 1;
798
799 using result_activity_t = std::pair<Result*, libkineto::GenericTraceActivity*>;
800 std::vector<result_activity_t> torch_events;
801
802 for (const auto idx : c10::irange(cpu_trace->activities.size())) {
803 auto& profiler_result = results[idx];
804 auto& activity = cpu_trace->activities[idx];
805
806 // add information about an associated forward op, if a sequence number
807 // is available (e.g. during training)
808
809 profiler_result->visit_if_base<ExtraFields<EventType::TorchOp>>(
810 [&](const auto& e) {
811 if (e.sequence_number_ >= 0) {
812 torch_events.emplace_back(profiler_result.get(), activity.get());
813 }
814 });
815 }
816
817 // We need to visit the events in chronological order.
818 // So we sort them by end_time_ns_ before processing.
819 std::sort(
820 torch_events.begin(),
821 torch_events.end(),
822 [](const result_activity_t& left, const result_activity_t& right) {
823 auto left_end_time =
824 std::get<ExtraFields<EventType::TorchOp>>(left.first->extra_fields_)
825 .end_time_ns_;
826 auto right_end_time =
827 std::get<ExtraFields<EventType::TorchOp>>(right.first->extra_fields_)
828 .end_time_ns_;
829 return left_end_time < right_end_time;
830 });
831
832 for (auto& [profiler_result, activity] : torch_events) {
833 generateForwardBackwardLink(
834 *profiler_result, fwd_bwd_link_id, *activity, tidSeq2activity);
835 }
836 }
837 #endif // USE_KINETO
838
// Metadata key under which passEventsToKineto records each activity's index
// into the profiler `results` vector. TransferEvents::extractIndex parses the
// value back out of the activity's metadata JSON so that Kineto activities
// can be re-associated with their `Result`s after Kineto hands them back.
static constexpr const char* indexKey = "Ev Idx";
840
passEventsToKineto(const std::vector<std::shared_ptr<Result>> & results,uint64_t start_time_ns,uint64_t end_time_ns,const ProfilerConfig & config)841 void passEventsToKineto(
842 const std::vector<std::shared_ptr<Result>>& results,
843 uint64_t start_time_ns,
844 uint64_t end_time_ns,
845 const ProfilerConfig& config) {
846 using namespace torch::profiler::impl::kineto;
847 TraceWrapper cpu_trace(
848 static_cast<int64_t>(start_time_ns), "PyTorch Profiler");
849
850 // Generate Kineto events for each event recorded by the PyTorch profiler.
851 for (const auto i : c10::irange(results.size())) {
852 const auto& e = results[i];
853 // (TODO): This is a temporary fix for async traces to make sure that we do
854 // not use int64 MIN as end time in Kineto. If we use that value, the
855 // duration will overflow and become a very large positive number. For a
856 // long term solution, add guards in kineto for each activity type
857 int64_t act_end_time = std::max(e->endTimeNS(), e->start_time_ns_);
858 auto* activity = cpu_trace.addCPUActivity(
859 e->name(),
860 e->kinetoType(),
861 e->kineto_info_,
862 e->correlationID(),
863 e->start_time_ns_,
864 act_end_time);
865
866 TORCH_INTERNAL_ASSERT(activity || !kKinetoAvailable);
867 if (activity) {
868 addMetadata(activity, indexKey, std::to_string(i));
869
870 // There is a longstanding regression for initializing
871 // on-demand Kineto activity handling. Enabling this path
872 // for Profiler API could cause side effects as much has changed since.
873 // Make a surgical fix here until we holistically assess the on-demand
874 // vs API path framentation, which has been snowballing in complexity
875 // and thus flakiness.
876 if (config.global()) {
877 e->kineto_activity_ = activity;
878 }
879 }
880 }
881
882 if (get_fwd_bwd_enabled()) {
883 generateForwardBackwardLinks(cpu_trace.get(), results);
884 }
885
886 // Kineto adds the events that it collected.
887 cpu_trace.transferCpuTrace(static_cast<int64_t>(end_time_ns));
888 }
889
#ifdef USE_KINETO
// There are two mechanisms that we use to connect Profiler and Kineto events.
// The first is the correlation ID. The profiler pushes a unique integer at the
// start of an op and pops it at the end. Kineto then associates the events
// that it collects with that correlation ID and sets the linked activity of
// the events that it collected to point to the profiler op.
//
// However, this is not a sufficient description because it does not retain
// dependency information between kineto ops. Consider a call to `torch.add`.
// Three events will be collected:
//   `aten::add`          (TorchOp, collected by profiler)
//   `cudaLaunchKernel`   (CUDA runtime event, collected by Kineto)
//   `at::vectorized_...` (GPU kernel, collected by Kineto)
// If we only relied on correlation IDs we would set both Kineto events as
// children of the `aten::add`, rather than the correct
//   `aten::add -> cudaLaunchKernel -> at::vectorized_...`
//
// Kineto surfaces this information through a second concept called a "flow".
// In this example, the `cudaLaunchKernel` event is the start of a flow and the
// GPU kernel has the same flow id but is not a start event. Thus, when merging
// the Kineto events into the call tree we first add all events which are flow
// start nodes. We then merge the rest, trying to pair them with flow starts
// and falling back to correlation ID if necessary. For any nodes without
// linked events the caller is determined using the normal tree construction
// algorithm.
//
// All of the work happens in the constructor; the object itself holds only
// bookkeeping state.
class TransferEvents {
  using itrace_t = libkineto::ITraceActivity;
  using activity_t = torch::profiler::impl::kineto::activity_t;

 public:
  // `results` is appended to with Results created for Kineto-only activities,
  // and parent/child links are written into its elements.
  TransferEvents(
      std::vector<std::shared_ptr<Result>>& results,
      trace_ptr_t& trace)
      : results_{results} {
    auto* trace_activities_ptr = trace->get()->activities();
    TORCH_INTERNAL_ASSERT(trace_activities_ptr != nullptr);
    trace_activities_ = *trace_activities_ptr;
    reassociate();
    extractEventsFromTrace();
    setParents();
  }

 private:
  // Parses the `results_` index that passEventsToKineto stored in the
  // activity metadata under `indexKey`. Returns `unmatchedIndex` when the key
  // is absent. The metadata has the form `..."Ev Idx": <n>, ...`.
  static long long extractIndex(const std::string& metadata_json) {
    static const auto prefix = fmt::format("\"{}\": ", indexKey);
    const auto pos = metadata_json.find(prefix);
    if (pos == std::string::npos) {
      return unmatchedIndex;
    }
    const auto begin = pos + prefix.size();
    auto end = metadata_json.find(',', begin);
    if (end == std::string::npos) {
      end = metadata_json.size();
    }
    // NB: substr takes a length, not an end position.
    return std::stoll(metadata_json.substr(begin, end - begin));
  }

  // Maps a Kineto activity to the profiler Result it came from. Uses the cache
  // first, then the encoded index. Returns nullptr if neither matches.
  std::shared_ptr<Result> lookup(const itrace_t* key) {
    if (key == nullptr) {
      return nullptr;
    }

    // First check the map.
    auto it = kineto_events_.find(key);
    if (it != kineto_events_.end()) {
      return it->second;
    }

    // Then fall back to the encoded metadata (`key` is non-null here).
    const auto index = extractIndex(key->metadataJson());
    if (index != unmatchedIndex) {
      auto out = results_.get().at(index);
      kineto_events_[key] = out;
      return out;
    }

    // And finally give up.
    return nullptr;
  }

  void reassociate() {
    // Match profiler events with the corresponding kineto events. Kineto may
    // have moved or copied the activities, so we have to recover the
    // relationship between `libkineto::ITraceActivity` and `Result`.
    for (const auto* activity : trace_activities_) {
      TORCH_INTERNAL_ASSERT(activity != nullptr);
      auto e = lookup(activity);
      if (e != nullptr) {
        TORCH_INTERNAL_ASSERT(e->kineto_activity_ == nullptr);
        e->kineto_activity_ = static_cast<const activity_t*>(activity);
      }
    }
    if (results_.get().size() != kineto_events_.size()) {
      TORCH_WARN(fmt::format(
          "Failed to recover relationship between all profiler and kineto events: "
          "{} vs. {} reassociated.",
          results_.get().size(),
          kineto_events_.size()));
    }
  }

  // Creates a new Kineto-typed Result describing an activity that the
  // profiler did not record itself. Its TID is a placeholder (`noTID`) that
  // setKinetoTID fills in later.
  std::shared_ptr<Result> resultFromActivity(const itrace_t* activity) {
    TORCH_INTERNAL_ASSERT(activity != nullptr);

    // Kineto is inconsistent with types, so we have to cast to int32.
    torch::profiler::impl::kineto::DeviceAndResource device_and_resource{
        static_cast<int32_t>(activity->deviceId()),
        static_cast<int32_t>(activity->resourceId())};

    auto event = Result::create(
        activity->timestamp(),
        noTID, // Placeholder
        device_and_resource,
        ExtraFields<EventType::Kineto>{
            activity->name(),
            activity->duration(),
            static_cast<uint64_t>(activity->correlationId()),
            activity->type(),
            {/*id=*/static_cast<uint32_t>(activity->flowId()),
             /*type=*/static_cast<uint32_t>(activity->flowType()),
             /*start=*/activity->flowStart()}});

    // NB: It's tempting to set `event->kineto_activity_`; however we can only
    // guarantee that the events we passed to Kineto are of type
    // `GenericTraceActivity`. Others may derive from ITraceActivity and thus
    // are not safe to cast.
    return event;
  }

  // Returns the Result for `activity`, creating (and registering) one for
  // Kineto-only activities. Returns nullptr for CPU-side activity types that
  // should have come from the profiler but could not be matched.
  std::shared_ptr<Result> toResult(const itrace_t* activity) {
    auto e = lookup(activity);

    // Until we are very sure that we can reassociate kineto and profiler
    // events we need to be very defensive.
    const auto type = activity->type();
    if (e == nullptr &&
        (type == libkineto::ActivityType::CPU_OP ||
         type == libkineto::ActivityType::CPU_INSTANT_EVENT ||
         type == libkineto::ActivityType::USER_ANNOTATION ||
         type == libkineto::ActivityType::PYTHON_FUNCTION)) {
      TORCH_WARN_ONCE(
          "Detected an event which was likely passed to kineto by the PyTorch "
          "profiler, but is not present in the set of known events: ",
          activity->name(),
          " This most likely means that Kineto has not "
          "maintained address stability for this event. Please report this to "
          "the PyTorch team.");
      return nullptr;
    }

    if (e == nullptr) {
      e = resultFromActivity(activity);
      results_.get().push_back(e);
      kineto_events_[activity] = e;
    }
    return e;
  }

  // Converts every trace activity to a Result and records its linked
  // activity (the correlation-ID link) on Kineto-typed Results.
  void extractEventsFromTrace() {
    for (const auto* activity : trace_activities_) {
      auto e = toResult(activity);
      const auto* linked_activity = activity->linkedActivity();
      if (e && linked_activity) {
        e->visit(c10::overloaded(
            [&](ExtraFields<EventType::Kineto>& i) {
              i.linked_activity_ = toResult(linked_activity);
            },
            [](auto&) { TORCH_INTERNAL_ASSERT(false); }));
      }
    }
  }

  // Assigns thread ids to the Kineto-typed Results in the subtree rooted at
  // `r`. A node inherits its parent's start tid; a root uses the current
  // thread's id.
  void setKinetoTID(
      std::shared_ptr<Result>& r,
      std::shared_ptr<Result> parent) {
    r->visit(c10::overloaded(
        [&](ExtraFields<EventType::Kineto>& /*unused*/) {
          TORCH_INTERNAL_ASSERT(r->start_tid_ == noTID);
          r->start_tid_ = parent ? parent->start_tid_
                                 : at::RecordFunction::currentThreadId();
        },
        [](auto&) {}));

    for (auto& child : r->children_) {
      setKinetoTID(child, r);
    }
  }

  void setParents() {
    // First pass: Collect start events and set parent to linked event.
    ska::flat_hash_map<uint32_t, std::shared_ptr<Result>> flow_map;
    for (auto& e : results_.get()) {
      TORCH_INTERNAL_ASSERT(e != nullptr);
      e->visit(c10::overloaded(
          [&](const ExtraFields<EventType::Kineto>& i) {
            if (i.flow.type == libkineto::kLinkAsyncCpuGpu && i.flow.start) {
              auto inserted = flow_map.insert({i.flow.id, e});
#ifdef USE_ROCM
              // `inserted.second` is false when the flow id was already
              // present, i.e. when a duplicate flow start was produced.
              if (!inserted.second) {
                TORCH_WARN_ONCE(
                    "ROCTracer produced duplicate flow start: ", i.flow.id);
              }
#else // USE_ROCM
              TORCH_INTERNAL_ASSERT(inserted.second);
#endif // USE_ROCM
            }
            TORCH_INTERNAL_ASSERT(e->parent_.expired());
            e->parent_ = i.linked_activity_;
          },
          [](const auto&) {}));
    }

    // Second pass: flow edges override linked-event parents; then register
    // each parented event as a (finished) child of its parent.
    for (auto& e : results_.get()) {
      e->visit(c10::overloaded(
          [&](const ExtraFields<EventType::Kineto>& i) {
            // Flow takes priority over linked event.
            const auto it = flow_map.find(i.flow.id);
            if (it != flow_map.end() &&
                i.flow.type == libkineto::kLinkAsyncCpuGpu && !i.flow.start) {
              e->parent_ = it->second;
            }

            // If a parent was set we have to do some bookkeeping.
            auto parent = e->parent_.lock();
            if (parent) {
              parent->children_.push_back(e);
              mark_finished(e);
            }
          },
          [](const auto&) {}));
    }

    // Set TIDs now that we have established lineage.
    for (auto& e : results_.get()) {
      if (e->parent_.expired()) {
        setKinetoTID(e, nullptr);
      }
    }
  }

  static constexpr long long unmatchedIndex = -1;
  static constexpr auto noTID = std::numeric_limits<uint64_t>::max();
  std::reference_wrapper<std::vector<std::shared_ptr<Result>>> results_;
  std::vector<const itrace_t*> trace_activities_;
  ska::flat_hash_map<const itrace_t*, std::shared_ptr<Result>> kineto_events_;
};
#else
// Without Kineto there is no trace to merge; accept any arguments and do
// nothing.
class TransferEvents {
 public:
  template <class... Args>
  TransferEvents(Args&&...) {}
};
#endif
1140
addKinetoEvents(std::vector<std::shared_ptr<Result>> & results,uint64_t start_time_ns,uint64_t end_time_ns,const ProfilerConfig & config)1141 trace_ptr_t addKinetoEvents(
1142 std::vector<std::shared_ptr<Result>>& results,
1143 uint64_t start_time_ns,
1144 uint64_t end_time_ns,
1145 const ProfilerConfig& config) {
1146 using namespace torch::profiler::impl::kineto;
1147 passEventsToKineto(results, start_time_ns, end_time_ns, config);
1148
1149 // In on demand mode kineto is directly controlled by other machinery.
1150 if (config.global()) {
1151 return nullptr;
1152 }
1153
1154 auto trace = std::make_unique<ActivityTraceWrapper>(stopTrace());
1155 TORCH_INTERNAL_ASSERT(trace || !kKinetoAvailable);
1156 TransferEvents transfer{results, trace};
1157 return trace;
1158 }
1159
1160 struct ResultGreater {
operator ()torch::profiler::impl::__anon1af765770c11::ResultGreater1161 bool operator()(const result_ptr_t& a, const result_ptr_t& b) const {
1162 return a->endTimeNS() > b->endTimeNS();
1163 }
1164 };
1165
set_in_tree_building(std::vector<result_ptr_t> & results,const bool value)1166 void set_in_tree_building(
1167 std::vector<result_ptr_t>& results,
1168 const bool value) {
1169 for (result_ptr_t& r : results) {
1170 r->visit(c10::overloaded(
1171 [value](ExtraFields<EventType::Vulkan>& i) {
1172 i.in_tree_building_ = value;
1173 },
1174 [&](auto&) {
1175 // pass
1176 }));
1177 }
1178 }
1179
/**
 * Reconstructs parent/child links for `sorted_events` in place by replaying
 * them as per-thread stacks.
 *
 * Callers in this file stable_sort the events by `start_time_ns_` before
 * calling. For each event, the event currently on top of its thread's stack
 * becomes its parent. A TorchOp whose own thread has no open event falls back
 * to the stack for its `forward_tid_`. An event is popped once an event
 * starting strictly after its end time is replayed. Vulkan events have
 * `in_tree_building_` set to true for the duration of this call.
 */
void build_tree(std::vector<std::shared_ptr<Result>>& sorted_events) {
  set_in_tree_building(sorted_events, true);

  using op_fields = ExtraFields<EventType::TorchOp>;
  // start_tid -> innermost still-open event on that thread.
  ska::flat_hash_map<uint64_t, std::shared_ptr<Result>> stacks;
  // Open events that have a real end time; earliest end on top.
  std::priority_queue<result_ptr_t, std::vector<result_ptr_t>, ResultGreater>
      end_events_;

  // Attaches `event` to its parent (if any) and opens it on its thread's stack.
  auto push_event = [&stacks, &end_events_](std::shared_ptr<Result>& event) {
    // Kineto builds subtrees using correlation ids and flows, so some Kineto
    // events are already marked finished before the main tree building
    // algorithm. It's fine to ignore them; the root event of these subtrees
    // is not a Kineto op and will be handled normally.
    if (std::holds_alternative<ExtraFields<EventType::Kineto>>(
            event->extra_fields_) &&
        event->finished_) {
      return;
    }

    TORCH_INTERNAL_ASSERT(event->parent_.expired());
    for (const auto& child : event->children_) {
      TORCH_INTERNAL_ASSERT(child->finished_);
    }
    TORCH_INTERNAL_ASSERT(!event->finished_);

    auto parent_it = stacks.find(event->start_tid_);
    if (parent_it == stacks.end()) {
      // No open event on this thread: for TorchOps, try the thread recorded
      // in `forward_tid_` (a value of 0 means "none").
      auto fwd_tid = event->visit(c10::overloaded(
          [](const op_fields& i) { return i.forward_tid_; },
          [](const auto&) -> uint64_t { return 0; }));
      if (fwd_tid) {
        parent_it = stacks.find(fwd_tid);
      }
    }

    if (parent_it != stacks.end()) {
      event->parent_ = parent_it->second;
      parent_it->second->children_.push_back(event);
    }

    if (event->endTimeNS() > event->start_time_ns_) {
      stacks[event->start_tid_] = event;
      end_events_.push(event);
    } else if (event->endTimeNS() == std::numeric_limits<c10::time_t>::min()) {
      // We use min time to indicate the lack of a termination event, so if we
      // encounter such a case we don't push to `end_events_`.
      stacks[event->start_tid_] = event;
    } else {
      // end <= start: the event is never put on a stack, so finish it now.
      mark_finished(event);
    }
  };

  // Closes `event` together with any still-open frames above it on its
  // thread's stack, then restores its parent as the top of that stack.
  auto pop_event = [&stacks](std::shared_ptr<Result> event) {
    if (event->finished_) {
      // This event was marked finished by a previous `pop_event` call.
      return;
    }

    auto start_tid = event->start_tid_;
    auto frame = stacks.at(start_tid);

    // Unwind frames that were opened after `event` and are still open.
    while (frame.get() != event.get()) {
      TORCH_INTERNAL_ASSERT(frame != nullptr);
      mark_finished(frame);
      TORCH_INTERNAL_ASSERT(!frame->parent_.expired());
      frame = frame->parent_.lock();
    }

    mark_finished(event);
    stacks.erase(start_tid);
    auto new_frame = event->parent_.lock();
    if (new_frame != nullptr) {
      stacks[start_tid] = new_frame;
    }
  };

  // Stack replay loop: before opening each event, close every open event
  // whose end time is strictly before the new event's start.
  for (auto& event : sorted_events) {
    while (!end_events_.empty() &&
           end_events_.top()->endTimeNS() < event->start_time_ns_) {
      pop_event(end_events_.top());
      end_events_.pop();
    }
    push_event(event);
  }

  // Cleanup remaining exit events.
  while (!end_events_.empty()) {
    pop_event(end_events_.top());
    end_events_.pop();
  }

  set_in_tree_building(sorted_events, false);
}
1274
1275 /**
1276 * Adjust r's duration to be the max of its current duration and the sum of all
1277 * of its children's adjusted durations (keeping its start time the same)
1278 * (adjust all child durations recursively)
1279 */
adjust_durations_dfs(std::shared_ptr<Result> & r)1280 int64_t adjust_durations_dfs(std::shared_ptr<Result>& r) {
1281 if (SOFT_ASSERT(r != nullptr)) {
1282 int64_t original_duration = r->endTimeNS() - r->start_time_ns_;
1283 int64_t children_total_duration = std::accumulate(
1284 r->children_.begin(),
1285 r->children_.end(),
1286 0,
1287 [](int64_t acc, std::shared_ptr<Result>& child) {
1288 return acc + adjust_durations_dfs(child);
1289 });
1290
1291 if (children_total_duration > original_duration) {
1292 r->visit(c10::overloaded(
1293 [&r, &children_total_duration](ExtraFields<EventType::TorchOp>& i) {
1294 i.end_time_ns_ = r->start_time_ns_ + children_total_duration;
1295 },
1296 [&children_total_duration](ExtraFields<EventType::Vulkan>& i) {
1297 i.duration_ns_ = children_total_duration;
1298 },
1299 [](ExtraFields<EventType::Allocation>& _) {
1300 // Pass- Allocation events can't have children
1301 },
1302 [&](auto&) {
1303 SOFT_ASSERT(
1304 false,
1305 "unexpected event type in mobile profiler adjust_durations_dfs: ",
1306 r->name());
1307 }));
1308 return children_total_duration;
1309 } else {
1310 return original_duration;
1311 }
1312 } else {
1313 return 0;
1314 }
1315 }
1316
1317 /**
1318 * 1) Adjust r's start time to be [new_start_time] (also adjusting end time and
1319 keeping duration the same)
1320 * 2) Recursively adjust r's children's start times, making them line up such
1321 that the last one ends at the same time as r
1322 * 3) Return r's final end time
1323 */
adjust_timestamps_dfs(std::shared_ptr<Result> & r,int64_t new_start_time)1324 int64_t adjust_timestamps_dfs(
1325 std::shared_ptr<Result>& r,
1326 int64_t new_start_time) {
1327 if (SOFT_ASSERT(r != nullptr)) {
1328 if (r->start_time_ns_ != new_start_time) {
1329 // Adjust start time (keeping duration constant)
1330 r->visit(c10::overloaded(
1331 [&r, &new_start_time](ExtraFields<EventType::TorchOp>& i) {
1332 i.end_time_ns_ =
1333 new_start_time + (i.end_time_ns_ - r->start_time_ns_);
1334 },
1335 [](ExtraFields<EventType::Vulkan>& i) {
1336 // Pass- We don't need to manually adjust end time for Vulkan events
1337 },
1338 [](ExtraFields<EventType::Allocation>& _) {
1339 // Pass- No duration or end time to adjust
1340 },
1341 [&](auto&) {
1342 SOFT_ASSERT(
1343 false,
1344 "unexpected event type in mobile profiler adjust_timestamps_dfs: ",
1345 r->name());
1346 }));
1347 r->start_time_ns_ = new_start_time;
1348 }
1349 int64_t children_total_duration = std::accumulate(
1350 r->children_.begin(),
1351 r->children_.end(),
1352 0,
1353 [](int64_t acc, std::shared_ptr<Result>& child) {
1354 return acc + (child->endTimeNS() - child->start_time_ns_);
1355 });
1356
1357 int64_t child_start_time = r->endTimeNS() - children_total_duration;
1358 for (std::shared_ptr<Result>& child : r->children_) {
1359 child_start_time = adjust_timestamps_dfs(child, child_start_time);
1360 }
1361 }
1362 return r->endTimeNS();
1363 }
1364
1365 /**
1366 * Adjust timestamps and durations of nodes in [out] such that
1367 * - Vulkan event timelines are synchronized with CPU event times
1368 * - Parent event timelines fully contain their child timelines
1369 * - No overlaps in timelines for nodes at the same depth
1370 */
adjust_timestamps(std::vector<std::shared_ptr<Result>> & out)1371 void adjust_timestamps(std::vector<std::shared_ptr<Result>>& out) {
1372 if (out.empty()) {
1373 return;
1374 }
1375
1376 int64_t min_start_time = out[0]->start_time_ns_;
1377 for (std::shared_ptr<Result>& r : out) {
1378 // Only begin traversal for root nodes.
1379 if (r->parent_.expired()) {
1380 adjust_durations_dfs(r);
1381 min_start_time = adjust_timestamps_dfs(
1382 r,
1383 std::max(
1384 r->tag() != EventType::Vulkan
1385 ? r->start_time_ns_
1386 : std::numeric_limits<int64_t>::min(),
1387 min_start_time));
1388 }
1389 }
1390 }
1391 } // namespace
1392
1393 std::pair<
1394 std::vector<std::shared_ptr<Result>>,
1395 std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>>
getRecords(std::function<c10::time_t (c10::approx_time_t)> time_converter,uint64_t start_time_ns,uint64_t end_time_ns)1396 RecordQueue::getRecords(
1397 std::function<c10::time_t(c10::approx_time_t)> time_converter,
1398 uint64_t start_time_ns,
1399 uint64_t end_time_ns) {
1400 auto converter = [&](c10::approx_time_t t) {
1401 return t == std::numeric_limits<c10::approx_time_t>::min()
1402 ? std::numeric_limits<c10::time_t>::min()
1403 : time_converter(t);
1404 };
1405
1406 // Lambda that checks that only the right side of the base intersects with
1407 // ev_start and ev_end
1408 auto right_intersection_only =
1409 [&](ProfilerStepInfo base, int64_t ev_start, int64_t ev_end) {
1410 return (base.start_time_ns < ev_start) &&
1411 (base.end_time_ns <= ev_end && base.end_time_ns > ev_start);
1412 };
1413 std::vector<std::shared_ptr<Result>> out;
1414 std::vector<python_tracer::CompressedEvent> python_enters;
1415 std::vector<ProfilerStepInfo> step_info;
1416 long unsigned int step_idx = 0;
1417 for (auto& subqueue_it : sub_queues_) {
1418 auto& queue = *subqueue_it.second;
1419 auto materialize = [&](auto& events) {
1420 for (auto& i : events) {
1421 c10::time_t start_time_ns = 0;
1422 if constexpr (std::is_same_v<
1423 std::remove_reference_t<decltype(i)>,
1424 ExtraFields<EventType::Backend>>) {
1425 start_time_ns = i.start_time_us_ * 1000;
1426 } else {
1427 start_time_ns = converter(i.start_time_);
1428 }
1429 out.emplace_back(Result::create(
1430 /*start_time_ns_=*/start_time_ns,
1431 /*start_tid_=*/queue.tid(),
1432 /*kineto_info_=*/queue.kineto_info(),
1433 /*extra_fields_=*/std::move(i)));
1434 }
1435 events.clear();
1436 };
1437
1438 queue.torch_ops_.materialize(
1439 out, step_info, converter, queue.tid(), queue.kineto_info());
1440 materialize(queue.backend_events_);
1441 materialize_vulkan(
1442 out, queue.vulkan_events_, converter, queue.tid(), queue.kineto_info());
1443 for (auto& i : queue.allocations_) {
1444 out.emplace_back(Result::create(
1445 /*start_time_ns_=*/converter(i.start_time_),
1446 /*start_tid_=*/queue.tid(),
1447 /*kineto_info_=*/queue.kineto_info(),
1448 /*extra_fields_=*/ExtraFields<EventType::Allocation>(i)));
1449 }
1450 materialize(queue.ooms_);
1451
1452 for (auto& i : queue.py_calls_) {
1453 python_enters.push_back(
1454 {i.first, queue.tid(), queue.kineto_info(), converter(i.second)});
1455 }
1456 }
1457
1458 if (python_tracer_) {
1459 std::vector<std::shared_ptr<torch::profiler::impl::Result>> ev;
1460 try {
1461 ev = python_tracer_->getEvents(
1462 converter, python_enters, static_cast<c10::time_t>(end_time_ns));
1463 } catch (std::exception& e) {
1464 // Normally addKinetoEvents() below will stop the trace - but if an
1465 // exception happens here then the events will never be stopped and future
1466 // runs will be broken - so make sure to stopTrace() if we see an
1467 // exception.
1468 torch::profiler::impl::kineto::stopTrace();
1469 throw;
1470 }
1471 // Placeholder for if we run out of ProfilerStep annotations
1472 ProfilerStepInfo defaultStep = {LLONG_MAX, LLONG_MAX, 0};
1473 ProfilerStepInfo step =
1474 step_idx < step_info.size() ? step_info[step_idx] : defaultStep;
1475 for (const auto& i : ev) {
1476 // If event has start time after step end time we can continue to the next
1477 // step
1478 while (i->start_time_ns_ > step.end_time_ns) {
1479 step_idx++;
1480 step = step_idx < step_info.size() ? step_info[step_idx] : defaultStep;
1481 }
1482 // If Step annotation starts before event and ends before event ends with
1483 // intersection then we move the lefthand side of the step annotation to
1484 // the event start time
1485 if (right_intersection_only(step, i->start_time_ns_, i->endTimeNS())) {
1486 auto currStepRes = out[step.out_idx];
1487 currStepRes->start_time_ns_ = i->start_time_ns_ + 1;
1488 step_idx++;
1489 step = step_idx < step_info.size() ? step_info[step_idx] : defaultStep;
1490 }
1491 out.push_back(i);
1492 }
1493 python_tracer_.reset();
1494 }
1495
1496 if (config_.experimental_config.adjust_timestamps) {
1497 std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
1498 return a->start_time_ns_ < b->start_time_ns_;
1499 });
1500 build_tree(out);
1501 adjust_timestamps(out);
1502 for (auto& r : out) {
1503 r->parent_.reset();
1504 // Reset these so that second build_tree can happen
1505 r->finished_ = false;
1506 r->children_.clear();
1507 }
1508 }
1509
1510 auto trace = addKinetoEvents(out, start_time_ns, end_time_ns, config_);
1511
1512 std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
1513 return a->start_time_ns_ < b->start_time_ns_;
1514 });
1515
1516 if (config_.report_input_shapes && config_.profile_memory) {
1517 calculateUniqueTensorIDs(out);
1518 }
1519
1520 build_tree(out);
1521 return {out, std::move(trace)};
1522 }
1523
namespace {
// Process-wide predicate consulted by get_record_concrete_inputs_enabled().
// Lazily initialized on first use; defaults to always returning true.
// No synchronization is performed on reads or writes here.
std::function<bool()>& record_concrete_inputs_enabled_fn() {
  static std::function<bool()> predicate{[]() { return true; }};
  return predicate;
}
} // namespace

// Evaluates the currently installed predicate.
bool get_record_concrete_inputs_enabled() {
  const std::function<bool()>& predicate = record_concrete_inputs_enabled_fn();
  return predicate();
}

// Replaces the predicate with `fn`, taking ownership of the callable.
void set_record_concrete_inputs_enabled_fn(std::function<bool()> fn) {
  auto& slot = record_concrete_inputs_enabled_fn();
  slot = std::move(fn);
}

// Installs a predicate that always returns `val`.
void set_record_concrete_inputs_enabled_val(bool val) {
  set_record_concrete_inputs_enabled_fn([val]() { return val; });
}
1542
namespace {
// Process-wide predicate consulted by get_fwd_bwd_enabled(). Lazily
// initialized on first use; defaults to always returning true.
// No synchronization is performed on reads or writes here.
std::function<bool()>& fwd_bwd_enabled_fn() {
  static std::function<bool()> predicate{[]() { return true; }};
  return predicate;
}
} // namespace

// Evaluates the currently installed predicate.
bool get_fwd_bwd_enabled() {
  const std::function<bool()>& predicate = fwd_bwd_enabled_fn();
  return predicate();
}

// Replaces the predicate with `fn`, taking ownership of the callable.
void set_fwd_bwd_enabled_fn(std::function<bool()> fn) {
  auto& slot = fwd_bwd_enabled_fn();
  slot = std::move(fn);
}

// Installs a predicate that always returns `val`.
void set_fwd_bwd_enabled_val(bool val) {
  set_fwd_bwd_enabled_fn([val]() { return val; });
}
1561
namespace {
// Process-wide predicate consulted by get_cuda_sync_enabled(). Lazily
// initialized on first use; defaults to always returning false.
// No synchronization is performed on reads or writes here.
std::function<bool()>& cuda_sync_enabled_fn() {
  static std::function<bool()> predicate{[]() { return false; }};
  return predicate;
}
} // namespace

// Evaluates the currently installed predicate.
bool get_cuda_sync_enabled() {
  const std::function<bool()>& predicate = cuda_sync_enabled_fn();
  return predicate();
}

// Replaces the predicate with `fn`, taking ownership of the callable.
void set_cuda_sync_enabled_fn(std::function<bool()> fn) {
  auto& slot = cuda_sync_enabled_fn();
  slot = std::move(fn);
}

// Installs a predicate that always returns `val`.
void set_cuda_sync_enabled_val(bool val) {
  set_cuda_sync_enabled_fn([val]() { return val; });
}
1580
1581 } // namespace torch::profiler::impl
1582