#pragma once

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/utils/variadic.h>

#include <ATen/SequenceNumber.h>
#include <ATen/core/Tensor.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
// Forward declarations for compiled autograd (see [Note: Compiled Autograd]).
namespace torch::dynamo::autograd {
struct CompiledNodeArgs;
struct SwapSavedVariables;
} // namespace torch::dynamo::autograd

namespace torch::autograd {

struct Edge;
struct FunctionPostHook;
struct FunctionPreHook;
struct Node;
struct PostAccumulateGradHook;

using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
using IndexRange = std::pair<size_t, size_t>;
using torch::dynamo::autograd::CompiledNodeArgs;
using torch::dynamo::autograd::SwapSavedVariables;

// Custom deleter to prevent stack overflows.
TORCH_API void deleteNode(Node* function);

// Guard that sets and restores the evaluating node
class NodeGuard {
 public:
  explicit NodeGuard(std::shared_ptr<Node> node);
  ~NodeGuard();

 private:
  std::shared_ptr<Node> last_evaluating_node_;
};

// Return the Node currently being evaluated (if any)
// This is only set during the backward pass while a Node is being
// executed.
TORCH_API std::shared_ptr<Node> get_current_node();
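
// For example (a sketch; assumes `node` is a live std::shared_ptr<Node>):
//
//   {
//     NodeGuard guard(node);
//     // get_current_node() now returns `node` until `guard` is destroyed,
//     // at which point the previously evaluating node is restored.
//   }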

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Node
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A `Node` is an abstract class that represents an operation taking zero
// or more input `Variable`s and producing zero or more output `Variable`s. All
// functions in PyTorch's autograd machinery derive from this class and
// override its `apply` method. Instances of such subclasses will then be
// invokable via the call operator.
//
// Nodes in the Autograd Graph
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When viewing the autograd system as a graph, `Node`s are the vertices or
// nodes, connected to each other via (directed) `Edge`s, which themselves are
// represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to
// and inputs of `Node`s, and travel along these edges during execution
// of the graph. When two or more `Edge`s (from different sources) point at the
// same input of a `Node`, the values produced along all of these edges are
// implicitly summed prior to being forwarded to the target `Node`.
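//
// For example (an illustrative case): if `y = x * x` reuses the same `x`,
// the two gradient edges created for `x` both target the same
// `AccumulateGrad` input, so the two gradient contributions are summed
// before that node runs.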
//
// Hierarchy
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Subclasses usually represent differentiable functions as well as their
// gradient operators. Note, however, that due to the very general definition
// of a `Node` taking *zero* or more inputs and producing *zero* or more
// outputs, uses of `Node`s are flexible and extend beyond purely
// mathematical operations. For example, the `AccumulateGrad` function is a
// *sink*: it takes one input, but produces no outputs, instead accumulating
// the input as a side effect. At the other extreme, the `GraphRoot` function
// receives no inputs from other functions, but produces multiple outputs.
//
// Interface
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The most important method on `Node` is the call operator, which takes in
// a list of variables and produces a list of variables. The precise size of
// these lists can be determined with `num_inputs()` and `num_outputs()`.
// `Node`s are stitched together via their `next_edge` interface, which lets
// you manipulate the set of outgoing edges of a `Node`. You can add an
// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and
// iterate over them via the `next_edges()` method. Other methods exist for
// integration with the JIT and other parts of PyTorch. Every `Node` has a
// *sequence number* that increases monotonically in the order of `Node`
// construction. It can be retrieved via the `sequence_nr()` method. Note that
// this sequence number is *thread local*. This means that when `Node`s
// `A`, `B` and `C` are created consecutively in the same thread, their
// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B`
// are created in one thread and `C` is created in a new thread, there are *no
// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`.
// See NOTE [ Sequence Number ] for more details on how sequence numbers are
// used.
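//
// For example, a node can be wired up and invoked roughly like this (a
// sketch; `MyBackwardFn` is an illustrative subclass, `producer` an assumed
// std::shared_ptr<Node> it should feed into, and `grad_outputs` the incoming
// gradients):
//
//   auto fn = std::make_shared<MyBackwardFn>();
//   fn->add_next_edge({producer, /*input_nr=*/0});
//   variable_list grad_inputs = (*fn)(std::move(grad_outputs));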
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct TORCH_API Node : std::enable_shared_from_this<Node> {
 public:
  /// Construct a new `Node` with the given `next_edges`
  explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list())
      : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) {
    for (const Edge& edge : next_edges_) {
      update_topological_nr(edge);
    }

    if (AnomalyMode::is_enabled()) {
      metadata()->store_stack();

      // If anomaly mode is enabled and the graph is being constructed, assign
      // the currently evaluating node as the parent of this node. A parent is
      // the Node that was being executed when this Node was created. We track
      // parents so that nested backward operations can be traced back to
      // their origin.
      assign_parent();
    }

    // Store the thread_id of the forward operator.
    // See NOTE [ Sequence Number ]
    thread_id_ = at::RecordFunction::currentThreadId();
  }

  explicit Node(edge_list&& next_edges = edge_list())
      : Node(
            /*sequence_nr=*/at::sequence_number::get_and_increment(),
            std::move(next_edges)) {}

  /// Nodes are neither copyable nor moveable.
  Node(const Node& other) = delete;
  Node(Node&& other) = delete;
  Node& operator=(const Node& other) = delete;
  Node& operator=(Node&& other) = delete;
  virtual ~Node() = default;

  std::shared_ptr<Node> getptr() {
    return shared_from_this();
  }
  /// Evaluates the function on the given inputs and returns the result of the
  /// function call.
  variable_list operator()(variable_list&& inputs) {
    // In the first iteration of named tensors, autograd ignores names and
    // operates on unnamed tensors. In the long term, autograd should
    // probably operate with names.
    at::NoNamesGuard no_names_guard;

#ifdef USE_ROCM
    // Keep track of backward pass for rocblas.
    at::ROCmBackwardPassGuard in_backward;
#endif

    auto step_callbacks =
        at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
    if (C10_UNLIKELY(step_callbacks.has_value())) {
      at::RecordFunction guard(std::move(*step_callbacks));
      // Using sequence number and thread id to correlate with
      // the forward pass function
      guard.setForwardThreadId(thread_id_);
      if (guard.needsInputs()) {
        std::vector<c10::IValue> inputs_vec(inputs.begin(), inputs.end());
        guard.before(
            name(),
            c10::ArrayRef<const c10::IValue>(
                inputs_vec.data(), inputs_vec.size()),
            static_cast<int64_t>(sequence_nr()));
      } else {
        guard.before(name(), static_cast<int64_t>(sequence_nr()));
      }
      return apply(std::move(inputs));
    } else {
      return apply(std::move(inputs));
    }
  }

  // Graph Connectivity API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
  // forward function.

  // Marker for expected undefined input
  struct undefined_input {};

  /// Adds the type and shape metadata for a new input. Returns the index of
  /// the new input.
  uint32_t add_input_metadata(
      const at::TensorOptions& options,
      c10::SymIntArrayRef shape,
      bool is_tensor_subclass,
      bool is_nested) noexcept {
    uint32_t input_nr = input_metadata_.size();
    auto meta_shape = MetadataShape{std::in_place_type<SymIntSmallVec>, shape};
    input_metadata_.emplace_back(
        options, meta_shape, is_tensor_subclass, is_nested);
    return input_nr;
  }

  uint32_t add_input_metadata(const at::Tensor& t) noexcept {
    uint32_t input_nr = input_metadata_.size();
    input_metadata_.emplace_back(t);
    return input_nr;
  }

  /// Adds a placeholder for an input that will not be used.
  uint32_t add_input_metadata(undefined_input u) noexcept {
    uint32_t input_nr = input_metadata_.size();
    input_metadata_.emplace_back();
    return input_nr;
  }

  uint32_t num_inputs() const noexcept {
    return input_metadata_.size();
  }

  const InputMetadata& input_metadata(size_t index) const {
    return input_metadata_[index];
  }

  // Danger: not thread safe, caller must protect with lock
  InputMetadata& mutable_input_metadata(size_t index) {
    return input_metadata_[index];
  }

  /**
   * Note: Function Streams
   * A function's stream (for a given device type) is the stream of the first
   * element of its input buffer on a device of that type.
   *
   * If all elements are on the same device they MUST share a stream. If
   * elements are on different devices (across multiple GPUs, for example)
   * they may have different streams.
   */
  std::optional<c10::Stream> stream() {
    auto opt_device_type = at::getAccelerator();
    if (!opt_device_type.has_value()) {
      return std::nullopt;
    }
    for (const auto& metadata : input_metadata_) {
      if (metadata.device().type() == opt_device_type.value())
        return metadata.stream();
    }

    return std::nullopt;
  }

  void clear_input_metadata() {
    input_metadata_.clear();
  }

  // Outputs ("Next Edges")

  void update_topological_nr(const Edge& edge) {
    TORCH_INTERNAL_ASSERT(
        !has_parent_,
        "Cannot update a node's topological_nr after it already has a parent."
        " If we allow this, we can no longer guarantee that a parent's"
        " topo_nr is always greater than those of all its children")
    Node* node = edge.function.get();
    if (node) {
      auto topo_nr = node->topological_nr();
      if (topological_nr_ <= topo_nr) {
        topological_nr_ = topo_nr + 1;
      }
    }
  }

  void set_next_edge(size_t index, Edge edge) {
    update_topological_nr(edge);
    next_edges_[index] = std::move(edge);
  }

  void add_next_edge(Edge edge) {
    update_topological_nr(edge);
    next_edges_.emplace_back(std::move(edge));
  }

  void set_next_edges(edge_list&& next_edges) {
    next_edges_ = std::move(next_edges);
    for (const auto& next_edge : next_edges_) {
      update_topological_nr(next_edge);
    }
  }

  const Edge& next_edge(size_t index) const noexcept {
    return next_edges_[index];
  }

  const edge_list& next_edges() const noexcept {
    return next_edges_;
  }

  edge_list& next_edges() noexcept {
    return next_edges_;
  }

  uint32_t num_outputs() const noexcept {
    return next_edges_.size();
  }

  // Miscellaneous Methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// NOTE [ Sequence Number ]
  ///
  /// The sequence_nr has two main usages in autograd:
  ///
  /// 1) Helps determine the node's execution priority in the engine.
  ///    All else being equal, nodes with higher priority numbers are executed
  ///    first. Thus, nodes corresponding to ops executed later are the first
  ///    to be executed in the backward pass. One caveat is that we prioritize
  ///    AccumulateGrad nodes by explicitly setting their sequence_nr to be
  ///    UINT64_MAX.
  /// 2) The sequence_nr of this `Node` is paired with the thread_id it was
  ///    created in, and together they serve as a unique identifier the
  ///    profiler uses to annotate recorded events. This helps users (and
  ///    possibly programs) interpreting the profiler's output to correlate
  ///    backward nodes with their forward ops. We need both sequence_nr and
  ///    thread_id to identify a node because sequence_nr is thread_local,
  ///    i.e., it starts counting up from zero in each new thread.
  uint64_t sequence_nr() const noexcept {
    return sequence_nr_;
  }

  void set_sequence_nr(uint64_t sequence_nr) {
    sequence_nr_ = sequence_nr;
  }

  // NOTE [ Topological Number ]
  //
  // topological_nr is used to prune branches in the DAG during autograd
  // discovery as maintaining topological_nr helps us check in O(1) if there
  // does NOT exist a directed path between two nodes.
  //
  // The topological order number of this `Node` represents the length of the
  // longest possible path from this Node to any leaf node. If you are a leaf
  // node, i.e., AccumulateGrad, this will be zero. This value has the property
  // that for every pair of nodes X, Y in G, the existence of a directed path
  // from X to Y implies topo_nr(X) > topo_nr(Y). The converse is not true,
  // however, so we cannot prove the existence of a path from X to Y, only its
  // non-existence.
  //
  // One assumption we make when using topo_nr is that once a node has been
  // used, i.e., has a parent node, its own topo_nr does not change. We have
  // added some checks with the `has_parent_` field to enforce this.
  //
  // What NOT to do:
  //
  //   1) 2 -> 1 -> 0      In this diagram we label nodes with their topo_nr.
  //      2 -> 1 -> 0      We have two simple graphs that can each arise from
  //                       `t.exp().exp()`, for example.
  //
  //   2) 2 -> 1 -> 0
  //            /
  //      2 -> 1 -> 0      We add 2 as a next edge to 1 even though 1 already
  //                       has a parent.
  //
  //   3) 2 -> 1 -> 0
  //            /
  //      2 -> 3 -> 0      2 < 3, yet there exists a path from 2 to 3!
  //
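  // For example, during graph discovery one can rule out the existence of a
  // path in O(1) (a sketch; `X` and `Y` are illustrative Node pointers):
  //
  //   if (X->topological_nr() <= Y->topological_nr()) {
  //     // no directed path from X to Y can exist
  //   }
  //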
  uint64_t topological_nr() const noexcept {
    has_parent_ = true;
    return topological_nr_;
  }

  // assigning a node as a parent to this node
  void assign_parent();

  /// Id of the thread that created Node
  uint64_t thread_id() const noexcept {
    return thread_id_;
  }

  /// Returns the name of the dynamic type of the function, for debugging.
  virtual std::string name() const;

  /// The difference between functions `should_compute_output` and
  /// `task_should_compute_output`:
  /// - `should_compute_output` should only be used during graph construction
  /// and takes into account only requires_grad information
  /// - `task_should_compute_output` should only be called during the backward
  /// pass (unless called directly through grad_fn) and takes into account the
  /// current graph task. Specifically, the autograd engine trims unnecessary
  /// edges when `inputs` are specified, and during backward untrimmed nodes
  /// left on the graph can/should check `task_should_compute_output` to see if
  /// any outgoing edges have been trimmed by the engine. If that is the case,
  /// gradient computation wrt those edges can be omitted.
  ///
  /// Returns true if the particular output edge is active, and that particular
  /// output of this function should be computed.
  bool should_compute_output(size_t output_edge_index) const {
    TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
    return next_edges_[output_edge_index].is_valid();
  }

  /// Returns true if any of the output edges in any of the ranges are active.
  bool should_compute_output(std::initializer_list<IndexRange> idxs) const {
    return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
      for (const auto i : c10::irange(range.first, range.second)) {
        if (should_compute_output(i))
          return true;
      }
      return false;
    });
  }

  /// Same as the above `should_compute_output` function but will also
  /// check whether this edge is needed within the current graph task.
  bool task_should_compute_output(size_t output_edge_index) const {
    TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range");
    const auto& next = next_edges_[output_edge_index];
    if (next.is_valid()) {
      const auto exec_info = get_current_graph_task_exec_info();
      if (exec_info && !exec_info->empty()) {
        auto it = exec_info->find(next.function.get());
        if (it == exec_info->end() || !it->second.should_execute()) {
          return false; // this edge is not needed for the current graph_task
        }
      }
      return true;
    }
    return false;
  }

  /// Returns true if any of the output edges in any of the ranges are active
  /// and should be computed in the current graph task.
  bool task_should_compute_output(
      std::initializer_list<IndexRange> idxs) const {
    return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
      for (const auto i : c10::irange(range.first, range.second)) {
        if (task_should_compute_output(i))
          return true;
      }
      return false;
    });
  }
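
  // For example, a derived node's `apply()` typically guards each gradient
  // computation (a sketch; the gradient expressions are illustrative):
  //
  //   variable_list grad_inputs(num_outputs());
  //   if (task_should_compute_output(0)) {
  //     grad_inputs[0] = grad * other;
  //   }
  //   if (task_should_compute_output(1)) {
  //     grad_inputs[1] = grad * self;
  //   }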

  /// Returns the `PyObject` stored for this `Node` (for Python
  /// interaction).
  PyObject* pyobj() const noexcept {
    return pyobj_;
  }

  /// Sets the `PyObject` stored for this `Node` (for Python interaction).
  void set_pyobj(PyObject* pyobj) noexcept {
    pyobj_ = pyobj;
  }

  /// Returns the anomaly metadata stored for this `Node`.
  /// If none exist, creates a new empty one.
  AnomalyMetadata* metadata() noexcept;

  // Hook API
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
    post_hooks_.emplace_back(std::move(post_hook));
    // Use the raw pointer as the unique key to identify this hook. This key
    // can then be used in del_post_hook(key) to remove this hook.
    return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
  }

  const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks()
      const noexcept {
    return post_hooks_;
  }

  // delete a post hook matching the key
  bool del_post_hook(const uintptr_t& key) {
    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
        post_hooks_.erase(it);
        return true;
      }
    }
    return false;
  }
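
  // Registration round-trip (a sketch; `MyPostHook` is an illustrative
  // FunctionPostHook subclass):
  //
  //   uintptr_t key = node->add_post_hook(std::make_unique<MyPostHook>());
  //   ...
  //   bool removed = node->del_post_hook(key); // true if the key matched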

  std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
    return post_hooks_;
  }

  void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
    pre_hooks_.emplace_back(std::move(pre_hook));
  }

  void add_tensor_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
    tensor_pre_hooks_.emplace_back(std::move(pre_hook));
  }

  void add_retains_grad_hook(
      std::unique_ptr<FunctionPreHook>&& pre_hook,
      size_t output_idx) {
    retains_grad_hooks_[output_idx] = std::move(pre_hook);
  }

  std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(size_t output_idx) {
    auto ret = std::move(retains_grad_hooks_[output_idx]);
    retains_grad_hooks_.erase(output_idx);
    return ret;
  }

  const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
      const noexcept {
    return pre_hooks_;
  }

  std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
    return pre_hooks_;
  }

  virtual std::vector<std::unique_ptr<FunctionPreHook>>&
  tensor_pre_hooks() noexcept {
    return tensor_pre_hooks_;
  }

  virtual std::unique_ptr<PostAccumulateGradHook>&
  tensor_post_acc_grad_hooks() noexcept {
    static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
    return empty;
  }

  std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>>&
  retains_grad_hooks() noexcept {
    return retains_grad_hooks_;
  }

  // Customization Points for Subclasses
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Releases saved variables if the operation won't be reused.
  virtual void release_variables() {}

  /// Called before an apply if `release_variables()` is going to be called.
  /// Allows larger ops like `InterpreterAutogradFunction` to incrementally
  /// release variables as they run.
  virtual void will_release_variables() {}

  /// Returns true if this function is traceable. An op is traceable if all
  /// operations happening within `apply()` are performed on autograd
  /// `Variables` (i.e. apply mostly instantiates and applies other functions).
  virtual bool is_traceable() {
    return false;
  }

  /// A `Node` is said to pass state transparently to backward if the state
  /// consists only of (Saved)Variables and non-variable objects that
  /// parameterize the operation in some way that defines the graph
  /// structure AND the backward function is traceable. In particular,
  /// parametrization MUST NOT depend on the data of any `Variable`.
  /// TODO: it might be possible to handle cases where backward is
  /// non-traceable but state passing could be considered transparent. This
  /// will probably depend on saved_variable_list being mutable.
  /// NOTE: this value matters only if is_traceable() returns false.
  virtual bool passes_state_transparently() {
    return false;
  }

  // see [Note: Compiled Autograd]
  // Used by compiled autograd to
  // 1) Extract tensors/symint args
  // 2) Collect node information for specialization and caching
  // Implementations in subclasses should call args.collect() with all node
  // attrs. These functions are only called during backward.
  virtual void compiled_args(CompiledNodeArgs& args) {
    throw std::runtime_error(
        std::string("compiled_args not implemented: ") + name());
  }

  // Used by compiled autograd to call apply() with different saved tensors
  // Implementations should call saved.before() on all attrs, then apply(),
  // then saved.after() on all attrs in the same order.
  virtual variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) {
    throw std::runtime_error(
        std::string("apply_with_saved not implemented: ") + name());
  }
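
  // A subclass might implement the pair roughly as follows (a sketch;
  // `MyNode` and `attr_` stand in for a concrete node and whatever state it
  // saves):
  //
  //   void MyNode::compiled_args(CompiledNodeArgs& args) {
  //     args.collect(attr_);
  //   }
  //   variable_list MyNode::apply_with_saved(
  //       const variable_list& inputs, SwapSavedVariables& saved) {
  //     saved.before(attr_);
  //     auto result = apply(variable_list(inputs));
  //     saved.after(attr_);
  //     return result;
  //   }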

 protected:
  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;

  /// Calls `apply()`, but instruments it with tracing machinery.
  variable_list traced_apply(variable_list inputs);

  // Sequence number used to correlate backward nodes with forward ops in the
  // profiler and provide determinism in the engine.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  uint64_t sequence_nr_;

  // See NOTE [ Topological Number ]
  uint64_t topological_nr_ = 0;

  // Tracks whether this node has been added as the next_edge of another node
  // via set_next_edge(s), which always calls topological_nr() of all its
  // children. See NOTE [ Topological Number ] for why we need this.
  mutable bool has_parent_ = false;

  // Id of the thread that created the instance
  uint64_t thread_id_ = 0;

  // Note [Thread Safety on Autograd Node]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // The autograd engine lets the owning thread, i.e. the one that calls
  // Engine::execute, drive the GraphTask execution. There might be cases
  // where parts of the GraphTask are shared across different `backward()` or
  // `grad()` calls, e.g. when the user forks new threads in the middle of the
  // forward pass and calls `backward()` separately from those threads. We
  // therefore need to protect against data races when NodeTasks read and
  // write shared variables.
  //
  // NB: This is only needed for autograd Nodes that run on CPU; technically
  // "CUDA", "XLA" nodes don't need locking because device threads are always
  // single threaded.
  //
  // Here we add a mutex to help protect the Node's thread safety, so that
  // different threads cannot race on the shared data when executing the same
  // NodeTask from multiple CPU threads. It IS the user/developer's
  // responsibility to take advantage of this mutex to protect the thread
  // safety of their autograd Node. The general strategy for thread safety on
  // an autograd Node:
  //
  // 1. Users should lock the mutex during Node::release_variables() if the
  //    Node needs to release the variables on the fly; this ensures that when
  //    we release saved_variables from one thread, no other thread can
  //    release the saved variables concurrently.
  // 2. Users should lock the mutex during Node::apply() if the Node writes to
  //    shared variables that could race across threads (e.g. AccumulateGrad,
  //    or custom C++ autograd Nodes that write to shared state).
  // 3. Items 1 and 2 work together so that while we release saved variables
  //    from one thread, no other thread can call Node::apply(); this ensures
  //    that variable references held by other threads aren't left dangling.
  // 4. If the Node doesn't release any variables and does not read/write any
  //    shared data, i.e. it is purely functional, the user doesn't need to
  //    lock the mutex.
  //
  // This way we can protect the thread safety of the autograd Node itself,
  // but we still cannot protect the thread safety of the Node's pre/post C++
  // hooks (Python hooks are automatically thread safe); we rely on the user
  // to write thread-safe C++ hooks if they want the hooks to be correctly
  // applied in a multithreaded environment.
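  //
  // For example, a Node that mutates shared state might guard its work like
  // this (a sketch):
  //
  //   variable_list MyNode::apply(variable_list&& grads) {
  //     std::lock_guard<std::mutex> lock(mutex_);
  //     // ... read/write shared members safely ...
  //   }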
  std::mutex mutex_;

  edge_list next_edges_;
  PyObject* pyobj_ = nullptr; // weak reference
  std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;

  // NOTE [Hooks ordering]
  // We have 3 separate fields for pre hooks registered to the autograd nodes
  // because the conditions under which they execute are different, and we
  // want more fine-grained control over the order in which different types
  // of hooks are executed.
  // - pre_hooks are only executed when the node itself is executed
  // - tensor_pre_hooks are executed as long as the engine traverses over the
  //   node, even if that node won't actually be executed
  // - retains_grad_hooks are like tensor_pre_hooks except they are always
  //   ordered after all other tensor pre hooks
  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
  std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
  std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>>
      retains_grad_hooks_;
  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
  at::SmallVector<InputMetadata, 2> input_metadata_;
};

/// See Node::is_traceable() for definition.
struct TraceableFunction : public Node {
  using Node::Node;
  bool is_traceable() final {
    return true;
  }
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Associated Free Nodes
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
// Implementation of `collect_next_edges` (see below).
struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
  edge_list next_edges;
  using IterArgs<MakeNextFunctionList>::operator();
  void operator()(const Variable& variable) {
    if (variable.defined()) {
      next_edges.emplace_back(impl::gradient_edge(variable));
    } else {
      next_edges.emplace_back();
    }
  }
  void operator()(const Variable* variable) {
    operator()(*variable);
  }
  void operator()(const std::optional<Variable>& variable) {
    if (variable.has_value()) {
      operator()(*variable);
    } else {
      next_edges.emplace_back();
    }
  }
};
} // namespace detail

/// Create an `Edge` between the given `variable` and the `function`, which is
/// assumed to be the gradient function of this variable (i.e. the function
/// through which this variable is backpropagated during the backward pass).
/// This sets the `grad_fn` property of the `variable`. This function assumes
/// that the `Variable` is a new input to the gradient function and its
/// `input_nr` is thus equal to `function->num_inputs()`. Additionally, it
/// increments the `Node`'s number of inputs by one. Approximately
/// equivalent to `variable.set_gradient_edge(function,
/// function->add_input_metadata(variable.dispatch_type(), variable.sizes()))`.
/// If you don't want the `Node`'s `num_inputs` to be incremented, use
/// `set_gradient_edge` directly.
inline void create_gradient_edge(
    Variable& variable,
    std::shared_ptr<Node> function) {
  // Copy before move.
  const auto input_nr = function->add_input_metadata(variable);
  impl::set_gradient_edge(variable, {std::move(function), input_nr});
}

/// Return true if any of the variables in the list require a gradient.
inline bool any_variable_requires_grad(const variable_list& variables) {
  return std::any_of(
      variables.begin(), variables.end(), [](const Variable& variable) {
        return variable.defined() && variable.requires_grad();
      });
}

/// Return the next edges of all the given variables, or tuples of variables.
template <typename... Variables>
edge_list collect_next_edges(Variables&&... variables) {
  detail::MakeNextFunctionList make;
  make.apply(std::forward<Variables>(variables)...);
  return std::move(make.next_edges);
}
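
// For example, wiring up a freshly created grad_fn for a binary op might look
// like this (a sketch; `MyMulBackward` is an illustrative Node subclass and
// `self`/`other` are the op's input tensors):
//
//   auto grad_fn = std::make_shared<MyMulBackward>();
//   grad_fn->set_next_edges(collect_next_edges(self, other));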

struct TypeAndSize {
  TypeAndSize() : options(at::TensorOptions()) {}
  /* implicit */
  TypeAndSize(const at::Tensor& t)
      : sym_sizes(t.sym_sizes().vec()), options(t.options()) {}

  at::Tensor zeros();

  std::vector<c10::SymInt> sym_sizes;
  at::TensorOptions options;
};

} // namespace torch::autograd