xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/functions/tensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/autograd/functions/tensor.h>
2 
3 #include <torch/csrc/autograd/function.h>
4 #include <torch/csrc/autograd/functions/basic_ops.h>
5 #include <torch/csrc/autograd/functions/utils.h>
6 #include <torch/csrc/autograd/graph_task.h>
7 #include <torch/csrc/autograd/variable.h>
8 #include <torch/csrc/dynamo/compiled_autograd.h>
9 
10 #include <ATen/ATen.h>
11 #include <c10/util/irange.h>
12 
13 #include <memory>
14 #include <stdexcept>
15 #include <utility>
16 
17 namespace torch::autograd {
18 
apply(variable_list && grads)19 auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
20   check_input_variables("CopyBackwards", grads, 1, -1, true);
21   auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grads[0]);
22   variable_list grad_inputs(2);
23   if (grad->defined()) {
24     if (task_should_compute_output(0)) {
25       grad_inputs[0] = at::zeros_like(*grad, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
26     }
27     if (task_should_compute_output(1)) {
28       // Handle R->C copies without raising a warning
29       const auto src_type = src_options.dtype().toScalarType();
30       if (!c10::isComplexType(src_type) && grad->is_complex()) {
31         grad = c10::MaybeOwned<at::Tensor>::owned(at::real(grads[0]));
32       }
33 
34       at::DeviceGuard device_guard(src_options.device());
35       grad_inputs[1] = grad->to(src_options);
36     }
37   }
38   return grad_inputs;
39 }
40 
compiled_args(CompiledNodeArgs & args)41 void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
42   args.collect(src_options);
43 }
// Compiled-autograd variant of apply(): brackets a regular apply() on a copy
// of `inputs` with saved.before/saved.after on src_options.
// NOTE(review): presumably before/after swap src_options for compiled
// autograd's traced value and restore it afterwards — confirm in
// SwapSavedVariables.
variable_list CopyBackwards::apply_with_saved(
    const variable_list& inputs,
    SwapSavedVariables& saved) {
  saved.before(src_options);
  variable_list inputs_copy(inputs);
  variable_list grads_out = apply(std::move(inputs_copy));
  saved.after(src_options);
  return grads_out;
}
52 
CopySlices(const Variable & base_var,at::TensorGeometry view_,std::unique_ptr<ViewFunc> view_fn_,std::shared_ptr<Node> fn_)53 CopySlices::CopySlices(
54     const Variable& base_var,
55     at::TensorGeometry view_,
56     std::unique_ptr<ViewFunc> view_fn_,
57     std::shared_ptr<Node> fn_)
58     : Node(),
59       base(base_var),
60       view(std::move(view_)),
61       view_fn(std::move(view_fn_)),
62       fn(std::move(fn_)) {
63   // Take the next_edges of fn as our own, except for index 0 which goes
64   // to base instead of the view.
65   add_input_metadata(base_var);
66   const auto num_outputs = fn->num_outputs();
67   next_edges_.reserve(num_outputs);
68   add_next_edge(impl::gradient_edge(base_var));
69   for (const auto i : c10::irange(1, num_outputs)) {
70     add_next_edge(fn->next_edge(i));
71   }
72 }
73 
74 // common code between apply/apply_with_saved
75 template <typename T>
apply_impl(variable_list && inputs,const T & call_fn)76 inline variable_list CopySlices::apply_impl(
77     variable_list&& inputs,
78     const T& call_fn) {
79   check_input_variables("CopySlices", inputs, 1, -1, true);
80   auto& grad = inputs[0];
81   if (!grad.defined()) {
82     return variable_list(num_outputs());
83   }
84 
85   // Acquire lock to here protect thread safety on fn
86   // see Note [Thread Safety on Autograd Node]
87   std::lock_guard<std::mutex> lock(mutex_);
88 
89   if (!fn) {
90     throw std::runtime_error(ERR_BACKWARD_TWICE);
91   }
92 
93   auto result =
94       grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides());
95   result.copy_(grad);
96 
97   at::Tensor grad_slice;
98   if (view_fn) {
99     grad_slice = (*view_fn)(result);
100   } else {
101     auto offset = view.sym_storage_offset() - base.sym_storage_offset();
102     grad_slice =
103         result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset);
104   }
105 
106   // See Note [View + Inplace update for view tensor] For more details on this
107   // block Since the gradient edge for the 0th input is different between `this`
108   // and `fn`, make sure that the one from `fn` has the same metadata in the
109   // current GraphTask's exec_info as the one on `this`.
110   const auto exec_info = get_current_graph_task_exec_info();
111   if (exec_info && !exec_info->empty()) {
112     const auto& fn_edge = fn->next_edge(0);
113     const auto& this_edge = this->next_edge(0);
114     TORCH_INTERNAL_ASSERT(fn_edge.is_valid() == this_edge.is_valid());
115     if (fn_edge.is_valid()) {
116       const auto fn_next_node = fn_edge.function.get();
117       auto it = exec_info->find(fn_next_node);
118       if (it == exec_info->end()) {
119         // Node is not in the exec_info already
120         if (task_should_compute_output(0)) {
121           // And we need gradient for the corresponding output
122           add_node_to_current_graph_task_exec_info(fn_next_node);
123           // There is no need to remove this after execution because we are
124           // guaranteed that this->next_edge(0) must be in the history of
125           // fn->next_edge(0) (we cannot easily assert this as it might be far
126           // away if there were many chained views). This means that, since
127           // fn->next_edge(0) was not needed (no exec_info entry for it), we
128           // know that nothing downstream of fn->next_edge(0) is needed either
129           // (otherwise the whole path from that Node to this->next_edge(0)
130           // would be needed as well). This means that no other Node will ever
131           // look at fn->next_edge(0) metadata and thus there is no need to
132           // clean them up.
133         }
134       } else {
135         TORCH_INTERNAL_ASSERT(
136             it->second.should_execute() == task_should_compute_output(0));
137       }
138     }
139   }
140 
141   // Sanity check that the graph was never modified after the fact (it is
142   // read-only!)
143   TORCH_INTERNAL_ASSERT(num_outputs() == fn->num_outputs());
144   for (const auto i : c10::irange(1, this->num_outputs())) {
145     TORCH_INTERNAL_ASSERT(
146         fn->next_edge(i).function.get() == this->next_edge(i).function.get());
147   }
148 
149   // TODO: We clone grad_slice because we modify it below and "fn" might save
150   // it for the backward of res. We might be able to avoid the clone() if
151   // double-backprop is disabled.
152   auto res = call_fn({grad_slice.clone(at::MemoryFormat::Contiguous)});
153 
154   variable_list grad_inputs(num_outputs());
155   for (const auto i : c10::irange(res.size())) {
156     if (task_should_compute_output(i)) {
157       if (!res[i].defined()) {
158         // If the output is not defined, treat it as if it was a zero tensor.
159         // This can happen if users define a custom Function.
160         continue;
161       }
162       if (i == 0) {
163         grad_slice.copy_(res[i]);
164         // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
165         grad_inputs[i] = std::move(result); // NOLINT(bugprone-use-after-move)
166       } else {
167         grad_inputs[i] = std::move(res[i]);
168       }
169     }
170   }
171 
172   return grad_inputs;
173 }
174 
release_variables()175 void CopySlices::release_variables() {
176   // Acquire lock to here protect thread safety on fn
177   std::lock_guard<std::mutex> lock(mutex_);
178   fn = nullptr;
179 }
180 
compiled_args(CompiledNodeArgs & args)181 void CopySlices::compiled_args(CompiledNodeArgs& args) {
182   TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
183   TORCH_INTERNAL_ASSERT((bool)fn);
184   args.collect(base);
185   args.collect(view);
186   args.collect(fn);
187   fn->compiled_args(args);
188 }
189 
// Compiled-autograd variant of apply(). Runs the shared CopySlices logic with
// fn invoked through fn->apply_with_saved, bracketed by saved.before/after on
// base and view. Asserts that fn was invoked exactly once.
// NOTE(review): apply_impl returns before invoking the callback when the
// incoming gradient is undefined, which would trip the call-count assertion.
// Presumably compiled autograd never passes an undefined grad here — confirm.
variable_list CopySlices::apply_with_saved(
    const variable_list& grads,
    SwapSavedVariables& saved) {
  saved.before(base);
  saved.before(view);

  int fn_invocations = 0;
  const auto invoke_fn = [this, &saved, &fn_invocations](
                             const variable_list& fn_inputs) {
    ++fn_invocations;
    return fn->apply_with_saved(fn_inputs, saved);
  };
  variable_list outputs = apply_impl(variable_list(grads), invoke_fn);

  TORCH_INTERNAL_ASSERT(fn_invocations == 1);
  saved.after(base);
  saved.after(view);
  return outputs;
}
207 
apply(variable_list && inputs1)208 auto CopySlices::apply(variable_list&& inputs1) -> variable_list {
209   return apply_impl(std::move(inputs1), [this](variable_list&& inputs2) {
210     return (*fn)(std::move(inputs2));
211   });
212 }
213 
214 } // namespace torch::autograd
215