// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <ATen/functorch/Macros.h>
#include <ATen/Tensor.h>
#include <ATen/functorch/Interpreter.h>

namespace at::functorch {

// NOTE: [functorch's TensorWrapper]
//
// Taking better suggestions for a name. TensorWrapper is the wrapper Tensor
// subclass for functorch's grad-based transforms (grad, vjp, jvp). It is
// analogous to how vmap uses BatchedTensor as the wrapper Tensor subclass.
//
// If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively
// another Variable.
//
// Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x))
// so that each TensorWrapper can hold its own AutogradMeta and participate in a
// **separate** autograd graph.
//
// There are alternative designs we could have chosen (e.g. each grad transform
// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper
// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels)
// without much modification. Since a TensorWrapper looks like a regular Tensor,
// the VariableType kernel can pull out the AutogradMeta struct from where it
// expects and extend the autograd graph.
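//
// A minimal sketch of the wrapping described above (illustrative only; the
// level values here are hypothetical and in practice come from the
// interpreters on the DynamicLayerStack):
//
//   Tensor x = ...;
//   Tensor inner = makeTensorWrapper(x, /*level=*/1);      // wrapped once for one grad transform
//   Tensor outer = makeTensorWrapper(inner, /*level=*/2);  // wrapped again for the nested transform
//
// `outer` is TensorWrapper(TensorWrapper(x)); each wrapper carries its own
// AutogradMeta and therefore participates in its own autograd graph.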

struct TORCH_API TensorWrapper : public c10::TensorImpl {
  explicit TensorWrapper(
      c10::DispatchKeySet key_set,
      Tensor value,
      int64_t level,
      std::shared_ptr<bool> is_alive,
      bool is_immutable = false,  // if true, this came from an operation that aliases an immutable tensor
      bool use_value_sizes_strides = true);

  void refreshMetadata();

  const Tensor& value() const {
    return value_;
  }
  std::optional<int64_t> level() const {
    if (is_alive()) {
      return level_;
    }
    return {};
  }
  bool is_immutable() const {
    return is_immutable_;
  }
  bool is_alive() const;

  // Overrides necessary for autograd
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

 private:
  const char* tensorimpl_type_name() const override;
  Tensor value_;
  int64_t level_;
  bool is_immutable_;

  // TensorWrapper receives a boolean flag indicating whether the Grad
  // Interpreter that created it is still alive.
  // If the Grad Interpreter is no longer alive, the wrapper attempts to behave
  // like a regular Tensor.
  //
  // When we exit the level, this wrapper may be marked as "not alive".
  // Wrappers that are not alive:
  // 1) may still have autograd metadata on them;
  // 2) forward their dispatches to the underlying value().
  // (See the sketch after this struct for the user-visible effect.)
  std::shared_ptr<bool> is_alive_;
};
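
// Illustrative sketch of the "alive" flag described above (hypothetical caller
// code, not part of this header): once the Grad Interpreter that created a
// wrapper has exited, level() returns nullopt and the wrapper defers to its
// underlying value().
//
//   if (TensorWrapper* wrapper = maybeGetTensorWrapper(t)) {
//     if (!wrapper->is_alive()) {
//       // wrapper->level() is std::nullopt here; dispatch falls through to
//       // wrapper->value().
//     }
//   }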

// There are two variants of makeTensorWrapper: one that accepts a level
// and one that accepts an Interpreter.
//
// The one that accepts a level tries to automatically get the life handle from the
// interpreter on the DynamicLayerStack.
// It needs to be used with caution: if the interpreter is not on the
// DynamicLayerStack, then we won't be able to find the life handle.
//
// In practice this isn't a problem: when we're constructing a TensorWrapper in
// Python, the corresponding interpreter is on the stack.
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
TORCH_API void dumpTensor(std::ostream& ss, const Tensor& tensor);
TORCH_API void dumpTensorCout(const Tensor& tensor);
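
// Illustrative sketch (hypothetical caller code; `x` and the level value are
// assumptions, not part of this header): wrap a tensor at a level, then unwrap
// and inspect it with the helpers above.
//
//   Tensor wrapped = makeTensorWrapper(x, /*level=*/1);
//   if (TensorWrapper* wrapper = maybeGetTensorWrapper(wrapped)) {
//     const Tensor& inner = wrapper->value();  // the Tensor held by the wrapper
//     dumpTensorCout(wrapped);                 // print wrapper structure for debugging
//   }
//   // maybeGetTensorWrapper is expected to return nullptr for a plain,
//   // unwrapped Tensor.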

} // namespace at::functorch