#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/NamedTensorUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at::native {

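// Minimum number of elements a storage must hold to back a tensor with the
// given `size` and `stride`: 1 + sum_d (size[d] - 1) * stride[d], or 0 if any
// dimension has size 0.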
template <typename T>
inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
      "storage_size_for(size, stride) requires that size and stride ",
      "have the same size as a precondition.");
  T storage_size = 1;
  for (const auto dim : c10::irange(size.size())) {
    if (size[dim] == 0) {
      storage_size = 0;
      break;
    }
    storage_size += (size[dim] - 1) * stride[dim];
  }
  return storage_size;
}

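// Named tensors cannot be resized: this only succeeds when the requested size
// already matches the tensor's current size (a no-op) and no explicit memory
// format is requested.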
inline const Tensor& resize_named_tensor_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  TORCH_INTERNAL_ASSERT(self.has_names());
  TORCH_CHECK(
      self.sizes() == size,
      "Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
      "Tensor",
      self.names(),
      " with size ",
      self.sizes(),
      " to ",
      size,
      "). This may be caused by passing a named tensor ",
      "as an `out=` argument; please ensure that the sizes are the same.");
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "Unsupported memory format for named tensor resize ",
      optional_memory_format.value());
  return self;
}

// For deterministic output, fill new elements that were added after a storage
// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
// before the resize happened.
inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
  const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
  int64_t new_storage_nbytes = storage.nbytes();
  int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
  int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
  if (new_storage_numel > old_storage_numel) {
    at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
    tensor_view.set_(
        storage,
        /*storage_offset=*/old_storage_numel,
        /*size=*/{new_storage_numel - old_storage_numel},
        /*stride=*/{1});
    at::native::fill_empty_deterministic_(tensor_view);
  }
  return tensor;
}
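
// A minimal sketch (not part of the upstream call sites) of how a resize
// implementation might use the helper above, assuming the deterministic-fill
// toggles exposed on at::globalContext(): capture the storage size before
// reallocating, resize, then fill the newly added tail.
//
//   int64_t old_storage_nbytes =
//       self.unsafeGetTensorImpl()->unsafe_storage().nbytes();
//   // ... reallocate the underlying storage to the new size here ...
//   if (at::globalContext().deterministicAlgorithms() &&
//       at::globalContext().deterministicFillUninitializedMemory()) {
//     at::native::fill_resize_deterministic_(self, old_storage_nbytes);
//   }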

} // namespace at::native