xref: /aosp_15_r20/external/executorch/runtime/kernel/test/operator_registry_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <gtest/gtest.h>
10 #include <vector>
11 
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/result.h>
14 #include <executorch/runtime/core/span.h>
15 #include <executorch/runtime/kernel/kernel_runtime_context.h>
16 #include <executorch/runtime/kernel/operator_registry.h>
17 #include <executorch/runtime/kernel/test/test_util.h>
18 #include <executorch/runtime/platform/runtime.h>
19 #include <executorch/test/utils/DeathTest.h>
20 
21 using namespace ::testing;
22 using exec_aten::Scalar;
23 using exec_aten::ScalarType;
24 using exec_aten::Tensor;
25 using executorch::runtime::Error;
26 using executorch::runtime::EValue;
27 using executorch::runtime::get_op_function_from_registry;
28 using executorch::runtime::Kernel;
29 using executorch::runtime::KernelKey;
30 using executorch::runtime::KernelRuntimeContext;
31 using executorch::runtime::OpFunction;
32 using executorch::runtime::register_kernels;
33 using executorch::runtime::registry_has_op_function;
34 using executorch::runtime::Result;
35 using executorch::runtime::Span;
36 using executorch::runtime::TensorMeta;
37 using executorch::runtime::testing::make_kernel_key;
38 
39 class OperatorRegistryTest : public ::testing::Test {
40  public:
SetUp()41   void SetUp() override {
42     executorch::runtime::runtime_init();
43   }
44 };
45 
// Registers a kernel under "foo" using only a name and a function, then checks
// that "foo" is resolvable while the near-miss name "fpp" is not.
TEST_F(OperatorRegistryTest, Basic) {
  Kernel kernels[] = {Kernel("foo", [](KernelRuntimeContext&, EValue**) {})};
  Span<const Kernel> kernels_span(kernels);
  // Check the registration status rather than discarding it, consistent with
  // the other registration tests in this file.
  Error status = register_kernels(kernels_span);
  EXPECT_EQ(status, Error::Ok);
  EXPECT_FALSE(registry_has_op_function("fpp"));
  EXPECT_TRUE(registry_has_op_function("foo"));
}
53 
// Registering two kernels that share the name "foo" in a single batch is
// expected to terminate the process.
TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) {
  Kernel duplicated[] = {
      Kernel("foo", [](KernelRuntimeContext&, EValue**) {}),
      Kernel("foo", [](KernelRuntimeContext&, EValue**) {}),
  };
  Span<const Kernel> to_register(duplicated);
  ET_EXPECT_DEATH({ (void)register_kernels(to_register); }, "");
}
61 
62 constexpr int BUF_SIZE = KernelKey::MAX_SIZE;
63 
// KernelKey equality: two copies of one key compare equal, while keys that
// differ in dtype (Long vs Float) or in dim order ({0,1,2,3} vs {0,3,1,2})
// compare unequal.
TEST_F(OperatorRegistryTest, KernelKeyEquals) {
  char long_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, long_buf);
  KernelKey long_contiguous(long_buf);

  // Independent copies of the same key.
  KernelKey copy_a(long_contiguous);
  KernelKey copy_b(long_contiguous);
  EXPECT_EQ(copy_a, copy_b);

  // Same dim order, different dtype.
  char float_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Float, {0, 1, 2, 3}}}, float_buf);
  KernelKey float_contiguous(float_buf);
  EXPECT_NE(copy_a, float_contiguous);

  // Same dtype, different dim order.
  char channels_first_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 3, 1, 2}}}, channels_first_buf);
  KernelKey long_channels_first(channels_first_buf);
  EXPECT_NE(copy_a, long_channels_first);
}
87 
// Registers a kernel for "test::boo" keyed on a Long input with dim order
// {0, 1, 2, 3}. Checks that no key-less lookup finds it, that a matching
// TensorMeta key does, and that the resolved function is the one registered.
TEST_F(OperatorRegistryTest, RegisterKernels) {
  char buf_long_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
  KernelKey key = KernelKey(buf_long_contiguous);

  // Writes Scalar(100) into the first stack slot so the test can observe
  // that this kernel ran.
  Kernel kernel_1 = Kernel(
      "test::boo", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  auto s1 = register_kernels({&kernel_1, 1});
  EXPECT_EQ(s1, Error::Ok);

  Tensor::DimOrderType dims[] = {0, 1, 2, 3};
  auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
  TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
  Span<const TensorMeta> user_kernel_key(meta);

  // No fallback (key-less) kernel is registered for this op.
  EXPECT_FALSE(registry_has_op_function("test::boo", {}));
  Result<OpFunction> fallback_func =
      get_op_function_from_registry("test::boo", {});
  EXPECT_NE(fallback_func.error(), Error::Ok);

  EXPECT_TRUE(registry_has_op_function("test::boo", user_kernel_key));
  Result<OpFunction> func =
      get_op_function_from_registry("test::boo", user_kernel_key);
  // ASSERT, not EXPECT: `func` is dereferenced below, so a failed lookup must
  // end the test instead of dereferencing a Result that holds an error.
  ASSERT_EQ(func.error(), Error::Ok);

  EValue values[1];
  values[0] = Scalar(0);
  EValue* evalues[1];
  evalues[0] = &values[0];
  KernelRuntimeContext context{};
  (*func)(context, evalues);

  auto val = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val, 100);
}
127 
TEST_F(OperatorRegistryTest,RegisterTwoKernels)128 TEST_F(OperatorRegistryTest, RegisterTwoKernels) {
129   char buf_long_contiguous[BUF_SIZE];
130   make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
131   KernelKey key_1 = KernelKey(buf_long_contiguous);
132 
133   char buf_float_contiguous[BUF_SIZE];
134   make_kernel_key({{ScalarType::Float, {0, 1, 2, 3}}}, buf_float_contiguous);
135   KernelKey key_2 = KernelKey(buf_float_contiguous);
136   Kernel kernel_1 = Kernel(
137       "test::bar", key_1, [](KernelRuntimeContext& context, EValue** stack) {
138         (void)context;
139         *(stack[0]) = Scalar(100);
140       });
141   Kernel kernel_2 = Kernel(
142       "test::bar", key_2, [](KernelRuntimeContext& context, EValue** stack) {
143         (void)context;
144         *(stack[0]) = Scalar(50);
145       });
146   Kernel kernels[] = {kernel_1, kernel_2};
147   auto s1 = register_kernels(kernels);
148   // has both kernels
149   Tensor::DimOrderType dims[] = {0, 1, 2, 3};
150   auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
151   TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
152   Span<const TensorMeta> user_kernel_key_1(meta);
153 
154   TensorMeta meta_2[] = {TensorMeta(ScalarType::Float, dim_order_type)};
155   Span<const TensorMeta> user_kernel_key_2(meta_2);
156 
157   // no fallback kernel is registered
158   EXPECT_FALSE(registry_has_op_function("test::bar", {}));
159   Result<OpFunction> fallback_func =
160       get_op_function_from_registry("test::bar", {});
161   EXPECT_NE(fallback_func.error(), Error::Ok);
162 
163   EValue values[1];
164   values[0] = Scalar(0);
165   EValue* evalues[1];
166   evalues[0] = &values[0];
167   KernelRuntimeContext context{};
168 
169   // test kernel_1
170   EXPECT_TRUE(registry_has_op_function("test::bar", user_kernel_key_1));
171   Result<OpFunction> func_1 =
172       get_op_function_from_registry("test::bar", user_kernel_key_1);
173   EXPECT_EQ(func_1.error(), Error::Ok);
174   (*func_1)(context, evalues);
175 
176   auto val_1 = values[0].toScalar().to<int64_t>();
177   ASSERT_EQ(val_1, 100);
178 
179   // test kernel_2
180   EXPECT_TRUE(registry_has_op_function("test::bar", user_kernel_key_2));
181   Result<OpFunction> func_2 =
182       get_op_function_from_registry("test::bar", user_kernel_key_2);
183   EXPECT_EQ(func_2.error(), Error::Ok);
184   values[0] = Scalar(0);
185   (*func_2)(context, evalues);
186 
187   auto val_2 = values[0].toScalar().to<int64_t>();
188   ASSERT_EQ(val_2, 50);
189 }
190 
// Registering two kernels with the same name ("test::baz") AND the same key in
// one batch is expected to terminate the process.
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
  char buf_long_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
  KernelKey key = KernelKey(buf_long_contiguous);

  Kernel kernel_1 = Kernel(
      "test::baz", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  Kernel kernel_2 = Kernel(
      "test::baz", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(50);
      });
  Kernel kernels[] = {kernel_1, kernel_2};
  // Discard the status explicitly instead of binding it to an unused local,
  // matching RegisterOpsMoreThanOnceDie.
  ET_EXPECT_DEATH({ (void)register_kernels(kernels); }, "");
}
211 
// A kernel registered for a Long input with dim order {0, 1, 2, 3} is found
// only by an exactly matching key. Changing the dim order to {0, 3, 1, 2} or
// the dtype to Float makes the lookup miss.
TEST_F(OperatorRegistryTest, ExecutorChecksKernel) {
  char key_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, key_buf);
  KernelKey registered_key(key_buf);

  Kernel qux_kernel = Kernel(
      "test::qux",
      registered_key,
      [](KernelRuntimeContext& ctx, EValue** stack) {
        (void)ctx;
        *(stack[0]) = Scalar(100);
      });
  Error status = register_kernels({&qux_kernel, 1});
  EXPECT_EQ(status, Error::Ok);

  Tensor::DimOrderType contiguous[] = {0, 1, 2, 3};
  Tensor::DimOrderType channels_first[] = {0, 3, 1, 2};
  Span<Tensor::DimOrderType> contiguous_order(contiguous, 4);
  Span<Tensor::DimOrderType> channels_first_order(channels_first, 4);

  // Exact match: Long with contiguous order.
  TensorMeta long_contiguous[] = {
      TensorMeta(ScalarType::Long, contiguous_order)};
  Span<const TensorMeta> matching(long_contiguous);
  EXPECT_TRUE(registry_has_op_function("test::qux", matching));

  // Same dtype but a different dim order.
  TensorMeta long_channels_first[] = {
      TensorMeta(ScalarType::Long, channels_first_order)};
  Span<const TensorMeta> wrong_order(long_channels_first);
  EXPECT_FALSE(registry_has_op_function("test::qux", wrong_order));

  // Same dim order but a different dtype.
  TensorMeta float_contiguous[] = {
      TensorMeta(ScalarType::Float, contiguous_order)};
  Span<const TensorMeta> wrong_dtype(float_contiguous);
  EXPECT_FALSE(registry_has_op_function("test::qux", wrong_dtype));
}
243 
// Registers a kernel for "test::quux" keyed on a Long input with dim order
// {0, 1, 2, 3}, resolves it by a matching TensorMeta key, and runs it.
TEST_F(OperatorRegistryTest, ExecutorUsesKernel) {
  char buf_long_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
  KernelKey key = KernelKey(buf_long_contiguous);

  // Writes Scalar(100) into the first stack slot.
  Kernel kernel_1 = Kernel(
      "test::quux", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  auto s1 = register_kernels({&kernel_1, 1});
  EXPECT_EQ(s1, Error::Ok);

  Tensor::DimOrderType dims[] = {0, 1, 2, 3};
  auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
  TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
  Span<const TensorMeta> user_kernel_key_1(meta);

  EXPECT_TRUE(registry_has_op_function("test::quux", user_kernel_key_1));
  Result<OpFunction> func =
      get_op_function_from_registry("test::quux", user_kernel_key_1);
  // ASSERT, not EXPECT: `func` is dereferenced below.
  ASSERT_EQ(func.error(), Error::Ok);

  EValue values[1];
  values[0] = Scalar(0);
  EValue* evalues[1];
  evalues[0] = &values[0];
  KernelRuntimeContext context{};
  (*func)(context, evalues);

  auto val = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val, 100);
}
277 
// Registers "test::corge" with an empty KernelKey{} (a fallback kernel).
// Checks that both the name-only and empty-key lookups find it, then runs it.
TEST_F(OperatorRegistryTest, ExecutorUsesFallbackKernel) {
  // Writes Scalar(100) into the first stack slot.
  Kernel kernel_1 = Kernel(
      "test::corge",
      KernelKey{},
      [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  auto s1 = register_kernels({&kernel_1, 1});
  EXPECT_EQ(s1, Error::Ok);

  EXPECT_TRUE(registry_has_op_function("test::corge"));
  EXPECT_TRUE(registry_has_op_function("test::corge", {}));

  Result<OpFunction> func = get_op_function_from_registry("test::corge", {});
  // ASSERT, not EXPECT: `func` is dereferenced below.
  ASSERT_EQ(func.error(), Error::Ok);

  EValue values[1];
  values[0] = Scalar(0);
  EValue* evalues[1];
  evalues[0] = &values[0];
  KernelRuntimeContext context{};
  (*func)(context, evalues);

  auto val = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val, 100);
}
305