1 #include <torch/csrc/autograd/variable.h>
2
3 #include <torch/csrc/autograd/InferenceMode.h>
4 #include <torch/csrc/autograd/autograd.h>
5 #include <torch/csrc/autograd/edge.h>
6 #include <torch/csrc/autograd/engine.h>
7 #include <torch/csrc/autograd/function.h>
8 #include <torch/csrc/autograd/functions/accumulate_grad.h>
9 #include <torch/csrc/autograd/functions/tensor.h>
10 #include <torch/csrc/autograd/functions/utils.h>
11 #include <torch/csrc/autograd/generated/Functions.h>
12 #include <torch/csrc/autograd/generated/ViewFuncs.h>
13 #include <torch/csrc/autograd/utils/error_messages.h>
14
15 #include <ATen/ATen.h>
16 #include <ATen/FuncTorchTLS.h>
17 #include <ATen/MemoryOverlap.h>
18 #include <c10/util/Exception.h>
19
20 #include <memory>
21 #include <mutex>
22 #include <stdexcept>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 namespace torch::autograd {
28
29 // Returns a ViewFunc with a corresponding view that matches the shape,
30 // stride, and storage offset of the given tensor.
31 // NB: On mobile, the as_strided() op and thus the generated AsStridedViewFunc
32 // may not be available.
create_view_func_matching(const Variable & t)33 static std::unique_ptr<ViewFunc> create_view_func_matching(const Variable& t) {
34 #ifdef AS_STRIDED_VIEW_FUNC_AVAILABLE
35 return std::make_unique<torch::autograd::generated::AsStridedViewFunc>(
36 t.sym_sizes(), t.sym_strides(), t.sym_storage_offset());
37 #else
38 return std::make_unique<ErroringViewFunc>("as_strided() not available");
39 #endif
40 }
41
DifferentiableViewMeta(at::TensorImpl * self_impl,std::optional<ViewInfo> backward_info,std::optional<ViewInfo> forward_info,bool shared_view_info,CreationMeta creation_meta)42 DifferentiableViewMeta::DifferentiableViewMeta(
43 at::TensorImpl* self_impl,
44 std::optional<ViewInfo> backward_info,
45 std::optional<ViewInfo> forward_info,
46 bool shared_view_info,
47 CreationMeta creation_meta)
48 : AutogradMeta(self_impl),
49 backward_info_(std::move(backward_info)),
50 forward_info_(std::move(forward_info)),
51 shared_view_info_(shared_view_info),
52 creation_meta_(creation_meta) {
53 is_view_ = true;
54 if (backward_info_.has_value()) {
55 self_impl->set_version_counter(
56 impl::version_counter(backward_info_.value().base_));
57 attr_version_ = self_impl->version_counter().current_version();
58 TORCH_INTERNAL_ASSERT(
59 backward_info_.value().base_.unsafeGetTensorImpl() != self_impl);
60 }
61 if (shared_view_info_) {
62 TORCH_INTERNAL_ASSERT(
63 backward_info_.has_value(),
64 "Shared view info require a backward view info.");
65 TORCH_INTERNAL_ASSERT(
66 !forward_info_.has_value(),
67 "Shared view info require forward view info to be empty")
68 }
69 }
70
71 // Chain this view info with the new view op between base and tensor
chain(const Variable & base,const Variable & tensor,std::unique_ptr<ViewFunc> view_func,std::function<Variable (const Variable &)> rev_view_func) const72 ViewInfo ViewInfo::chain(
73 const Variable& base,
74 const Variable& tensor,
75 std::unique_ptr<ViewFunc> view_func,
76 std::function<Variable(const Variable&)> rev_view_func) const {
77 // Set `view_func` using the root base as input.
78 // `view_func` is used to recover views in backward when either as_strided is
79 // not supported or the view function changes the metadata which is not
80 // recorded by as_strided See Note [View + Inplace update on base tensor] and
81 // [View + Inplace update on view tensor] for more details how we use this
82 // function in backward.
83 if (view_func) {
84 // both current_view and it's parent have a view_func
85 if (view_fn_) {
86 view_func = std::make_unique<ChainedViewFunc>(
87 view_fn_->clone_and_set(), std::move(view_func));
88
89 // assume view_fn_ / rev_view_fn_ always exist together or neither are set
90 auto prev_rev_fn = rev_view_fn_;
91 rev_view_func = [=](const at::Tensor& root_view) {
92 auto temp = rev_view_func(root_view);
93 return prev_rev_fn(temp);
94 };
95 } else {
96 // current_view has a view_func and but it's parent doesn't have one
97 if (base.unsafeGetTensorImpl()->support_as_strided()) {
98 auto match_base_view_func = create_view_func_matching(base);
99 view_func = std::make_unique<ChainedViewFunc>(
100 std::move(match_base_view_func), std::move(view_func));
101
102 // assume view_fn_ / rev_view_fn_ always exist together or neither are
103 // set
104 const auto& root_base = base._base();
105 auto root_base_size = root_base.sym_sizes().vec();
106 auto root_base_stride = root_base.sym_strides().vec();
107 auto root_base_storage_offset = root_base.sym_storage_offset();
108 rev_view_func = [=](const at::Tensor& root_view) {
109 auto temp = rev_view_func(root_view);
110 return temp.as_strided_symint(
111 root_base_size, root_base_stride, root_base_storage_offset);
112 };
113 } else {
114 // This case should be relatively rare: parent view doesn't have a
115 // view_func() AND as_strided() isn't supported; there's no obvious way
116 // to chain the two views.
117 auto error_msg =
118 ("Attempted to chain views when the parent view has no view_func() and "
119 "does not support as_strided(). This is not supported.");
120 view_func = std::make_unique<ErroringViewFunc>(error_msg);
121 rev_view_func = [=](const at::Tensor& root_view) {
122 TORCH_CHECK(false, error_msg);
123 return root_view;
124 };
125 }
126 }
127 } else if (view_fn_) {
128 // if current_view doesn't have a view_func but it's parent has one
129 auto match_tensor_view_func = create_view_func_matching(tensor);
130 view_func = std::make_unique<ChainedViewFunc>(
131 view_fn_->clone_and_set(), std::move(match_tensor_view_func));
132
133 // assume view_fn_ / rev_view_fn_ always exist together or neither are set
134 auto prev_rev_view_fn = rev_view_fn_;
135 auto base_size = base.sym_sizes().vec();
136 auto base_stride = base.sym_strides().vec();
137 auto base_storage_offset = base.sym_storage_offset();
138 rev_view_func = [=](const at::Tensor& root_view) {
139 auto temp = root_view.as_strided_symint(
140 base_size, base_stride, base_storage_offset);
141 return prev_rev_view_fn(temp);
142 };
143 }
144
145 return ViewInfo(base_, std::move(view_func), std::move(rev_view_func));
146 }
147
148 namespace {
149
150 at::Tensor singleton_undefined_tensor;
151
152 struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory {
maketorch::autograd::__anon548c2c850511::ConcreteAutogradMetaFactory153 std::unique_ptr<c10::AutogradMetaInterface> make() const override {
154 return std::make_unique<AutogradMeta>();
155 }
undefined_tensortorch::autograd::__anon548c2c850511::ConcreteAutogradMetaFactory156 const at::Tensor& undefined_tensor() const override {
157 return singleton_undefined_tensor;
158 }
159 };
160
161 ConcreteAutogradMetaFactory meta_factory;
162
163 static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(
164 &meta_factory);
165
166 } // namespace
167
168 namespace impl {
169
materialize_autograd_meta(const at::TensorBase & self)170 AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
171 TORCH_CHECK(
172 self.defined(),
173 "cannot call materialize_autograd_meta() on undefined tensor");
174 auto p = self.unsafeGetTensorImpl();
175 if (!p->autograd_meta()) {
176 p->set_autograd_meta(std::make_unique<AutogradMeta>());
177 }
178 return get_autograd_meta(self);
179 }
180
update_tensor_hooks_on_new_gradfn(const at::TensorBase & self,const std::shared_ptr<torch::autograd::Node> & old_fn,const std::shared_ptr<torch::autograd::Node> & new_fn)181 static void update_tensor_hooks_on_new_gradfn(
182 const at::TensorBase& self,
183 const std::shared_ptr<torch::autograd::Node>& old_fn,
184 const std::shared_ptr<torch::autograd::Node>& new_fn) {
185 // This function is called whenever the grad_fn of the tensor is
186 // changed. We assume here that new_fn does not yet have hooks of
187 // its own.
188 //
189 // This function does two things:
190 // (1) reset the list when grad_fn is updated, so new hooks don't
191 // get erroneously registered to the old grad_fn.
192 // Note that the old cpp_hooks_list_ is still kept alive by the
193 // old grad_fn so hooks registered to the older version of the tensor
194 // will continue to be active.
195 // (2) If there is a retains_grad hook registered, move that from the
196 // old cpp_hooks_list_ to the new one
197 const auto& meta = impl::get_autograd_meta(self);
198 TORCH_INTERNAL_ASSERT(meta);
199 TORCH_INTERNAL_ASSERT(new_fn);
200 meta->cpp_hooks_list_ = nullptr;
201 const c10::impl::PyInterpreter* interp =
202 self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
203 if (interp) {
204 (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
205 }
206 if (self.retains_grad()) {
207 TORCH_INTERNAL_ASSERT(old_fn);
208 auto out = old_fn->pop_retains_grad_hook(self.output_nr());
209 TORCH_INTERNAL_ASSERT(out != nullptr);
210 new_fn->add_retains_grad_hook(std::move(out), self.output_nr());
211 }
212 }
213
rebase_history(const Variable & self,Edge gradient_edge)214 void rebase_history(const Variable& self, Edge gradient_edge) {
215 TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
216 const auto& meta = impl::get_autograd_meta(self);
217 auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr;
218 auto diff_view_meta = get_view_autograd_meta(self);
219 if (diff_view_meta && diff_view_meta->has_bw_view()) {
220 // See NOTE [ View + Inplace detection ]
221 auto creation_meta = diff_view_meta->get_creation_meta();
222 // Do not use handle_view_on_rebase here as check_inplace should have been
223 // called before this and either throw an error
224 TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
225 TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
226 TORCH_INTERNAL_ASSERT(gradient_edge.function);
227 TORCH_CHECK(
228 gradient_edge.function->num_inputs() == 1,
229 "Functions which modify views in-place must return a single Variable");
230 const auto& view_info = diff_view_meta->get_backward_view();
231 diff_view_meta->output_nr_ = gradient_edge.input_nr;
232 auto copy_slices = std::make_shared<CopySlices>(
233 view_info.base_,
234 at::TensorGeometry(self),
235 view_info.has_view_fn() ? view_info.view_fn().clone_and_set() : nullptr,
236 std::move(gradient_edge.function));
237 if (self.requires_grad()) {
238 // If self did not previously require grad, there are no hooks to move
239 torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
240 view_info.base_, view_info.base_.grad_fn(), copy_slices);
241 }
242 set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
243 self.grad_fn(); // trigger an update to the view's grad_fn
244 return;
245 }
246
247 set_gradient_edge(self, std::move(gradient_edge));
248 // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
249 torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
250 self, old_fn, self.grad_fn());
251 }
252
create_cpp_hook(const at::TensorBase & self,bool is_retains_grad_hook)253 void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
254 const auto& fn = self.grad_fn();
255 std::shared_ptr<hooks_list>& list =
256 materialize_autograd_meta(self)->cpp_hooks_list_;
257 list = std::make_shared<hooks_list>();
258 auto hook_ptr =
259 std::make_unique<CppFunctionTensorPreHook>(list, self.output_nr());
260 // NB: we could potentially only update hooks_ if !fn, but it shouldn't
261 // matter
262 // and this was the way before, so we keep it like this for now.
263 clear_hooks(self);
264 add_hook(self, std::make_unique<CppFunctionTensorPreHook>(list, 0));
265 if (fn) {
266 fn->add_tensor_pre_hook(std::move(hook_ptr));
267 }
268 }
269
set_grad_accumulator(const Variable & self,std::weak_ptr<Node> grad_accumulator)270 void set_grad_accumulator(
271 const Variable& self,
272 std::weak_ptr<Node> grad_accumulator) {
273 materialize_autograd_meta(self)->grad_accumulator_ =
274 std::move(grad_accumulator);
275 }
276
try_get_grad_accumulator(const Variable & self)277 std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
278 if (get_autograd_meta(self)) {
279 return get_autograd_meta(self)->grad_accumulator_.lock();
280 } else {
281 return nullptr;
282 }
283 }
284
grad_accumulator(const Variable & self)285 std::shared_ptr<Node> grad_accumulator(const Variable& self) {
286 auto autograd_meta = get_autograd_meta(self);
287 if (!autograd_meta) {
288 return nullptr;
289 }
290 if (autograd_meta->grad_fn_) {
291 throw std::logic_error(
292 "grad_accumulator() should be only called on leaf Variables");
293 }
294 if (!autograd_meta->requires_grad_) {
295 return nullptr;
296 }
297
298 std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
299
300 auto result = autograd_meta->grad_accumulator_.lock();
301 if (result)
302 return result;
303
304 c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
305 auto intrusive_from_this =
306 c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
307 result = std::make_shared<AccumulateGrad>(
308 Variable(std::move(intrusive_from_this)));
309 autograd_meta->grad_accumulator_ = result;
310 return result;
311 }
312
gradient_edge(const Variable & self)313 Edge gradient_edge(const Variable& self) {
314 // If grad_fn is null (as is the case for a leaf node), we instead
315 // interpret the gradient function to be a gradient accumulator, which will
316 // accumulate its inputs into the grad property of the variable. These
317 // nodes get suppressed in some situations, see "suppress gradient
318 // accumulation" below. Note that only variables which have `requires_grad =
319 // True` can have gradient accumulators.
320 if (const auto& gradient = self.grad_fn()) {
321 return Edge(gradient, self.output_nr());
322 } else {
323 return Edge(grad_accumulator(self), 0);
324 }
325 }
326
set_gradient_edge(const Variable & self,Edge edge)327 void set_gradient_edge(const Variable& self, Edge edge) {
328 auto* meta = materialize_autograd_meta(self);
329 meta->grad_fn_ = std::move(edge.function);
330 meta->output_nr_ = edge.input_nr;
331 // For views, make sure this new grad_fn_ is not overwritten unless it is
332 // necessary in the VariableHooks::grad_fn below. This logic is only relevant
333 // for custom autograd Functions for which multiple operations can happen on a
334 // given Tensor before its gradient edge is set when exiting the custom
335 // Function.
336 auto diff_view_meta = get_view_autograd_meta(self);
337 if (diff_view_meta && diff_view_meta->has_bw_view()) {
338 diff_view_meta->set_attr_version(self._version());
339 }
340 }
341
grad_fn_unsafe(const Variable & self)342 Node* grad_fn_unsafe(const Variable& self) {
343 if (get_autograd_meta(self)) {
344 return get_autograd_meta(self)->grad_fn_.get();
345 } else {
346 return nullptr;
347 }
348 }
349
350 // Versions
351 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
352
set_version_counter(const Variable & self,const c10::VariableVersion & version_counter)353 void set_version_counter(
354 const Variable& self,
355 const c10::VariableVersion& version_counter) {
356 TORCH_CHECK(
357 self.defined(), "cannot call set_version_counter() on undefined tensor");
358 self.unsafeGetTensorImpl()->set_version_counter(version_counter);
359 }
360
bump_version(const Variable & self)361 void bump_version(const Variable& self) {
362 TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor");
363 self.unsafeGetTensorImpl()->bump_version();
364 }
365
version_counter(const Variable & self)366 const c10::VariableVersion& version_counter(const Variable& self) {
367 TORCH_CHECK(
368 self.defined(), "cannot call version_counter() on undefined tensor");
369 return self.unsafeGetTensorImpl()->version_counter();
370 }
371
372 // Hooks
373 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
374
add_hook(const at::TensorBase & self,std::unique_ptr<FunctionPreHook> hook)375 void add_hook(
376 const at::TensorBase& self,
377 std::unique_ptr<FunctionPreHook> hook) {
378 AutogradMeta* meta = materialize_autograd_meta(self);
379 TORCH_INTERNAL_ASSERT(meta->hooks_.empty());
380 meta->hooks_.push_back(std::move(hook));
381 }
382
hooks(const Variable & self)383 std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable& self) {
384 TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
385 return get_autograd_meta(self)->hooks_;
386 }
387
clear_hooks(const at::TensorBase & self)388 void clear_hooks(const at::TensorBase& self) {
389 // This is a little goofy, but usually this should be a no oop
390 materialize_autograd_meta(self)->hooks_.clear();
391 }
392
set_post_acc_grad_hooks(const at::TensorBase & self,std::unique_ptr<PostAccumulateGradHook> dict)393 void set_post_acc_grad_hooks(
394 const at::TensorBase& self,
395 std::unique_ptr<PostAccumulateGradHook> dict) {
396 AutogradMeta* meta = materialize_autograd_meta(self);
397 meta->post_acc_grad_hooks_ = std::move(dict);
398 }
399
post_acc_grad_hooks(const Variable & self)400 std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
401 const Variable& self) {
402 TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
403 return get_autograd_meta(self)->post_acc_grad_hooks_;
404 }
405
set_name(const Variable & self,const std::string & name)406 void set_name(const Variable& self, const std::string& name) {
407 materialize_autograd_meta(self)->name_ = name;
408 }
409
410 // Miscellaneous
411 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
412
get_autograd_meta(const at::TensorBase & self)413 AutogradMeta* get_autograd_meta(const at::TensorBase& self) {
414 // NB: could return nullptr
415 TORCH_CHECK(
416 self.defined(), "cannot call get_autograd_meta() on undefined tensor");
417 return static_cast<AutogradMeta*>(
418 self.unsafeGetTensorImpl()->autograd_meta());
419 }
420
get_view_autograd_meta(const at::TensorBase & self)421 DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) {
422 // NB: return nullptr if self is not a view
423 AutogradMeta* meta = get_autograd_meta(self);
424 if (meta && meta->is_view_) {
425 return static_cast<DifferentiableViewMeta*>(meta);
426 } else {
427 return nullptr;
428 }
429 }
430
431 } // namespace impl
432
// Lets the definitions below refer to at::Tensor simply as `Tensor`.
using at::Tensor;

// The VariableHooks instance defined by this file. Its address is passed to
// an at::impl::VariableHooksRegisterer when that global is constructed.
// NOTE(review): presumably this is how ATen reaches the autograd
// implementations of the hooks defined below -- confirm against the
// registerer's definition.
VariableHooks variableHooks;
at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);
437
// Returns a tensor built from a shallow, detached copy of `self`'s
// TensorImpl. The copy gets a version counter starting at 0, disallows tensor
// metadata changes, and has no autograd metadata.
// Raises if `self` is undefined.
at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const {
  TORCH_CHECK(
      self.defined(), "cannot call variable_data() on undefined tensor");
  auto* source_impl = self.unsafeGetTensorImpl();
  auto detached_impl = source_impl->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/false);
  detached_impl->set_autograd_meta(nullptr);
  return at::Tensor(detached_impl);
}
447
// Returns a tensor built from a shallow, detached copy of `self`'s
// TensorImpl. The copy is given `self`'s version counter and `self`'s
// allow_tensor_metadata_change setting.
// Raises if `self` is undefined.
at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
  auto* source_impl = self.unsafeGetTensorImpl();
  auto detached_impl = source_impl->shallow_copy_and_detach(
      /*version_counter=*/source_impl->version_counter(),
      /*allow_tensor_metadata_change=*/
      source_impl->allow_tensor_metadata_change());
  return at::Tensor(detached_impl);
}
456
is_leaf(const at::TensorBase & self) const457 bool VariableHooks::is_leaf(const at::TensorBase& self) const {
458 if (impl::get_autograd_meta(self)) {
459 return impl::get_autograd_meta(self)->grad_fn_ == nullptr;
460 } else {
461 return true;
462 }
463 }
464
output_nr(const at::TensorBase & self) const465 int64_t VariableHooks::output_nr(const at::TensorBase& self) const {
466 if (impl::get_autograd_meta(self)) {
467 return impl::get_autograd_meta(self)->output_nr_;
468 } else {
469 return 0;
470 }
471 }
472
set_data(const at::TensorBase & self_base,const at::TensorBase & new_data_base) const473 void VariableHooks::set_data(
474 const at::TensorBase& self_base,
475 const at::TensorBase& new_data_base) const {
476 at::OptionalTensorRef self_ref(self_base);
477 const Tensor& self = *self_ref;
478 at::OptionalTensorRef new_data_ref(new_data_base);
479 const Tensor& new_data = *new_data_ref;
480
481 // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
482 // from `new_data` to `var`. It requires that `new_data` and `var` have
483 // compatible tensor type.
484 TORCH_CHECK(
485 _has_compatible_shallow_copy_type(self, new_data),
486 "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");
487
488 TORCH_CHECK(
489 !self.requires_grad() ||
490 isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())),
491 "data set to a tensor that requires gradients must be floating point or complex dtype");
492
493 // Resets gradient accumulator if metadata is out of date
494 AutogradMeta* autograd_meta = impl::get_autograd_meta(self);
495 if (autograd_meta) {
496 std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
497 auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
498 if (prior_accumulator) {
499 const auto prior_device = prior_accumulator->input_metadata(0).device();
500 const auto new_device = new_data.device();
501
502 if (!new_data.options().type_equal(self.options()) ||
503 prior_device != new_device) {
504 autograd_meta->grad_accumulator_.reset();
505 }
506 }
507 }
508
509 // Version counter is not shared when we replace a `Variable`'s tensor data
510 // by calling `set_data(...)`. The original version of the `Variable` is
511 // always preserved. See NOTE [ Version Counter Sharing ] for details.
512 //
513 // `var.set_data(new_data)` always ignores `var`'s
514 // `allow_tensor_metadata_change_`, because users need this API as an escape
515 // hatch for changing a tensor's metadata regardless of its
516 // `allow_tensor_metadata_change_` value, and the users are responsible for
517 // ensuring this is the behavior they want.
518 self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr());
519 }
520
// Forwards to self.variable_data() and returns its result.
at::TensorBase VariableHooks::data(const at::TensorBase& self) const {
  auto detached = self.variable_data();
  return detached;
}
524
_version(const at::TensorBase & self) const525 int64_t VariableHooks::_version(const at::TensorBase& self) const {
526 return self.unsafeGetTensorImpl()->version_counter().current_version();
527 }
528
retain_grad(const at::TensorBase & self) const529 void VariableHooks::retain_grad(const at::TensorBase& self) const {
530 TORCH_CHECK(
531 self.requires_grad(),
532 "can't retain_grad on Tensor that has requires_grad=False");
533
534 // temporary hack to improve functorch UX.
535 const auto& functorch_tls = at::functorch::functorchTLSAccessor();
536 if (functorch_tls) {
537 functorch_tls->checkSupportsRetainGrad();
538 }
539
540 if (self.is_leaf()) { // no-op for leaves
541 return;
542 }
543 if (impl::get_autograd_meta(self)->retains_grad_) {
544 return;
545 }
546 c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());
547
548 auto retain_grad_hook = [weak_self](const at::TensorBase& grad_base) {
549 at::Tensor grad{grad_base};
550 if (!weak_self.expired() && grad.defined()) {
551 auto var = weak_self.lock();
552 if (!var->grad().defined()) {
553 if (grad.is_sparse()) {
554 var->mutable_grad() = grad.clone();
555 } else {
556 var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous);
557 }
558 } else {
559 var->mutable_grad() = var->grad() + grad;
560 }
561 }
562 return at::TensorBase{};
563 };
564
565 const auto& fn = self.grad_fn();
566 fn->add_retains_grad_hook(
567 std::make_unique<CppFunctionSingleTensorPreHook>(
568 std::move(retain_grad_hook), self.output_nr()),
569 self.output_nr());
570 impl::get_autograd_meta(self)->retains_grad_ = true;
571 }
572
retains_grad(const at::TensorBase & self) const573 bool VariableHooks::retains_grad(const at::TensorBase& self) const {
574 if (impl::get_autograd_meta(self)) {
575 return impl::get_autograd_meta(self)->retains_grad_;
576 } else {
577 return false;
578 }
579 }
580
_backward(const Tensor & self,at::TensorList inputs,const std::optional<Tensor> & gradient,std::optional<bool> keep_graph,bool create_graph) const581 void VariableHooks::_backward(
582 const Tensor& self,
583 at::TensorList inputs,
584 const std::optional<Tensor>& gradient,
585 std::optional<bool> keep_graph,
586 bool create_graph) const {
587 // TODO torch::autograd::backward should take the std::optional<Tensor>
588 // gradient directly instead of us having to unwrap it to Tensor _gradient
589 // here.
590 Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
591 std::vector<torch::autograd::Variable> input_vars(
592 inputs.begin(), inputs.end());
593 torch::autograd::backward(
594 {self}, {std::move(_gradient)}, keep_graph, create_graph, input_vars);
595 }
596
requires_grad_(const at::TensorBase & self,bool _requires_grad) const597 void VariableHooks::requires_grad_(
598 const at::TensorBase& self,
599 bool _requires_grad) const {
600 if (!self.is_leaf() && !_requires_grad) {
601 throw std::runtime_error(
602 autograd::utils::requires_grad_leaf_error(_requires_grad));
603 }
604 self.set_requires_grad(_requires_grad);
605 }
606
607 // Backward View Variables
608 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
609
is_view(const at::TensorBase & self) const610 bool VariableHooks::is_view(const at::TensorBase& self) const {
611 auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
612 if (diff_view_meta) {
613 return diff_view_meta->has_bw_view();
614 } else {
615 return false;
616 }
617 }
618
// Returns the base of the backward view `self`.
// Throws std::runtime_error when `self` is not a view, and raises via
// TORCH_CHECK when it is a view without a backward view.
const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (!diff_view_meta) {
    throw std::runtime_error("Can't get base of non-view Tensor");
  }
  TORCH_CHECK(
      diff_view_meta->has_bw_view(),
      "Can't get base of non-backward view Tensor");
  return diff_view_meta->get_backward_view().base_;
}
630
631 namespace {
632 std::string singleton_string;
633 }
634
name(const at::TensorBase & self) const635 const std::string& VariableHooks::name(const at::TensorBase& self) const {
636 TORCH_CHECK(
637 self.defined(), "cannot call variable_data() on undefined tensor");
638 if (torch::autograd::impl::get_autograd_meta(self)) {
639 return torch::autograd::impl::get_autograd_meta(self)->name_;
640 } else {
641 return singleton_string;
642 }
643 }
644
645 namespace {
646 std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
647 }
648
grad_fn(const at::TensorBase & self) const649 const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
650 const at::TensorBase& self) const {
651 auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
652 if (diff_view_meta && diff_view_meta->has_bw_view()) {
653 // See NOTE [ View + Inplace detection ]
654 std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
655 auto& view_info = diff_view_meta->get_backward_view();
656 if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
657 return diff_view_meta->grad_fn_;
658 }
659 auto current_version = self._version();
660 auto old_fn = diff_view_meta->grad_fn_;
661 if (diff_view_meta->get_attr_version() != current_version) {
662 // This is an indirect rebase_history due to another view or the base
663 // being modified inplace
664 handle_view_on_rebase(diff_view_meta, /* indirect */ true);
665 TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
666 // Note [View + Inplace update for view tensor]
667 // An inplace update happened on Tensor `self` (which is a view).
668 // For example:
669 // view_1 = view_op_1(diff_view_meta->base_)
670 // view_2 = view_op_2(view_1)
671 // ...
672 // self = view_op_n(view_n-1)
673 // self = inplace_op(self)
674 //
675 // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to
676 // represent the chain of view backward ops for efficiency.
677 //
678 // However in XLA backend we don't have full support of
679 // AsStridedBackward0, we instead run a full forward pass with a tensor
680 // that requires gradient to get proper grad_fn setup, then save it to
681 // DifferentiableViewMeta for future use. This is fairly cheap for XLA
682 // lazy tensor approach (but would be really expensive for CPU/CUDA). XLA
683 // Tensor only run through VariableType dispatch and lower the forward
684 // pass to a XLA HLO graph, then we take grad_fn and never materialize the
685 // tensor content. So we only construct the graph but not execute it,
686 // which is a fairly cheap operation to do.
687 //
688 // See Note [View + Inplace update for base tensor] for what we do to base
689 // tensor when an in-place operation happens.
690 //
691 // TODO: Potentially the following logic can be replaced by special logic
692 // in VariableType_x.cpp
693 // that would provide a way to recreate the grad_fn chain.
694 if (view_info.has_view_fn()) {
695 auto& view_fn = view_info.view_fn();
696 Tensor diff_view;
697 {
698 // We can reach this path with grad_mode disabled, e.g. engine
699 AutoGradMode grad_mode(true);
700 diff_view = view_fn(view_info.base_);
701 }
702 diff_view_meta->grad_fn_ = diff_view.grad_fn();
703 } else {
704 auto fn =
705 std::make_shared<torch::autograd::generated::AsStridedBackward0>();
706 fn->self_geometry = at::TensorGeometry(view_info.base_);
707 fn->size = self.sym_sizes().vec();
708 fn->stride = self.sym_strides().vec();
709 fn->storage_offset = self.sym_storage_offset();
710 fn->set_next_edges(
711 torch::autograd::collect_next_edges(view_info.base_));
712 fn->add_input_metadata(
713 view_info.base_.options(),
714 self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
715 // intentional
716 self.unsafeGetTensorImpl()->is_python_dispatch(),
717 self.is_nested());
718 diff_view_meta->grad_fn_ = std::move(fn);
719 }
720 diff_view_meta->set_attr_version(current_version);
721
722 torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
723 self, old_fn, diff_view_meta->grad_fn_);
724 }
725 return diff_view_meta->grad_fn_;
726 }
727
728 if (torch::autograd::impl::get_autograd_meta(self)) {
729 return torch::autograd::impl::get_autograd_meta(self)->grad_fn_;
730 } else {
731 return singleton_shared_ptr;
732 }
733 }
734
remove_hook(const at::TensorBase & self,unsigned pos) const735 void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos)
736 const {
737 auto& list =
738 torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_;
739 TORCH_CHECK(
740 list && pos < list->size(), "Invalid index, no hook at position ", pos);
741 // Hook will be ignored
742 (*list)[pos] = nullptr;
743 }
744
_register_hook(const at::TensorBase & self,std::function<at::TensorBase (const at::TensorBase &)> hook) const745 unsigned VariableHooks::_register_hook(
746 const at::TensorBase& self,
747 std::function<at::TensorBase(const at::TensorBase&)> hook) const {
748 TORCH_CHECK(
749 self.requires_grad(),
750 "cannot register a hook on a variable that "
751 "doesn't require gradient");
752 // NB: materialize_autograd_meta unnecessary due to requires grad check
753 auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_;
754 if (!list) {
755 torch::autograd::impl::create_cpp_hook(
756 self, /*is_retains_grad_hooks=*/false);
757 }
758 unsigned idx = list->size();
759 list->push_back(hook);
760 return idx;
761 }
762
handle_view_on_rebase(DifferentiableViewMeta * diff_view_meta,bool indirect)763 void handle_view_on_rebase(
764 DifferentiableViewMeta* diff_view_meta,
765 bool indirect) {
766 /// See NOTE [ View + Inplace detection ] for justification of the logic below
767 auto creation_meta = diff_view_meta->get_creation_meta();
768 if (creation_meta != CreationMeta::DEFAULT) {
769 auto grad_fn = diff_view_meta->grad_fn_.get();
770 std::string msg;
771 std::string modified_obj;
772 // Create the header for the error message.
773 if (indirect) {
774 modified_obj = "its base or another view of its base has been";
775 } else {
776 modified_obj = "is being";
777 }
778
779 if (creation_meta == CreationMeta::INFERENCE_MODE ||
780 creation_meta == CreationMeta::NO_GRAD_MODE || !grad_fn) {
781 std::string prefix;
782 if (grad_fn) {
783 prefix = c10::str(
784 "Output ",
785 diff_view_meta->output_nr_,
786 " of ",
787 grad_fn->name(),
788 " is a view of a view which was created in");
789 } else {
790 prefix = "A view was created in";
791 }
792 if (creation_meta == CreationMeta::INFERENCE_MODE) {
793 msg = c10::str(
794 prefix,
795 " inference mode and ",
796 modified_obj,
797 " modified inplace in normal mode.");
798 } else {
799 // create_meta is not necessarily CreationMeta::NO_GRAD_MODE
800 // e.g. CreationMeta::IN_CUSTOM_FUNCTION is possible, but we know that
801 // if there is no grad_fn, that means that the view was performed in
802 // no-grad mode
803 msg = c10::str(
804 prefix,
805 " no_grad mode and ",
806 modified_obj,
807 " modified inplace with grad mode enabled.");
808 }
809 } else {
810 msg = c10::str(
811 "Output ",
812 diff_view_meta->output_nr_,
813 " of ",
814 grad_fn->name(),
815 " is a view and ",
816 modified_obj,
817 " modified inplace.");
818 }
819
820 if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
821 msg = c10::str(
822 msg,
823 " This view is the output of a function that returns multiple views. Such functions do not"
824 " allow the output views to be modified inplace. You should replace the inplace operation by an"
825 " out-of-place one.");
826 } else if (creation_meta == CreationMeta::NO_GRAD_MODE) {
827 msg = c10::str(
828 msg,
829 " Given that this use case is ambiguous and error-prone, it is forbidden."
830 " You can clarify your code by moving both the view and the inplace either both"
831 " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want"
832 " the inplace to be tracked).");
833 } else if (creation_meta == CreationMeta::INFERENCE_MODE) {
834 msg = c10::str(
835 msg,
836 " Given that this use case is ambiguous and error-prone, it is forbidden."
837 " You can clarify your code by moving both the view and the inplace either both"
838 " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want"
839 " the inplace to be tracked).");
840 } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
841 msg = c10::str(
842 msg,
843 " This view was created inside a custom Function (or because an input was returned as-is) and the"
844 " autograd logic to handle view+inplace would override the custom backward associated with the custom"
845 " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by"
846 " cloning the output of the custom Function.");
847 } else {
848 TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
849 }
850
851 TORCH_CHECK(false, msg);
852 }
853 }
854
get_symints() const855 std::vector<c10::SymInt> ChainedViewFunc::get_symints() const {
856 auto symints = first->get_symints();
857 auto second_symints = second->get_symints();
858 symints.reserve(symints.size() + second_symints.size());
859 symints.insert(
860 symints.end(),
861 std::make_move_iterator(second_symints.begin()),
862 std::make_move_iterator(second_symints.end()));
863 return symints;
864 }
865
get_tensors() const866 std::vector<at::Tensor> ChainedViewFunc::get_tensors() const {
867 auto tensors = first->get_tensors();
868 auto second_tensors = second->get_tensors();
869 tensors.reserve(tensors.size() + second_tensors.size());
870 tensors.insert(
871 tensors.end(),
872 std::make_move_iterator(second_tensors.begin()),
873 std::make_move_iterator(second_tensors.end()));
874 return tensors;
875 }
876
// Applies the chain to `input_base`: `first` runs on the input, and `second`
// runs on the result of `first`.
at::Tensor ChainedViewFunc::operator()(const at::Tensor& input_base) const {
  const auto after_first = (*first)(input_base);
  return (*second)(after_first);
}
880
clone_and_set(std::optional<std::vector<c10::SymInt>> symints,std::optional<std::vector<at::Tensor>> tensors) const881 std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
882 std::optional<std::vector<c10::SymInt>> symints,
883 std::optional<std::vector<at::Tensor>> tensors) const {
884 std::optional<std::vector<c10::SymInt>> first_symints;
885 std::optional<std::vector<c10::SymInt>> second_symints;
886 if (symints.has_value()) {
887 TORCH_INTERNAL_ASSERT(symints->size() == num_symints());
888 first_symints = std::vector<c10::SymInt>(
889 symints->begin(), symints->begin() + first->num_symints());
890 second_symints = std::vector<c10::SymInt>(
891 symints->begin() + first->num_symints(), symints->end());
892 }
893
894 std::optional<std::vector<at::Tensor>> first_tensors;
895 std::optional<std::vector<at::Tensor>> second_tensors;
896 if (tensors.has_value()) {
897 TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors());
898 first_tensors = std::vector<at::Tensor>(
899 tensors->begin(), tensors->begin() + first->num_tensors());
900 second_tensors = std::vector<at::Tensor>(
901 tensors->begin() + first->num_tensors(), tensors->end());
902 }
903
904 return std::make_unique<ChainedViewFunc>(
905 first->clone_and_set(first_symints, first_tensors),
906 second->clone_and_set(second_symints, second_tensors));
907 }
908
909 } // namespace torch::autograd
910