1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/portable_type/tensor.h>
13 #include <executorch/runtime/platform/runtime.h>
14 #include <gtest/gtest.h>
15 #include <torch/library.h>
16
17 using namespace ::testing;
18 using ::executorch::extension::internal::type_convert;
19 using ::executorch::extension::internal::type_map;
20 using ::torch::executor::ScalarType;
21 using ::torch::executor::Tensor;
22
// Minimal out-variant kernel used to exercise WRAP_TO_ATEN: the input `a` is
// ignored and `out` is returned exactly as it was passed in.
Tensor& my_op_out([[maybe_unused]] const Tensor& a, Tensor& out) {
  return out;
}
27
// Out-variant kernel that increments the first int32 element of `out` in
// place. The input `a` is not read.
Tensor& add_1_out([[maybe_unused]] const Tensor& a, Tensor& out) {
  int32_t* const out_data = out.mutable_data_ptr<int32_t>();
  out_data[0] = out_data[0] + 1;
  return out;
}
33
add_optional_scalar_out(torch::executor::optional<int64_t> s1,torch::executor::optional<int64_t> s2,Tensor & out)34 Tensor& add_optional_scalar_out(
35 torch::executor::optional<int64_t> s1,
36 torch::executor::optional<int64_t> s2,
37 Tensor& out) {
38 if (s1.has_value()) {
39 out.mutable_data_ptr<int64_t>()[0] += s1.value();
40 }
41 if (s2.has_value()) {
42 out.mutable_data_ptr<int64_t>()[0] += s2.value();
43 }
44 return out;
45 }
46
add_optional_tensor_out(torch::executor::optional<torch::executor::Tensor> s1,torch::executor::optional<torch::executor::Tensor> s2,Tensor & out)47 Tensor& add_optional_tensor_out(
48 torch::executor::optional<torch::executor::Tensor> s1,
49 torch::executor::optional<torch::executor::Tensor> s2,
50 Tensor& out) {
51 if (s1.has_value()) {
52 out.mutable_data_ptr<int64_t>()[0] +=
53 s1.value().mutable_data_ptr<int64_t>()[0];
54 }
55 if (s2.has_value()) {
56 out.mutable_data_ptr<int64_t>()[0] +=
57 s2.value().mutable_data_ptr<int64_t>()[0];
58 }
59 return out;
60 }
61
sum_arrayref_scalar_out(torch::executor::ArrayRef<int64_t> a,Tensor & out)62 Tensor& sum_arrayref_scalar_out(
63 torch::executor::ArrayRef<int64_t> a,
64 Tensor& out) {
65 for (int i = 0; i < a.size(); i++) {
66 out.mutable_data_ptr<int64_t>()[0] += a[i];
67 }
68 return out;
69 }
70
sum_arrayref_tensor_out(torch::executor::ArrayRef<torch::executor::Tensor> a,Tensor & out)71 Tensor& sum_arrayref_tensor_out(
72 torch::executor::ArrayRef<torch::executor::Tensor> a,
73 Tensor& out) {
74 for (int i = 0; i < a.size(); i++) {
75 out.mutable_data_ptr<int32_t>()[0] += a[i].const_data_ptr<int32_t>()[0];
76 }
77 return out;
78 }
79
sum_arrayref_optional_tensor_out(torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> a,Tensor & out)80 Tensor& sum_arrayref_optional_tensor_out(
81 torch::executor::ArrayRef<
82 torch::executor::optional<torch::executor::Tensor>> a,
83 Tensor& out) {
84 for (int i = 0; i < a.size(); i++) {
85 if (a[i].has_value()) {
86 out.mutable_data_ptr<int32_t>()[0] +=
87 a[i].value().const_data_ptr<int32_t>()[0];
88 }
89 }
90 return out;
91 }
92
// Stand-in with the signature of the quantized embedding-byte out-variant.
// Only `weight_quant_max` is consumed: it is narrowed to int32 and subtracted
// from the first int32 element of `out`. Every other input is ignored.
Tensor& quantized_embedding_byte_out(
    [[maybe_unused]] const Tensor& weight,
    [[maybe_unused]] const Tensor& weight_scales,
    [[maybe_unused]] const Tensor& weight_zero_points,
    [[maybe_unused]] int64_t weight_quant_min,
    int64_t weight_quant_max,
    [[maybe_unused]] const Tensor& indices,
    Tensor& out) {
  const auto delta = static_cast<int32_t>(weight_quant_max);
  int32_t* const out_data = out.mutable_data_ptr<int32_t>();
  out_data[0] -= delta;
  return out;
}
109
110 class MakeATenFunctorFromETFunctorTest : public ::testing::Test {
111 public:
SetUp()112 void SetUp() override {
113 torch::executor::runtime_init();
114 }
115 };
116
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_Scalar) {
  // Plain scalars map to themselves.
  using Mapped = type_map<int64_t>::type;
  EXPECT_TRUE((std::is_same_v<Mapped, int64_t>));
}
120
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_Tensor) {
  // ET Tensor maps to at::Tensor with value / reference / const qualifiers
  // carried over unchanged.
  using ETTensor = torch::executor::Tensor;
  EXPECT_TRUE((std::is_same_v<type_map<ETTensor>::type, at::Tensor>));
  EXPECT_TRUE((std::is_same_v<type_map<ETTensor&>::type, at::Tensor&>));
  EXPECT_TRUE(
      (std::is_same_v<type_map<const ETTensor>::type, const at::Tensor>));
  EXPECT_TRUE(
      (std::is_same_v<type_map<const ETTensor&>::type, const at::Tensor&>));
}
136
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_Optionals) {
  // ET optional<X> maps to std::optional of the mapped payload type.
  using ETTensor = torch::executor::Tensor;

  // Scalar payload.
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::optional<int64_t>>::type,
               std::optional<int64_t>>));
  // Tensor payload.
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::optional<ETTensor>>::type,
               std::optional<at::Tensor>>));
  // ArrayRef payloads.
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::optional<
                   torch::executor::ArrayRef<int64_t>>>::type,
               std::optional<c10::ArrayRef<int64_t>>>));
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::optional<
                   torch::executor::ArrayRef<ETTensor>>>::type,
               std::optional<c10::ArrayRef<at::Tensor>>>));
}
157
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_ArrayRef) {
  // ET ArrayRef<X> maps to c10::ArrayRef of the mapped element type.
  using ETTensor = torch::executor::Tensor;

  // Scalar elements.
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::ArrayRef<int64_t>>::type,
               c10::ArrayRef<int64_t>>));
  // Tensor elements.
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::ArrayRef<ETTensor>>::type,
               c10::ArrayRef<at::Tensor>>));
  // Optional elements.
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::ArrayRef<
                   torch::executor::optional<int64_t>>>::type,
               c10::ArrayRef<std::optional<int64_t>>>));
  EXPECT_TRUE((std::is_same_v<
               type_map<torch::executor::ArrayRef<
                   torch::executor::optional<ETTensor>>>::type,
               c10::ArrayRef<std::optional<at::Tensor>>>));
}
178
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_Tensor) {
  using ETTensor = torch::executor::Tensor;

  // ATen -> ExecuTorch.
  at::Tensor aten_src = torch::tensor({1});
  auto et_result = type_convert<at::Tensor, ETTensor>(aten_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(et_result), ETTensor>));

  // ExecuTorch -> ATen.
  torch::executor::testing::TensorFactory<ScalarType::Int> factory;
  ETTensor et_src = factory.ones({3});
  auto aten_result = type_convert<ETTensor, at::Tensor>(et_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(aten_result), at::Tensor>));
}
191
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_OptionalScalar) {
  using ATenOpt = std::optional<int64_t>;
  using ETOpt = torch::executor::optional<int64_t>;

  // ATen -> ExecuTorch, starting from an empty optional.
  ATenOpt aten_src{};
  auto et_result = type_convert<ATenOpt, ETOpt>(aten_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(et_result), ETOpt>));

  // ExecuTorch -> ATen, starting from an empty optional.
  ETOpt et_src{};
  auto aten_result = type_convert<ETOpt, ATenOpt>(et_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(aten_result), ATenOpt>));
}
212
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_OptionalTensor) {
  // Convert an empty optional at to et; emptiness must be preserved.
  auto optional_at_in = std::optional<at::Tensor>();
  auto optional_et =
      type_convert<
          std::optional<at::Tensor>,
          torch::executor::optional<torch::executor::Tensor>>(optional_at_in)
          .call();
  EXPECT_TRUE((std::is_same<
               decltype(optional_et),
               torch::executor::optional<torch::executor::Tensor>>::value));
  EXPECT_FALSE(optional_et.has_value());

  // Convert a populated optional et to at. This must convert `et_in` (which
  // holds a tensor), not the empty `optional_et` from above, otherwise the
  // non-empty ET -> ATen path is never exercised.
  torch::executor::testing::TensorFactory<ScalarType::Int> tf;
  auto et_in = torch::executor::optional<torch::executor::Tensor>(tf.ones({3}));
  auto optional_at_out = type_convert<
                             torch::executor::optional<torch::executor::Tensor>,
                             std::optional<at::Tensor>>(et_in)
                             .call();
  EXPECT_TRUE(
      (std::is_same<decltype(optional_at_out), std::optional<at::Tensor>>::
           value));
  EXPECT_TRUE(optional_at_out.has_value());
}
236
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ArrayRefScalar) {
  using ATenRef = c10::ArrayRef<int64_t>;
  using ETRef = torch::executor::ArrayRef<int64_t>;
  const std::vector<int64_t> values = {1, 2, 3};

  // ATen -> ExecuTorch.
  ATenRef aten_src(values);
  auto et_result = type_convert<ATenRef, ETRef>(aten_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(et_result), ETRef>));

  // ExecuTorch -> ATen.
  ETRef et_src(values.data(), values.size());
  auto aten_result = type_convert<ETRef, ATenRef>(et_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(aten_result), ATenRef>));
}
260
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ArrayRefTensor) {
  using ETTensor = torch::executor::Tensor;
  using ATenRef = c10::ArrayRef<at::Tensor>;
  using ETRef = torch::executor::ArrayRef<ETTensor>;

  // ATen -> ExecuTorch.
  const std::vector<at::Tensor> aten_tensors{
      torch::tensor({1}), torch::tensor({2})};
  ATenRef aten_src(aten_tensors.data(), aten_tensors.size());
  auto et_result = type_convert<ATenRef, ETRef>(aten_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(et_result), ETRef>));

  // ExecuTorch -> ATen.
  torch::executor::testing::TensorFactory<ScalarType::Int> factory;
  std::vector<ETTensor> et_tensors{factory.ones({1}), factory.ones({2})};
  ETRef et_src(et_tensors.data(), et_tensors.size());
  auto aten_result = type_convert<ETRef, ATenRef>(et_src).call();
  EXPECT_TRUE((std::is_same_v<decltype(aten_result), ATenRef>));
}
289
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_Basic) {
  // The second WRAP_TO_ATEN argument appears to be the position of the `out`
  // parameter (index 1 for my_op_out). my_op_out hands back its `out`
  // argument, so the result must carry the second tensor's value.
  auto wrapped = WRAP_TO_ATEN(my_op_out, 1);
  at::Tensor input = torch::tensor({1.0f});
  at::Tensor output = torch::tensor({2.0f});
  at::Tensor result = wrapped(input, output);
  EXPECT_EQ(result.const_data_ptr<float>()[0], 2.0f);
}
297
298 // Register operators.
// Registers the wrapped ExecuTorch kernels as ATen operators under the
// `my_op` namespace. Entries given only a name have their schema inferred from
// the wrapped function's signature; embedding_byte spells its schema out. The
// body is a function definition, so no trailing semicolon follows the brace
// (a stray `;` here is an empty declaration flagged by -Wextra-semi).
TORCH_LIBRARY(my_op, m) {
  m.def("add_1.out", WRAP_TO_ATEN(add_1_out, 1));
  m.def(
      "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
      WRAP_TO_ATEN(quantized_embedding_byte_out, 6));
  m.def("add_optional_scalar.out", WRAP_TO_ATEN(add_optional_scalar_out, 2));
  m.def("add_optional_tensor.out", WRAP_TO_ATEN(add_optional_tensor_out, 2));
  m.def("sum_arrayref_scalar.out", WRAP_TO_ATEN(sum_arrayref_scalar_out, 1));
  m.def("sum_arrayref_tensor.out", WRAP_TO_ATEN(sum_arrayref_tensor_out, 1));
  m.def(
      "sum_arrayref_optional_tensor.out",
      WRAP_TO_ATEN(sum_arrayref_optional_tensor_out, 1));
}
312
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_RegisterWrappedFunction) {
  auto op = c10::Dispatcher::singleton().findSchema({"my_op::add_1", "out"});
  // ASSERT, not EXPECT: everything below dereferences `op`.
  ASSERT_TRUE(op.has_value());
  at::Tensor a =
      torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor b =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  torch::jit::Stack stack = {a, b};
  op.value().callBoxed(&stack);
  // ASSERT so stack[0] is never read out of range; `1u` keeps the comparison
  // unsigned like stack.size() and avoids a sign-compare warning.
  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int32_t>()[0], 3);
}
325
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_EmbeddingByte) {
  auto op =
      c10::Dispatcher::singleton().findSchema({"my_op::embedding_byte", "out"});
  // ASSERT, not EXPECT: everything below dereferences `op`.
  ASSERT_TRUE(op.has_value());
  at::Tensor weight =
      torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor scale =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor zero_point =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor indices =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor out =
      torch::tensor({4}, torch::TensorOptions().dtype(torch::kInt32));
  torch::jit::Stack stack = {weight, scale, zero_point, 0, 1, indices, out};
  op.value().callBoxed(&stack);
  // ASSERT so stack[0] is never read out of range; `1u` avoids sign-compare.
  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int32_t>()[0], 3);
}
345
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_OptionalScalarAdd) {
  std::optional<int64_t> a = std::optional<int64_t>(3);
  std::optional<int64_t> b = std::optional<int64_t>();
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::add_optional_scalar", "out"});
  // ASSERT, not EXPECT: everything below dereferences `op`.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {a, b, out};
  op.value().callBoxed(&stack);

  // ASSERT so stack[0] is never read out of range; `1u` avoids sign-compare.
  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 3);
}
360
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_OptionalTensorAdd) {
  std::optional<at::Tensor> a = std::optional<at::Tensor>(torch::tensor({8}));
  std::optional<at::Tensor> b = std::optional<at::Tensor>();
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::add_optional_tensor", "out"});
  // ASSERT, not EXPECT: everything below dereferences `op`.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {a, b, out};
  op.value().callBoxed(&stack);

  // ASSERT so stack[0] is never read out of range; `1u` avoids sign-compare.
  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 8);
}
375
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefScalarAdd) {
  std::vector<int64_t> vec{2, 3, 4};
  at::ArrayRef<int64_t> arrayref = at::ArrayRef(vec.data(), vec.size());
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::sum_arrayref_scalar", "out"});
  // ASSERT, not EXPECT: everything below dereferences `op`.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {arrayref, out};
  op.value().callBoxed(&stack);

  // ASSERT so stack[0] is never read out of range; `1u` avoids sign-compare.
  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 9);
}
390
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefTensorAdd) {
  std::vector<at::Tensor> vec{
      torch::tensor({1}), torch::tensor({2}), torch::tensor({3})};
  at::ArrayRef arrayref = at::ArrayRef(vec.data(), vec.size());
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::sum_arrayref_tensor", "out"});
  // ASSERT, not EXPECT: everything below dereferences `op`.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {arrayref, out};
  op.value().callBoxed(&stack);

  // ASSERT so stack[0] is never read out of range; `1u` avoids sign-compare.
  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 6);
}
406
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefOptional) {
  std::vector<std::optional<at::Tensor>> vec{
      std::optional<at::Tensor>(torch::tensor({1})),
      std::optional<at::Tensor>(),
      std::optional<at::Tensor>(torch::tensor({3}))};
  at::Tensor out = torch::tensor({0});

  at::ArrayRef arrayref = at::ArrayRef(vec.data(), vec.size());
  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::sum_arrayref_optional_tensor", "out"});
  // ASSERT, not EXPECT: everything below dereferences `op`.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {arrayref, out};
  op.value().callBoxed(&stack);

  // ASSERT so stack[0] is never read out of range; `1u` avoids sign-compare.
  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 4);
}
424