#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/Functions.h>
#include <ATen/MetaFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
#include <torch/csrc/lazy/ts_backend/config.h>
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
#include <torch/csrc/lazy/ts_backend/tensor_aten_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h>
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
#include <torch/library.h>

using at::Tensor;

namespace torch {
namespace lazy {
namespace {

at::Tensor CreateLtcTensor(
    const at::Tensor& tensor,
    const std::optional<torch::lazy::BackendDevice>& device) {
  if (tensor.defined() && device) {
    return torch::lazy::CreateAtenFromLtcTensor(
        torch::lazy::LazyTensor::Create(tensor, *device));
  }
  return tensor;
}

std::optional<torch::lazy::BackendDevice> GetLtcDevice(
    const std::optional<c10::Device>& device) {
  if (!device) {
    return std::nullopt;
  }
  if (device->type() != at::kLazy) {
    return std::nullopt;
  }
  return torch::lazy::atenDeviceToBackendDevice(*device);
}

} // namespace

// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor LazyNativeFunctions::clone(
    const at::Tensor& self,
    std::optional<at::MemoryFormat> memory_format) {
  auto self_lt = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
}
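
// A rough usage sketch (hypothetical tensors, for illustration only): since
// clone is a no-op, cloning a lazy tensor just re-wraps the existing IR value
// rather than copying any data, e.g.
//
//   auto t = at::ones({2, 2}, at::TensorOptions().device(at::kLazy));
//   auto c = t.clone();  // no data copy; c wraps the same IR value as t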

at::Tensor LazyNativeFunctions::_copy_from(
    const at::Tensor& self,
    const at::Tensor& dst,
    bool non_blocking) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    // providing a new 'eager' value (self) for an existing lazy tensor (dst)
    static bool sync_update = FLAGS_torch_lazy_ts_tensor_update_sync;
    TORCH_CHECK(dst_tensor);
    dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
  } else if (!dst_tensor) {
    // materializing a lazy tensor (self) and copying its value into an eager
    // tensor (dst). detached=false lets us skip a copy in `ToTensor`, which
    // should be safe because we are only going to use the tensor for
    // dst.copy_()
    TORCH_CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/false);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // Copying one lazy tensor to another
    if (!dst_tensor->CurrentIrValue()) {
      // if dest is not backed by IR (e.g. result of some lazy operation),
      // then it should have at::Tensor data backing it instead
      auto dst_tensor_data = dst_tensor->CurrentTensorData();
      TORCH_CHECK(dst_tensor_data);
      auto src_tensor_data = self_tensor->CurrentTensorData();
      if (src_tensor_data) {
        // both src/dst are simply backed by at::Tensor data, no IR - do a
        // straightforward copy
        dst_tensor_data->copy_(*src_tensor_data);
      } else {
        // src needs to be materialized before its result can be used for a
        // copy into dst. Since we use the src tensor only for making a copy,
        // we don't need to detach it. Note: it would be even more efficient
        // if we could cause ToTensor to materialize the value directly into
        // dst's buffer (that would need to be detached though).
        dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
      }
    } else {
      copy_(dst_tensor, self_tensor);
      auto* impl =
          dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
      impl->set_tensor(dst_tensor);
    }
  }
  return dst;
}
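
// A rough sketch (hypothetical tensors, for illustration only) of the copy
// directions the branches above handle:
//
//   lazy_dst.copy_(cpu_src);   // eager -> lazy: update dst from the eager value
//   cpu_dst.copy_(lazy_src);   // lazy -> eager: materialize src, then copy
//   lazy_dst.copy_(lazy_src);  // lazy -> lazy: share IR or tensor data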

at::Tensor LazyNativeFunctions::_copy_from_and_resize(
    const at::Tensor& self,
    const at::Tensor& dst) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    TORCH_CHECK(dst_tensor);
    dst_tensor->UpdateFromTensorOut(self);
  } else if (!dst_tensor) {
    TORCH_CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // at this point we know dst is a lazy tensor
    auto* dest_impl =
        dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
    TORCH_CHECK(dest_impl);
    dest_impl->tensor()->UpdateFromTensorOut(self_tensor);
    dest_impl->force_refresh_sizes();
  }
  return dst;
}

at::Tensor LazyNativeFunctions::_to_copy(
    const at::Tensor& self,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking,
    std::optional<at::MemoryFormat> memory_format) {
  if (force_eager_fallback(at::aten::_to_copy)) {
    TORCH_INTERNAL_ASSERT(
        false,
        "Fallback is currently impossible for _to_copy since the fallback helper itself reinvokes _to_copy");
  }

  auto options = self.options();
  if (dtype) {
    // I put each of these setters in a conditional instead of doing
    // `self.options().dtype(dtype).layout(layout)...` because calling
    // .dtype(nullopt) on an options() that already has dtype appears to wipe
    // it
    options = options.dtype(dtype);
  }
  if (layout) {
    options = options.layout(layout);
  }
  if (memory_format) {
    options = options.memory_format(memory_format);
  }
  if (pin_memory) {
    // TODO(whc) can we honor 'pin_memory' in some/all cases?
    options = options.pinned_memory(pin_memory);
    TORCH_WARN_ONCE(
        "Pinned memory used in lazy _to_copy, check if the behavior is as intended");
  }

  TORCH_LAZY_FN_COUNTER("lazy::");
  auto lazy_self = torch::lazy::TryGetLtcTensor(self);
  if (!lazy_self && device && device->type() == c10::kLazy) {
    // Case 1: eager->lazy (we create a new lazy tensor)
    // See Note [Lazy Tensor Functionalization]
    // Invariant: if the functionalization key is in the exclude set, then
    // we're expected to return an ordinary tensor, which will be "lifted"
    // into a functional wrapper later.
    bool functionalize_output =
        !c10::impl::tls_local_dispatch_key_set().excluded_.has(
            c10::DispatchKey::Functionalize);
    return torch::lazy::to_lazy_tensor(
        self,
        options,
        *device,
        /*non_blocking=*/non_blocking,
        /*functionalize_output=*/functionalize_output);
  } else if (device && device->type() != c10::kLazy) {
    // Case 2: lazy->eager (forces a graph break since we are materializing a
    // tensor)

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    options = options.device(device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
    return moved_eager_tensor;
  } else if (
      device && device->type() == c10::kLazy && device->has_index() &&
      device->index() != self.device().index()) {
    // Case 3: lazy:0 -> lazy:1

    // TODO(whc) what do we actually want to do here?
    // option 1: materialize, move eager tensor, create new lazy tensor
    //   - this should be our default, as it is what would happen before we
    //     implemented _to_copy
    //   - actually combines case 1 + case 2
    // option 2: support multiple devices inside one lazy/TS executor (case 4)
    //   - but: we may have other assumptions that there is just one device
    //     per executor? so don't take this lightly

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    // we move the eager tensor to the 'eager' equivalent of our lazy device
    // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is
    // what we use
    auto eager_device = c10::Device(
        torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index());
    options = options.device(eager_device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true);
    lazy_self = torch::lazy::GetOrCreateLtcTensor(
        moved_eager_tensor,
        torch::lazy::atenDeviceToBackendDevice(eager_device));
    return torch::lazy::CreateAtenFromLtcTensor(lazy_self);

  } else {
    // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy
    // graph)

    // Note: captured _to_copy will be executed with real eager tensors, not
    // lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this
    // captured IR, or we will try to convert an eager tensor back to a lazy
    // one inside the torchscript executor. lazy:0 -> lazy:1 is handled in
    // case 3, so we can safely drop the device argument.
    device = std::nullopt;

    torch::lazy::NodePtr node = torch::lazy::ReuseNode<ToCopy>(
        lazy_self->GetIrValue(),
        dtype,
        layout,
        device,
        pin_memory,
        non_blocking,
        memory_format);
    if (!node) {
      auto shapes = torch::lazy::compute_shape__to_copy(
          self, dtype, layout, device, pin_memory, non_blocking, memory_format);
      TORCH_INTERNAL_ASSERT(shapes.size() == 1);
      node = torch::lazy::MakeNode<ToCopy>(
          lazy_self->GetIrValue(),
          dtype,
          layout,
          device,
          pin_memory,
          non_blocking,
          memory_format,
          std::move(shapes));
      CacheNode(node);
    }

    auto result =
        torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
            std::move(node), lazy_self->GetDevice()));
    return result;
  }
}
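
// A rough sketch (hypothetical tensors, for illustration only) of how the
// four cases above could be reached via .to():
//
//   cpu_t.to(at::kLazy);                   // case 1: eager -> lazy
//   lazy_t.to(at::kCPU);                   // case 2: lazy -> eager, graph break
//   lazy0_t.to(at::Device(at::kLazy, 1));  // case 3: lazy:0 -> lazy:1
//   lazy_t.to(at::kDouble);                // case 4: stays inside the lazy graph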

at::Tensor LazyNativeFunctions::empty_symint(
    at::SymIntArrayRef sym_size,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory,
    std::optional<at::MemoryFormat> memory_format) {
  // TODO: support this directly
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
  at::TensorOptions options = at::TensorOptions()
                                  .device(c10::Device(device_type))
                                  .layout(layout)
                                  .pinned_memory(pin_memory)
                                  .dtype(dtype);
  auto x_result = at::empty(size, options, memory_format);
  auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
  // See Note [Lazy Tensor Functionalization]
  if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
          c10::DispatchKey::Functionalize)) {
    // Invariant: if the functionalization key is in the exclude set, then
    // we're expected to return an ordinary tensor, which will be "lifted"
    // into a functional wrapper later.
    return tensor;
  } else {
    auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
    return wrapped;
  }
}

at::Tensor LazyNativeFunctions::empty_strided_symint(
    at::SymIntArrayRef sym_size,
    at::SymIntArrayRef sym_stride,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  at::Tensor t =
      empty_symint(sym_size, dtype, layout, device, pin_memory, std::nullopt);
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride);
  return t.as_strided(size, stride, /*storage_offset=*/0);
}

at::Tensor& LazyNativeFunctions::fill_(
    at::Tensor& self,
    const at::Scalar& value) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::fill_(self_tensor, value);
  return self;
}

at::Tensor LazyNativeFunctions::max_pool3d(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  return torch::lazy::MaxPool3dAutogradFunctionTS::apply(
      self, kernel_size, stride, padding, dilation, ceil_mode);
}

// We need to explicitly override max pooling operators and just call the
// fallback for them because we've customized the autograd function for them
// (backward needs saved indices from forward).
std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::max_pool3d_with_indices(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  return at::native::
      call_fallback_fn<&ltc_eager_fallback, ATEN_OP(max_pool3d_with_indices)>::
          call(self, kernel_size, stride, padding, dilation, ceil_mode);
}

at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode,
    const at::Tensor& indices) {
  return at::native::call_fallback_fn<
      &ltc_eager_fallback,
      ATEN_OP(max_pool3d_with_indices_backward)>::
      call(
          grad_output,
          self,
          kernel_size,
          stride,
          padding,
          dilation,
          ceil_mode,
          indices);
}

at::Tensor LazyNativeFunctions::_unsafe_view(
    const at::Tensor& self,
    at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return LazyNativeFunctions::view_copy_symint(
      self, c10::fromIntArrayRefSlow(size));
}

// This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a
// FunctionalTensorWrapper object.
at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}

at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}

// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      block_diag)>::call(tensors);
}

at::Tensor LazyNativeFunctions::new_empty_strided_symint(
    const at::Tensor& self,
    c10::SymIntArrayRef size,
    c10::SymIntArrayRef stride,
    std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout,
    std::optional<at::Device> device,
    std::optional<bool> pin_memory) {
  return at::functionalization::
      functionalize_aten_op_symint<ATEN_OP(new_empty_strided)>::call(
          self, size, stride, dtype, layout, device, pin_memory);
}

at::Tensor LazyNativeFunctions::narrow_copy_symint(
    const at::Tensor& self,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt length) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      narrow_copy)>::call(self, dim, start, length);
}

at::Tensor LazyNativeFunctions::pixel_shuffle(
    const at::Tensor& self,
    int64_t upscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_shuffle)>::call(self, upscale_factor);
}

at::Tensor LazyNativeFunctions::pixel_unshuffle(
    const at::Tensor& self,
    int64_t downscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_unshuffle)>::call(self, downscale_factor);
}

at::Tensor LazyNativeFunctions::select_backward_symint(
    const at::Tensor& grad_output,
    c10::SymIntArrayRef input_sizes,
    int64_t dim,
    c10::SymInt index) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      select_backward)>::call(grad_output, input_sizes, dim, index);
}

at::Tensor LazyNativeFunctions::_trilinear(
    const at::Tensor& i1,
    const at::Tensor& i2,
    const at::Tensor& i3,
    at::IntArrayRef expand1,
    at::IntArrayRef expand2,
    at::IntArrayRef expand3,
    at::IntArrayRef sumdim,
    int64_t unroll_dim) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::
      call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
}

at::Tensor LazyNativeFunctions::linalg_pinv(
    const at::Tensor& self,
    const std::optional<at::Tensor>& atol,
    const std::optional<at::Tensor>& rtol,
    bool hermitian) {
  return at::functionalization::functionalize_aten_op<ATEN_OP2(
      linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}

// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy any mutations
// back to the inputs.
at::Tensor& LazyNativeFunctions::logsumexp_out(
    const at::Tensor& self,
    at::IntArrayRef dim,
    bool keepdim,
    at::Tensor& out) {
  auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
  auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
  // directly call the composite kernel from core.
  // Make sure to re-enable functionalization first.
  auto curr_tls = c10::impl::tls_local_dispatch_key_set();
  auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
  tls_reenable_functionalize.set_included(curr_tls.included_);
  tls_reenable_functionalize.set_excluded(
      curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
  c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
  at::native::logsumexp_out(self_wrapped, dim, keepdim, out_wrapped);
  auto out_unwrapped =
      at::functionalization::impl::from_functional_tensor(out_wrapped);
  // propagate mutations back to the inputs (including resizing)
  out.resize_(out_unwrapped.sizes());
  out.copy_(out_unwrapped);
  return out;
}
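
// A rough sketch (hypothetical tensors, for illustration only): the
// wrap / re-enable-functionalization / unwrap sequence above is what lets an
// out= call such as
//
//   at::logsumexp_out(lazy_out, lazy_self, /*dim=*/{1}, /*keepdim=*/false);
//
// reuse the composite kernel from core and then write the result back into
// the caller's `out` tensor via resize_() and copy_().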

at::Tensor LazyNativeFunctions::diag_embed(
    const at::Tensor& self,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      diag_embed)>::call(self, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::diagonal_backward_symint(
    const at::Tensor& grad_output,
    at::SymIntArrayRef input_sizes,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::slice_backward_symint(
    const at::Tensor& grad_output,
    at::SymIntArrayRef input_sizes,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt end,
    c10::SymInt step) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}

// re-use the composite kernel from core, that way we don't need to provide a
// backwards formula for native_group_norm
std::tuple<Tensor, Tensor, Tensor> LazyNativeFunctions::native_group_norm(
    const at::Tensor& input,
    const std::optional<at::Tensor>& weight,
    const std::optional<at::Tensor>& bias,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  return at::native::math_group_norm(
      input, weight, bias, N, C, HxW, group, eps);
}

} // namespace lazy
} // namespace torch