// xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/variable.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/InferenceMode.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/csrc/autograd/utils/error_messages.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/MemoryOverlap.h>
#include <c10/util/Exception.h>

#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace torch::autograd {

// Returns a ViewFunc that produces a view matching the shape, stride, and
// storage offset of the given tensor.
// NB: On mobile, the as_strided() op and thus the generated AsStridedViewFunc
// may not be available.
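// Illustrative sketch (approximate, for exposition only): given a view `v` of
// a base tensor `b`, applying the returned func to `b` reproduces a tensor
// with `v`'s geometry, roughly:
//   auto fn = create_view_func_matching(v);
//   auto v2 = (*fn)(b); // ~ b.as_strided(v.sizes(), v.strides(), v.storage_offset())
//                       // (on builds where as_strided() is available)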
static std::unique_ptr<ViewFunc> create_view_func_matching(const Variable& t) {
#ifdef AS_STRIDED_VIEW_FUNC_AVAILABLE
  return std::make_unique<torch::autograd::generated::AsStridedViewFunc>(
      t.sym_sizes(), t.sym_strides(), t.sym_storage_offset());
#else
  return std::make_unique<ErroringViewFunc>("as_strided() not available");
#endif
}

DifferentiableViewMeta::DifferentiableViewMeta(
    at::TensorImpl* self_impl,
    std::optional<ViewInfo> backward_info,
    std::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta)
    : AutogradMeta(self_impl),
      backward_info_(std::move(backward_info)),
      forward_info_(std::move(forward_info)),
      shared_view_info_(shared_view_info),
      creation_meta_(creation_meta) {
  is_view_ = true;
  if (backward_info_.has_value()) {
    self_impl->set_version_counter(
        impl::version_counter(backward_info_.value().base_));
    attr_version_ = self_impl->version_counter().current_version();
    TORCH_INTERNAL_ASSERT(
        backward_info_.value().base_.unsafeGetTensorImpl() != self_impl);
  }
  if (shared_view_info_) {
    TORCH_INTERNAL_ASSERT(
        backward_info_.has_value(),
        "Shared view info requires a backward view info.");
    TORCH_INTERNAL_ASSERT(
        !forward_info_.has_value(),
        "Shared view info requires the forward view info to be empty");
  }
}

// Chain this view info with the new view op between base and tensor
ViewInfo ViewInfo::chain(
    const Variable& base,
    const Variable& tensor,
    std::unique_ptr<ViewFunc> view_func,
    std::function<Variable(const Variable&)> rev_view_func) const {
  // Set `view_func` using the root base as input.
  // `view_func` is used to recover views in backward when either as_strided is
  // not supported or the view function changes metadata that is not recorded
  // by as_strided. See Note [View + Inplace update on base tensor] and
  // Note [View + Inplace update on view tensor] for more details on how this
  // function is used in backward.
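  //
  // Illustrative summary (approximate, for exposition only): if this ViewInfo
  // already maps some root base R to `base` via view_fn_, and the new op maps
  // `base` to `tensor` via `view_func`, then the chained functions satisfy
  //   chained_view(R)   == view_func(view_fn_(R))
  //   chained_rev(view) == rev_view_fn_(rev_view_func(view))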
  if (view_func) {
    // both current_view and its parent have a view_func
    if (view_fn_) {
      view_func = std::make_unique<ChainedViewFunc>(
          view_fn_->clone_and_set(), std::move(view_func));

      // assume view_fn_ / rev_view_fn_ always exist together or neither is set
      auto prev_rev_fn = rev_view_fn_;
      rev_view_func = [=](const at::Tensor& root_view) {
        auto temp = rev_view_func(root_view);
        return prev_rev_fn(temp);
      };
    } else {
      // current_view has a view_func but its parent doesn't have one
      if (base.unsafeGetTensorImpl()->support_as_strided()) {
        auto match_base_view_func = create_view_func_matching(base);
        view_func = std::make_unique<ChainedViewFunc>(
            std::move(match_base_view_func), std::move(view_func));

        // assume view_fn_ / rev_view_fn_ always exist together or neither is
        // set
        const auto& root_base = base._base();
        auto root_base_size = root_base.sym_sizes().vec();
        auto root_base_stride = root_base.sym_strides().vec();
        auto root_base_storage_offset = root_base.sym_storage_offset();
        rev_view_func = [=](const at::Tensor& root_view) {
          auto temp = rev_view_func(root_view);
          return temp.as_strided_symint(
              root_base_size, root_base_stride, root_base_storage_offset);
        };
      } else {
        // This case should be relatively rare: parent view doesn't have a
        // view_func() AND as_strided() isn't supported; there's no obvious way
        // to chain the two views.
        auto error_msg =
            ("Attempted to chain views when the parent view has no view_func() and "
             "does not support as_strided(). This is not supported.");
        view_func = std::make_unique<ErroringViewFunc>(error_msg);
        rev_view_func = [=](const at::Tensor& root_view) {
          TORCH_CHECK(false, error_msg);
          return root_view;
        };
      }
    }
  } else if (view_fn_) {
    // current_view doesn't have a view_func but its parent has one
    auto match_tensor_view_func = create_view_func_matching(tensor);
    view_func = std::make_unique<ChainedViewFunc>(
        view_fn_->clone_and_set(), std::move(match_tensor_view_func));

    // assume view_fn_ / rev_view_fn_ always exist together or neither is set
    auto prev_rev_view_fn = rev_view_fn_;
    auto base_size = base.sym_sizes().vec();
    auto base_stride = base.sym_strides().vec();
    auto base_storage_offset = base.sym_storage_offset();
    rev_view_func = [=](const at::Tensor& root_view) {
      auto temp = root_view.as_strided_symint(
          base_size, base_stride, base_storage_offset);
      return prev_rev_view_fn(temp);
    };
  }

  return ViewInfo(base_, std::move(view_func), std::move(rev_view_func));
}

namespace {

at::Tensor singleton_undefined_tensor;

struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory {
  std::unique_ptr<c10::AutogradMetaInterface> make() const override {
    return std::make_unique<AutogradMeta>();
  }
  const at::Tensor& undefined_tensor() const override {
    return singleton_undefined_tensor;
  }
};

ConcreteAutogradMetaFactory meta_factory;

static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(
    &meta_factory);

} // namespace

namespace impl {

AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
  TORCH_CHECK(
      self.defined(),
      "cannot call materialize_autograd_meta() on undefined tensor");
  auto p = self.unsafeGetTensorImpl();
  if (!p->autograd_meta()) {
    p->set_autograd_meta(std::make_unique<AutogradMeta>());
  }
  return get_autograd_meta(self);
}

static void update_tensor_hooks_on_new_gradfn(
    const at::TensorBase& self,
    const std::shared_ptr<torch::autograd::Node>& old_fn,
    const std::shared_ptr<torch::autograd::Node>& new_fn) {
  // This function is called whenever the grad_fn of the tensor is
  // changed. We assume here that new_fn does not yet have hooks of
  // its own.
  //
  // This function does two things:
  // (1) resets the list when grad_fn is updated, so new hooks don't
  //     get erroneously registered to the old grad_fn.
  //     Note that the old cpp_hooks_list_ is still kept alive by the
  //     old grad_fn, so hooks registered to the older version of the tensor
  //     will continue to be active.
  // (2) if there is a retains_grad hook registered, moves it from the
  //     old cpp_hooks_list_ to the new one.
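  //
  // Illustrative example (approximate, for exposition only): if a hook was
  // registered on tensor `t` and `t` is then modified in-place so that its
  // grad_fn changes, the previously registered hook stays attached to the old
  // grad_fn through the old list, while a retains_grad hook (if any) is
  // migrated to the new grad_fn below.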
  const auto& meta = impl::get_autograd_meta(self);
  TORCH_INTERNAL_ASSERT(meta);
  TORCH_INTERNAL_ASSERT(new_fn);
  meta->cpp_hooks_list_ = nullptr;
  const c10::impl::PyInterpreter* interp =
      self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
  if (interp) {
    (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
  }
  if (self.retains_grad()) {
    TORCH_INTERNAL_ASSERT(old_fn);
    auto out = old_fn->pop_retains_grad_hook(self.output_nr());
    TORCH_INTERNAL_ASSERT(out != nullptr);
    new_fn->add_retains_grad_hook(std::move(out), self.output_nr());
  }
}

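// Illustrative note (approximate, for exposition only): rebase_history() runs
// after an in-place op on a tensor tracked by autograd, e.g. roughly:
//   auto b = a.view({-1}); // b is a differentiable view of a
//   b.mul_(2);             // rebase_history(b, edge to the new grad node)
// For the view case below, the base's history is rebased through a CopySlices
// node so gradients flow correctly to both the base and the view.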
void rebase_history(const Variable& self, Edge gradient_edge) {
  TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
  const auto& meta = impl::get_autograd_meta(self);
  auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr;
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    // See NOTE [ View + Inplace detection ]
    auto creation_meta = diff_view_meta->get_creation_meta();
    // Do not use handle_view_on_rebase here: check_inplace should have been
    // called before this and would have either thrown an error or ensured
    // that creation_meta is DEFAULT.
    TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
    TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
    TORCH_INTERNAL_ASSERT(gradient_edge.function);
    TORCH_CHECK(
        gradient_edge.function->num_inputs() == 1,
        "Functions which modify views in-place must return a single Variable");
    const auto& view_info = diff_view_meta->get_backward_view();
    diff_view_meta->output_nr_ = gradient_edge.input_nr;
    auto copy_slices = std::make_shared<CopySlices>(
        view_info.base_,
        at::TensorGeometry(self),
        view_info.has_view_fn() ? view_info.view_fn().clone_and_set() : nullptr,
        std::move(gradient_edge.function));
    if (self.requires_grad()) {
      // If self did not previously require grad, there are no hooks to move
      torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
          view_info.base_, view_info.base_.grad_fn(), copy_slices);
    }
    set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
    self.grad_fn(); // trigger an update to the view's grad_fn
    return;
  }

  set_gradient_edge(self, std::move(gradient_edge));
  // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
  torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
      self, old_fn, self.grad_fn());
}

void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
  const auto& fn = self.grad_fn();
  std::shared_ptr<hooks_list>& list =
      materialize_autograd_meta(self)->cpp_hooks_list_;
  list = std::make_shared<hooks_list>();
  auto hook_ptr =
      std::make_unique<CppFunctionTensorPreHook>(list, self.output_nr());
  // NB: we could potentially only update hooks_ if !fn, but it shouldn't
  // matter, and this was the way before, so we keep it like this for now.
  clear_hooks(self);
  add_hook(self, std::make_unique<CppFunctionTensorPreHook>(list, 0));
  if (fn) {
    fn->add_tensor_pre_hook(std::move(hook_ptr));
  }
}

void set_grad_accumulator(
    const Variable& self,
    std::weak_ptr<Node> grad_accumulator) {
  materialize_autograd_meta(self)->grad_accumulator_ =
      std::move(grad_accumulator);
}

std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_accumulator_.lock();
  } else {
    return nullptr;
  }
}

std::shared_ptr<Node> grad_accumulator(const Variable& self) {
  auto autograd_meta = get_autograd_meta(self);
  if (!autograd_meta) {
    return nullptr;
  }
  if (autograd_meta->grad_fn_) {
    throw std::logic_error(
        "grad_accumulator() should only be called on leaf Variables");
  }
  if (!autograd_meta->requires_grad_) {
    return nullptr;
  }

  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);

  auto result = autograd_meta->grad_accumulator_.lock();
  if (result)
    return result;

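  // Note: the incref + reclaim pair below constructs an owning intrusive_ptr
  // to self's TensorImpl without stealing the caller's reference; reclaim()
  // alone would assume ownership of an existing reference, so the refcount is
  // bumped first.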
  c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
  auto intrusive_from_this =
      c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
  result = std::make_shared<AccumulateGrad>(
      Variable(std::move(intrusive_from_this)));
  autograd_meta->grad_accumulator_ = result;
  return result;
}

Edge gradient_edge(const Variable& self) {
  // If grad_fn is null (as is the case for a leaf node), we instead
  // interpret the gradient function to be a gradient accumulator, which will
  // accumulate its inputs into the grad property of the variable. These
  // nodes get suppressed in some situations, see "suppress gradient
  // accumulation" below. Note that only variables which have `requires_grad =
  // True` can have gradient accumulators.
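  //
  // Illustrative summary (approximate, for exposition only):
  //   - non-leaf y = x * 2:               Edge(y.grad_fn(), y.output_nr())
  //   - leaf x with requires_grad=true:   Edge(AccumulateGrad for x, 0)
  //   - leaf x with requires_grad=false:  Edge(nullptr, 0)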
  if (const auto& gradient = self.grad_fn()) {
    return Edge(gradient, self.output_nr());
  } else {
    return Edge(grad_accumulator(self), 0);
  }
}

void set_gradient_edge(const Variable& self, Edge edge) {
  auto* meta = materialize_autograd_meta(self);
  meta->grad_fn_ = std::move(edge.function);
  meta->output_nr_ = edge.input_nr;
  // For views, make sure this new grad_fn_ is not overwritten unless it is
  // necessary in the VariableHooks::grad_fn below. This logic is only relevant
  // for custom autograd Functions for which multiple operations can happen on a
  // given Tensor before its gradient edge is set when exiting the custom
  // Function.
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    diff_view_meta->set_attr_version(self._version());
  }
}

Node* grad_fn_unsafe(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_fn_.get();
  } else {
    return nullptr;
  }
}

// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void set_version_counter(
    const Variable& self,
    const c10::VariableVersion& version_counter) {
  TORCH_CHECK(
      self.defined(), "cannot call set_version_counter() on undefined tensor");
  self.unsafeGetTensorImpl()->set_version_counter(version_counter);
}

void bump_version(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor");
  self.unsafeGetTensorImpl()->bump_version();
}

const c10::VariableVersion& version_counter(const Variable& self) {
  TORCH_CHECK(
      self.defined(), "cannot call version_counter() on undefined tensor");
  return self.unsafeGetTensorImpl()->version_counter();
}

// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void add_hook(
    const at::TensorBase& self,
    std::unique_ptr<FunctionPreHook> hook) {
  AutogradMeta* meta = materialize_autograd_meta(self);
  TORCH_INTERNAL_ASSERT(meta->hooks_.empty());
  meta->hooks_.push_back(std::move(hook));
}

std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable& self) {
  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
  return get_autograd_meta(self)->hooks_;
}

void clear_hooks(const at::TensorBase& self) {
  // This is a little goofy, but usually this should be a no-op
  materialize_autograd_meta(self)->hooks_.clear();
}

void set_post_acc_grad_hooks(
    const at::TensorBase& self,
    std::unique_ptr<PostAccumulateGradHook> dict) {
  AutogradMeta* meta = materialize_autograd_meta(self);
  meta->post_acc_grad_hooks_ = std::move(dict);
}

std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable& self) {
  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
  return get_autograd_meta(self)->post_acc_grad_hooks_;
}

void set_name(const Variable& self, const std::string& name) {
  materialize_autograd_meta(self)->name_ = name;
}

// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

AutogradMeta* get_autograd_meta(const at::TensorBase& self) {
  // NB: could return nullptr
  TORCH_CHECK(
      self.defined(), "cannot call get_autograd_meta() on undefined tensor");
  return static_cast<AutogradMeta*>(
      self.unsafeGetTensorImpl()->autograd_meta());
}

DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) {
  // NB: returns nullptr if self is not a view
  AutogradMeta* meta = get_autograd_meta(self);
  if (meta && meta->is_view_) {
    return static_cast<DifferentiableViewMeta*>(meta);
  } else {
    return nullptr;
  }
}

} // namespace impl

using at::Tensor;

VariableHooks variableHooks;
at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);

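// Note: variable_data() returns a shallow copy with a fresh version counter,
// metadata changes disallowed, and no autograd metadata (it is what backs
// VariableHooks::data() below), whereas tensor_data() shares the original
// version counter and keeps the original allow_tensor_metadata_change flag.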
at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const {
  TORCH_CHECK(
      self.defined(), "cannot call variable_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/false);
  self_impl_copy->set_autograd_meta(nullptr);
  return at::Tensor(self_impl_copy);
}

at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(),
      /*allow_tensor_metadata_change=*/
      self.unsafeGetTensorImpl()->allow_tensor_metadata_change());
  return at::Tensor(self_impl_copy);
}

bool VariableHooks::is_leaf(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->grad_fn_ == nullptr;
  } else {
    return true;
  }
}

int64_t VariableHooks::output_nr(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->output_nr_;
  } else {
    return 0;
  }
}

void VariableHooks::set_data(
    const at::TensorBase& self_base,
    const at::TensorBase& new_data_base) const {
  at::OptionalTensorRef self_ref(self_base);
  const Tensor& self = *self_ref;
  at::OptionalTensorRef new_data_ref(new_data_base);
  const Tensor& new_data = *new_data_ref;

  // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
  // from `new_data` to `var`. It requires that `new_data` and `var` have
  // compatible tensor types.
  TORCH_CHECK(
      _has_compatible_shallow_copy_type(self, new_data),
      "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");

  TORCH_CHECK(
      !self.requires_grad() ||
          isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())),
      "data set to a tensor that requires gradients must be floating point or complex dtype");

  // Resets gradient accumulator if metadata is out of date
  AutogradMeta* autograd_meta = impl::get_autograd_meta(self);
  if (autograd_meta) {
    std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
    auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
    if (prior_accumulator) {
      const auto prior_device = prior_accumulator->input_metadata(0).device();
      const auto new_device = new_data.device();

      if (!new_data.options().type_equal(self.options()) ||
          prior_device != new_device) {
        autograd_meta->grad_accumulator_.reset();
      }
    }
  }

  // Version counter is not shared when we replace a `Variable`'s tensor data
  // by calling `set_data(...)`. The original version of the `Variable` is
  // always preserved. See NOTE [ Version Counter Sharing ] for details.
  //
  // `var.set_data(new_data)` always ignores `var`'s
  // `allow_tensor_metadata_change_`, because users need this API as an escape
  // hatch for changing a tensor's metadata regardless of its
  // `allow_tensor_metadata_change_` value, and the users are responsible for
  // ensuring this is the behavior they want.
  self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr());
}

at::TensorBase VariableHooks::data(const at::TensorBase& self) const {
  return self.variable_data();
}

int64_t VariableHooks::_version(const at::TensorBase& self) const {
  return self.unsafeGetTensorImpl()->version_counter().current_version();
}

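// Illustrative note (approximate, for exposition only): retain_grad() makes a
// non-leaf tensor keep its gradient after backward by registering a hook on
// its grad_fn that clones/accumulates the incoming grad into .grad, e.g.
//   auto y = x * 2;     // x is a leaf with requires_grad=true; y is non-leaf
//   y.retain_grad();
//   y.sum().backward(); // y.grad() is now populated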
void VariableHooks::retain_grad(const at::TensorBase& self) const {
  TORCH_CHECK(
      self.requires_grad(),
      "can't retain_grad on Tensor that has requires_grad=False");

  // temporary hack to improve functorch UX.
  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    functorch_tls->checkSupportsRetainGrad();
  }

  if (self.is_leaf()) { // no-op for leaves
    return;
  }
  if (impl::get_autograd_meta(self)->retains_grad_) {
    return;
  }
  c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());

  auto retain_grad_hook = [weak_self](const at::TensorBase& grad_base) {
    at::Tensor grad{grad_base};
    if (!weak_self.expired() && grad.defined()) {
      auto var = weak_self.lock();
      if (!var->grad().defined()) {
        if (grad.is_sparse()) {
          var->mutable_grad() = grad.clone();
        } else {
          var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous);
        }
      } else {
        var->mutable_grad() = var->grad() + grad;
      }
    }
    return at::TensorBase{};
  };

  const auto& fn = self.grad_fn();
  fn->add_retains_grad_hook(
      std::make_unique<CppFunctionSingleTensorPreHook>(
          std::move(retain_grad_hook), self.output_nr()),
      self.output_nr());
  impl::get_autograd_meta(self)->retains_grad_ = true;
}

bool VariableHooks::retains_grad(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->retains_grad_;
  } else {
    return false;
  }
}

void VariableHooks::_backward(
    const Tensor& self,
    at::TensorList inputs,
    const std::optional<Tensor>& gradient,
    std::optional<bool> keep_graph,
    bool create_graph) const {
  // TODO: torch::autograd::backward should take the std::optional<Tensor>
  // gradient directly instead of us having to unwrap it to Tensor _gradient
  // here.
  Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
  std::vector<torch::autograd::Variable> input_vars(
      inputs.begin(), inputs.end());
  torch::autograd::backward(
      {self}, {std::move(_gradient)}, keep_graph, create_graph, input_vars);
}

void VariableHooks::requires_grad_(
    const at::TensorBase& self,
    bool _requires_grad) const {
  if (!self.is_leaf() && !_requires_grad) {
    throw std::runtime_error(
        autograd::utils::requires_grad_leaf_error(_requires_grad));
  }
  self.set_requires_grad(_requires_grad);
}

// Backward View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

bool VariableHooks::is_view(const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta) {
    return diff_view_meta->has_bw_view();
  } else {
    return false;
  }
}

const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta) {
    TORCH_CHECK(
        diff_view_meta->has_bw_view(),
        "Can't get base of non-backward view Tensor");
    return diff_view_meta->get_backward_view().base_;
  } else {
    throw std::runtime_error("Can't get base of non-view Tensor");
  }
}

namespace {
std::string singleton_string;
}

const std::string& VariableHooks::name(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call name() on undefined tensor");
  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->name_;
  } else {
    return singleton_string;
  }
}

namespace {
std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
}

const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
    const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    // See NOTE [ View + Inplace detection ]
    std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
    auto& view_info = diff_view_meta->get_backward_view();
    if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
      return diff_view_meta->grad_fn_;
    }
    auto current_version = self._version();
    auto old_fn = diff_view_meta->grad_fn_;
    if (diff_view_meta->get_attr_version() != current_version) {
      // This is an indirect rebase_history due to another view or the base
      // being modified inplace
      handle_view_on_rebase(diff_view_meta, /* indirect */ true);
      TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
      // Note [View + Inplace update for view tensor]
      // An inplace update happened on Tensor `self` (which is a view).
      // For example:
      //   view_1 = view_op_1(diff_view_meta->base_)
      //   view_2 = view_op_2(view_1)
      //   ...
      //   self = view_op_n(view_n-1)
      //   self = inplace_op(self)
      //
      // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to
      // represent the chain of view backward ops for efficiency.
      //
      // However, the XLA backend doesn't have full support for
      // AsStridedBackward0. Instead we run a full forward pass with a tensor
      // that requires gradient to get a proper grad_fn setup, then save it to
      // DifferentiableViewMeta for future use. This is fairly cheap for the
      // XLA lazy tensor approach (but would be really expensive for CPU/CUDA).
      // XLA tensors only run through VariableType dispatch and lower the
      // forward pass to an XLA HLO graph; we then take grad_fn and never
      // materialize the tensor content. So we only construct the graph but do
      // not execute it, which is a fairly cheap operation.
      //
      // See Note [View + Inplace update for base tensor] for what we do to base
      // tensor when an in-place operation happens.
      //
      // TODO: Potentially the following logic can be replaced by special logic
      // in VariableType_x.cpp that would provide a way to recreate the grad_fn
      // chain.
      if (view_info.has_view_fn()) {
        auto& view_fn = view_info.view_fn();
        Tensor diff_view;
        {
          // We can reach this path with grad_mode disabled, e.g. from the
          // engine
          AutoGradMode grad_mode(true);
          diff_view = view_fn(view_info.base_);
        }
        diff_view_meta->grad_fn_ = diff_view.grad_fn();
      } else {
        auto fn =
            std::make_shared<torch::autograd::generated::AsStridedBackward0>();
        fn->self_geometry = at::TensorGeometry(view_info.base_);
        fn->size = self.sym_sizes().vec();
        fn->stride = self.sym_strides().vec();
        fn->storage_offset = self.sym_storage_offset();
        fn->set_next_edges(
            torch::autograd::collect_next_edges(view_info.base_));
        fn->add_input_metadata(
            view_info.base_.options(),
            self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
                              // intentional
            self.unsafeGetTensorImpl()->is_python_dispatch(),
            self.is_nested());
        diff_view_meta->grad_fn_ = std::move(fn);
      }
      diff_view_meta->set_attr_version(current_version);

      torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
          self, old_fn, diff_view_meta->grad_fn_);
    }
    return diff_view_meta->grad_fn_;
  }

  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->grad_fn_;
  } else {
    return singleton_shared_ptr;
  }
}

void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos)
    const {
  auto& list =
      torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_;
  TORCH_CHECK(
      list && pos < list->size(), "Invalid index, no hook at position ", pos);
  // Hook will be ignored
  (*list)[pos] = nullptr;
}

unsigned VariableHooks::_register_hook(
    const at::TensorBase& self,
    std::function<at::TensorBase(const at::TensorBase&)> hook) const {
  TORCH_CHECK(
      self.requires_grad(),
      "cannot register a hook on a variable that "
      "doesn't require gradient");
  // NB: materialize_autograd_meta unnecessary due to requires grad check
  auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_;
  if (!list) {
    torch::autograd::impl::create_cpp_hook(
        self, /*is_retains_grad_hook=*/false);
  }
  unsigned idx = list->size();
  list->push_back(hook);
  return idx;
}

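// Illustrative note (approximate, for exposition only): this error path is
// hit when a view with a restricted CreationMeta is modified in place, e.g. a
// view created inside a no_grad/inference_mode block, returned by a
// multi-output view op, or created inside a custom Function, that is later
// mutated with grad mode enabled. The message built below spells out which
// case applies and how to rewrite the offending code.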
void handle_view_on_rebase(
    DifferentiableViewMeta* diff_view_meta,
    bool indirect) {
  // See NOTE [ View + Inplace detection ] for justification of the logic below
  auto creation_meta = diff_view_meta->get_creation_meta();
  if (creation_meta != CreationMeta::DEFAULT) {
    auto grad_fn = diff_view_meta->grad_fn_.get();
    std::string msg;
    std::string modified_obj;
    // Create the header for the error message.
    if (indirect) {
      modified_obj = "its base or another view of its base has been";
    } else {
      modified_obj = "is being";
    }

    if (creation_meta == CreationMeta::INFERENCE_MODE ||
        creation_meta == CreationMeta::NO_GRAD_MODE || !grad_fn) {
      std::string prefix;
      if (grad_fn) {
        prefix = c10::str(
            "Output ",
            diff_view_meta->output_nr_,
            " of ",
            grad_fn->name(),
            " is a view of a view which was created in");
      } else {
        prefix = "A view was created in";
      }
      if (creation_meta == CreationMeta::INFERENCE_MODE) {
        msg = c10::str(
            prefix,
            " inference mode and ",
            modified_obj,
            " modified inplace in normal mode.");
      } else {
        // creation_meta is not necessarily CreationMeta::NO_GRAD_MODE
        // e.g. CreationMeta::IN_CUSTOM_FUNCTION is possible, but we know that
        // if there is no grad_fn, that means that the view was performed in
        // no-grad mode
        msg = c10::str(
            prefix,
            " no_grad mode and ",
            modified_obj,
            " modified inplace with grad mode enabled.");
      }
    } else {
      msg = c10::str(
          "Output ",
          diff_view_meta->output_nr_,
          " of ",
          grad_fn->name(),
          " is a view and ",
          modified_obj,
          " modified inplace.");
    }

    if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
      msg = c10::str(
          msg,
          " This view is the output of a function that returns multiple views. Such functions do not"
          " allow the output views to be modified inplace. You should replace the inplace operation by an"
          " out-of-place one.");
    } else if (creation_meta == CreationMeta::NO_GRAD_MODE) {
      msg = c10::str(
          msg,
          " Given that this use case is ambiguous and error-prone, it is forbidden."
          " You can clarify your code by moving the view and the inplace either both"
          " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want"
          " the inplace to be tracked).");
    } else if (creation_meta == CreationMeta::INFERENCE_MODE) {
      msg = c10::str(
          msg,
          " Given that this use case is ambiguous and error-prone, it is forbidden."
          " You can clarify your code by moving the view and the inplace either both"
          " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want"
          " the inplace to be tracked).");
    } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
      msg = c10::str(
          msg,
          " This view was created inside a custom Function (or because an input was returned as-is) and the"
          " autograd logic to handle view+inplace would override the custom backward associated with the custom"
          " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by"
          " cloning the output of the custom Function.");
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
    }

    TORCH_CHECK(false, msg);
  }
}

std::vector<c10::SymInt> ChainedViewFunc::get_symints() const {
  auto symints = first->get_symints();
  auto second_symints = second->get_symints();
  symints.reserve(symints.size() + second_symints.size());
  symints.insert(
      symints.end(),
      std::make_move_iterator(second_symints.begin()),
      std::make_move_iterator(second_symints.end()));
  return symints;
}

std::vector<at::Tensor> ChainedViewFunc::get_tensors() const {
  auto tensors = first->get_tensors();
  auto second_tensors = second->get_tensors();
  tensors.reserve(tensors.size() + second_tensors.size());
  tensors.insert(
      tensors.end(),
      std::make_move_iterator(second_tensors.begin()),
      std::make_move_iterator(second_tensors.end()));
  return tensors;
}

at::Tensor ChainedViewFunc::operator()(const at::Tensor& input_base) const {
  return (*second)((*first)(input_base));
}

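// Note: the flattened symint/tensor lists passed in here are expected to be
// laid out as [first's state..., second's state...], matching get_symints()
// and get_tensors() above, so they are split at first->num_symints() /
// first->num_tensors().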
std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  std::optional<std::vector<c10::SymInt>> first_symints;
  std::optional<std::vector<c10::SymInt>> second_symints;
  if (symints.has_value()) {
    TORCH_INTERNAL_ASSERT(symints->size() == num_symints());
    first_symints = std::vector<c10::SymInt>(
        symints->begin(), symints->begin() + first->num_symints());
    second_symints = std::vector<c10::SymInt>(
        symints->begin() + first->num_symints(), symints->end());
  }

  std::optional<std::vector<at::Tensor>> first_tensors;
  std::optional<std::vector<at::Tensor>> second_tensors;
  if (tensors.has_value()) {
    TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors());
    first_tensors = std::vector<at::Tensor>(
        tensors->begin(), tensors->begin() + first->num_tensors());
    second_tensors = std::vector<at::Tensor>(
        tensors->begin() + first->num_tensors(), tensors->end());
  }

  return std::make_unique<ChainedViewFunc>(
      first->clone_and_set(first_symints, first_tensors),
      second->clone_and_set(second_symints, second_tensors));
}

} // namespace torch::autograd