1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/tensor/tensor_ptr_maker.h>
10
11 #include <random>
12
13 namespace executorch {
14 namespace extension {
15 namespace {
16
17 template <
18 typename INT_T,
19 typename std::enable_if<
20 std::is_integral<INT_T>::value && !std::is_same<INT_T, bool>::value,
21 bool>::type = true>
extract_scalar(exec_aten::Scalar scalar,INT_T * out_val)22 bool extract_scalar(exec_aten::Scalar scalar, INT_T* out_val) {
23 if (!scalar.isIntegral(/*includeBool=*/false)) {
24 return false;
25 }
26 int64_t val = scalar.to<int64_t>();
27 if (val < std::numeric_limits<INT_T>::lowest() ||
28 val > std::numeric_limits<INT_T>::max()) {
29 return false;
30 }
31 *out_val = static_cast<INT_T>(val);
32 return true;
33 }
34
35 template <
36 typename FLOAT_T,
37 typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
38 type = true>
extract_scalar(exec_aten::Scalar scalar,FLOAT_T * out_val)39 bool extract_scalar(exec_aten::Scalar scalar, FLOAT_T* out_val) {
40 double val;
41 if (scalar.isFloatingPoint()) {
42 val = scalar.to<double>();
43 if (std::isfinite(val) &&
44 (val < std::numeric_limits<FLOAT_T>::lowest() ||
45 val > std::numeric_limits<FLOAT_T>::max())) {
46 return false;
47 }
48 } else if (scalar.isIntegral(/*includeBool=*/false)) {
49 val = static_cast<double>(scalar.to<int64_t>());
50 } else {
51 return false;
52 }
53 *out_val = static_cast<FLOAT_T>(val);
54 return true;
55 }
56
57 template <
58 typename BOOL_T,
59 typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
60 true>
extract_scalar(exec_aten::Scalar scalar,BOOL_T * out_val)61 bool extract_scalar(exec_aten::Scalar scalar, BOOL_T* out_val) {
62 if (scalar.isIntegral(false)) {
63 *out_val = static_cast<bool>(scalar.to<int64_t>());
64 return true;
65 }
66 if (scalar.isBoolean()) {
67 *out_val = scalar.to<bool>();
68 return true;
69 }
70 return false;
71 }
72
// Extracts `scalar` into the lvalue `out_val` via extract_scalar(). If the
// scalar has the wrong type or does not fit, it triggers ET_CHECK_MSG with a
// message naming the scalar expression.
// The arguments are parenthesized for macro hygiene. The expansion has no
// trailing semicolon, so the call site supplies it like any other statement;
// this avoids a stray `;;` and keeps the macro usable in an unbraced if/else.
#define ET_EXTRACT_SCALAR(scalar, out_val)  \
  ET_CHECK_MSG(                             \
      extract_scalar((scalar), &(out_val)), \
      #scalar " could not be extracted: wrong type or out of range")
77
78 template <typename Distribution>
random_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism,Distribution && distribution)79 TensorPtr random_strided(
80 std::vector<exec_aten::SizesType> sizes,
81 std::vector<exec_aten::StridesType> strides,
82 exec_aten::ScalarType type,
83 exec_aten::TensorShapeDynamism dynamism,
84 Distribution&& distribution) {
85 auto tensor =
86 empty_strided(std::move(sizes), std::move(strides), type, dynamism);
87 std::default_random_engine gen{std::random_device{}()};
88
89 ET_SWITCH_REALB_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
90 std::generate_n(tensor->mutable_data_ptr<CTYPE>(), tensor->numel(), [&]() {
91 return static_cast<CTYPE>(distribution(gen));
92 });
93 });
94 return tensor;
95 }
96
97 } // namespace
98
empty_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)99 TensorPtr empty_strided(
100 std::vector<exec_aten::SizesType> sizes,
101 std::vector<exec_aten::StridesType> strides,
102 exec_aten::ScalarType type,
103 exec_aten::TensorShapeDynamism dynamism) {
104 std::vector<uint8_t> data(
105 exec_aten::compute_numel(sizes.data(), sizes.size()) *
106 exec_aten::elementSize(type));
107 return make_tensor_ptr(
108 std::move(sizes),
109 std::move(data),
110 {},
111 std::move(strides),
112 type,
113 dynamism);
114 }
115
full_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::Scalar fill_value,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)116 TensorPtr full_strided(
117 std::vector<exec_aten::SizesType> sizes,
118 std::vector<exec_aten::StridesType> strides,
119 exec_aten::Scalar fill_value,
120 exec_aten::ScalarType type,
121 exec_aten::TensorShapeDynamism dynamism) {
122 auto tensor =
123 empty_strided(std::move(sizes), std::move(strides), type, dynamism);
124 ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
125 CTYPE value;
126 ET_EXTRACT_SCALAR(fill_value, value);
127 std::fill(
128 tensor->mutable_data_ptr<CTYPE>(),
129 tensor->mutable_data_ptr<CTYPE>() + tensor->numel(),
130 value);
131 });
132 return tensor;
133 }
134
rand_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)135 TensorPtr rand_strided(
136 std::vector<exec_aten::SizesType> sizes,
137 std::vector<exec_aten::StridesType> strides,
138 exec_aten::ScalarType type,
139 exec_aten::TensorShapeDynamism dynamism) {
140 return random_strided(
141 std::move(sizes),
142 std::move(strides),
143 type,
144 dynamism,
145 std::uniform_real_distribution<float>(0.0f, 1.0f));
146 }
147
randn_strided(std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)148 TensorPtr randn_strided(
149 std::vector<exec_aten::SizesType> sizes,
150 std::vector<exec_aten::StridesType> strides,
151 exec_aten::ScalarType type,
152 exec_aten::TensorShapeDynamism dynamism) {
153 return random_strided(
154 std::move(sizes),
155 std::move(strides),
156 type,
157 dynamism,
158 std::normal_distribution<float>(0.0f, 1.0f));
159 }
160
randint_strided(int64_t low,int64_t high,std::vector<exec_aten::SizesType> sizes,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)161 TensorPtr randint_strided(
162 int64_t low,
163 int64_t high,
164 std::vector<exec_aten::SizesType> sizes,
165 std::vector<exec_aten::StridesType> strides,
166 exec_aten::ScalarType type,
167 exec_aten::TensorShapeDynamism dynamism) {
168 return random_strided(
169 std::move(sizes),
170 std::move(strides),
171 type,
172 dynamism,
173 std::uniform_int_distribution<int64_t>(low, high - 1));
174 }
175
176 } // namespace extension
177 } // namespace executorch
178