1 #include <ATen/core/dynamic_type.h>
2 #include <ATen/core/type_factory.h>
3 #include <torch/csrc/jit/mobile/register_ops_common_utils.h>
4
5 namespace torch::jit {
6
// Maps a possibly-negative list index to its non-negative equivalent.
// Negative indices count from the end of the list (idx == -1 refers to the
// last element), so they are shifted by `list_size`; non-negative indices are
// returned unchanged. No bounds checking is performed: an index below
// -list_size still comes back negative, and an index >= list_size is passed
// through as-is, so callers must validate the result themselves.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx >= 0 ? idx : list_size + idx;
}
14
tensorToListRecursive(char * data,int64_t cur_dim,int64_t num_tensor_dims,at::TypePtr ty,at::ScalarType scalar_ty,at::IntArrayRef sizes,at::IntArrayRef strides,size_t element_size)15 IValue tensorToListRecursive(
16 char* data,
17 int64_t cur_dim,
18 int64_t num_tensor_dims,
19 at::TypePtr ty,
20 at::ScalarType scalar_ty,
21 at::IntArrayRef sizes,
22 at::IntArrayRef strides,
23 size_t element_size) {
24 // If ty is a ListType, get the element type.
25 if (auto list_type = ty->cast<at::ListType>()) {
26 ty = list_type->getElementType();
27 } else {
28 // If the output type is a scalar, read and push one scalar of
29 // the right type onto the stack.
30 if (ty == at::IntType::get()) {
31 int64_t scalar = *(int64_t*)data;
32 return IValue(scalar);
33 } else if (ty == at::FloatType::get()) {
34 TORCH_INTERNAL_ASSERT(
35 scalar_ty == at::ScalarType::Float ||
36 scalar_ty == at::ScalarType::Double,
37 "Unexpected scalar type for Tensor");
38 double scalar =
39 scalar_ty == at::ScalarType::Float ? *(float*)data : *(double*)data;
40 return IValue(scalar);
41 } else if (ty == at::ComplexType::get()) {
42 TORCH_INTERNAL_ASSERT(
43 scalar_ty == at::ScalarType::ComplexFloat ||
44 scalar_ty == at::ScalarType::ComplexDouble,
45 "Unexpected scalar type for Tensor");
46 c10::complex<double> scalar = scalar_ty == at::ScalarType::ComplexFloat
47 ? *(c10::complex<float>*)data
48 : *(c10::complex<double>*)data;
49 return IValue(scalar);
50 } else if (ty == at::BoolType::get()) {
51 bool scalar = *(bool*)data;
52 return IValue(scalar);
53 } else {
54 TORCH_CHECK(
55 false,
56 ty->repr_str(),
57 " is not one of the supported types for tolist: int, float, bool");
58 }
59 }
60
61 // Make the result list consisting of elements of type ty. Since this
62 // invocation is processing dimension cur_dim, there will be sizes[cur_dim]
63 // output elements.
64 auto result = c10::impl::GenericList(ty);
65 result.reserve(sizes[cur_dim]);
66
67 // Since ty was a list type, tensorToListRecursive needs to be called
68 // recursively on each slice of the tensor in the current dimension.
69 for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
70 auto inner_result = tensorToListRecursive(
71 data,
72 cur_dim + 1,
73 num_tensor_dims,
74 ty,
75 scalar_ty,
76 sizes,
77 strides,
78 element_size);
79
80 if (inner_result.isList()) {
81 result.emplace_back(inner_result.toList());
82 } else if (inner_result.isComplexDouble()) {
83 result.emplace_back(inner_result.toComplexDouble());
84 } else if (inner_result.isDouble()) {
85 result.emplace_back(inner_result.toDouble());
86 } else if (inner_result.isInt()) {
87 result.emplace_back(inner_result.toInt());
88 } else if (inner_result.isBool()) {
89 result.emplace_back(inner_result.toBool());
90 } else {
91 TORCH_INTERNAL_ASSERT(
92 false && "Unknown return type for tensorToListRecursive");
93 }
94
95 data += strides[cur_dim] * element_size;
96 }
97
98 return result;
99 }
100
101 } // namespace torch::jit
102