xref: /aosp_15_r20/external/executorch/extension/tensor/tensor_ptr.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <executorch/extension/tensor/tensor_ptr.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
14 
15 namespace executorch {
16 namespace extension {
17 namespace {
18 #ifndef USE_ATEN_LIB
19 /**
20  * A structure that consolidates the metadata (sizes, dim_order, strides) and
21  * the data buffer associated with a Tensor. Since Tensor does not own
22  * the memory for these metadata arrays or the data itself, this structure
23  * ensures that they are managed together and have the same lifetime as the
24  * Tensor. When the Tensor is destroyed, the Storage structure ensures
25  * proper cleanup of the associated metadata and data if needed.
26  */
27 struct Storage final {
28   exec_aten::TensorImpl tensor_impl;
29   exec_aten::Tensor tensor;
30   std::vector<exec_aten::SizesType> sizes;
31   std::vector<exec_aten::DimOrderType> dim_order;
32   std::vector<exec_aten::StridesType> strides;
33   std::function<void(void*)> deleter;
34 
Storageexecutorch::extension::__anonea639a470111::Storage35   Storage(
36       exec_aten::TensorImpl&& tensor_impl,
37       std::vector<exec_aten::SizesType>&& sizes,
38       std::vector<exec_aten::DimOrderType>&& dim_order,
39       std::vector<exec_aten::StridesType>&& strides,
40       std::function<void(void*)>&& deleter)
41       : tensor_impl(std::move(tensor_impl)),
42         tensor(&this->tensor_impl),
43         sizes(std::move(sizes)),
44         dim_order(std::move(dim_order)),
45         strides(std::move(strides)),
46         deleter(std::move(deleter)) {}
47 
~Storageexecutorch::extension::__anonea639a470111::Storage48   ~Storage() {
49     if (deleter) {
50       deleter(tensor_impl.mutable_data());
51     }
52   }
53 };
54 #endif // USE_ATEN_LIB
55 } // namespace
56 
make_tensor_ptr(std::vector<exec_aten::SizesType> sizes,void * data,std::vector<exec_aten::DimOrderType> dim_order,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism,std::function<void (void *)> deleter)57 TensorPtr make_tensor_ptr(
58     std::vector<exec_aten::SizesType> sizes,
59     void* data,
60     std::vector<exec_aten::DimOrderType> dim_order,
61     std::vector<exec_aten::StridesType> strides,
62     exec_aten::ScalarType type,
63     exec_aten::TensorShapeDynamism dynamism,
64     std::function<void(void*)> deleter) {
65   const auto dim = sizes.size();
66   ET_CHECK_MSG(
67       dim_order.empty() || dim_order.size() == dim,
68       "dim_order size must match sizes or be empty.");
69   ET_CHECK_MSG(
70       strides.empty() || strides.size() == dim,
71       "strides size must match sizes or be empty.");
72 
73   if (dim_order.empty()) {
74     dim_order.resize(dim);
75     std::iota(dim_order.begin(), dim_order.end(), 0);
76     if (!strides.empty()) {
77       std::sort(dim_order.begin(), dim_order.end(), [&](size_t a, size_t b) {
78         return strides[a] > strides[b];
79       });
80     }
81   }
82   std::vector<exec_aten::StridesType> computed_strides(dim);
83   auto error = runtime::dim_order_to_stride(
84       sizes.data(), dim_order.data(), dim, computed_strides.data());
85   ET_CHECK_MSG(error == runtime::Error::Ok, "Failed to compute strides.");
86 
87   if (!strides.empty()) {
88     ET_CHECK_MSG(computed_strides == strides, "Invalid strides provided.");
89   } else {
90     strides = std::move(computed_strides);
91   }
92 #ifndef USE_ATEN_LIB
93   exec_aten::TensorImpl tensor_impl(
94       type,
95       dim,
96       sizes.data(),
97       data,
98       dim_order.data(),
99       strides.data(),
100       dim > 0 ? dynamism : exec_aten::TensorShapeDynamism::STATIC);
101   auto storage = std::make_shared<Storage>(
102       std::move(tensor_impl),
103       std::move(sizes),
104       std::move(dim_order),
105       std::move(strides),
106       std::move(deleter));
107   const auto tensor_ptr = &storage->tensor;
108   return std::shared_ptr<exec_aten::Tensor>(std::move(storage), tensor_ptr);
109 #else
110   auto options = c10::TensorOptions()
111                      .dtype(c10::scalarTypeToTypeMeta(type))
112                      .device(c10::kCPU);
113   auto storage = c10::Storage(
114       c10::Storage::use_byte_size_t(),
115       at::detail::computeStorageNbytes(
116           sizes, strides, options.dtype().itemsize()),
117       c10::InefficientStdFunctionContext::makeDataPtr(
118           data, std::move(deleter), options.device()),
119       nullptr,
120       false);
121   auto tensor_impl = c10::make_intrusive<exec_aten::TensorImpl>(
122       std::move(storage),
123       c10::DispatchKeySet(c10::DispatchKey::CPU),
124       options.dtype());
125   tensor_impl->set_sizes_and_strides(sizes, strides);
126   return std::make_shared<exec_aten::Tensor>(std::move(tensor_impl));
127 #endif // USE_ATEN_LIB
128 }
129 
make_tensor_ptr(std::vector<exec_aten::SizesType> sizes,std::vector<uint8_t> data,std::vector<exec_aten::DimOrderType> dim_order,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)130 TensorPtr make_tensor_ptr(
131     std::vector<exec_aten::SizesType> sizes,
132     std::vector<uint8_t> data,
133     std::vector<exec_aten::DimOrderType> dim_order,
134     std::vector<exec_aten::StridesType> strides,
135     exec_aten::ScalarType type,
136     exec_aten::TensorShapeDynamism dynamism) {
137   ET_CHECK_MSG(
138       data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) *
139               exec_aten::elementSize(type),
140       "Data size is smaller than required by sizes and scalar type.");
141   auto data_ptr = data.data();
142   return make_tensor_ptr(
143       std::move(sizes),
144       data_ptr,
145       std::move(dim_order),
146       std::move(strides),
147       type,
148       dynamism,
149       // Data is moved into the deleter and is destroyed together with Storage.
150       [data = std::move(data)](void*) {});
151 }
152 
clone_tensor_ptr(const exec_aten::Tensor & tensor)153 TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor) {
154   std::vector<exec_aten::SizesType> sizes(
155       tensor.sizes().begin(), tensor.sizes().end());
156   std::vector<exec_aten::DimOrderType> dim_order{
157 #ifndef USE_ATEN_LIB
158       tensor.dim_order().begin(), tensor.dim_order().end()
159 #endif // USE_ATEN_LIB
160   };
161   std::vector<exec_aten::StridesType> strides(
162       tensor.strides().begin(), tensor.strides().end());
163   auto dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND;
164 #ifndef USE_ATEN_LIB
165   dynamism = tensor.shape_dynamism();
166 #endif // USE_ATEN_LIB
167   return tensor.const_data_ptr()
168       ? make_tensor_ptr(
169             std::move(sizes),
170             std::vector<uint8_t>(
171                 (uint8_t*)tensor.const_data_ptr(),
172                 (uint8_t*)tensor.const_data_ptr() + tensor.nbytes()),
173             std::move(dim_order),
174             std::move(strides),
175             tensor.scalar_type(),
176             dynamism)
177       : make_tensor_ptr(
178             std::move(sizes),
179             nullptr,
180             std::move(dim_order),
181             std::move(strides),
182             tensor.scalar_type(),
183             dynamism);
184 }
185 
// clone_tensor_ptr(const exec_aten::Tensor&) is defined once, earlier in this
// file; a second identical definition here would be a redefinition error.
218 
resize_tensor_ptr(TensorPtr & tensor,const std::vector<exec_aten::SizesType> & sizes)219 runtime::Error resize_tensor_ptr(
220     TensorPtr& tensor,
221     const std::vector<exec_aten::SizesType>& sizes) {
222   return runtime::resize_tensor(
223       *tensor,
224       exec_aten::ArrayRef<exec_aten::SizesType>(sizes.data(), sizes.size()));
225 }
226 
227 } // namespace extension
228 } // namespace executorch
229