1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/portable_type/tensor.h>
13 #include <executorch/runtime/platform/runtime.h>
14 #include <gtest/gtest.h>
15 #include <torch/library.h>
16 
17 using namespace ::testing;
18 using ::executorch::extension::internal::type_convert;
19 using ::executorch::extension::internal::type_map;
20 using ::torch::executor::ScalarType;
21 using ::torch::executor::Tensor;
22 
// Trivial out-variant kernel: performs no computation and simply returns
// `out`. The input tensor is intentionally unused.
Tensor& my_op_out([[maybe_unused]] const Tensor& a, Tensor& out) {
  return out;
}
27 
// Out-variant kernel that increments the first int32 element of `out` by one
// and returns `out`. The input tensor is ignored.
Tensor& add_1_out([[maybe_unused]] const Tensor& a, Tensor& out) {
  int32_t* const first = out.mutable_data_ptr<int32_t>();
  ++first[0];
  return out;
}
33 
add_optional_scalar_out(torch::executor::optional<int64_t> s1,torch::executor::optional<int64_t> s2,Tensor & out)34 Tensor& add_optional_scalar_out(
35     torch::executor::optional<int64_t> s1,
36     torch::executor::optional<int64_t> s2,
37     Tensor& out) {
38   if (s1.has_value()) {
39     out.mutable_data_ptr<int64_t>()[0] += s1.value();
40   }
41   if (s2.has_value()) {
42     out.mutable_data_ptr<int64_t>()[0] += s2.value();
43   }
44   return out;
45 }
46 
add_optional_tensor_out(torch::executor::optional<torch::executor::Tensor> s1,torch::executor::optional<torch::executor::Tensor> s2,Tensor & out)47 Tensor& add_optional_tensor_out(
48     torch::executor::optional<torch::executor::Tensor> s1,
49     torch::executor::optional<torch::executor::Tensor> s2,
50     Tensor& out) {
51   if (s1.has_value()) {
52     out.mutable_data_ptr<int64_t>()[0] +=
53         s1.value().mutable_data_ptr<int64_t>()[0];
54   }
55   if (s2.has_value()) {
56     out.mutable_data_ptr<int64_t>()[0] +=
57         s2.value().mutable_data_ptr<int64_t>()[0];
58   }
59   return out;
60 }
61 
sum_arrayref_scalar_out(torch::executor::ArrayRef<int64_t> a,Tensor & out)62 Tensor& sum_arrayref_scalar_out(
63     torch::executor::ArrayRef<int64_t> a,
64     Tensor& out) {
65   for (int i = 0; i < a.size(); i++) {
66     out.mutable_data_ptr<int64_t>()[0] += a[i];
67   }
68   return out;
69 }
70 
sum_arrayref_tensor_out(torch::executor::ArrayRef<torch::executor::Tensor> a,Tensor & out)71 Tensor& sum_arrayref_tensor_out(
72     torch::executor::ArrayRef<torch::executor::Tensor> a,
73     Tensor& out) {
74   for (int i = 0; i < a.size(); i++) {
75     out.mutable_data_ptr<int32_t>()[0] += a[i].const_data_ptr<int32_t>()[0];
76   }
77   return out;
78 }
79 
sum_arrayref_optional_tensor_out(torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> a,Tensor & out)80 Tensor& sum_arrayref_optional_tensor_out(
81     torch::executor::ArrayRef<
82         torch::executor::optional<torch::executor::Tensor>> a,
83     Tensor& out) {
84   for (int i = 0; i < a.size(); i++) {
85     if (a[i].has_value()) {
86       out.mutable_data_ptr<int32_t>()[0] +=
87           a[i].value().const_data_ptr<int32_t>()[0];
88     }
89   }
90   return out;
91 }
92 
// Stand-in for a quantized embedding-byte kernel, used to exercise a
// many-argument schema. Every input except weight_quant_max is ignored.
// weight_quant_max (narrowed to int32) is subtracted from the first int32
// element of `out`, and `out` is returned.
Tensor& quantized_embedding_byte_out(
    [[maybe_unused]] const Tensor& weight,
    [[maybe_unused]] const Tensor& weight_scales,
    [[maybe_unused]] const Tensor& weight_zero_points,
    [[maybe_unused]] int64_t weight_quant_min,
    int64_t weight_quant_max,
    [[maybe_unused]] const Tensor& indices,
    Tensor& out) {
  const auto delta = static_cast<int32_t>(weight_quant_max);
  int32_t* const data = out.mutable_data_ptr<int32_t>();
  data[0] -= delta;
  return out;
}
109 
110 class MakeATenFunctorFromETFunctorTest : public ::testing::Test {
111  public:
SetUp()112   void SetUp() override {
113     torch::executor::runtime_init();
114   }
115 };
116 
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_Scalar) {
  // Plain scalars map to themselves.
  constexpr bool kIsIdentity =
      std::is_same_v<type_map<int64_t>::type, int64_t>;
  EXPECT_TRUE(kIsIdentity);
}
120 
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_Tensor) {
  // ExecuTorch Tensor maps to at::Tensor, preserving reference and const
  // qualifiers: value, ref, const, and const ref are checked in that order.
  using ETTensor = torch::executor::Tensor;
  constexpr bool kValue = std::is_same_v<type_map<ETTensor>::type, at::Tensor>;
  constexpr bool kRef =
      std::is_same_v<type_map<ETTensor&>::type, at::Tensor&>;
  constexpr bool kConst =
      std::is_same_v<type_map<const ETTensor>::type, const at::Tensor>;
  constexpr bool kConstRef =
      std::is_same_v<type_map<const ETTensor&>::type, const at::Tensor&>;
  EXPECT_TRUE(kValue);
  EXPECT_TRUE(kRef);
  EXPECT_TRUE(kConst);
  EXPECT_TRUE(kConstRef);
}
136 
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_Optionals) {
  // ExecuTorch optionals map to std::optional of the mapped payload type,
  // for scalar, tensor, and ArrayRef payloads.
  constexpr bool kScalar = std::is_same_v<
      type_map<torch::executor::optional<int64_t>>::type,
      std::optional<int64_t>>;
  constexpr bool kTensor = std::is_same_v<
      type_map<torch::executor::optional<torch::executor::Tensor>>::type,
      std::optional<at::Tensor>>;
  constexpr bool kScalarList = std::is_same_v<
      type_map<torch::executor::optional<
          torch::executor::ArrayRef<int64_t>>>::type,
      std::optional<c10::ArrayRef<int64_t>>>;
  constexpr bool kTensorList = std::is_same_v<
      type_map<torch::executor::optional<
          torch::executor::ArrayRef<torch::executor::Tensor>>>::type,
      std::optional<c10::ArrayRef<at::Tensor>>>;
  EXPECT_TRUE(kScalar);
  EXPECT_TRUE(kTensor);
  EXPECT_TRUE(kScalarList);
  EXPECT_TRUE(kTensorList);
}
157 
TEST_F(MakeATenFunctorFromETFunctorTest, TestTypeMap_ArrayRef) {
  // ExecuTorch ArrayRef maps to c10::ArrayRef of the mapped element type,
  // for scalar, tensor, and optional elements.
  constexpr bool kScalars = std::is_same_v<
      type_map<torch::executor::ArrayRef<int64_t>>::type,
      c10::ArrayRef<int64_t>>;
  constexpr bool kTensors = std::is_same_v<
      type_map<torch::executor::ArrayRef<torch::executor::Tensor>>::type,
      c10::ArrayRef<at::Tensor>>;
  constexpr bool kOptionalScalars = std::is_same_v<
      type_map<torch::executor::ArrayRef<
          torch::executor::optional<int64_t>>>::type,
      c10::ArrayRef<std::optional<int64_t>>>;
  constexpr bool kOptionalTensors = std::is_same_v<
      type_map<torch::executor::ArrayRef<
          torch::executor::optional<torch::executor::Tensor>>>::type,
      c10::ArrayRef<std::optional<at::Tensor>>>;
  EXPECT_TRUE(kScalars);
  EXPECT_TRUE(kTensors);
  EXPECT_TRUE(kOptionalScalars);
  EXPECT_TRUE(kOptionalTensors);
}
178 
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_Tensor) {
  // ATen -> ExecuTorch.
  at::Tensor aten_tensor = torch::tensor({1});
  auto et_result =
      type_convert<at::Tensor, torch::executor::Tensor>(aten_tensor).call();
  EXPECT_TRUE((std::is_same_v<decltype(et_result), torch::executor::Tensor>));

  // ExecuTorch -> ATen.
  torch::executor::testing::TensorFactory<ScalarType::Int> factory;
  torch::executor::Tensor et_tensor = factory.ones({3});
  auto aten_result =
      type_convert<torch::executor::Tensor, at::Tensor>(et_tensor).call();
  EXPECT_TRUE((std::is_same_v<decltype(aten_result), at::Tensor>));
}
191 
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_OptionalScalar) {
  // std::optional<int64_t> -> ExecuTorch optional<int64_t>.
  std::optional<int64_t> empty_aten{};
  auto et_result =
      type_convert<std::optional<int64_t>, torch::executor::optional<int64_t>>(
          empty_aten)
          .call();
  EXPECT_TRUE(
      (std::is_same_v<decltype(et_result), torch::executor::optional<int64_t>>));

  // ExecuTorch optional<int64_t> -> std::optional<int64_t>.
  torch::executor::optional<int64_t> empty_et{};
  auto aten_result =
      type_convert<torch::executor::optional<int64_t>, std::optional<int64_t>>(
          empty_et)
          .call();
  EXPECT_TRUE((std::is_same_v<decltype(aten_result), std::optional<int64_t>>));
}
212 
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_OptionalTensor) {
  // Convert an empty optional at::Tensor to an ExecuTorch optional.
  auto optional_at_in = std::optional<at::Tensor>();
  auto optional_et =
      type_convert<
          std::optional<at::Tensor>,
          torch::executor::optional<torch::executor::Tensor>>(optional_at_in)
          .call();
  EXPECT_TRUE((std::is_same<
               decltype(optional_et),
               torch::executor::optional<torch::executor::Tensor>>::value));
  EXPECT_FALSE(optional_et.has_value());

  // Convert a non-empty ExecuTorch optional tensor to ATen. Previously this
  // converted `optional_et` (empty) and left `et_in` unused, so the path that
  // actually carries a tensor was never exercised.
  torch::executor::testing::TensorFactory<ScalarType::Int> tf;
  auto et_in = torch::executor::optional<torch::executor::Tensor>(tf.ones({3}));
  auto optional_at_out = type_convert<
                             torch::executor::optional<torch::executor::Tensor>,
                             std::optional<at::Tensor>>(et_in)
                             .call();
  EXPECT_TRUE(
      (std::is_same<decltype(optional_at_out), std::optional<at::Tensor>>::
           value));
  EXPECT_TRUE(optional_at_out.has_value());
}
236 
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ArrayRefScalar) {
  // Both directions view the same backing vector.
  const std::vector<int64_t> values = {1, 2, 3};

  // c10::ArrayRef<int64_t> -> ExecuTorch ArrayRef<int64_t>.
  c10::ArrayRef<int64_t> aten_ref(values);
  auto et_result =
      type_convert<c10::ArrayRef<int64_t>, torch::executor::ArrayRef<int64_t>>(
          aten_ref)
          .call();
  EXPECT_TRUE(
      (std::is_same_v<decltype(et_result), torch::executor::ArrayRef<int64_t>>));

  // ExecuTorch ArrayRef<int64_t> -> c10::ArrayRef<int64_t>.
  torch::executor::ArrayRef<int64_t> et_ref(values.data(), values.size());
  auto aten_result =
      type_convert<torch::executor::ArrayRef<int64_t>, c10::ArrayRef<int64_t>>(
          et_ref)
          .call();
  EXPECT_TRUE((std::is_same_v<decltype(aten_result), c10::ArrayRef<int64_t>>));
}
260 
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ArrayRefTensor) {
  // c10::ArrayRef<at::Tensor> -> ExecuTorch ArrayRef<Tensor>.
  const std::vector<at::Tensor> aten_tensors{
      torch::tensor({1}), torch::tensor({2})};
  c10::ArrayRef<at::Tensor> aten_list(
      aten_tensors.data(), aten_tensors.size());
  auto et_list =
      type_convert<
          c10::ArrayRef<at::Tensor>,
          torch::executor::ArrayRef<torch::executor::Tensor>>(aten_list)
          .call();
  EXPECT_TRUE((std::is_same_v<
               decltype(et_list),
               torch::executor::ArrayRef<torch::executor::Tensor>>));

  // ExecuTorch ArrayRef<Tensor> -> c10::ArrayRef<at::Tensor>.
  torch::executor::testing::TensorFactory<ScalarType::Int> factory;
  std::vector<torch::executor::Tensor> et_tensors{
      factory.ones({1}), factory.ones({2})};
  torch::executor::ArrayRef<torch::executor::Tensor> et_list_in(
      et_tensors.data(), et_tensors.size());
  auto aten_list_out =
      type_convert<
          torch::executor::ArrayRef<torch::executor::Tensor>,
          c10::ArrayRef<at::Tensor>>(et_list_in)
          .call();
  EXPECT_TRUE(
      (std::is_same_v<decltype(aten_list_out), c10::ArrayRef<at::Tensor>>));
}
289 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_Basic) {
  // my_op_out ignores its input and returns `out` untouched, so the wrapped
  // call should yield a tensor holding out's original value (2.0f).
  auto wrapped = WRAP_TO_ATEN(my_op_out, 1);
  at::Tensor input = torch::tensor({1.0f});
  at::Tensor output = torch::tensor({2.0f});
  at::Tensor returned = wrapped(input, output);
  EXPECT_EQ(returned.const_data_ptr<float>()[0], 2.0f);
}
297 
298 // Register operators.
// Register the wrapped ExecuTorch kernels above as ATen operators in the
// `my_op` namespace so the TestWrap_* cases can find them through the c10
// dispatcher. Only embedding_byte.out gets an explicit schema. The others pass
// just the operator name, and their schema is presumably inferred from the
// wrapped function's signature.
TORCH_LIBRARY(my_op, m) {
  m.def("add_1.out", WRAP_TO_ATEN(add_1_out, 1));
  m.def(
      "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
      WRAP_TO_ATEN(quantized_embedding_byte_out, 6));
  m.def("add_optional_scalar.out", WRAP_TO_ATEN(add_optional_scalar_out, 2));
  m.def("add_optional_tensor.out", WRAP_TO_ATEN(add_optional_tensor_out, 2));
  m.def("sum_arrayref_scalar.out", WRAP_TO_ATEN(sum_arrayref_scalar_out, 1));
  m.def("sum_arrayref_tensor.out", WRAP_TO_ATEN(sum_arrayref_tensor_out, 1));
  m.def(
      "sum_arrayref_optional_tensor.out",
      WRAP_TO_ATEN(sum_arrayref_optional_tensor_out, 1));
}  // No trailing ';': it would be a stray empty declaration (-Wextra-semi).
312 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_RegisterWrappedFunction) {
  auto op = c10::Dispatcher::singleton().findSchema({"my_op::add_1", "out"});
  // ASSERT rather than EXPECT: op.value() below is invalid on an empty
  // optional, so the test must stop here if the op was not registered.
  ASSERT_TRUE(op.has_value());
  at::Tensor a =
      torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor b =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  torch::jit::Stack stack = {a, b};
  op.value().callBoxed(&stack);
  // Guard stack[0] below. The unsigned literal avoids a sign-compare warning.
  ASSERT_EQ(stack.size(), 1u);
  // add_1_out increments out[0]; `b` is the out tensor, so 2 + 1 == 3.
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int32_t>()[0], 3);
}
325 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_EmbeddingByte) {
  auto op =
      c10::Dispatcher::singleton().findSchema({"my_op::embedding_byte", "out"});
  // ASSERT rather than EXPECT: op.value() below is invalid on an empty optional.
  ASSERT_TRUE(op.has_value());
  at::Tensor weight =
      torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor scale =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor zero_point =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor indices =
      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
  at::Tensor out =
      torch::tensor({4}, torch::TensorOptions().dtype(torch::kInt32));
  // Arguments in schema order: weight_quant_min = 0, weight_quant_max = 1.
  torch::jit::Stack stack = {weight, scale, zero_point, 0, 1, indices, out};
  op.value().callBoxed(&stack);
  ASSERT_EQ(stack.size(), 1u);
  // quantized_embedding_byte_out subtracts weight_quant_max: 4 - 1 == 3.
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int32_t>()[0], 3);
}
345 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_OptionalScalarAdd) {
  // Only the first optional carries a value, so out[0] becomes 0 + 3.
  std::optional<int64_t> a = std::optional<int64_t>(3);
  std::optional<int64_t> b = std::optional<int64_t>();
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::add_optional_scalar", "out"});
  // ASSERT rather than EXPECT: op.value() below is invalid on an empty optional.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {a, b, out};
  op.value().callBoxed(&stack);

  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 3);
}
360 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_OptionalTensorAdd) {
  // Only the first optional carries a tensor, so out[0] becomes 0 + 8.
  std::optional<at::Tensor> a = std::optional<at::Tensor>(torch::tensor({8}));
  std::optional<at::Tensor> b = std::optional<at::Tensor>();
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::add_optional_tensor", "out"});
  // ASSERT rather than EXPECT: op.value() below is invalid on an empty optional.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {a, b, out};
  op.value().callBoxed(&stack);

  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 8);
}
375 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefScalarAdd) {
  // The kernel sums the list into out[0]: 0 + 2 + 3 + 4 == 9.
  std::vector<int64_t> vec{2, 3, 4};
  at::ArrayRef<int64_t> arrayref = at::ArrayRef(vec.data(), vec.size());
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::sum_arrayref_scalar", "out"});
  // ASSERT rather than EXPECT: op.value() below is invalid on an empty optional.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {arrayref, out};
  op.value().callBoxed(&stack);

  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 9);
}
390 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefTensorAdd) {
  // The kernel sums each tensor's first element into out[0]: 1 + 2 + 3 == 6.
  std::vector<at::Tensor> vec{
      torch::tensor({1}), torch::tensor({2}), torch::tensor({3})};
  at::ArrayRef arrayref = at::ArrayRef(vec.data(), vec.size());
  at::Tensor out = torch::tensor({0});

  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::sum_arrayref_tensor", "out"});
  // ASSERT rather than EXPECT: op.value() below is invalid on an empty optional.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {arrayref, out};
  op.value().callBoxed(&stack);

  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 6);
}
406 
TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefOptional) {
  // The empty middle entry is skipped, so out[0] becomes 0 + 1 + 3 == 4.
  std::vector<std::optional<at::Tensor>> vec{
      std::optional<at::Tensor>(torch::tensor({1})),
      std::optional<at::Tensor>(),
      std::optional<at::Tensor>(torch::tensor({3}))};
  at::Tensor out = torch::tensor({0});

  at::ArrayRef arrayref = at::ArrayRef(vec.data(), vec.size());
  auto op = c10::Dispatcher::singleton().findSchema(
      {"my_op::sum_arrayref_optional_tensor", "out"});
  // ASSERT rather than EXPECT: op.value() below is invalid on an empty optional.
  ASSERT_TRUE(op.has_value());
  torch::jit::Stack stack = {arrayref, out};
  op.value().callBoxed(&stack);

  ASSERT_EQ(stack.size(), 1u);
  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int64_t>()[0], 4);
}
424