#include <torch/csrc/autograd/function.h>

#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/ATen.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif

namespace torch::autograd {

// The node currently being evaluated. In anomaly mode, it is assigned as the
// parent of any new nodes created while this node is being evaluated.
C10_DEFINE_TLS_static(std::shared_ptr<Node>, tls_current_evaluating_node);
#define current_evaluating_node (tls_current_evaluating_node.get())

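// RAII guard: installs `node` as the thread-local current evaluating node and
// restores the previously installed one on destruction.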
NodeGuard::NodeGuard(std::shared_ptr<Node> node)
    : last_evaluating_node_(std::move(current_evaluating_node)) {
  current_evaluating_node = std::move(node);
}
NodeGuard::~NodeGuard() {
  // restore the previous evaluating node
  current_evaluating_node = std::move(last_evaluating_node_);
}

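// Returns the node currently being evaluated on this thread, or an empty
// shared_ptr if no node is being evaluated.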
std::shared_ptr<Node> get_current_node() {
  return current_evaluating_node;
}

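// Records the node currently being evaluated as this node's parent in the
// anomaly metadata, so anomaly mode can trace where this node was created.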
void Node::assign_parent() {
  metadata()->assign_parent(current_evaluating_node);
}

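// Human-readable name of this node: the demangled C++ class name.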
auto Node::name() const -> std::string {
  return c10::demangle(typeid(*this).name());
}

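// Lazily constructs the anomaly metadata for this node on first access.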
AnomalyMetadata* Node::metadata() noexcept {
  if (!anomaly_metadata_) {
    anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata();
  }
  return anomaly_metadata_.get();
}

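// Releases `func`'s saved variables and moves every next-edge Node that
// `func` solely owns (use_count() == 1) onto `stack`, so the caller can free
// those nodes iteratively instead of recursing through the graph.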
static void gatherFunctions(
    Node* func,
    std::vector<std::shared_ptr<Node>>& stack) {
  func->release_variables();

  for (auto& edge : func->next_edges()) {
    if (edge.function.use_count() == 1) {
      // We hold the last reference: move it onto the stack so it is freed
      // iteratively by deleteNode's loop rather than recursively here.
      stack.emplace_back(std::move(edge.function));
    } else {
      // Someone else still owns this node; just drop our reference.
      edge.function.reset();
    }
  }
}

/*
 * Fix for #5534: prevent stack overflow on deletion of deep computation graphs
 *
 * Sometimes one can end up with a very big computation graph of Nodes
 * and Edges. Each std::shared_ptr<Node> contains a list of Edges, and
 * each Edge contains a std::shared_ptr<Node>. Deleting a
 * std::shared_ptr<Node> can trigger the recursive deletion of other
 * std::shared_ptr<Node>s: this can stack overflow if the graph
 * is deep enough. Here is an example of such a graph:
 *
 * shared_ptr<Node> -> Edge -> shared_ptr<Node> -> Edge -> ... ->
 * shared_ptr<Node>
 *
 * The solution here is to detect when we are decrementing away the last
 * reference to a Node, and when doing so to buffer up the Nodes that
 * would otherwise be recursively decremented. We can then decrement (and
 * free) the original Node without causing a recursive cascade, before
 * draining the buffer and applying the same behavior to each buffered
 * Node. This is, in effect, converting recursion to a loop, using a heap
 * buffer in place of the recursive call stack.
 */
void deleteNode(Node* function) {
  // To avoid stack overflow on large computational graphs,
  // we need to track reference decrementing and freeing
  // on the heap.
  function->release_variables();
  std::vector<std::shared_ptr<Node>> stack;
  gatherFunctions(function, stack);
  delete function;

  while (!stack.empty()) {
    auto func = std::move(stack.back());
    stack.pop_back();
    gatherFunctions(func.get(), stack);
    // Reference count is decremented on the loop backedge.
  }
}

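// Materializes a zero-filled tensor with the (possibly symbolic) sizes and
// tensor options recorded by this TypeAndSize.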
at::Tensor TypeAndSize::zeros() {
  return at::zeros_symint(sym_sizes, options);
}

} // namespace torch::autograd