#include <torch/csrc/autograd/function.h>

#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/ATen.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif

namespace torch::autograd {

// The node currently being evaluated. This is used to assign the current
// node as a parent of new nodes created during its evaluation in anomaly
// mode.
C10_DEFINE_TLS_static(std::shared_ptr<Node>, tls_current_evaluating_node);
#define current_evaluating_node (tls_current_evaluating_node.get())

NodeGuard::NodeGuard(std::shared_ptr<Node> node)
    : last_evaluating_node_(std::move(current_evaluating_node)) {
  current_evaluating_node = std::move(node);
}

NodeGuard::~NodeGuard() {
  // restore the previous evaluating node
  current_evaluating_node = std::move(last_evaluating_node_);
}
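
// Illustrative usage (a sketch, not code from this file): the engine scopes a
// NodeGuard around the evaluation of a node so that, in anomaly mode, nodes
// created inside the scope can be parented to it; the previous value is
// restored on scope exit.
//
//   {
//     NodeGuard guard(node);
//     // Nodes constructed here see `node` via get_current_node().
//   } // previous evaluating node restored here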

std::shared_ptr<Node> get_current_node() {
  return current_evaluating_node;
}

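// In anomaly mode, record the node currently being evaluated as this node's
// parent, so error reports can print the chain of node creations.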
void Node::assign_parent() {
  metadata()->assign_parent(current_evaluating_node);
}

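// The demangled class name of the concrete Node subclass (e.g. the name of a
// generated backward function); surfaced in error messages and anomaly-mode
// traces.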
auto Node::name() const -> std::string {
  return c10::demangle(typeid(*this).name());
}

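// Lazily allocate the anomaly metadata on first access.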
AnomalyMetadata* Node::metadata() noexcept {
  if (!anomaly_metadata_) {
    anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata();
  }
  return anomaly_metadata_.get();
}

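// Collect into `stack` every next-edge Node to which `func` holds the last
// remaining reference, so the caller can destroy them iteratively; edges
// whose Node is still referenced elsewhere are simply reset.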
static void gatherFunctions(
    Node* func,
    std::vector<std::shared_ptr<Node>>& stack) {
  func->release_variables();

  for (auto& edge : func->next_edges()) {
    if (edge.function.use_count() == 1) {
      stack.emplace_back(std::move(edge.function));
    } else {
      edge.function.reset();
    }
  }
}

/*
 * Fix for #5534: prevent stack overflow on deletion of deep computation graph
 *
 * Sometimes one can end up with a very big computation graph of Nodes
 * and Edges. Each std::shared_ptr<Node> contains a list of Edges, and
 * each Edge contains a std::shared_ptr<Node>. Deleting a
 * std::shared_ptr<Node> can trigger the recursive deletion of other
 * std::shared_ptr<Node>s: this can overflow the stack if the graph
 * is deep enough. Here is an example of such a graph:
 *
 * shared_ptr<Node> -> Edge -> shared_ptr<Node> -> Edge -> ... ->
 * shared_ptr<Node>
 *
 * The solution here is to detect when we are decrementing away the last
 * reference to a Node and, when doing so, to buffer up the Nodes that
 * would otherwise be recursively decremented. We can then decrement (and
 * free) the original Node without causing a recursive cascade, and
 * afterwards drain the buffer, applying the same treatment to each Node
 * in it. This, in effect, converts the recursion into a loop, using a
 * heap buffer in place of the recursive call stack.
 */
void deleteNode(Node* function) {
  // To avoid stack overflow on large computational graphs,
  // we need to track reference decrementing and freeing
  // on the heap.
  function->release_variables();
  std::vector<std::shared_ptr<Node>> stack;
  gatherFunctions(function, stack);
  delete function;

  while (!stack.empty()) {
    auto func = std::move(stack.back());
    stack.pop_back();
    gatherFunctions(func.get(), stack);
    // Reference count is decremented on the loop backedge.
  }
}
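
// Illustrative wiring (a sketch, not code from this file): deleteNode is
// installed as the custom deleter when a Node is wrapped in a shared_ptr, so
// releasing the last reference enters the iterative teardown above instead
// of a recursive chain of shared_ptr destructors. `MyNode` is a hypothetical
// Node subclass:
//
//   std::shared_ptr<Node> node(new MyNode(), deleteNode);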
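
// TypeAndSize records a tensor's symbolic sizes and TensorOptions so that a
// zero tensor of matching shape can be materialized later, e.g. when a
// gradient is needed for an input whose backward never produced one.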
at::Tensor TypeAndSize::zeros() {
  return at::zeros_symint(sym_sizes, options);
}

} // namespace torch::autograd