xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/data_flow.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <memory>
#include <vector>

#include <ATen/core/TensorBody.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/strong_type.h>
9 
10 namespace torch::profiler::impl {
11 
// Identity is a complex concept in PyTorch. A Tensor might not have an
// associated storage, multiple Tensors might share the same underlying
// storage, the storage of a Tensor might change over time, etc.
//
// For the purpose of profiling we're mostly interested in data flow
// analysis. As a result, we can take an expansive view of identity:
// Tensors share an ID if they share a TensorImpl or storage data.
//
// This identity equality is transitive; if Tensors T0 and T1 share a storage
// S0 and T1 later points to a different storage S1 then all Tensors which
// point to either S0 or S1 are considered to have the same identity. (Since
// the profiler cannot reason beyond that.)
//
// The profiler will handle lifetime analysis to ensure that identities do
// not run afoul of the ABA problem. This does, however, mean that identities
// can only be assigned when memory profiling is enabled.
28 using TensorID = strong::type<size_t, struct TensorID_, strong::regular>;
29 
30 // Uniquely identifies an allocation. (Generally a StorageImpl's data ptr.)
31 using AllocationID = strong::type<
32     size_t,
33     struct StorageID_,
34     strong::ordered,
35     strong::regular,
36     strong::hashable>;
37 
38 // We use a Tensor's TensorImpl adress and StorageImpl data start to build the
39 // data flow graph. We do not hold an owning reference so we wrap them in strong
40 // types to prevent direct access.
41 using TensorImplAddress = strong::type<
42     const c10::TensorImpl*,
43     struct TensorImplAddress_,
44     strong::regular,
45     strong::hashable,
46     strong::boolean>;
47 
48 using StorageImplData = strong::type<
49     const void*,
50     struct StorageImplData_,
51     strong::regular,
52     strong::hashable,
53     strong::boolean>;
54 
// ============================================================================
// == weak_intrusive_ptr and the ABA problem for TensorImpl* ==================
// ============================================================================
// Tracking `TensorImpl`s is an important part of identity tracking, because
// a Tensor might change storage; however when it does we want to retain the
// fact that the old and new storage belong to the same logical Tensor. We
// cannot take an owning reference to the Tensor because that would change
// program semantics by extending the lifetime of the Tensor. However if we
// store a raw TensorImpl* pointer the TensorImpl might be deleted and a new
// TensorImpl might be created that reuses the address. (ABA problem)
//
// Fortunately, there is a feature of `c10::intrusive_ptr` that we can use to
// prevent address reuse for the duration of profiling: the weak intrusive
// ptr. When a Tensor's refcount reaches zero but there are outstanding weak
// references (`weakcount_ > 0`) it will free the underlying managed resources
// by calling `target_->release_resources()`, but it will not call `delete`.
// (Instead, `delete` is called when the last weak reference is destroyed.)
// This means that we can safely use address identity to track `TensorImpl`s.
73 class WeakTensor {
74  public:
WeakTensor(const at::Tensor & t)75   explicit WeakTensor(const at::Tensor& t) : weak_self_(t.getIntrusivePtr()) {}
76 
get()77   auto get() const {
78     return TensorImplAddress{weak_self_._unsafe_get_target()};
79   }
80 
81  private:
82   c10::weak_intrusive_ptr<c10::TensorImpl> weak_self_;
83 };
84 
struct Result;

// Post-processing pass that assigns TensorIDs to profiler results so that
// Tensors sharing a TensorImpl or storage get the same ID (see the identity
// discussion at the top of this file). `sorted_results` is mutated in place;
// the name suggests callers must pass results in sorted (time) order —
// NOTE(review): confirm the required ordering against the implementation.
void calculateUniqueTensorIDs(
    std::vector<std::shared_ptr<Result>>& sorted_results);
89 
90 } // namespace torch::profiler::impl
91