xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ResizeCommon.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/native/TensorFactories.h>
5 #include <ATen/NamedTensorUtils.h>
6 #include <c10/util/irange.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/empty.h>
12 #endif
13 
14 namespace at::native {
15 
16 template <typename T>
storage_size_for(ArrayRef<T> size,ArrayRef<T> stride)17 inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
18   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
19       "storage_size_for(size, stride) requires that size and stride ",
20       "have the same size as a precondition.");
21   T storage_size = 1;
22   for (const auto dim : c10::irange(size.size())) {
23     if (size[dim] == 0) {
24       storage_size = 0;
25       break;
26     }
27     storage_size += (size[dim] - 1) * stride[dim];
28   }
29   return storage_size;
30 }
31 
resize_named_tensor_(const Tensor & self,IntArrayRef size,std::optional<MemoryFormat> optional_memory_format)32 inline const Tensor& resize_named_tensor_(
33     const Tensor& self,
34     IntArrayRef size,
35     std::optional<MemoryFormat> optional_memory_format) {
36   TORCH_INTERNAL_ASSERT(self.has_names());
37   TORCH_CHECK(
38       self.sizes() == size,
39       "Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
40       "Tensor",
41       self.names(),
42       " with size ",
43       self.sizes(),
44       " to ",
45       size,
46       "). This may be caused by passing a named tensor ",
47       "as an `out=` argument; please ensure that the sizes are the same. ");
48   TORCH_CHECK(
49       !optional_memory_format.has_value(),
50       "Unsupported memory format for named tensor resize ",
51       optional_memory_format.value());
52   return self;
53 }
54 
55 // For deterministic output, fill new elements that were added after a storage
56 // resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
57 // before the resize happened.
fill_resize_deterministic_(const Tensor & tensor,int64_t old_storage_nbytes)58 inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
59   const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
60   int64_t new_storage_nbytes = storage.nbytes();
61   int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
62   int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
63   if (new_storage_numel > old_storage_numel) {
64     at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
65     tensor_view.set_(
66       storage,
67       /*storage_offset=*/old_storage_numel,
68       /*size=*/{new_storage_numel - old_storage_numel},
69       /*stride=*/{1});
70     at::native::fill_empty_deterministic_(tensor_view);
71   }
72   return tensor;
73 }
74 
75 } // namespace at::native
76