#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/lift.h>
#include <ATen/ops/lift_fresh.h>
#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>

#include <utility>
#endif

namespace {
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  // NB: auto_functionalize handles the case where outputs do not have alias info.
  // This error message therefore suggests that users modify their custom op to the
  // point where auto_functionalize works, instead of asking them to try the raw
  // functionalization API (because that is a bit difficult to use).
  // If you're here and want to try the raw functionalization kernel approach,
  // see https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa
  TORCH_CHECK(
    !schema.hasAnyAliasInfo(),
    "Found a custom (non-ATen) operator whose output has alias annotations: ",
    op.schema(),
    ". We only support functionalizing operators whose outputs do not have alias ",
    "annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas ",
    "'Tensor' is a Tensor without. The '(a)' is the alias annotation). "
    "The alias annotation specifies that the output ",
    "Tensor shares storage with an input that has the same annotation. ",
    "Please check if ",
    "(1) the output needs to be an output (if not, don't return it), ",
    "(2) if the output doesn't share storage with any inputs, then ",
    "delete the alias annotation. ",
    "(3) if the output indeed shares storage with an input, then add a ",
    ".clone() before returning it to prevent storage sharing and then "
    "delete the alias annotation. ",
    "Otherwise, please file an issue on GitHub.");
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  auto any_functional_inputs = false;
  auto any_tensor_inputs = false;
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      any_tensor_inputs = true;
      const auto& t = ivalue.toTensor();
      if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(t);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      any_tensor_inputs = true;
      auto tensors = ivalue.toTensorList();
      if (at::functionalization::impl::isFunctionalTensor(tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isOptionalTensorList()) {
      any_tensor_inputs = true;
      auto opt_tensors = ivalue.toOptionalTensorList();
      if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(opt_tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    }
  }
  // we should wrap the output if any inputs were wrapped,
  // OR if we're hitting a factory function (with no tensor inputs)
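  // For illustration (hypothetical examples, not exercised by this file):
  // - a factory op with no tensor inputs (e.g. something like aten::rand)
  //   has all of its tensor outputs wrapped into FunctionalTensorWrappers below.
  // - a custom op whose functional tensor inputs were unwrapped above has all
  //   of its tensor outputs re-wrapped below.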
  auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
  {
    at::AutoDispatchSkipFunctionalize guard;
    op.callBoxed(stack);
  }
  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);

  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor() && should_wrap_outputs) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList() && should_wrap_outputs) {
      auto tensors = ivalue.toTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList() && should_wrap_outputs) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}
}

// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details
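//
// A rough sketch of the two regimes (illustrative only, not part of this kernel):
//
//   at::Tensor t = at::zeros({4, 4});
//   t.resize_({8, 8});   // grows the storage   -> handled as a mutation of t
//   t.resize_({2, 2});   // fits in the storage -> handled as a view of t,
//                        // equivalent to t.as_strided({2, 2}, {2, 1})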
static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet [[maybe_unused]], const at::Tensor & self, at::IntArrayRef size, std::optional<at::MemoryFormat> memory_format) {
  // First unwrap the tensor arguments
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  // Case 1: arguments are not functional tensors, so we no-op and redispatch.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    self_.resize_(size, memory_format);
    return self;
  }

  // Case 2: actually functionalize resize_()
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::resize(self_, size, memory_format);
  }

  auto itemsize = self.dtype().itemsize();
  auto storage_offset = self.storage_offset();
  auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
  auto needs_resize_storage = new_size_bytes > self.storage().nbytes();

  if (needs_resize_storage) {
    // If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
    // See Note [resize_() in functionalization pass]
    auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
    func_impl->maybe_replace_storage(tmp_output);
    // See the note - we're guaranteed at this point that "self" is *not* a view (and has no outstanding views),
    // so we don't need to treat the output of resize as a view tensor.
    return self;
  }

  // Otherwise, we know that we're resizing to a smaller size.
  // resize_() is effectively a view operator.
  // The output of resizing is equivalent to taking a slice of a larger tensor.
  // We have to emulate this "slicing" with an as_strided call.
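  // E.g. (illustrative): shrinking a contiguous {4, 4} tensor to {2, 2} is
  // modeled below as base.as_strided({2, 2}, /*stride=*/{2, 1}), where the
  // strides come from c10::contiguous_strides({2, 2}).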
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      if (reapply_views) {
        return base.as_strided(size, c10::contiguous_strides(size));
      } else {
        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
      }
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
    },
    /*has_symbolic_inputs=*/false
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  return self;
}

static at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
  // See Note [Exporting and compiling a graph with lift_fresh_copy]
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    return self.view_as(self);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
  // Note [Exporting and compiling a graph with lift_fresh_copy]
  // If out is already a functional tensor, don't wrap it twice.
  // In theory this could be useful if we want to nest functionalization with itself,
  // but that isn't really a use case today.
  // Needed for https://github.com/pytorch/pytorch/issues/105327
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // Note [Composite Functionalization under PreDispatch mode]
    // When we are tracing under PreDispatch, the PreDispatch key will be
    // in the local include TLS. As a result, when we redispatch here,
    // we will end up hitting the PreDispatch stack first. So, we should
    // directly redispatch to the Functionalize key manually.
    static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::clone", "").typed<at::Tensor(const at::Tensor &, std::optional<at::MemoryFormat>)>();
    return op.redispatch(c10::DispatchKeySet({c10::DispatchKey::Functionalize}), self, std::nullopt);
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh_copy(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

static bool device_opted_into_functionalization(c10::Device self_device, std::optional<c10::Device> tgt_device) {
  // If the target device is empty, then the output tensor should be on the same device as the input
  auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
  return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}

// Note: I only need this because the to.dtype/to.dtype_layout overloads call this, so we skip the op above.
// We should probably get rid of this though.
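// For reference (an assumption about the usual call path, not something this
// file checks): a call like x.to(at::kFloat) decomposes into aten::_to_copy,
// which is how execution lands in the kernel below rather than in the boxed
// fallback.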
static at::Tensor _to_copy_functionalize(
    const at::Tensor & self,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking,
    std::optional<at::MemoryFormat> memory_format) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // sync any pending updates
    at::functionalization::impl::sync(self);
    // pass the unwrapped tensor to the backend
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);

  // Special case: if the Functionalize key is not in TLS, we assume that we're running
  // on a lazy backend (LTC).
  // In that case, if we're copying to a non-functionalize-enabled device,
  // then the functionalization pass should "end". We need to sync any updates on the input
  // tensor, but we shouldn't wrap the output.
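  // Illustrative scenario (assumed from the note above, not exercised here):
  // under LTC/XLA tracing, copying a functional XLA tensor to a plain CPU
  // device leaves the functionalized region, so the result is returned
  // unwrapped by the early return below.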
  if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
    if (!device_opted_into_functionalization(self.device(), device)) {
      return out;
    }
  }
  return at::functionalization::impl::to_functional_tensor(out);
}


// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
//   that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
//   does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
//   Specifically, autograd will redispatch to functionalization's
//   boxed fallback kernel, which creates a new FunctionalTensorWrapper output
//   that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually re-ifies the aliasing relationship.
//
// Another way to handle this would be to fix unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and don't actually have to worry about propagating
// mutations between the input and output.
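//
// Rough intuition for the ViewMeta registered below (illustrative only):
//   forward:  view  = at::_unsafe_view_symint(base, size)
//   reverse:  base' = at::_unsafe_view_symint(mutated_view, base.sym_sizes())
// i.e. replaying a mutation made through the view just reshapes it back to
// base's original shape.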
static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view_symint(self, size);
  }

  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view_symint(self_, size);
  }

  bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });

  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(base, size);
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
    },
    /*has_symbolic_inputs=*/has_symbolic_inputs
  );

  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
  // See Note [Propagating strides in the functionalization pass]
  // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
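  // For example (hypothetical numbers): viewing a contiguous {3, 4} tensor as
  // {2, -1} gives inferred_size = {2, 6}, and computeStride() then yields the
  // contiguous strides {6, 1}, since the element count (12) is unchanged.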
  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
  return out;
}

static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
  // error case
  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed");

  // nop case
  if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) {
    at::AutoDispatchSkipFunctionalize guard;
    return self.set_(src);
  }

  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
    "set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");

  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src);
  // See Note [Ordering of resize_() and set_()]
  TORCH_CHECK(!self_impl->was_inductor_storage_resized(),
    "storage_resize_() followed by set_() in torch.compile is not supported today");
  self_impl->set__impl(src_impl);
  return self;
}

TORCH_LIBRARY_IMPL(_, Functionalize, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}

TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  m.impl("resize_", TORCH_FN(resize__functionalization));
  m.impl("lift", TORCH_FN(lift_functionalize));
  m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
  m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
  m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
  m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
  // The overloads of set_() that take in a storage should never
  // appear under torch.compile, because dynamo graph-breaks on them
  m.impl("set_.source_Tensor", TORCH_FN(set__functionalize));
}