// xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/Functions.h>
#include <ATen/MetaFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
#include <torch/csrc/lazy/ts_backend/config.h>
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
#include <torch/csrc/lazy/ts_backend/tensor_aten_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h>
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
#include <torch/library.h>

using at::Tensor;

namespace torch {
namespace lazy {
namespace {

at::Tensor CreateLtcTensor(
    const at::Tensor& tensor,
    const std::optional<torch::lazy::BackendDevice>& device) {
  if (tensor.defined() && device) {
    return torch::lazy::CreateAtenFromLtcTensor(
        torch::lazy::LazyTensor::Create(tensor, *device));
  }
  return tensor;
}

std::optional<torch::lazy::BackendDevice> GetLtcDevice(
    const std::optional<c10::Device>& device) {
  if (!device) {
    return std::nullopt;
  }
  if (device->type() != at::kLazy) {
    return std::nullopt;
  }
  return torch::lazy::atenDeviceToBackendDevice(*device);
}

} // namespace

// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor LazyNativeFunctions::clone(
    const at::Tensor& self,
    std::optional<at::MemoryFormat> memory_format) {
  auto self_lt = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
}
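
// Illustrative sketch, not part of the original file: because clone above only
// re-wraps the existing IR value, cloning a lazy tensor copies no device data.
// Assuming a registered lazy backend, something like
//
//   at::Tensor a = at::ones({2, 2}, at::TensorOptions().device(at::kLazy));
//   at::Tensor b = a.clone(); // b aliases the same IR value as a
//
// is safe precisely because every lazy op is functional: a later in-place op
// on either tensor produces a fresh IR value rather than mutating shared data.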

at::Tensor LazyNativeFunctions::_copy_from(
    const at::Tensor& self,
    const at::Tensor& dst,
    bool non_blocking) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    // providing a new 'eager' value (self) for an existing lazy tensor (dst)
    static bool sync_update = FLAGS_torch_lazy_ts_tensor_update_sync;
    TORCH_CHECK(dst_tensor);
    dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
  } else if (!dst_tensor) {
    // materializing a lazy tensor (self) and copying its value into the eager
    // tensor (dst). detached=false lets us skip a copy in `ToTensor`, which
    // should be safe because we are only going to use the tensor for
    // dst.copy_()
    TORCH_CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/false);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // Copying one lazy tensor to another
    if (!dst_tensor->CurrentIrValue()) {
      // if dest is not backed by IR (which it would be, e.g., as the result of
      // some lazy operation), then it should have at::Tensor data backing it
      // instead
      auto dst_tensor_data = dst_tensor->CurrentTensorData();
      TORCH_CHECK(dst_tensor_data);
      auto src_tensor_data = self_tensor->CurrentTensorData();
      if (src_tensor_data) {
        // both src/dst are simply backed by at::Tensor data, no IR, so do a
        // straightforward copy
        dst_tensor_data->copy_(*src_tensor_data);
      } else {
        // src needs to be materialized before its result can be used for a
        // copy into dst. Since we use the src tensor only for making a copy,
        // we don't need to detach it. Note: it would be even more efficient if
        // we could cause ToTensor to materialize the value directly into dst's
        // buffer (that would need to be detached though).
        dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
      }
    } else {
      copy_(dst_tensor, self_tensor);
      auto* impl =
          dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
      impl->set_tensor(dst_tensor);
    }
  }
  return dst;
}
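
// Illustrative sketch (an assumption about dispatch, not in the original
// file): aten::copy_ on a lazy tensor is routed through _copy_from, so the
// three branches above roughly correspond to:
//
//   lazy_dst.copy_(cpu_src);   // !self_tensor: update the lazy tensor from eager data
//   cpu_dst.copy_(lazy_src);   // !dst_tensor: materialize the lazy value into cpu_dst
//   lazy_dst.copy_(lazy_src);  // both lazy: copy the IR value (or raw tensor data) across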

at::Tensor LazyNativeFunctions::_copy_from_and_resize(
    const at::Tensor& self,
    const at::Tensor& dst) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    TORCH_CHECK(dst_tensor);
    dst_tensor->UpdateFromTensorOut(self);
  } else if (!dst_tensor) {
    TORCH_CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // at this point we know dst is a lazy tensor
    auto* dest_impl =
        dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
    TORCH_CHECK(dest_impl);
    dest_impl->tensor()->UpdateFromTensorOut(self_tensor);
    dest_impl->force_refresh_sizes();
  }
  return dst;
}

at::Tensor LazyNativeFunctions::_to_copy(
    const at::Tensor& self,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking,
    std::optional<at::MemoryFormat> memory_format) {
  if (force_eager_fallback(at::aten::_to_copy)) {
    TORCH_INTERNAL_ASSERT(
        false,
        "Fallback is currently impossible for _to_copy since the fallback helper itself reinvokes _to_copy");
  }

  auto options = self.options();
  if (dtype) {
    // I put each of these setters in a conditional instead of chaining
    // `self.options().dtype(dtype).layout(layout)...` because calling
    // .dtype(nullopt) on an options() that already has dtype appears to wipe it
    options = options.dtype(dtype);
  }
  if (layout) {
    options = options.layout(layout);
  }
  if (memory_format) {
    options = options.memory_format(memory_format);
  }
  if (pin_memory) {
    // TODO(whc) can we honor 'pin_memory' in some/all cases?
    options = options.pinned_memory(pin_memory);
    TORCH_WARN_ONCE(
        "Pinned memory used in lazy _to_copy, check if the behavior is as intended");
  }

  TORCH_LAZY_FN_COUNTER("lazy::");
  auto lazy_self = torch::lazy::TryGetLtcTensor(self);
  if (!lazy_self && device && device->type() == c10::kLazy) {
    // Case 1: eager->lazy (we create a new lazy tensor)
    // See Note [Lazy Tensor Functionalization]
    // Invariant: if the functionalization key is in the exclude set, then we're
    // expected to return an ordinary tensor, which will be "lifted" into a
    // functional wrapper later.
    bool functionalize_output =
        !c10::impl::tls_local_dispatch_key_set().excluded_.has(
            c10::DispatchKey::Functionalize);
    return torch::lazy::to_lazy_tensor(
        self,
        options,
        *device,
        /*non_blocking=*/non_blocking,
        /*functionalize_output=*/functionalize_output);
  } else if (device && device->type() != c10::kLazy) {
    // Case 2: lazy->eager (forces a graph break since we are materializing a
    // tensor)

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    options = options.device(device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
    return moved_eager_tensor;
  } else if (
      device && device->type() == c10::kLazy && device->has_index() &&
      device->index() != self.device().index()) {
    // Case 3: lazy:0 -> lazy:1

    // TODO(whc) what do we actually want to do here?
    //   option 1: materialize, move eager tensor, create new lazy tensor
    //     - this should be our default, as it is what would happen before we
    //     implemented _to_copy
    //     - actually combines case 1 + case 2
    //   option 2: support multiple devices inside one lazy/TS executor (case 4)
    //     - but: we may have other assumptions that there is just one device
    //     per executor? so don't take this lightly

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    // we move the eager tensor to the 'eager' equivalent of our lazy device,
    // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is
    // what we use
    auto eager_device = c10::Device(
        torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index());
    options = options.device(eager_device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true);
    lazy_self = torch::lazy::GetOrCreateLtcTensor(
        moved_eager_tensor,
        torch::lazy::atenDeviceToBackendDevice(eager_device));
    return torch::lazy::CreateAtenFromLtcTensor(lazy_self);

  } else {
    // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy
    // graph)

    // Note: the captured _to_copy will be executed with real eager tensors, not
    // lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this
    // captured IR, or we will try to convert an eager tensor back to a lazy one
    // inside the torchscript executor. lazy:0 -> lazy:1 is handled in case 3,
    // so we can safely drop the device argument.
    device = std::nullopt;

    torch::lazy::NodePtr node = torch::lazy::ReuseNode<ToCopy>(
        lazy_self->GetIrValue(),
        dtype,
        layout,
        device,
        pin_memory,
        non_blocking,
        memory_format);
    if (!node) {
      auto shapes = torch::lazy::compute_shape__to_copy(
          self, dtype, layout, device, pin_memory, non_blocking, memory_format);
      TORCH_INTERNAL_ASSERT(shapes.size() == 1);
      node = torch::lazy::MakeNode<ToCopy>(
          lazy_self->GetIrValue(),
          dtype,
          layout,
          device,
          pin_memory,
          non_blocking,
          memory_format,
          std::move(shapes));
      CacheNode(node);
    }

    auto result =
        torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
            std::move(node), lazy_self->GetDevice()));
    return result;
  }
};
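
// Illustrative sketch, not part of the original file: assuming a registered
// lazy backend, the four cases above map roughly to:
//
//   cpu_t.to(at::kLazy);                  // case 1: eager -> lazy, new lazy tensor
//   lazy_t.to(at::kCPU);                  // case 2: lazy -> eager, materializes the graph
//   lazy_t.to(c10::Device(at::kLazy, 1)); // case 3: cross-index move via the backend's eager device
//   lazy_t.to(at::kHalf);                 // case 4: stays inside the lazy graph as a ToCopy node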

at::Tensor LazyNativeFunctions::empty_symint(
    at::SymIntArrayRef sym_size,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory,
    std::optional<at::MemoryFormat> memory_format) {
  // TODO: support this directly
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
  at::TensorOptions options = at::TensorOptions()
                                  .device(c10::Device(device_type))
                                  .layout(layout)
                                  .pinned_memory(pin_memory)
                                  .dtype(dtype);
  auto x_result = at::empty(size, options, memory_format);
  auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
  // See Note [Lazy Tensor Functionalization]
  if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
          c10::DispatchKey::Functionalize)) {
    // Invariant: if the functionalization key is in the exclude set, then we're
    // expected to return an ordinary tensor, which will be "lifted" into a
    // functional wrapper later.
    return tensor;
  } else {
    auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
    return wrapped;
  }
}
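
// Illustrative sketch, not part of the original file: under normal dispatch the
// Functionalize key is not excluded, so a lazy factory call is expected to come
// back wrapped, e.g.
//
//   at::Tensor t = at::empty({4}, at::TensorOptions().device(at::kLazy));
//   bool wrapped = at::functionalization::impl::isFunctionalTensor(t); // expected true
//
// The excluded-key branch exists for callers that have temporarily disabled
// functionalization and will lift the plain tensor themselves.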

at::Tensor LazyNativeFunctions::empty_strided_symint(
    at::SymIntArrayRef sym_size,
    at::SymIntArrayRef sym_stride,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  at::Tensor t =
      empty_symint(sym_size, dtype, layout, device, pin_memory, std::nullopt);
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride);
  return t.as_strided(size, stride, /*storage_offset=*/0);
}

at::Tensor& LazyNativeFunctions::fill_(
    at::Tensor& self,
    const at::Scalar& value) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::fill_(self_tensor, value);
  return self;
}

at::Tensor LazyNativeFunctions::max_pool3d(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  return torch::lazy::MaxPool3dAutogradFunctionTS::apply(
      self, kernel_size, stride, padding, dilation, ceil_mode);
}

// We need to explicitly override max pooling operators and just call the
// fallback for them because we've customized the autograd function for them
// (backward needs saved indices from forward).
std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::max_pool3d_with_indices(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  return at::native::
      call_fallback_fn<&ltc_eager_fallback, ATEN_OP(max_pool3d_with_indices)>::
          call(self, kernel_size, stride, padding, dilation, ceil_mode);
}

at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode,
    const at::Tensor& indices) {
  return at::native::call_fallback_fn<
      &ltc_eager_fallback,
      ATEN_OP(max_pool3d_with_indices_backward)>::
      call(
          grad_output,
          self,
          kernel_size,
          stride,
          padding,
          dilation,
          ceil_mode,
          indices);
}

at::Tensor LazyNativeFunctions::_unsafe_view(
    const at::Tensor& self,
    at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return LazyNativeFunctions::view_copy_symint(
      self, c10::fromIntArrayRefSlow(size));
}

// This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a
// FunctionalTensorWrapper object.
at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}
at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}
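
// Illustrative sketch (an assumption about dispatch, not in the original
// file): when torch.tensor(...) builds a lazy tensor, the data is constructed
// eagerly first and then handed to lift/lift_fresh, conceptually:
//
//   at::Tensor eager = at::tensor({1, 2, 3});
//   at::Tensor wrapped = LazyNativeFunctions::lift(eager);
//   // at::functionalization::impl::isFunctionalTensor(wrapped) is now true;
//   // from_functional_tensor(wrapped) would give back the unwrapped tensor.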

// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      block_diag)>::call(tensors);
}
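
// Illustrative sketch (an assumption about the helper's mechanics, not in the
// original file): functionalize_aten_op conceptually
//   1. wraps the lazy inputs in FunctionalTensorWrapper,
//   2. re-dispatches the op so the composite kernel runs with the Functionalize
//      key active, turning its internal views into *_copy ops that LTC can
//      trace, and
//   3. unwraps the functional result back into a plain lazy tensor.
// That is what lets these composite kernels be reused without per-op lowerings.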
at::Tensor LazyNativeFunctions::new_empty_strided_symint(
    const at::Tensor& self,
    c10::SymIntArrayRef size,
    c10::SymIntArrayRef stride,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory) {
  return at::functionalization::
      functionalize_aten_op_symint<ATEN_OP(new_empty_strided)>::call(
          self, size, stride, dtype, layout, device, pin_memory);
}

at::Tensor LazyNativeFunctions::narrow_copy_symint(
    const at::Tensor& self,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt length) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      narrow_copy)>::call(self, dim, start, length);
}
at::Tensor LazyNativeFunctions::pixel_shuffle(
    const at::Tensor& self,
    int64_t upscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_shuffle)>::call(self, upscale_factor);
}
at::Tensor LazyNativeFunctions::pixel_unshuffle(
    const at::Tensor& self,
    int64_t downscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_unshuffle)>::call(self, downscale_factor);
}
at::Tensor LazyNativeFunctions::select_backward_symint(
    const at::Tensor& grad_output,
    c10::SymIntArrayRef input_sizes,
    int64_t dim,
    c10::SymInt index) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      select_backward)>::call(grad_output, input_sizes, dim, index);
}
at::Tensor LazyNativeFunctions::_trilinear(
    const at::Tensor& i1,
    const at::Tensor& i2,
    const at::Tensor& i3,
    at::IntArrayRef expand1,
    at::IntArrayRef expand2,
    at::IntArrayRef expand3,
    at::IntArrayRef sumdim,
    int64_t unroll_dim) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::
      call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
}
at::Tensor LazyNativeFunctions::linalg_pinv(
    const at::Tensor& self,
    const std::optional<at::Tensor>& atol,
    const std::optional<at::Tensor>& rtol,
    bool hermitian) {
  return at::functionalization::functionalize_aten_op<ATEN_OP2(
      linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}

// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy any mutations
// back to the inputs.
at::Tensor& LazyNativeFunctions::logsumexp_out(
    const at::Tensor& self,
    at::IntArrayRef dim,
    bool keepdim,
    at::Tensor& out) {
  auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
  auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
  // directly call the composite kernel from core.
  // Make sure to re-enable functionalization first.
  auto curr_tls = c10::impl::tls_local_dispatch_key_set();
  auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
  tls_reenable_functionalize.set_included(curr_tls.included_);
  tls_reenable_functionalize.set_excluded(
      curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
  c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
  at::native::logsumexp_out(self_wrapped, dim, keepdim, out_wrapped);
  auto out_unwrapped =
      at::functionalization::impl::from_functional_tensor(out_wrapped);
  // propagate mutations back to the inputs (including resizing)
  out.resize_(out_unwrapped.sizes());
  out.copy_(out_unwrapped);
  return out;
}
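
// Illustrative sketch, not part of the original file: the wrap / re-enable
// functionalization / unwrap / copy-back pattern above is the general recipe
// for out= variants here, e.g. a call such as
//
//   at::logsumexp_out(lazy_out, lazy_self, /*dim=*/{1}, /*keepdim=*/false);
//
// ends up resizing and overwriting lazy_out with the composite kernel's result
// instead of requiring a dedicated logsumexp.out lowering.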

at::Tensor LazyNativeFunctions::diag_embed(
    const at::Tensor& self,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      diag_embed)>::call(self, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::diagonal_backward_symint(
    const at::Tensor& grad_output,
    at::SymIntArrayRef input_sizes,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::slice_backward_symint(
    const at::Tensor& grad_output,
    at::SymIntArrayRef input_sizes,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt end,
    c10::SymInt step) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}

// re-use the composite kernel from core; that way we don't need to provide a
// backward formula for native_group_norm
std::tuple<Tensor, Tensor, Tensor> LazyNativeFunctions::native_group_norm(
    const at::Tensor& input,
    const std::optional<at::Tensor>& weight,
    const std::optional<at::Tensor>& bias,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  return at::native::math_group_norm(
      input, weight, bias, N, C, HxW, group, eps);
}

} // namespace lazy
} // namespace torch