#pragma once
#include <ATen/ThreadLocalState.h>
#include <ATen/core/Tensor.h>
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/utils/warnings.h>
#include <vector>

namespace torch::autograd {

using edge_list = std::vector<Edge>;
struct ReadyQueue;

static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;

// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask : std::enable_shared_from_this<GraphTask> {
  std::atomic<uint64_t> outstanding_tasks_{0};
  // Indicates if an error occurred while executing any task. When this is
  // true, it signals all threads to stop executing.
  std::atomic_bool has_error_{false};
  std::atomic_bool future_completed_{false};
  // It is safe to read keep_graph_ without synchronization
  bool keep_graph_;

  // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
  // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
  std::mutex mutex_;
  std::unordered_map<Node*, InputBuffer> not_ready_;
  std::unordered_map<Node*, int> dependencies_;

  // Records the nodes that are in the graph
  std::unordered_set<Node*> nodes_in_graph_;
  c10::SmallVector<Node*, 4> graph_roots_;
  // Note [Exec info]
  // Exec info is created for each GraphTask, which allows filtering paths on
  // the graph that are not needed. Its semantics are a bit complicated. If it
  // is empty, the task runs in "default" mode, which means that all
  // next_edges we encounter should get executed. If it is not empty, only
  // functions that have an entry, and whose entry has needed == True, should
  // be executed. exec_info is only empty when the graph is executed via
  // .backward() and the inputs parameter is not passed. Otherwise, when
  // executed through .grad(), or when the inputs arg is specified for
  // .backward(), exec_info will be non-empty.
  //
  struct ExecInfo {
    struct Capture {
      Capture(const Capture&) = delete;
      Capture(Capture&&) = default;

      Capture(int input_idx, int output_idx)
          : input_idx_(input_idx), output_idx_(output_idx) {}
      int input_idx_; // within Node inputs
      int output_idx_; // within the output vector of a GraphTask

      // This hook will be executed after a grad is captured. The captured
      // grad will be replaced by the return value of the hook.
      struct GradCaptureHook {
        virtual ~GradCaptureHook() = default;
        virtual at::Tensor operator()(const at::Tensor& grad) = 0;
      };
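      // A minimal illustrative sketch (not part of this header) of what a
      // GradCaptureHook implementation could look like; the name CloneHook
      // and the clone() behavior are hypothetical:
      //
      //   struct CloneHook : GradCaptureHook {
      //     at::Tensor operator()(const at::Tensor& grad) override {
      //       // The value returned here replaces the captured grad.
      //       return grad.clone();
      //     }
      //   };
      //
      // Hooks registered on a Capture are chained: each hook receives the
      // previous hook's output, and the last hook's output is what ends up
      // being captured (see the comment on hooks_ below).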
      // NOTE [Deprecated capture hooks]
      //
      // The current status of capture hooks is that we continue to support
      // their single use by distributed in the dist_engine. If anyone else
      // needs to use them for other purposes, they should file an issue.
      //
      // Capture hooks were originally created because there was no way to
      // register pre/post hooks on a grad_fn such that the hook would still
      // be executed even if that grad_fn belongs to a Tensor passed as the
      // inputs= argument of .grad(). As far as I know, only dist_engine uses
      // this hook.
      //
      // However, there are other alternatives today, like tensor hooks, that
      // can replace the usage that originally motivated their creation.
      // Also, capture hooks are an outlier among the types of hooks that
      // autograd offers in terms of how they are registered and behave,
      // e.g. a capture hook is registered not to the graph, but to a
      // particular graph_task! This makes them a burden to maintain.
      //
      // It would be very nice to clean up/do a migration from the pre/post
      // hooks used in distributed to tensor hooks, but for now we just mark
      // this method as deprecated to prevent additional usage.
      //
      // If you still think you really need capture hooks, please file an
      // issue (and tag autograd).
      const std::vector<std::unique_ptr<GradCaptureHook>>&
      DO_NOT_USE_DEPRECATED_get_capture_hooks() const {
        return hooks_;
      }
      // See NOTE [Deprecated capture hooks]
      void DO_NOT_USE_DEPRECATED_register_capture_hook(
          std::unique_ptr<GradCaptureHook> hook) {
        hooks_.push_back(std::move(hook));
      }

     private:
      // The hooks will be called one by one in the order in which they were
      // added. The input grad of a hook is the output of its preceding hook;
      // the first hook takes the captured grad as its input. The output of
      // the last hook replaces the captured grad.
      std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
    };

    bool should_execute() const {
      return needed_ || captures_;
    }

    bool needed_ = false;
    std::unique_ptr<std::vector<Capture>> captures_;
  };
  // exec_info_ is safe to read without synchronization
  std::unordered_map<Node*, ExecInfo> exec_info_;
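  // A sketch, for illustration only, of the filtering described in
  // Note [Exec info] above; the helper name should_follow is hypothetical
  // and does not exist in the engine:
  //
  //   bool should_follow(const GraphTask& task, Node* fn) {
  //     if (task.exec_info_.empty()) {
  //       return true; // "default" mode: execute every next_edge
  //     }
  //     auto it = task.exec_info_.find(fn);
  //     return it != task.exec_info_.end() && it->second.should_execute();
  //   }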
  // Captured variables are the grads that we captured and will return to the
  // user. After execution of the GraphTask is completed, the captured_vars_
  // are moved out of the GraphTask and are no longer valid.
  std::vector<Variable> captured_vars_;

  // Note: this field is not ready to be used until the proper
  // `thread_locals_.set_grad_mode()` call in the constructor.
  at::ThreadLocalState thread_locals_ = at::ThreadLocalState();

  std::unordered_set<c10::Stream> leaf_streams;

  // Per-device current streams of the execute() that called this GraphTask.
  // These will be synced with leaf_streams in exec_post_processing.
  std::vector<std::optional<c10::Stream>> caller_current_streams_;

  // Collects caller_current_streams_ for the accelerator device.
  void stash_current_streams();

  void init_to_execute(
      Node& graph_root,
      const edge_list& outputs,
      bool accumulate_grad,
      uint64_t min_topo_nr);

  // The value of worker_device in the thread that created this task.
  // See Note [Reentrant backwards]
  // Safe to read owner_ and reentrant_depth_ without synchronization
  int owner_;
  // The number of parent graph tasks for this graph task
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int reentrant_depth_;

  bool can_checkpoint() const {
    return exec_info_.empty();
  }

  // check if the GraphTask is completed or not
  bool completed();
  // mark the graph task as completed and trigger post processing
  void mark_as_completed_and_run_post_processing();

  // Set an appropriate exception on this graph_task, encountered while
  // running the provided function.
  void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);

  // Set an appropriate exception on this graph_task, encountered while
  // running the provided function, but doesn't signal completion on
  // 'future_result_' right away. The user needs to explicitly mark
  // 'future_result_' completed with an appropriate exception.
  void set_exception_without_signal(const std::shared_ptr<Node>& fn);

  // Whether or not to stop execution for this GraphTask when an error is
  // encountered. When set to true, this would cause Engine::execute() to
  // throw an exception as soon as the autograd engine receives an exception.
  bool exit_on_error_;

  // CPU threads are dedicated to processing CPU work for the backward they
  // invoked. So any given graph task maintains its own cpu_ready_queue_ where
  // you should send work for it to be done. We memoize the cpu_ready_queue_
  // per GraphTask so that we know which ready queue to push to if we are on
  // a device thread (i.e. GPU) but the next NodeTask should be run on the
  // CPU.
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;

  // Future representing the completion of the graph task. Notified when all
  // tasks are done.
  c10::intrusive_ptr<at::ivalue::Future> future_result_;

  // Final callbacks installed during execution of this GraphTask
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_. Intentionally not
  // reusing mutex_, as the two are protecting different data structures.
  std::mutex final_callbacks_lock_;

  utils::DelayWarningHandler warning_handler_;

  uint64_t id_;

  GraphTask(
      bool keep_graph,
      bool grad_mode,
      int reentrant_depth,
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      c10::SmallVector<Node*, 4> graph_roots,
      bool exit_on_error = false);

 private:
  // run GraphTask post processing
  void exec_post_processing();
};

// The guard that sets and restores current_graph_task.
class GraphTaskGuard {
 public:
  explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
  ~GraphTaskGuard();

  void restore_current_graph_task();

 private:
  std::shared_ptr<GraphTask> last_graph_task_;
};

TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info();
TORCH_API const std::unordered_set<Node*>*
get_current_graph_task_nodes_in_graph();
TORCH_API bool get_current_graph_task_keep_graph();
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
TORCH_API int get_current_graph_task_id();
void add_node_to_current_graph_task_exec_info(Node* fn);

} // namespace torch::autograd
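// An illustrative sketch (not part of this header) of how the query
// functions above might be used by code that runs under the engine, e.g. a
// tensor hook. The helper name is hypothetical, and the sketch assumes that
// get_current_graph_task_exec_info() returns nullptr when no GraphTask is
// currently being executed:
//
//   bool running_under_backward() {
//     return torch::autograd::get_current_graph_task_exec_info() != nullptr;
//   }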