// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <ATen/functorch/Macros.h>
#include <ATen/Tensor.h>
#include <ATen/functorch/Interpreter.h>

namespace at::functorch {

// NOTE: [functorch's TensorWrapper]
//
// Taking better suggestions for a name. TensorWrapper is the wrapper Tensor
// subclass for functorch's grad-based transforms (grad, vjp, jvp). It is
// analogous to how vmap uses BatchedTensor as the wrapper Tensor subclass.
//
// If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively
// another Variable.
//
// Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x)).
// The nesting exists so that each TensorWrapper can hold its own AutogradMeta and
// participate in a **separate** autograd graph.
//
// There are alternative designs we could have chosen (e.g. each grad transform
// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper
// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels)
// without much modification. Since a TensorWrapper looks like a regular Tensor,
// the VariableType kernel can pull the AutogradMeta struct out of where it
// expects it to be and extend the autograd graph.

struct TORCH_API TensorWrapper : public c10::TensorImpl {
  explicit TensorWrapper(
      c10::DispatchKeySet key_set,
      Tensor value,
      int64_t level,
      std::shared_ptr<bool> is_alive,
      bool is_immutable = false, // if true, this came from an operation that aliases an immutable tensor
      bool use_value_sizes_strides = true);

  void refreshMetadata();

  const Tensor& value() const {
    return value_;
  }
  std::optional<int64_t> level() const {
    if (is_alive()) {
      return level_;
    }
    return {};
  }
  bool is_immutable() const {
    return is_immutable_;
  }
  bool is_alive() const;

  // Overrides necessary for autograd
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

 private:
  const char* tensorimpl_type_name() const override;
  Tensor value_;
  int64_t level_;
  bool is_immutable_;

  // TensorWrapper receives a boolean flag indicating whether the Grad Interpreter
  // that created it is still alive.
  // If the Grad Interpreter is no longer alive, then the wrapper attempts to
  // behave like a regular Tensor.
  //
  // When we exit the level, this wrapper may be marked as "not alive".
  // Wrappers that are not alive:
  // 1) may still have autograd metadata on them
  // 2) forward dispatches to the underlying value()
  std::shared_ptr<bool> is_alive_;
};

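// NOTE: [Example: peeling a nested TensorWrapper]
//
// A minimal sketch (not actual functorch kernel code) of how the accessors
// above compose for a doubly-wrapped tensor such as the one produced by
// grad(grad(torch.sin))(x). The function name `inspect` and the locals
// `outer`/`inner` are illustrative only; maybeGetTensorWrapper is declared
// further down in this header.
//
//   void inspect(const Tensor& t) {
//     TensorWrapper* outer = maybeGetTensorWrapper(t);
//     if (!outer) {
//       return; // plain Tensor: no grad transform wrapped it
//     }
//     // Level of the outermost grad transform, or nullopt if that
//     // transform has already exited (i.e. is_alive() is false).
//     std::optional<int64_t> outer_level = outer->level();
//     // One layer down: for grad(grad(...)), `inner` is the wrapper made by
//     // the inner transform, and inner->value() is the original tensor.
//     TensorWrapper* inner = maybeGetTensorWrapper(outer->value());
//   }
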
// There are two variants of makeTensorWrapper: one that accepts a level
// and one that accepts an Interpreter.
//
// The one that accepts a level tries to automatically get the life handle from
// the interpreter on the DynamicLayerStack.
// It needs to be used with caution: if the interpreter is not on the
// DynamicLayerStack, then we won't be able to find the life handle.
//
// In practice this isn't a problem: when we're constructing TensorWrapper in
// Python, the corresponding interpreter is on the stack.
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
TORCH_API void dumpTensor(std::ostream& ss, const Tensor& tensor);
TORCH_API void dumpTensorCout(const Tensor& tensor);

} // namespace at::functorch
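// NOTE: [Example: choosing a makeTensorWrapper overload]
//
// A hedged usage sketch (assumed call sites, not code from this file) of the
// two overloads declared above; the helper names `wrapAtLevel` and
// `wrapWithInterpreter` are hypothetical. The level-based overload looks the
// life handle up on the DynamicLayerStack, so it should only be called while
// the matching grad interpreter is on that stack; the Interpreter-based
// overload is handed the interpreter directly, which presumably sidesteps
// that lookup.
//
//   using namespace at::functorch;
//
//   Tensor wrapAtLevel(const Tensor& t, int64_t level) {
//     // Assumes an interpreter for `level` is currently on the DynamicLayerStack.
//     return makeTensorWrapper(t, level);
//   }
//
//   Tensor wrapWithInterpreter(const Tensor& t, const Interpreter& interp) {
//     return makeTensorWrapper(t, interp, /*is_immutable=*/false);
//   }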