/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

#include <cstring> // for std::memcpy

#include <ATen/Tensor.h> // @manual
#include <executorch/runtime/platform/assert.h>

namespace executorch {
namespace runtime {
/**
 * Implementation of the ATen tensor utilities. This file should only be
 * included in the `<target>_aten` target and only be used in ATen mode. It
 * explicitly takes at::Tensor (instead of exec_aten::Tensor) so that an
 * incorrect build fails at compile time.
 */
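// Fills `out_dim_order` with the dim order implied by `tensor`'s strides.
// `out_dim_order_size` must equal tensor.dim().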
Error get_dim_order(
    const at::Tensor& tensor,
    exec_aten::DimOrderType* out_dim_order,
    size_t out_dim_order_size) {
  ET_CHECK_OR_RETURN_ERROR(
      out_dim_order_size == tensor.dim(),
      InvalidArgument,
      "out_dim_order_size needs to be equal to the number of dimensions of the tensor. out_dim_order_size %zu, tensor.dim() %" PRId64,
      out_dim_order_size,
      tensor.dim());
  return stride_to_dim_order(
      tensor.strides().data(), tensor.dim(), out_dim_order);
}

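// Returns true if the tensor's dim order can be retrieved and passes
// validate_dim_order(); logs the offending dim order and returns false
// otherwise.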
bool tensor_has_valid_dim_order(at::Tensor t) {
  exec_aten::DimOrderType dim_order[kTensorDimensionLimit];
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      get_dim_order(t, dim_order, t.dim()) == Error::Ok,
      "Failed to retrieve dim order from tensor!");

  if (!validate_dim_order(dim_order, t.dim())) {
    ET_LOG(Error, "Tensor dim order is not valid:");
    for (size_t d = 0; d < t.dim(); ++d) {
      ET_LOG(
          Error,
          "    dim_order(%zu): %zu",
          static_cast<size_t>(d),
          static_cast<size_t>(dim_order[d]));
    }
    return false;
  }
  return true;
}

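// Returns true if the tensor uses the default (contiguous) or channels-last
// dim order; logs the actual dim order otherwise.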
inline bool tensor_is_default_or_channels_last_dim_order(at::Tensor t) {
  exec_aten::DimOrderType dim_order[kTensorDimensionLimit];
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      get_dim_order(t, dim_order, t.dim()) == Error::Ok,
      "Failed to retrieve dim order from tensor!");

  bool ret_val = is_contiguous_dim_order(dim_order, t.dim()) ||
      is_channels_last_dim_order(dim_order, t.dim());

  if (!ret_val) {
    ET_LOG(
        Error,
        "Expected tensor to have default or channels last dim order, but got");
    for (size_t d = 0; d < t.dim(); ++d) {
      ET_LOG(
          Error,
          "    dim_order(%zu): %zu",
          static_cast<size_t>(d),
          static_cast<size_t>(dim_order[d]));
    }
  }
  return ret_val;
}

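// Returns true when all tensors in `tensor_list` share a dim order, i.e. all
// are contiguous or all are channels-last. Lists with fewer than two tensors
// trivially pass.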
bool tensors_have_same_dim_order(
    const exec_aten::ArrayRef<exec_aten::Tensor> tensor_list) {
  if (tensor_list.size() < 2) {
    return true;
  }

  exec_aten::DimOrderType first_dim_order[kTensorDimensionLimit];
  exec_aten::DimOrderType other_dim_order[kTensorDimensionLimit];

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      get_dim_order(tensor_list[0], first_dim_order, tensor_list[0].dim()) ==
          Error::Ok,
      "Failed to retrieve dim order from 1st input tensor!");

  bool all_contiguous =
      is_contiguous_dim_order(first_dim_order, tensor_list[0].dim());
  bool all_channels_last =
      is_channels_last_dim_order(first_dim_order, tensor_list[0].dim());

  for (size_t i = 1; i < tensor_list.size(); ++i) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        get_dim_order(tensor_list[i], other_dim_order, tensor_list[i].dim()) ==
            Error::Ok,
        "Failed to retrieve dim order from %zu-th input tensor!",
        i);

    all_contiguous = all_contiguous &&
        is_contiguous_dim_order(other_dim_order, tensor_list[i].dim());
    all_channels_last = all_channels_last &&
        is_channels_last_dim_order(other_dim_order, tensor_list[i].dim());
  }

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      all_contiguous || all_channels_last,
      "%zu input tensors have different dim orders",
      tensor_list.size());

  return all_contiguous || all_channels_last;
}

namespace internal {

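// Makes t_dst's storage point at t_src's data without copying. Both tensors
// must have the same nbytes(), and t_src must have a non-null data pointer.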
Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
  at::StorageImpl* storage =
      t_dst.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl();

  ET_CHECK_OR_RETURN_ERROR(
      t_dst.nbytes() == t_src.nbytes(),
      InvalidArgument,
      "t_dst.nbytes() %zu != t_src.nbytes() %zu.",
      t_dst.nbytes(),
      t_src.nbytes());

  ET_CHECK_OR_RETURN_ERROR(
      t_src.mutable_data_ptr() != nullptr,
      InvalidArgument,
      "Source tensor must have a non-null data_ptr.");
  // Point the destination's storage at the source tensor's data pointer.
  storage->set_data_ptr(
      at::DataPtr(t_src.mutable_data_ptr(), at::DeviceType::CPU));
  storage->set_nbytes(t_src.nbytes());

  return Error::Ok;
}

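// Copies t_src's bytes into t_dst's preallocated storage. The destination
// data pointer must be non-null, and the byte counts must match whenever the
// source has data to copy.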
Error copy_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
  void* dst_data_ptr = t_dst.unsafeGetTensorImpl()
                           ->unsafe_storage()
                           .unsafeGetStorageImpl()
                           ->data_ptr()
                           .get();

  // Currently, even zero-sized tensors receive a data pointer under
  // pre-allocated memory planning, so this check is valid.
  // TODO(jakeszwe, shunting, gasoonjia): make this explicit in the design if
  // other people write their own memory plans.
  ET_CHECK_OR_RETURN_ERROR(
      dst_data_ptr != nullptr,
      InvalidArgument,
      "Destination tensor data pointer must not be null.");

  // Sources with a size-0 dimension can have a null data pointer.
  if (t_src.const_data_ptr() != nullptr) {
    ET_CHECK_OR_RETURN_ERROR(
        t_dst.nbytes() == t_src.nbytes(),
        InvalidArgument,
        "t_dst.nbytes() %zu != t_src.nbytes() %zu.",
        t_dst.nbytes(),
        t_src.nbytes());
    // Copy the source data to the preallocated memory of the destination,
    // which must be the same size as the source.
    std::memcpy(dst_data_ptr, t_src.const_data_ptr(), t_src.nbytes());
  }

  return Error::Ok;
}

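// Rebinds t's storage to `buffer`, which must hold at least t.nbytes() bytes.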
ET_NODISCARD Error
set_tensor_data(const at::Tensor& t, void* buffer, size_t buffer_size) {
  ET_CHECK_OR_RETURN_ERROR(
      buffer_size >= t.nbytes(),
      InvalidArgument,
      "buffer_size %zu is smaller than tensor nbytes %zu",
      buffer_size,
      t.nbytes());
  t.unsafeGetTensorImpl()->unsafe_storage().set_data_ptr(
      at::DataPtr(buffer, at::DeviceType::CPU));
  return Error::Ok;
}

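// Shrinks the tensor to zero elements and resets its storage, releasing the
// underlying data pointer.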
void reset_data_ptr(const at::Tensor& tensor) {
  auto impl = tensor.unsafeGetTensorImpl();
  impl->set_sizes_contiguous(0);
  impl->unsafe_storage().unsafeGetStorageImpl()->reset();
}

/// Most callers should use resize_tensor() instead.
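/// Only the sizes may change; the rank must stay the same, and the strides
/// are recomputed as if the tensor were contiguous.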
Error resize_tensor_impl(
    c10::TensorImpl* impl,
    c10::ArrayRef<exec_aten::SizesType> new_sizes) {
  // The lean-mode Tensor will perform this check, but at::Tensor won't.
  // Although at::Tensor can be resized in this case, it's not allowed by the
  // higher-level constraints of the runtime.
  if (impl->dim() != new_sizes.size()) {
    ET_LOG(
        Error,
        "Tensor rank is not mutable: old dim: %" PRId64 " new dim: %zu",
        impl->dim(),
        new_sizes.size());
    return torch::executor::Error::NotSupported;
  }
  // Will panic on failure.
  impl->set_sizes_contiguous(new_sizes);
  return torch::executor::Error::Ok;
}

} // namespace internal

} // namespace runtime
} // namespace executorch