// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/BatchedTensorImpl.h>

#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <iostream>

namespace at::functorch {
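// Debugging helper: recursively prints the functorch wrapper structure of a
// tensor, producing output along the lines of
//   Wrapper[lvl=2, Batched[lvl=1 dim=0, Tensor[3, 4]]]
// dumpTensorCout below writes the same thing to std::cout followed by a
// newline, which can be handy to call from a debugger.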
void dumpTensor(std::ostream& ss, const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    auto* batched = maybeGetBatchedImpl(tensor);
    if (batched) {
      ss << "Batched[lvl=" << batched->level() << " dim=" << batched->bdim() << ", ";
      dumpTensor(ss, batched->value());
      ss << "]";
      return;
    }
    ss << "Tensor" << tensor.sizes();
    return;
  }
  ss << "Wrapper[";
  if (wrapped->level().has_value()) {
    ss << "lvl=" << wrapped->level().value() << ", ";
  } else {
    ss << "dead, ";
  }
  dumpTensor(ss, wrapped->value());
  ss << "]";
}

void TensorWrapper::refreshMetadata() {
  // update size, strides and storage_offset
  set_sizes_and_strides(
      value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());

  refresh_numel();
  refresh_contiguous();
}

void dumpTensorCout(const Tensor& tensor) {
  dumpTensor(std::cout, tensor);

  std::cout << '\n';
}

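// Helper shared by the shallow_copy_and_detach overloads below: builds a
// TensorWrapper whose key set propagates the relevant keys from `tensor`
// (including the Autograd backend keys) and adds FuncTorchGradWrapper.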
static c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr<bool>& life_handle) {
  auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
  auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
  key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
  return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
}

// use makeTensorWrapper instead to avoid potential footguns:
// unsafeMakeTensorWrapper doesn't check that level and life_handle
// refer to the same interpreter
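// Illustrative example (not a real call site): prefer
//   makeTensorWrapper(tensor, interpreter, /*is_immutable=*/false)
// which takes both the level and the life handle from the same Interpreter,
// over
//   unsafeMakeTensorWrapper(tensor, level, /*is_immutable=*/false, handle)
// where `level` and `handle` could come from different interpreters.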
static Tensor unsafeMakeTensorWrapper(
    const Tensor& tensor,
    int64_t level,
    bool is_immutable,
    const std::shared_ptr<bool>& life_handle) {
  auto wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    TORCH_INTERNAL_ASSERT(wrapped->level() < level);
  }

  auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
  auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
  key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
  auto result = at::detail::make_tensor<TensorWrapper>(
      key_set, tensor, level, life_handle, is_immutable);
  TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::FuncTorchGradWrapper));

  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    result.unsafeGetTensorImpl()->set_wrapped_number(true);
  }

  return result;
}

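// Wraps `tensor` for the transform at `level`. This overload fetches the
// life handle for `level` via getLifeHandleForLevel; the Interpreter overload
// below takes both the level and the life handle from the same interpreter.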
Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable) {
  auto life_handle = getLifeHandleForLevel(level);
  return unsafeMakeTensorWrapper(
      tensor,
      level,
      is_immutable,
      life_handle);
}

Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable) {
  return unsafeMakeTensorWrapper(
      tensor,
      interpreter.level(),
      is_immutable,
      interpreter.is_alive_ptr());
}

bool TensorWrapper::is_alive() const {
  return *is_alive_;
}

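// Both shallow_copy_and_detach overloads produce a fresh TensorWrapper around
// the same underlying value_, sharing its level and life handle; only the
// version counter and the metadata-mutability flag are taken from the caller.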
c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive_);
  dest_impl->set_version_counter(version_counter);

  // TODO: is this even right?
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  return dest_impl;
}

c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive_);
  dest_impl->set_version_counter(std::move(version_counter));

  // TODO: is this even right?
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  return dest_impl;
}

void TensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  TORCH_CHECK(false, "mutating directly with `.data` inside functorch transform is not allowed.");
}

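// The wrapper reports the wrapped value's dtype and device, mirrors its
// sizes/strides/storage_offset via refreshMetadata(), and disallows direct
// storage access.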
TensorWrapper::TensorWrapper(
    c10::DispatchKeySet key_set,
    Tensor value,
    int64_t level,
    std::shared_ptr<bool> is_alive,
    bool is_immutable,
    bool use_value_sizes_strides)
  : TensorImpl(key_set, value.dtype(), value.device())
  , value_(std::move(value))
  , level_(level)
  , is_immutable_(is_immutable)
  , is_alive_(std::move(is_alive))
{
  TORCH_INTERNAL_ASSERT(value_.defined());

  // TODO: need to reset sizes/strides on mutation
  TORCH_INTERNAL_ASSERT(use_value_sizes_strides);
  refreshMetadata();

  set_storage_access_should_throw();
}

const char* TensorWrapper::tensorimpl_type_name() const {
  return "TensorWrapper";
}

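// Returns the impl as a TensorWrapper* if `tensor` carries the
// FuncTorchGradWrapper dispatch key, otherwise nullptr. The cast relies on
// the invariant that only TensorWrapper impls carry that key.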
TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor) {
  if (!tensor.key_set().has(DispatchKey::FuncTorchGradWrapper)) {
    return nullptr;
  }
  return static_cast<TensorWrapper*>(tensor.unsafeGetTensorImpl());
}

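// Boxed fallback registered for FuncTorchGradWrapper (see the
// TORCH_LIBRARY_IMPL block at the bottom of this file): unwraps any dead
// TensorWrapper arguments on the stack and re-dispatches the op.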
static void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto args_size = op.schema().arguments().size();
  int64_t unwrapped_count = 0;
  auto unwrapIfDeadAndIncrement = [&](const Tensor& tensor) {
    auto* wrapped = maybeGetTensorWrapper(tensor);
    if (!wrapped) {
      return tensor;
    }

    // NOTE: We need to test both that the wrapper is alive and that the
    // functorch mode dispatch keys are active, because certain ops may
    // disable the keys without marking the relevant tensor as dead.
    // Example: the torch.tensor([x, y, z]) variant that accepts a list of
    // scalars hits this case.
    constexpr auto functorch_mode_ks = DispatchKeySet(
        {DispatchKey::FuncTorchDynamicLayerFrontMode,
         DispatchKey::FuncTorchDynamicLayerBackMode});
    if (wrapped->is_alive() && wrapped->key_set().has_any(functorch_mode_ks)) {
      return tensor;
    }
    unwrapped_count++;
    return wrapped->value();
  };

  foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrapIfDeadAndIncrement);
  TORCH_INTERNAL_ASSERT(unwrapped_count > 0, "Should have at least one dead wrapper");

  // re-dispatch
  op.callBoxed(stack);
}

// TensorWrapper backend fallback: Unwrap and fallthrough.

TORCH_LIBRARY_IMPL(_, FuncTorchGradWrapper, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>());
}

} // namespace at::functorch