xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TracerMode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/LocalDispatchKeySet.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Export.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker // NOTE [Tracing Mode Switches]
8*da0073e9SAndroid Build Coastguard Worker //
9*da0073e9SAndroid Build Coastguard Worker // Historically, tracing function was controlled by two switches:
10*da0073e9SAndroid Build Coastguard Worker //
11*da0073e9SAndroid Build Coastguard Worker // - `AutoDispatchBelowADInplaceOrView` guard
12*da0073e9SAndroid Build Coastguard Worker //
13*da0073e9SAndroid Build Coastguard Worker //    Tracing function used to be script-generated inside `VariableType_*.cpp`
14*da0073e9SAndroid Build Coastguard Worker //    kernels, sharing the same `Autograd` dispatch key with autograd function.
//    Therefore, before tracing function was moved out of VariableType, the
//    `AutoDispatchBelowADInplaceOrView` guard could also disable tracing as a
//    side effect of disabling `Autograd` dispatching.
18*da0073e9SAndroid Build Coastguard Worker //
19*da0073e9SAndroid Build Coastguard Worker // - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
20*da0073e9SAndroid Build Coastguard Worker //
21*da0073e9SAndroid Build Coastguard Worker //    It stores tracing data in a `TracingState` object in TLS. If the
22*da0073e9SAndroid Build Coastguard Worker //    `TracingState` object in TLS is `null`, then tracing is paused.
23*da0073e9SAndroid Build Coastguard Worker //
24*da0073e9SAndroid Build Coastguard Worker //    The `TracingState` object is created in `tracer::trace()` - the main
25*da0073e9SAndroid Build Coastguard Worker //    entrance of tracing function. It's temporarily set to `null` inside
26*da0073e9SAndroid Build Coastguard Worker //    generated VariableType (now TraceType) to bypass tracing for intermediate
27*da0073e9SAndroid Build Coastguard Worker //    ops (ops being called by other ops). After the intermediate op call
28*da0073e9SAndroid Build Coastguard Worker //    finishes it's set back to the original `TracingState` object.
29*da0073e9SAndroid Build Coastguard Worker //
//    The `TracingState` object in TLS can also be read/written via its Python
31*da0073e9SAndroid Build Coastguard Worker //    binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
32*da0073e9SAndroid Build Coastguard Worker //    which are also exposed as `TORCH_API`.
33*da0073e9SAndroid Build Coastguard Worker //
34*da0073e9SAndroid Build Coastguard Worker // Two new switches were introduced since tracing function was moved out of
35*da0073e9SAndroid Build Coastguard Worker // VariableType:
36*da0073e9SAndroid Build Coastguard Worker //
37*da0073e9SAndroid Build Coastguard Worker // - `tracer::impl::set_dispatch_enabled()` API
38*da0073e9SAndroid Build Coastguard Worker //
39*da0073e9SAndroid Build Coastguard Worker //    Unlike the special `Autograd` dispatch key which is included in dispatch
40*da0073e9SAndroid Build Coastguard Worker //    key set by default, `Tracer` dispatch key is off by default. The
41*da0073e9SAndroid Build Coastguard Worker //    dispatching switch can be toggled via this new API.
42*da0073e9SAndroid Build Coastguard Worker //
43*da0073e9SAndroid Build Coastguard Worker // - `tracer::impl::NoTracerDispatchMode` guard
44*da0073e9SAndroid Build Coastguard Worker //
45*da0073e9SAndroid Build Coastguard Worker //    It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
46*da0073e9SAndroid Build Coastguard Worker //    after tracing was moved out of VariableType.
47*da0073e9SAndroid Build Coastguard Worker //
48*da0073e9SAndroid Build Coastguard Worker // Before tracing function was moved out of VariableType, tracing was enabled
49*da0073e9SAndroid Build Coastguard Worker // when the following conditions are satisfied:
50*da0073e9SAndroid Build Coastguard Worker //
51*da0073e9SAndroid Build Coastguard Worker //    1) `TracingState` object in TLS != null;
52*da0073e9SAndroid Build Coastguard Worker //       - Either inside the execution scope of `tracer::trace()`, or
53*da0073e9SAndroid Build Coastguard Worker //       - Eagerly called `setTracingState()` with non-null object.
54*da0073e9SAndroid Build Coastguard Worker //    2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
55*da0073e9SAndroid Build Coastguard Worker //
56*da0073e9SAndroid Build Coastguard Worker // After:
57*da0073e9SAndroid Build Coastguard Worker //
58*da0073e9SAndroid Build Coastguard Worker //    1) `TracingState` object in TLS != null;
59*da0073e9SAndroid Build Coastguard Worker //    2) Has called `tracer::impl::set_dispatch_enabled(true)`;
//    3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
61*da0073e9SAndroid Build Coastguard Worker //
62*da0073e9SAndroid Build Coastguard Worker // [TODOs]
63*da0073e9SAndroid Build Coastguard Worker //
64*da0073e9SAndroid Build Coastguard Worker // - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()`
65*da0073e9SAndroid Build Coastguard Worker //
66*da0073e9SAndroid Build Coastguard Worker //   Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
67*da0073e9SAndroid Build Coastguard Worker //   to keep the semantics exactly the same as before - it's confusing to keep
68*da0073e9SAndroid Build Coastguard Worker //   both switches, though. We should consider simplifying/limiting the exposed
69*da0073e9SAndroid Build Coastguard Worker //   `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
70*da0073e9SAndroid Build Coastguard Worker //   these two can be unified.
71*da0073e9SAndroid Build Coastguard Worker //
72*da0073e9SAndroid Build Coastguard Worker // - `AutoDispatchBelowADInplaceOrView` v.s.
73*da0073e9SAndroid Build Coastguard Worker // `tracer::impl::NoTracerDispatchMode`
74*da0073e9SAndroid Build Coastguard Worker //
75*da0073e9SAndroid Build Coastguard Worker //   We don't need to always set both guards together to keep semantics
//   unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView`
//   we don't need to set the new tracer guard:
78*da0073e9SAndroid Build Coastguard Worker //
79*da0073e9SAndroid Build Coastguard Worker //   * Script-generated VariableType kernels. The guard is not necessary as
80*da0073e9SAndroid Build Coastguard Worker //     tracing is already disabled explicitly by `setTracingState(null)` in
81*da0073e9SAndroid Build Coastguard Worker //     generated TraceType kernels - we could keep it as is or use the new guard
82*da0073e9SAndroid Build Coastguard Worker //     instead.
83*da0073e9SAndroid Build Coastguard Worker //
84*da0073e9SAndroid Build Coastguard Worker //   * Custom ops. Will be handled by fallback kernel for `Tracer`.
85*da0073e9SAndroid Build Coastguard Worker //
86*da0073e9SAndroid Build Coastguard Worker //   * Functions that are not likely to be called in tracing context (no python
87*da0073e9SAndroid Build Coastguard Worker //     binding / not an operator), e.g.: all mobile forward() wrappers, test
//     binaries, etc.
89*da0073e9SAndroid Build Coastguard Worker //
90*da0073e9SAndroid Build Coastguard Worker //   * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
91*da0073e9SAndroid Build Coastguard Worker //     It's not necessary as tracing is off by default.
92*da0073e9SAndroid Build Coastguard Worker //
//   For the rest of the cases we might need to have both:
94*da0073e9SAndroid Build Coastguard Worker //
95*da0073e9SAndroid Build Coastguard Worker //   * Functions that might be reachable from eager mode python (especially
96*da0073e9SAndroid Build Coastguard Worker //     factory methods), e.g.:
97*da0073e9SAndroid Build Coastguard Worker //     `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
98*da0073e9SAndroid Build Coastguard Worker //     Without the new guard it will add `aten::empty` to the traced graph.
99*da0073e9SAndroid Build Coastguard Worker //
100*da0073e9SAndroid Build Coastguard Worker //   * Some manually maintained functions, e.g.:
101*da0073e9SAndroid Build Coastguard Worker //     `torch/csrc/autograd/VariableTypeManual.cpp`.
102*da0073e9SAndroid Build Coastguard Worker //     Set the new guard if it's not obvious whether `setTracingState(null)`
103*da0073e9SAndroid Build Coastguard Worker //     has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
104*da0073e9SAndroid Build Coastguard Worker //     guard.
105*da0073e9SAndroid Build Coastguard Worker //
//   We might need to tweak the usage of the new guard to optimize/fix things.
107*da0073e9SAndroid Build Coastguard Worker //   It should only affect the correctness of tracing function, because the
108*da0073e9SAndroid Build Coastguard Worker //   guard is essentially no-op when the master `setTracingState()` switch is
109*da0073e9SAndroid Build Coastguard Worker //   off.
110*da0073e9SAndroid Build Coastguard Worker 
111*da0073e9SAndroid Build Coastguard Worker // TODO: move this from `at::` to `jit::torch::` after
112*da0073e9SAndroid Build Coastguard Worker // `aten/src/ATen/cpp_custom_type_hack.h` is removed.
113*da0073e9SAndroid Build Coastguard Worker 
114*da0073e9SAndroid Build Coastguard Worker namespace at::tracer::impl {
115*da0073e9SAndroid Build Coastguard Worker 
is_dispatch_enabled()116*da0073e9SAndroid Build Coastguard Worker inline bool is_dispatch_enabled() {
117*da0073e9SAndroid Build Coastguard Worker   return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
118*da0073e9SAndroid Build Coastguard Worker       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
119*da0073e9SAndroid Build Coastguard Worker }
120*da0073e9SAndroid Build Coastguard Worker 
set_dispatch_enabled(bool enabled)121*da0073e9SAndroid Build Coastguard Worker inline void set_dispatch_enabled(bool enabled) {
122*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(
123*da0073e9SAndroid Build Coastguard Worker       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
124*da0073e9SAndroid Build Coastguard Worker       "Cannot enable tracing within the scope of NoTracerDispatchMode!");
125*da0073e9SAndroid Build Coastguard Worker   c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
126*da0073e9SAndroid Build Coastguard Worker }
127*da0073e9SAndroid Build Coastguard Worker 
// Scope guard type that holds an `ExcludeDispatchKeyGuard` constructed with
// the `Tracer` dispatch key. While an instance is alive,
// `set_dispatch_enabled()` asserts (it checks that `Tracer` is not excluded).
// Presumably the exclusion is undone when the member guard is destroyed -
// confirm against `c10/core/impl/LocalDispatchKeySet.h`.
struct NoTracerDispatchMode {
  // Member guard responsible for the `Tracer` key exclusion.
  c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
};
131*da0073e9SAndroid Build Coastguard Worker 
132*da0073e9SAndroid Build Coastguard Worker } // namespace at::tracer::impl
133