#pragma once

#include <ATen/Tensor.h>
#include <c10/util/irange.h>

namespace torch::autograd::utils {

// Helper functions to enforce the "Gradient Layout Contract" described in
// torch/csrc/autograd/functions/accumulate_grad.h.
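//
// In short (accumulate_grad.h remains the authoritative description): when
// variable is non-overlapping and dense, a stashed grad is expected to use
// variable's strides, so variable and its grad are laid out the same way in
// memory; for any other strided variable, the grad is expected to be
// rowmajor contiguous (at::MemoryFormat::Contiguous).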

// Checks if grad obeys the contract with variable.
inline bool obeys_layout_contract(
    const at::Tensor& grad,
    const at::Tensor& variable) {
  TORCH_INTERNAL_ASSERT(!grad.is_sparse());
  TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
  TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (variable.is_nested()) {
    // TODO: Nested Tensor does not have an implementation of detach. The
    // current implementation of nested tensor likely does obey the gradient
    // layout contract and should return true, but this may change in the
    // future.
    return false;
  } else if (variable.is_sparse()) {
    // The Gradient Layout Contract is not applicable to sparse layouts.
    return false;
  } else if (variable.is_non_overlapping_and_dense()) {
    // Only compare strides for dimensions that are not of size 1.
    const auto& grad_sizes = grad.sym_sizes();
    const auto& grad_strides = grad.sym_strides();
    const auto& variable_strides = variable.sym_strides();
    for (const auto idx : c10::irange(grad_sizes.size())) {
      if (grad_sizes[idx] != 1) {
        if (grad_strides[idx] != variable_strides[idx]) {
          return false;
        }
      } else {
        // This should not be needed, but we don't check whether a Tensor has
        // views before stashing it, and 0-strided Tensors of size 1 are
        // actually views for ops like cat.
        // TODO: Actually detect views in the accumulateGrad function so that
        // this Tensor is not considered at all.
        if (grad_strides[idx] == 0) {
          return false;
        }
      }
    }
    return true;
  } else {
    return grad.is_contiguous(at::MemoryFormat::Contiguous);
  }
}
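
// Example (illustrative sketch, not part of the original header):
//
//   // variable: a transposed view with sizes (3, 2) and strides (1, 3).
//   auto variable = at::empty({2, 3}).t();
//   // A grad that shares variable's strides obeys the contract:
//   auto grad_ok = at::empty_strided({3, 2}, {1, 3});
//   obeys_layout_contract(grad_ok, variable);        // true
//   // A rowmajor grad (strides (2, 1)) does not:
//   auto grad_rowmajor = at::empty({3, 2});
//   obeys_layout_contract(grad_rowmajor, variable);  // false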

// Creates a clone of new_grad that obeys the contract with variable.
// The clone should attach to new_grad's history if GradMode::is_enabled().
inline at::Tensor clone_obey_contract(
    const at::Tensor& new_grad,
    const at::Tensor& variable) {
  if (variable.is_non_overlapping_and_dense()) {
    // (1) Allocate a tensor with variable's sizes and strides, then copy
    // new_grad into it.
    // Does this dicey-looking sequence attach the result to new_grad's
    // history if GradMode::is_enabled()?  Yes, and @alband says it should.
    return std::move(new_grad
                         .new_empty_strided_symint(
                             variable.sym_sizes(),
                             variable.sym_strides(),
                             variable.options().memory_format(std::nullopt))
                         .copy_(new_grad));
  } else {
    // (2) Fall back to a rowmajor-contiguous clone.
    return new_grad.clone(at::MemoryFormat::Contiguous);
  }
}
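
// Typical usage (sketch only; the real call sites live in the AccumulateGrad
// logic in accumulate_grad.h and differ in detail):
//
//   if (!obeys_layout_contract(new_grad, variable)) {
//     new_grad = clone_obey_contract(new_grad, variable);
//   }
//   // new_grad can now be stashed as variable's .grad() per the contract.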

} // namespace torch::autograd::utils