#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/lift.h>
#include <ATen/ops/lift_fresh.h>
#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>

#include <utility>
#endif

namespace {
  void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
    const auto& schema = op.schema();
    // NB: auto_functionalize handles the case where outputs do not have alias info.
    // This error message therefore suggests that users modify their custom op to the
    // point where auto_functionalize works, instead of asking them to try the raw
    // functionalization API (because that is a bit difficult to use).
    // If you're here and want to try the raw functionalization kernel approach,
    // see https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa
    TORCH_CHECK(
      !schema.hasAnyAliasInfo(),
      "Found a custom (non-ATen) operator whose output has alias annotations: ",
      op.schema(),
      ". We only support functionalizing operators whose outputs do not have alias ",
      "annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas ",
      "'Tensor' is a Tensor without one; the '(a)' is the alias annotation). "
      "The alias annotation specifies that the output ",
      "Tensor shares storage with an input that has the same annotation. ",
      "Please check if ",
      "(1) the output needs to be an output (if not, don't return it), ",
      "(2) if the output doesn't share storage with any inputs, then ",
      "delete the alias annotation. ",
      "(3) if the output indeed shares storage with an input, then add a ",
      ".clone() before returning it to prevent storage sharing and then "
      "delete the alias annotation. ",
      "Otherwise, please file an issue on GitHub.");
    const auto num_arguments = schema.arguments().size();
    const auto arguments_begin = stack->size() - num_arguments;
    auto arguments = torch::jit::last(stack, num_arguments);

    auto any_functional_inputs = false;
    auto any_tensor_inputs = false;
    for (uint64_t idx = 0; idx < num_arguments; ++idx) {
      const auto& ivalue = arguments[idx];
      if (ivalue.isTensor()) {
        any_tensor_inputs = true;
        const auto& t = ivalue.toTensor();
        if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(t);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
          (*stack)[arguments_begin + idx] = t_new;
        }
      } else if (ivalue.isTensorList()) {
        any_tensor_inputs = true;
        auto tensors = ivalue.toTensorList();
        if (at::functionalization::impl::isFunctionalTensor(tensors)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(tensors);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
          (*stack)[arguments_begin + idx] = t_new;
        }
      } else if (ivalue.isOptionalTensorList()) {
        any_tensor_inputs = true;
        auto opt_tensors = ivalue.toOptionalTensorList();
        if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(opt_tensors);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
          (*stack)[arguments_begin + idx] = t_new;
        }
      }
    }
    // we should wrap the output if any inputs were wrapped,
    // OR if we're hitting a factory function (with no tensor inputs)
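    // For example (hypothetical ops, for illustration only): a factory-style op such as
    // mylib::make_buffer(int n) -> Tensor has no tensor inputs, so any_tensor_inputs stays false
    // and its outputs get wrapped; an op invoked with only plain (non-functional) tensors leaves
    // any_functional_inputs false, so its outputs are returned unwrapped.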
    auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
    {
      at::AutoDispatchSkipFunctionalize guard;
      op.callBoxed(stack);
    }
    const auto num_returns = schema.returns().size();
    const auto returns_begin = stack->size() - num_returns;
    auto returns = torch::jit::last(stack, num_returns);

    for (const auto idx : c10::irange(num_returns)) {
      const auto& ivalue = returns[idx];
      if (ivalue.isTensor() && should_wrap_outputs) {
        const auto& t = ivalue.toTensor();
        if (!t.defined()) continue;
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
        (*stack)[returns_begin + idx] = t_new;
      } else if (ivalue.isTensorList() && should_wrap_outputs) {
        auto tensors = ivalue.toTensorList();
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
        (*stack)[returns_begin + idx] = t_new;
      } else if (ivalue.isOptionalTensorList() && should_wrap_outputs) {
        auto opt_tensors = ivalue.toOptionalTensorList();
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
        (*stack)[returns_begin + idx] = t_new;
      }
    }
  }
}

// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details
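// A rough sketch of the two behaviors handled below (illustrative shapes only):
//   auto t = at::zeros({4});
//   t.resize_({8});  // grows the storage -> treated as a mutation of the wrapper's storage
//   t.resize_({2});  // shrinks the tensor -> treated as a view (an as_strided "slice" of the base)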
static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet [[maybe_unused]], const at::Tensor & self, at::IntArrayRef size, std::optional<at::MemoryFormat> memory_format) {
  // First unwrap the tensor arguments
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  // Case 1: arguments are not functional tensors, so we no-op and redispatch.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
     at::AutoDispatchSkipFunctionalize guard;
     self_.resize_(size, memory_format);
     return self;
  }

  // Case 2: actually functionalize resize_()
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::resize(self_, size, memory_format);
  }

  auto itemsize = self.dtype().itemsize();
  auto storage_offset = self.storage_offset();
  auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
  auto needs_resize_storage = new_size_bytes > self.storage().nbytes();

  if (needs_resize_storage) {
    // If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
    // See Note [resize_() in functionalization pass]
    auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
    func_impl->maybe_replace_storage(tmp_output);
    // See the note - we're guaranteed at this point that "self" is *not* a view (and has no outstanding views),
    // so we don't need to treat the output of resize_() as a view tensor.
    return self;
  }

  // Otherwise, we know that we're resizing to a smaller size.
  // resize_() is effectively a view operator.
  // The output of resizing is equivalent to taking a slice of a larger tensor.
  // We have to emulate this "slicing" with an as_strided call.
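  // Concretely (illustrative numbers): shrinking a contiguous {8} tensor to {2} is modeled as
  // base.as_strided({2}, {1}), i.e. the first two elements of the larger base; mutations to that
  // "slice" are written back with as_strided_scatter in the inverse lambda below.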
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      if (reapply_views) {
        return base.as_strided(size, c10::contiguous_strides(size));
      } else {
        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
      }
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
    },
    /*has_symbolic_inputs=*/false
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  return self;
}


static at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
  // See Note [Exporting and compiling a graph with lift_fresh_copy]
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    return self.view_as(self);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
  // Note [Exporting and compiling a graph with lift_fresh_copy]
  // If self is already a functional tensor, don't wrap it twice.
  // In theory this could be useful if we want to nest functionalization with itself,
  // but that isn't really a use case today.
  // Needed for https://github.com/pytorch/pytorch/issues/105327
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // Note [Composite Functionalization under PreDispatch mode]
    // When we are tracing under PreDispatch, the PreDispatch key will be
    // in the local include TLS. As a result, when we redispatch here,
    // we will end up hitting the PreDispatch stack first. So, we should
    // directly redispatch to the Functionalize key manually.
    static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::clone", "").typed<at::Tensor(const at::Tensor &, std::optional<at::MemoryFormat>)>();
    return op.redispatch(c10::DispatchKeySet({c10::DispatchKey::Functionalize}), self, std::nullopt);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh_copy(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static bool device_opted_into_functionalization(c10::Device self_device, std::optional<c10::Device> tgt_device) {
    // If no target device is given, then the output tensor should be on the same device as the input
    auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
    return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}

// Note: I only need this because the to.dtype/to.dtype_layout overloads call this, so we skip the op above.
// We should probably get rid of this though.
static at::Tensor _to_copy_functionalize(
        const at::Tensor & self,
        std::optional<at::ScalarType> dtype,
        std::optional<at::Layout> layout,
        std::optional<at::Device> device,
        std::optional<bool> pin_memory,
        bool non_blocking,
        std::optional<at::MemoryFormat> memory_format) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // sync any pending updates
    at::functionalization::impl::sync(self);
    // pass the unwrapped tensor to the backend
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);

  // Special case: if the Functionalize key is not in TLS, we assume that we're running
  // on a lazy backend (LTC).
  // In that case, if we're copying to a non-functionalize-enabled device,
  // then the functionalization pass should "end". We need to sync any updates on the input
  // tensor, but we shouldn't wrap the output.
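  // For example (assumed LTC/XLA setup, for illustration): when the lazy backend drives
  // functionalization (so Functionalize is not in the TLS include set) and an XLA tensor is
  // copied to CPU, the CPU result below is returned as-is rather than wrapped.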
  if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
    if (!device_opted_into_functionalization(self.device(), device)) {
      return out;
    }
  }
  return at::functionalization::impl::to_functional_tensor(out);
}


// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
//   that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
//   does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
//   Specifically, autograd will redispatch to functionalization's
//   boxed fallback kernel, which creates a new FunctionalTensorWrapper output
//   that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually re-ifies the aliasing relationship.
//
// Another way to handle this would be to fix unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and don't actually have to worry about propagating
// mutations between the input and output.
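// Rough sketch of the failure mode this avoids (not executed here): via the boxed fallback,
// out = _unsafe_view(self, size) would get a brand-new FunctionalTensorWrapper, so autograd's
// debug assert that out and self share a storage would fire. The kernel below instead records a
// ViewMeta so the output wrapper aliases self, matching _unsafe_view's real aliasing behavior.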
static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view_symint(self, size);
  }

  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view_symint(self_, size);
  }

  bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });

  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(base, size);
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
    },
    /*has_symbolic_inputs=*/has_symbolic_inputs
  );

  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
  // See Note [Propagating strides in the functionalization pass]
  // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
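  // For example (illustrative shapes): viewing a contiguous {2, 3} tensor as {6} gives
  // inferred_size = {6} and a computed stride of {1}; a size of {-1} would likewise be resolved
  // to {6} by infer_size_dv from self's element count.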
  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
  return out;
}

static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
  // error case
  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed");

  // nop case
  if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) {
    at::AutoDispatchSkipFunctionalize guard;
    return self.set_(src);
  }

  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");

  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src);
  // See Note [Ordering of resize_() and set_()]
  TORCH_CHECK(!self_impl->was_inductor_storage_resized(),
    "storage_resize_() followed by set_() in torch.compile is not supported today");
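  // A minimal illustration of the unsupported ordering (hypothetical traced program): resizing a
  // tensor's storage via storage_resize_() and then calling set_() on that same tensor inside a
  // torch.compile'd region lands here and raises, since that ordering is not functionalized yet.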
  self_impl->set__impl(src_impl);
  return self;
}

TORCH_LIBRARY_IMPL(_, Functionalize, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}

TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  m.impl("resize_", TORCH_FN(resize__functionalization));
  m.impl("lift", TORCH_FN(lift_functionalize));
  m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
  m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
  m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
  m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
  // The overloads of set_() that take in a Storage should never
  // appear with torch.compile, because dynamo graph-breaks on them.
  m.impl("set_.source_Tensor", TORCH_FN(set__functionalize));
}