xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ExpandBase.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/TensorBase.h>
2 
3 // Broadcasting utilities for working with TensorBase
4 namespace at {
5 namespace internal {
6 TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
7 } // namespace internal
8 
expand_size(const TensorBase & self,IntArrayRef size)9 inline c10::MaybeOwned<TensorBase> expand_size(
10     const TensorBase& self,
11     IntArrayRef size) {
12   if (size.equals(self.sizes())) {
13     return c10::MaybeOwned<TensorBase>::borrowed(self);
14   }
15   return c10::MaybeOwned<TensorBase>::owned(
16       at::internal::expand_slow_path(self, size));
17 }
18 c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
19     delete;
20 
expand_inplace(const TensorBase & tensor,const TensorBase & to_expand)21 inline c10::MaybeOwned<TensorBase> expand_inplace(
22     const TensorBase& tensor,
23     const TensorBase& to_expand) {
24   return expand_size(to_expand, tensor.sizes());
25 }
26 c10::MaybeOwned<TensorBase> expand_inplace(
27     const TensorBase& tensor,
28     TensorBase&& to_expand) = delete;
29 
30 } // namespace at
31