#pragma once

#include <ATen/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

namespace torch::autograd::utils {

// Helper functions to enforce the "Gradient Layout Contract" described in
// torch/csrc/autograd/functions/accumulate_grad.h.

// Checks if grad obeys the contract with variable.
inline bool obeys_layout_contract(
    const at::Tensor& grad,
    const at::Tensor& variable) {
  TORCH_INTERNAL_ASSERT(!grad.is_sparse());
  TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
  TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (variable.is_nested()) {
    // TODO: Nested Tensor does not have an implementation of detach. The
    // current implementation of nested tensor likely does obey the gradient
    // contract and should return true, but this may change in the future.
    return false;
  } else if (variable.is_sparse()) {
    // The Gradient Layout Contract is not applicable for sparse layouts.
    return false;
  } else if (variable.is_non_overlapping_and_dense()) {
    // Only compare strides for dimensions that are not of size 1.
    const auto& grad_sizes = grad.sym_sizes();
    const auto& grad_strides = grad.sym_strides();
    const auto& variable_strides = variable.sym_strides();
    for (const auto idx : c10::irange(grad_sizes.size())) {
      if (grad_sizes[idx] != 1) {
        if (grad_strides[idx] != variable_strides[idx]) {
          return false;
        }
      } else {
        // This should not be needed but we don't check if a Tensor has views
        // before stashing it. And 0-strided Tensors of size 1 are actually
        // views for ops like cat.
        // TODO: Actually detect views in the accumulateGrad function so that
        // this Tensor is not considered at all.
        if (grad_strides[idx] == 0) {
          return false;
        }
      }
    }
    return true;
  } else {
    return grad.is_contiguous(at::MemoryFormat::Contiguous);
  }
}

// Creates a clone of new_grad that obeys the contract with variable.
// The clone should attach to new_grad's history if GradMode::is_enabled().
inline at::Tensor clone_obey_contract(
    const at::Tensor& new_grad,
    const at::Tensor& variable) {
  if (variable.is_non_overlapping_and_dense()) {
    // (1)
    // Does this dicey-looking sequence attach the result to new_grad's
    // history if GradMode::is_enabled()? Yes, and @alband says it should.
    return std::move(new_grad
                         .new_empty_strided_symint(
                             variable.sym_sizes(),
                             variable.sym_strides(),
                             variable.options().memory_format(std::nullopt))
                         .copy_(new_grad));
  } else {
    // (2)
    return new_grad.clone(at::MemoryFormat::Contiguous);
  }
}

} // namespace torch::autograd::utils
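
// ---------------------------------------------------------------------------
// Illustrative usage: a minimal sketch of how these helpers could be combined
// when stashing a gradient for a parameter, loosely mirroring the flow
// described in accumulate_grad.h. `stash_grad`, `param`, and `incoming_grad`
// are hypothetical names used only for illustration; this is a sketch under
// those assumptions, not the actual AccumulateGrad implementation.
//
//   at::Tensor stash_grad(
//       const at::Tensor& param,
//       const at::Tensor& incoming_grad) {
//     // Assumes incoming_grad is dense; obeys_layout_contract asserts that
//     // the grad is neither sparse nor sparse CSR.
//     // If the incoming gradient already matches param's layout (same strides
//     // for non-singleton dims, or rowmajor-contiguous when param is not
//     // non-overlapping-and-dense), it can be stored as-is.
//     if (torch::autograd::utils::obeys_layout_contract(incoming_grad, param)) {
//       return incoming_grad;
//     }
//     // Otherwise, materialize a copy whose layout follows the contract,
//     // keeping incoming_grad's autograd history when grad mode is enabled.
//     return torch::autograd::utils::clone_obey_contract(incoming_grad, param);
//   }
// ---------------------------------------------------------------------------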