xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ExpandBase.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/TensorBase.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker // Broadcasting utilities for working with TensorBase
4*da0073e9SAndroid Build Coastguard Worker namespace at {
5*da0073e9SAndroid Build Coastguard Worker namespace internal {
6*da0073e9SAndroid Build Coastguard Worker TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
7*da0073e9SAndroid Build Coastguard Worker } // namespace internal
8*da0073e9SAndroid Build Coastguard Worker 
expand_size(const TensorBase & self,IntArrayRef size)9*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<TensorBase> expand_size(
10*da0073e9SAndroid Build Coastguard Worker     const TensorBase& self,
11*da0073e9SAndroid Build Coastguard Worker     IntArrayRef size) {
12*da0073e9SAndroid Build Coastguard Worker   if (size.equals(self.sizes())) {
13*da0073e9SAndroid Build Coastguard Worker     return c10::MaybeOwned<TensorBase>::borrowed(self);
14*da0073e9SAndroid Build Coastguard Worker   }
15*da0073e9SAndroid Build Coastguard Worker   return c10::MaybeOwned<TensorBase>::owned(
16*da0073e9SAndroid Build Coastguard Worker       at::internal::expand_slow_path(self, size));
17*da0073e9SAndroid Build Coastguard Worker }
18*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
19*da0073e9SAndroid Build Coastguard Worker     delete;
20*da0073e9SAndroid Build Coastguard Worker 
expand_inplace(const TensorBase & tensor,const TensorBase & to_expand)21*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<TensorBase> expand_inplace(
22*da0073e9SAndroid Build Coastguard Worker     const TensorBase& tensor,
23*da0073e9SAndroid Build Coastguard Worker     const TensorBase& to_expand) {
24*da0073e9SAndroid Build Coastguard Worker   return expand_size(to_expand, tensor.sizes());
25*da0073e9SAndroid Build Coastguard Worker }
26*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<TensorBase> expand_inplace(
27*da0073e9SAndroid Build Coastguard Worker     const TensorBase& tensor,
28*da0073e9SAndroid Build Coastguard Worker     TensorBase&& to_expand) = delete;
29*da0073e9SAndroid Build Coastguard Worker 
30*da0073e9SAndroid Build Coastguard Worker } // namespace at
31