xref: /aosp_15_r20/external/executorch/runtime/executor/tensor_parser_portable.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/executor/tensor_parser.h>
10 
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <executorch/runtime/executor/memory_manager.h>
15 #include <executorch/runtime/executor/program.h>
16 #include <executorch/runtime/platform/profiler.h>
17 #include <executorch/schema/program_generated.h>
18 
19 namespace executorch {
20 namespace runtime {
21 namespace deserialization {
22 
23 using torch::executor::ScalarType;
24 using torch::executor::Tensor;
25 using torch::executor::TensorImpl;
26 
parseTensor(const Program * program,MemoryManager * memory_manager,const executorch_flatbuffer::Tensor * s_tensor)27 Result<Tensor> parseTensor(
28     const Program* program,
29     MemoryManager* memory_manager,
30     const executorch_flatbuffer::Tensor* s_tensor) {
31   EXECUTORCH_SCOPE_PROF("TensorParser::parseTensor");
32   auto method_allocator = memory_manager->method_allocator();
33 
34   ET_CHECK_OR_RETURN_ERROR(
35       s_tensor->storage_offset() == 0,
36       NotSupported,
37       "Non-zero storage offset %" PRId32 " not supported",
38       s_tensor->storage_offset());
39 
40   ScalarType scalar_type = static_cast<ScalarType>(s_tensor->scalar_type());
41   ET_CHECK_OR_RETURN_ERROR(
42       isValid(scalar_type) &&
43           // Types that do not yet have deserialization support.
44           scalar_type != exec_aten::ScalarType::ComplexHalf &&
45           scalar_type != exec_aten::ScalarType::ComplexFloat &&
46           scalar_type != exec_aten::ScalarType::ComplexDouble,
47       InvalidProgram,
48       "Invalid or unsupported ScalarType %" PRId8,
49       static_cast<int8_t>(scalar_type));
50 
51   TensorShapeDynamism dynamism =
52       static_cast<TensorShapeDynamism>(s_tensor->shape_dynamism());
53   // TODO(T175194371): Remove this check once fully dynamic shapes are
54   // supported.
55   ET_CHECK_OR_RETURN_ERROR(
56       dynamism != TensorShapeDynamism::DYNAMIC_UNBOUND,
57       NotSupported,
58       "Fully dynamic tensor shapes not yet supported: T175194371");
59 
60   ET_CHECK_OR_RETURN_ERROR(
61       s_tensor->sizes() != nullptr, InvalidProgram, "Missing sizes field");
62   const auto serialized_sizes = s_tensor->sizes()->data();
63   const auto dim = s_tensor->sizes()->size();
64 
65   ET_CHECK_OR_RETURN_ERROR(
66       s_tensor->dim_order() != nullptr,
67       InvalidProgram,
68       "Missing dim_order field");
69   ET_CHECK_OR_RETURN_ERROR(
70       s_tensor->dim_order()->size() == dim,
71       InvalidProgram,
72       "dim_order size %" PRIu32 " != dim %" PRIu32,
73       s_tensor->dim_order()->size(),
74       dim);
75   const auto serialized_dim_order = s_tensor->dim_order()->data();
76 
77   exec_aten::SizesType* sizes = nullptr;
78   exec_aten::DimOrderType* dim_order = nullptr;
79   // For dynamic shape tensors, allocate local buffers to allow mutable sizes
80   // and strides
81   if (dynamism != TensorShapeDynamism::STATIC) {
82     // copy sizes and dim order out of flatbuffer
83     // kimishpate: I think dim order can remain immutable and point to fb
84     // memory, unless we plan to implement in-place permute
85     exec_aten::SizesType* sizes_buf = ET_ALLOCATE_LIST_OR_RETURN_ERROR(
86         method_allocator, exec_aten::SizesType, dim);
87     exec_aten::DimOrderType* dim_order_buf = ET_ALLOCATE_LIST_OR_RETURN_ERROR(
88         method_allocator, exec_aten::DimOrderType, dim);
89     std::memcpy(
90         sizes_buf, serialized_sizes, sizeof(exec_aten::SizesType) * dim);
91     std::memcpy(
92         dim_order_buf,
93         serialized_dim_order,
94         sizeof(exec_aten::DimOrderType) * dim);
95 
96     sizes = sizes_buf;
97     dim_order = dim_order_buf;
98   } else {
99     // Const cast safe here as these tensors can't be resized, so these fields
100     // will not be modified.
101     sizes = const_cast<exec_aten::SizesType*>(serialized_sizes);
102     dim_order = const_cast<exec_aten::DimOrderType*>(serialized_dim_order);
103   }
104   // We will remove strides from schema.
105   // Allocating strides buffer here and populating it.
106   // In subsequent diffs we can remove strides accessor, however this
107   // will introduce incompatible APIs between ATen Tensor and ETensor.
108   exec_aten::StridesType* strides = ET_ALLOCATE_LIST_OR_RETURN_ERROR(
109       method_allocator, exec_aten::StridesType, dim);
110   auto status = dim_order_to_stride(sizes, dim_order, dim, strides);
111   ET_CHECK_OR_RETURN_ERROR(
112       status == Error::Ok,
113       Internal,
114       "dim_order_to_stride returned invalid status");
115 
116   auto* tensor_impl =
117       ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(method_allocator, TensorImpl);
118   // Placement new on the allocated memory space. Note that we create this first
119   // with null data so we can find its expected size before getting its memory.
120   new (tensor_impl) TensorImpl(
121       scalar_type,
122       dim,
123       sizes,
124       /*data=*/nullptr,
125       dim_order,
126       strides,
127       dynamism);
128 
129   // Now that we know how big the tensor is, find and assign its memory.
130   Result<void*> data_ptr = getTensorDataPtr(
131       s_tensor,
132       program,
133       tensor_impl->nbytes(),
134       memory_manager->planned_memory());
135   if (!data_ptr.ok()) {
136     ET_LOG(
137         Error,
138         "getTensorDataPtr() failed: 0x%" PRIx32,
139         static_cast<uint32_t>(data_ptr.error()));
140     return data_ptr.error();
141   }
142   tensor_impl->set_data(data_ptr.get());
143 
144   return Tensor(tensor_impl);
145 }
146 
147 } // namespace deserialization
148 } // namespace runtime
149 } // namespace executorch
150