#include <ATen/core/dynamic_type.h>
#include <ATen/core/type_factory.h>
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>

namespace torch::jit {

int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  if (idx < 0) {
    // Handle negative indexing
    idx = list_size + idx;
  }
  return idx;
}
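
// For example, with a list of size 4, normalizeIndex(-1, 4) yields 3 and
// normalizeIndex(2, 4) is returned unchanged. No bounds checking or clamping
// is performed, so an out-of-range index is returned as-is.
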
IValue tensorToListRecursive(
    char* data,
    int64_t cur_dim,
    int64_t num_tensor_dims,
    at::TypePtr ty,
    at::ScalarType scalar_ty,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    size_t element_size) {
  // If ty is a ListType, get the element type.
  if (auto list_type = ty->cast<at::ListType>()) {
    ty = list_type->getElementType();
  } else {
    // If the output type is a scalar, read and push one scalar of
    // the right type onto the stack.
    if (ty == at::IntType::get()) {
      int64_t scalar = *(int64_t*)data;
      return IValue(scalar);
    } else if (ty == at::FloatType::get()) {
      TORCH_INTERNAL_ASSERT(
          scalar_ty == at::ScalarType::Float ||
              scalar_ty == at::ScalarType::Double,
          "Unexpected scalar type for Tensor");
      double scalar =
          scalar_ty == at::ScalarType::Float ? *(float*)data : *(double*)data;
      return IValue(scalar);
    } else if (ty == at::ComplexType::get()) {
      TORCH_INTERNAL_ASSERT(
          scalar_ty == at::ScalarType::ComplexFloat ||
              scalar_ty == at::ScalarType::ComplexDouble,
          "Unexpected scalar type for Tensor");
      c10::complex<double> scalar = scalar_ty == at::ScalarType::ComplexFloat
          ? *(c10::complex<float>*)data
          : *(c10::complex<double>*)data;
      return IValue(scalar);
    } else if (ty == at::BoolType::get()) {
      bool scalar = *(bool*)data;
      return IValue(scalar);
    } else {
      TORCH_CHECK(
          false,
          ty->repr_str(),
          " is not one of the supported types for tolist: int, float, complex, bool");
    }
  }

  // Make the result list consisting of elements of type ty. Since this
  // invocation is processing dimension cur_dim, there will be sizes[cur_dim]
  // output elements.
  auto result = c10::impl::GenericList(ty);
  result.reserve(sizes[cur_dim]);

  // Since ty was a list type, tensorToListRecursive needs to be called
  // recursively on each slice of the tensor in the current dimension.
  for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
    auto inner_result = tensorToListRecursive(
        data,
        cur_dim + 1,
        num_tensor_dims,
        ty,
        scalar_ty,
        sizes,
        strides,
        element_size);

    if (inner_result.isList()) {
      result.emplace_back(inner_result.toList());
    } else if (inner_result.isComplexDouble()) {
      result.emplace_back(inner_result.toComplexDouble());
    } else if (inner_result.isDouble()) {
      result.emplace_back(inner_result.toDouble());
    } else if (inner_result.isInt()) {
      result.emplace_back(inner_result.toInt());
    } else if (inner_result.isBool()) {
      result.emplace_back(inner_result.toBool());
    } else {
      TORCH_INTERNAL_ASSERT(
          false && "Unknown return type for tensorToListRecursive");
    }

    data += strides[cur_dim] * element_size;
  }

  return result;
}
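
// A sketch of a typical call (assumed usage; the tensor `t` below is
// illustrative, not part of this file): converting a contiguous 2-D float
// tensor into an IValue holding a List[List[float]] would look roughly like
//
//   auto elem_ty = at::ListType::create(at::FloatType::get());
//   auto list_ty = at::ListType::create(elem_ty);
//   IValue out = tensorToListRecursive(
//       (char*)t.data_ptr(),
//       /*cur_dim=*/0,
//       /*num_tensor_dims=*/t.dim(),
//       list_ty,
//       t.scalar_type(),
//       t.sizes(),
//       t.strides(),
//       t.element_size());
//
// Each level of recursion strips one ListType from `ty`; after emitting the
// element at index i of the current dimension, `data` is advanced by
// strides[cur_dim] * element_size bytes.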

} // namespace torch::jit