xref: /aosp_15_r20/external/executorch/extension/tensor/tensor_ptr_maker.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <executorch/extension/tensor/tensor_ptr_maker.h>

#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdint>
#include <limits>
#include <random>
#include <type_traits>
#include <vector>
12 
13 namespace executorch {
14 namespace extension {
15 namespace {
16 
17 template <
18     typename INT_T,
19     typename std::enable_if<
20         std::is_integral<INT_T>::value && !std::is_same<INT_T, bool>::value,
21         bool>::type = true>
extract_scalar(exec_aten::Scalar scalar,INT_T * out_val)22 bool extract_scalar(exec_aten::Scalar scalar, INT_T* out_val) {
23   if (!scalar.isIntegral(/*includeBool=*/false)) {
24     return false;
25   }
26   int64_t val = scalar.to<int64_t>();
27   if (val < std::numeric_limits<INT_T>::lowest() ||
28       val > std::numeric_limits<INT_T>::max()) {
29     return false;
30   }
31   *out_val = static_cast<INT_T>(val);
32   return true;
33 }
34 
35 template <
36     typename FLOAT_T,
37     typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
38         type = true>
extract_scalar(exec_aten::Scalar scalar,FLOAT_T * out_val)39 bool extract_scalar(exec_aten::Scalar scalar, FLOAT_T* out_val) {
40   double val;
41   if (scalar.isFloatingPoint()) {
42     val = scalar.to<double>();
43     if (std::isfinite(val) &&
44         (val < std::numeric_limits<FLOAT_T>::lowest() ||
45          val > std::numeric_limits<FLOAT_T>::max())) {
46       return false;
47     }
48   } else if (scalar.isIntegral(/*includeBool=*/false)) {
49     val = static_cast<double>(scalar.to<int64_t>());
50   } else {
51     return false;
52   }
53   *out_val = static_cast<FLOAT_T>(val);
54   return true;
55 }
56 
57 template <
58     typename BOOL_T,
59     typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
60         true>
extract_scalar(exec_aten::Scalar scalar,BOOL_T * out_val)61 bool extract_scalar(exec_aten::Scalar scalar, BOOL_T* out_val) {
62   if (scalar.isIntegral(false)) {
63     *out_val = static_cast<bool>(scalar.to<int64_t>());
64     return true;
65   }
66   if (scalar.isBoolean()) {
67     *out_val = scalar.to<bool>();
68     return true;
69   }
70   return false;
71 }
72 
// Extracts `scalar` into the lvalue `out_val` via extract_scalar(), failing
// ET_CHECK_MSG with a message naming `scalar` if the kind is wrong or the value
// is out of range. The expansion carries no trailing semicolon, so the caller
// supplies one. This keeps `if (c) ET_EXTRACT_SCALAR(a, b); else ...`
// well-formed. Arguments are parenthesized so compound expressions expand
// safely.
#define ET_EXTRACT_SCALAR(scalar, out_val)  \
  ET_CHECK_MSG(                             \
      extract_scalar((scalar), &(out_val)), \
      #scalar " could not be extracted: wrong type or out of range")
77 
78 template <typename Distribution>
random_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism,Distribution && distribution)79 TensorPtr random_strided(
80     std::vector<exec_aten::SizesType> sizes,
81     std::vector<exec_aten::StridesType> strides,
82     exec_aten::ScalarType type,
83     exec_aten::TensorShapeDynamism dynamism,
84     Distribution&& distribution) {
85   auto tensor =
86       empty_strided(std::move(sizes), std::move(strides), type, dynamism);
87   std::default_random_engine gen{std::random_device{}()};
88 
89   ET_SWITCH_REALB_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
90     std::generate_n(tensor->mutable_data_ptr<CTYPE>(), tensor->numel(), [&]() {
91       return static_cast<CTYPE>(distribution(gen));
92     });
93   });
94   return tensor;
95 }
96 
97 } // namespace
98 
empty_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)99 TensorPtr empty_strided(
100     std::vector<exec_aten::SizesType> sizes,
101     std::vector<exec_aten::StridesType> strides,
102     exec_aten::ScalarType type,
103     exec_aten::TensorShapeDynamism dynamism) {
104   std::vector<uint8_t> data(
105       exec_aten::compute_numel(sizes.data(), sizes.size()) *
106       exec_aten::elementSize(type));
107   return make_tensor_ptr(
108       std::move(sizes),
109       std::move(data),
110       {},
111       std::move(strides),
112       type,
113       dynamism);
114 }
115 
full_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::Scalar fill_value,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)116 TensorPtr full_strided(
117     std::vector<exec_aten::SizesType> sizes,
118     std::vector<exec_aten::StridesType> strides,
119     exec_aten::Scalar fill_value,
120     exec_aten::ScalarType type,
121     exec_aten::TensorShapeDynamism dynamism) {
122   auto tensor =
123       empty_strided(std::move(sizes), std::move(strides), type, dynamism);
124   ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
125     CTYPE value;
126     ET_EXTRACT_SCALAR(fill_value, value);
127     std::fill(
128         tensor->mutable_data_ptr<CTYPE>(),
129         tensor->mutable_data_ptr<CTYPE>() + tensor->numel(),
130         value);
131   });
132   return tensor;
133 }
134 
rand_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)135 TensorPtr rand_strided(
136     std::vector<exec_aten::SizesType> sizes,
137     std::vector<exec_aten::StridesType> strides,
138     exec_aten::ScalarType type,
139     exec_aten::TensorShapeDynamism dynamism) {
140   return random_strided(
141       std::move(sizes),
142       std::move(strides),
143       type,
144       dynamism,
145       std::uniform_real_distribution<float>(0.0f, 1.0f));
146 }
147 
randn_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)148 TensorPtr randn_strided(
149     std::vector<exec_aten::SizesType> sizes,
150     std::vector<exec_aten::StridesType> strides,
151     exec_aten::ScalarType type,
152     exec_aten::TensorShapeDynamism dynamism) {
153   return random_strided(
154       std::move(sizes),
155       std::move(strides),
156       type,
157       dynamism,
158       std::normal_distribution<float>(0.0f, 1.0f));
159 }
160 
randint_strided(int64_t low,int64_t high,std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)161 TensorPtr randint_strided(
162     int64_t low,
163     int64_t high,
164     std::vector<exec_aten::SizesType> sizes,
165     std::vector<exec_aten::StridesType> strides,
166     exec_aten::ScalarType type,
167     exec_aten::TensorShapeDynamism dynamism) {
168   return random_strided(
169       std::move(sizes),
170       std::move(strides),
171       type,
172       dynamism,
173       std::uniform_int_distribution<int64_t>(low, high - 1));
174 }
175 
176 } // namespace extension
177 } // namespace executorch
178