#pragma once
#include <ATen/ThreadLocalState.h>
#include <ATen/core/Tensor.h>
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/utils/warnings.h>
#include <vector>

namespace torch::autograd {

using edge_list = std::vector<Edge>;
struct ReadyQueue;

static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;

// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask : std::enable_shared_from_this<GraphTask> {
  std::atomic<uint64_t> outstanding_tasks_{0};
  // Indicates if an error occurred while executing any task. When this is
  // true, it signals all threads to stop executing.
  std::atomic_bool has_error_{false};
  std::atomic_bool future_completed_{false};
  // It is safe to read keep_graph_ without synchronization
  bool keep_graph_;

  // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
  // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
  std::mutex mutex_;
  std::unordered_map<Node*, InputBuffer> not_ready_;
  std::unordered_map<Node*, int> dependencies_;

  // Records the nodes that are in the graph
  std::unordered_set<Node*> nodes_in_graph_;
  c10::SmallVector<Node*, 4> graph_roots_;
  // Note [Exec info]
  // Exec info is created for each GraphTask, which allows filtering paths on
  // the graph that are not needed. Its semantics are a bit involved. If it's
  // empty, the task runs in "default" mode, meaning that all next_edges we
  // encounter should get executed. If it's not empty, only functions that
  // have an entry whose needed flag is true should be executed. exec_info is
  // only empty when the graph is executed via .backward() and the inputs
  // parameter is not passed. Otherwise, when executed through .grad(), or
  // when the inputs arg is specified for .backward(), exec_info will be
  // non-empty.
  //
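  // A minimal sketch (a hypothetical helper, not part of this header) of how
  // a consumer of exec_info could decide whether to follow an edge under the
  // semantics described above:
  //
  //   bool should_follow_edge(
  //       const std::unordered_map<Node*, GraphTask::ExecInfo>& exec_info,
  //       Node* next_fn) {
  //     if (exec_info.empty()) {
  //       return true; // "default" mode: execute every encountered edge
  //     }
  //     auto it = exec_info.find(next_fn);
  //     return it != exec_info.end() && it->second.should_execute();
  //   }
  //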
  struct ExecInfo {
    struct Capture {
      Capture(const Capture&) = delete;
      Capture(Capture&&) = default;

      Capture(int input_idx, int output_idx)
          : input_idx_(input_idx), output_idx_(output_idx) {}
      int input_idx_; // within Node inputs
      int output_idx_; // within the output vector of a GraphTask

      // This hook will be executed after a grad is captured. The captured
      // grad will be replaced by the return value of the hook.
      struct GradCaptureHook {
        virtual ~GradCaptureHook() = default;
        virtual at::Tensor operator()(const at::Tensor& grad) = 0;
      };
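      // A minimal sketch (illustrative only; the name and the scaling
      // behavior are made up) of what a GradCaptureHook implementation could
      // look like; the hook receives the captured grad and its return value
      // replaces it:
      //
      //   struct ScaleGradHook : GradCaptureHook {
      //     explicit ScaleGradHook(double s) : scale_(s) {}
      //     at::Tensor operator()(const at::Tensor& grad) override {
      //       return grad * scale_;
      //     }
      //     double scale_;
      //   };
      //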
      // NOTE [Deprecated capture hooks]
      //
      // The current status of capture hooks is that we continue to support
      // their single usage by distributed in the dist_engine. If anyone else
      // needs to use them for other purposes, they should file an issue.
      //
      // Capture hooks were originally created because there was no way to
      // register pre/post hooks to a grad_fn such that they would still be
      // executed even if that grad_fn belongs to a Tensor passed as the
      // input= of .grad. As far as I know, only dist_engine uses this hook.
      //
      // However, there are other alternatives today, like tensor hooks, that
      // can replace the usage that originally motivated their creation. Also,
      // capture hooks are an outlier among the kinds of hooks autograd
      // offers in how they are registered and behave, e.g. they are
      // registered not to the graph, but to a particular graph_task! This
      // makes them a burden to maintain.
      //
      // It would be very nice to clean up/do a migration from the pre/post
      // hooks used in distributed to tensor hooks, but for now we just mark
      // this method as deprecated to prevent additional usage.
      //
      // If you still think you really need capture hooks, please file an
      // issue (and tag autograd).
      const std::vector<std::unique_ptr<GradCaptureHook>>&
      DO_NOT_USE_DEPRECATED_get_capture_hooks() const {
        return hooks_;
      }
      // See NOTE [Deprecated capture hooks]
      void DO_NOT_USE_DEPRECATED_register_capture_hook(
          std::unique_ptr<GradCaptureHook> hook) {
        hooks_.push_back(std::move(hook));
      }

     private:
      // The hooks will be called one by one in the order in which they were
      // added. The input grad of a hook will be the output of its preceding
      // hook. The first hook will take the captured grad as its input. The
      // output of the last hook will replace the captured grad.
      std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
    };
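
    // Illustrative only (not the engine's actual code): when a grad is
    // captured, the hooks of the matching Capture are applied in
    // registration order, each receiving the previous hook's output:
    //
    //   at::Tensor grad = captured_grad;
    //   for (auto& hook : capture.DO_NOT_USE_DEPRECATED_get_capture_hooks()) {
    //     grad = (*hook)(grad);
    //   }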

    bool should_execute() const {
      return needed_ || captures_;
    }

    bool needed_ = false;
    std::unique_ptr<std::vector<Capture>> captures_;
  };
  // exec_info_ is safe to read without synchronization
  std::unordered_map<Node*, ExecInfo> exec_info_;
  // captured_vars_ holds the captured grads that we return to the user. After
  // execution of the GraphTask is completed, the captured_vars_ are moved
  // out of the GraphTask and are no longer valid.
  std::vector<Variable> captured_vars_;

  // Note: this field is not ready to be used until the proper
  // `thread_locals_.set_grad_mode()` call in the constructor.
  at::ThreadLocalState thread_locals_ = at::ThreadLocalState();

  std::unordered_set<c10::Stream> leaf_streams;

  // Per-device current streams of the execute() that called this GraphTask.
  // These will be synced with leaf_streams in exec_post_processing.
  std::vector<std::optional<c10::Stream>> caller_current_streams_;

  // Collects caller_current_streams_ for the accelerator device.
  void stash_current_streams();

  void init_to_execute(
      Node& graph_root,
      const edge_list& outputs,
      bool accumulate_grad,
      uint64_t min_topo_nr);

  // The value of worker_device in the thread that created this task.
  // See Note [Reentrant backwards]
  // Safe to read owner_ and reentrant_depth_ without synchronization
  int owner_;
  // The number of parent graph tasks for this graph task
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int reentrant_depth_;

  bool can_checkpoint() const {
    return exec_info_.empty();
  }

  // Check if the GraphTask is completed or not
  bool completed();
  // Mark the graph task as completed and trigger post processing
  void mark_as_completed_and_run_post_processing();

  // Set an appropriate exception on this graph_task for an error encountered
  // while running the provided function.
  void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);

  // Set an appropriate exception on this graph_task for an error encountered
  // while running the provided function, but don't signal completion on
  // 'future_result_' right away. The user needs to explicitly mark
  // 'future_result_' completed with an appropriate exception.
  void set_exception_without_signal(const std::shared_ptr<Node>& fn);

  // Whether or not to stop execution for this GraphTask when an error is
  // encountered. When set to true, this would cause Engine::execute() to
  // throw an exception as soon as the autograd engine receives an exception.
  bool exit_on_error_;

  // CPU threads are dedicated to processing CPU work for the backward call
  // they invoked. So any given graph task maintains its own cpu_ready_queue_
  // to which work for it should be sent. We memoize the cpu_ready_queue_ per
  // GraphTask so that we know which ready queue to push to if we are on a
  // device thread (i.e. GPU) but the next NodeTask should be run on the CPU.
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;
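
  // A routing sketch of the idea above (illustrative only;
  // `ready_queue_for_device` is a hypothetical helper, not the engine's
  // actual API): CPU work goes to this task's own queue, device work goes to
  // the per-device queue:
  //
  //   std::shared_ptr<ReadyQueue> target = (device == CPU_DEVICE)
  //       ? graph_task->cpu_ready_queue_
  //       : ready_queue_for_device(device);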

  // Future representing the completion of the graph task. Notified when all
  // tasks are done.
  c10::intrusive_ptr<at::ivalue::Future> future_result_;

  // Final callbacks installed during execution of this GraphTask
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_. Intentionally not
  // reusing mutex_, as the two protect different data structures.
  std::mutex final_callbacks_lock_;

  utils::DelayWarningHandler warning_handler_;

  uint64_t id_;

  GraphTask(
      bool keep_graph,
      bool grad_mode,
      int reentrant_depth,
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      c10::SmallVector<Node*, 4> graph_roots,
      bool exit_on_error = false);

 private:
  // run GraphTask post processing
  void exec_post_processing();
};
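
// A construction sketch (values and `root_fn` are illustrative only; the
// real call sites live in the autograd engine):
//
//   auto cpu_queue = std::make_shared<ReadyQueue>();
//   auto task = std::make_shared<GraphTask>(
//       /*keep_graph=*/false,
//       /*grad_mode=*/true,
//       /*reentrant_depth=*/0,
//       cpu_queue,
//       /*graph_roots=*/c10::SmallVector<Node*, 4>{root_fn.get()},
//       /*exit_on_error=*/false);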

// The guard that sets and restores current_graph_task.
class GraphTaskGuard {
 public:
  explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
  ~GraphTaskGuard();

  void restore_current_graph_task();

 private:
  std::shared_ptr<GraphTask> last_graph_task_;
};
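
// Usage sketch (illustrative): install `task` as the current graph task for
// the duration of a scope; the previously current task is restored when the
// guard is destroyed:
//
//   {
//     GraphTaskGuard guard(task);
//     // ... code that relies on the get_current_graph_task_* accessors ...
//   } // previous graph task restored here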

TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info();
TORCH_API const std::unordered_set<Node*>*
get_current_graph_task_nodes_in_graph();
TORCH_API bool get_current_graph_task_keep_graph();
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
TORCH_API int get_current_graph_task_id();
void add_node_to_current_graph_task_exec_info(Node* fn);

} // namespace torch::autograd