1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <gtest/gtest.h>
10 #include <vector>
11
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/result.h>
14 #include <executorch/runtime/core/span.h>
15 #include <executorch/runtime/kernel/kernel_runtime_context.h>
16 #include <executorch/runtime/kernel/operator_registry.h>
17 #include <executorch/runtime/kernel/test/test_util.h>
18 #include <executorch/runtime/platform/runtime.h>
19 #include <executorch/test/utils/DeathTest.h>
20
21 using namespace ::testing;
22 using exec_aten::Scalar;
23 using exec_aten::ScalarType;
24 using exec_aten::Tensor;
25 using executorch::runtime::Error;
26 using executorch::runtime::EValue;
27 using executorch::runtime::get_op_function_from_registry;
28 using executorch::runtime::Kernel;
29 using executorch::runtime::KernelKey;
30 using executorch::runtime::KernelRuntimeContext;
31 using executorch::runtime::OpFunction;
32 using executorch::runtime::register_kernels;
33 using executorch::runtime::registry_has_op_function;
34 using executorch::runtime::Result;
35 using executorch::runtime::Span;
36 using executorch::runtime::TensorMeta;
37 using executorch::runtime::testing::make_kernel_key;
38
39 class OperatorRegistryTest : public ::testing::Test {
40 public:
SetUp()41 void SetUp() override {
42 executorch::runtime::runtime_init();
43 }
44 };
45
// Registering a kernel named "foo" (built without an explicit KernelKey) makes
// registry_has_op_function("foo") true, while the unregistered near-miss name
// "fpp" stays false.
TEST_F(OperatorRegistryTest, Basic) {
  Kernel kernels[] = {Kernel("foo", [](KernelRuntimeContext&, EValue**) {})};
  Span<const Kernel> kernels_span(kernels);
  // Check the status instead of discarding it, so a registration failure is
  // reported here rather than surfacing as a confusing lookup failure below.
  EXPECT_EQ(register_kernels(kernels_span), Error::Ok);
  EXPECT_FALSE(registry_has_op_function("fpp"));
  EXPECT_TRUE(registry_has_op_function("foo"));
}
53
// Registering, in a single batch, two kernels that share the name "foo" (both
// built without an explicit KernelKey) is expected to terminate the process.
TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) {
  Kernel duplicate_kernels[] = {
      Kernel("foo", [](KernelRuntimeContext&, EValue**) {}),
      Kernel("foo", [](KernelRuntimeContext&, EValue**) {}),
  };
  const Span<const Kernel> batch(duplicate_kernels);
  ET_EXPECT_DEATH({ (void)register_kernels(batch); }, "");
}
61
62 constexpr int BUF_SIZE = KernelKey::MAX_SIZE;
63
// KernelKey equality: two copies of the same key compare equal, while keys
// that differ in dtype (Long vs Float) or in dim order ({0, 1, 2, 3} vs
// {0, 3, 1, 2}) compare unequal.
TEST_F(OperatorRegistryTest, KernelKeyEquals) {
  // Long, dim order {0, 1, 2, 3}.
  char long_contiguous_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, long_contiguous_buf);
  KernelKey long_contiguous(long_contiguous_buf);

  // Two independent copies of the same key must be equal.
  KernelKey first_copy(long_contiguous);
  KernelKey second_copy(long_contiguous);
  EXPECT_EQ(first_copy, second_copy);

  // Same dim order, different dtype.
  char float_contiguous_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Float, {0, 1, 2, 3}}}, float_contiguous_buf);
  KernelKey float_contiguous(float_contiguous_buf);
  EXPECT_NE(first_copy, float_contiguous);

  // Same dtype, dim order {0, 3, 1, 2}.
  char long_channels_first_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 3, 1, 2}}}, long_channels_first_buf);
  KernelKey long_channels_first(long_channels_first_buf);
  EXPECT_NE(first_copy, long_channels_first);
}
87
// A kernel registered only under a specific KernelKey (Long, dim order
// {0, 1, 2, 3}) is found for a matching TensorMeta list, is NOT found for an
// empty list (no fallback kernel is registered), and writes 100 into the
// first stack slot when invoked.
TEST_F(OperatorRegistryTest, RegisterKernels) {
  char buf_long_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
  KernelKey key = KernelKey(buf_long_contiguous);

  Kernel kernel_1 = Kernel(
      "test::boo", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  auto s1 = register_kernels({&kernel_1, 1});
  EXPECT_EQ(s1, Error::Ok);

  Tensor::DimOrderType dims[] = {0, 1, 2, 3};
  auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
  TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
  Span<const TensorMeta> user_kernel_key(meta);

  // no fallback kernel is registered
  EXPECT_FALSE(registry_has_op_function("test::boo", {}));
  Result<OpFunction> fallback_func =
      get_op_function_from_registry("test::boo", {});
  EXPECT_NE(fallback_func.error(), Error::Ok);

  EXPECT_TRUE(registry_has_op_function("test::boo", user_kernel_key));
  Result<OpFunction> func =
      get_op_function_from_registry("test::boo", user_kernel_key);
  // ASSERT rather than EXPECT: `*func` is dereferenced below, and an errored
  // Result must not be invoked, so stop this test if the lookup failed.
  ASSERT_EQ(func.error(), Error::Ok);

  // Single-slot argument stack; the kernel overwrites slot 0.
  EValue values[1];
  values[0] = Scalar(0);
  EValue* stack[1];
  stack[0] = &values[0];
  KernelRuntimeContext context{};
  (*func)(context, stack);

  auto val = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val, 100);
}
127
// Two kernels registered under the same name ("test::bar") but different
// KernelKeys (Long vs Float, both dim order {0, 1, 2, 3}) are each resolved
// from their matching TensorMeta list and run independently (writing 100 and
// 50 respectively). An empty TensorMeta list finds nothing because no
// fallback kernel is registered.
TEST_F(OperatorRegistryTest, RegisterTwoKernels) {
  char buf_long_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
  KernelKey key_1 = KernelKey(buf_long_contiguous);

  char buf_float_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Float, {0, 1, 2, 3}}}, buf_float_contiguous);
  KernelKey key_2 = KernelKey(buf_float_contiguous);
  Kernel kernel_1 = Kernel(
      "test::bar", key_1, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  Kernel kernel_2 = Kernel(
      "test::bar", key_2, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(50);
      });
  Kernel kernels[] = {kernel_1, kernel_2};
  // Both kernels must register successfully.
  auto s1 = register_kernels(kernels);
  EXPECT_EQ(s1, Error::Ok);

  // TensorMeta lists matching key_1 (Long) and key_2 (Float).
  Tensor::DimOrderType dims[] = {0, 1, 2, 3};
  auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
  TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
  Span<const TensorMeta> user_kernel_key_1(meta);

  TensorMeta meta_2[] = {TensorMeta(ScalarType::Float, dim_order_type)};
  Span<const TensorMeta> user_kernel_key_2(meta_2);

  // no fallback kernel is registered
  EXPECT_FALSE(registry_has_op_function("test::bar", {}));
  Result<OpFunction> fallback_func =
      get_op_function_from_registry("test::bar", {});
  EXPECT_NE(fallback_func.error(), Error::Ok);

  // Single-slot argument stack shared by both invocations.
  EValue values[1];
  values[0] = Scalar(0);
  EValue* evalues[1];
  evalues[0] = &values[0];
  KernelRuntimeContext context{};

  // test kernel_1
  EXPECT_TRUE(registry_has_op_function("test::bar", user_kernel_key_1));
  Result<OpFunction> func_1 =
      get_op_function_from_registry("test::bar", user_kernel_key_1);
  // Stop before invoking *func_1 if the lookup failed.
  ASSERT_EQ(func_1.error(), Error::Ok);
  (*func_1)(context, evalues);

  auto val_1 = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val_1, 100);

  // test kernel_2
  EXPECT_TRUE(registry_has_op_function("test::bar", user_kernel_key_2));
  Result<OpFunction> func_2 =
      get_op_function_from_registry("test::bar", user_kernel_key_2);
  // Stop before invoking *func_2 if the lookup failed.
  ASSERT_EQ(func_2.error(), Error::Ok);
  values[0] = Scalar(0);
  (*func_2)(context, evalues);

  auto val_2 = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val_2, 50);
}
190
// Registering, in a single batch, two kernels with the same name ("test::baz")
// and the same KernelKey (Long, dim order {0, 1, 2, 3}) is expected to
// terminate the process.
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
  char buf_long_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
  KernelKey key = KernelKey(buf_long_contiguous);

  Kernel kernel_1 = Kernel(
      "test::baz", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  Kernel kernel_2 = Kernel(
      "test::baz", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(50);
      });
  Kernel kernels[] = {kernel_1, kernel_2};
  // The status is never observed (registration is expected to die), so
  // discard it explicitly instead of binding an unused local.
  ET_EXPECT_DEATH({ (void)register_kernels(kernels); }, "");
}
211
// A kernel registered for (Long, dim order {0, 1, 2, 3}) is reported by
// registry_has_op_function() only for that exact TensorMeta. Changing the dim
// order to {0, 3, 1, 2} or the dtype to Float yields no match.
TEST_F(OperatorRegistryTest, ExecutorChecksKernel) {
  char key_buf[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, key_buf);
  KernelKey long_key(key_buf);

  Kernel qux(
      "test::qux", long_key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  EXPECT_EQ(register_kernels({&qux, 1}), Error::Ok);

  // Matching TensorMeta: same dtype and dim order as the registered key.
  Tensor::DimOrderType contiguous_dims[] = {0, 1, 2, 3};
  Span<Tensor::DimOrderType> contiguous(contiguous_dims, 4);
  TensorMeta long_contiguous_meta[] = {
      TensorMeta(ScalarType::Long, contiguous)};
  Span<const TensorMeta> long_contiguous(long_contiguous_meta);
  EXPECT_TRUE(registry_has_op_function("test::qux", long_contiguous));

  // Same dtype, different dim order: no match.
  Tensor::DimOrderType channels_first_dims[] = {0, 3, 1, 2};
  Span<Tensor::DimOrderType> channels_first(channels_first_dims, 4);
  TensorMeta long_channels_first_meta[] = {
      TensorMeta(ScalarType::Long, channels_first)};
  Span<const TensorMeta> long_channels_first(long_channels_first_meta);
  EXPECT_FALSE(registry_has_op_function("test::qux", long_channels_first));

  // Same dim order, different dtype: no match.
  TensorMeta float_contiguous_meta[] = {
      TensorMeta(ScalarType::Float, contiguous)};
  Span<const TensorMeta> float_contiguous(float_contiguous_meta);
  EXPECT_FALSE(registry_has_op_function("test::qux", float_contiguous));
}
243
// A kernel registered for (Long, dim order {0, 1, 2, 3}) is retrieved via
// get_op_function_from_registry() with a matching TensorMeta list and writes
// 100 into the first stack slot when invoked.
TEST_F(OperatorRegistryTest, ExecutorUsesKernel) {
  char buf_long_contiguous[BUF_SIZE];
  make_kernel_key({{ScalarType::Long, {0, 1, 2, 3}}}, buf_long_contiguous);
  KernelKey key = KernelKey(buf_long_contiguous);

  Kernel kernel_1 = Kernel(
      "test::quux", key, [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  auto s1 = register_kernels({&kernel_1, 1});
  EXPECT_EQ(s1, Error::Ok);

  Tensor::DimOrderType dims[] = {0, 1, 2, 3};
  auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
  TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
  Span<const TensorMeta> user_kernel_key_1(meta);

  EXPECT_TRUE(registry_has_op_function("test::quux", user_kernel_key_1));
  Result<OpFunction> func =
      get_op_function_from_registry("test::quux", user_kernel_key_1);
  // ASSERT rather than EXPECT: `*func` is dereferenced below, and an errored
  // Result must not be invoked, so stop this test if the lookup failed.
  ASSERT_EQ(func.error(), Error::Ok);

  // Single-slot argument stack; the kernel overwrites slot 0.
  EValue values[1];
  values[0] = Scalar(0);
  EValue* stack[1];
  stack[0] = &values[0];
  KernelRuntimeContext context{};
  (*func)(context, stack);

  auto val = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val, 100);
}
277
// A kernel registered with a default-constructed KernelKey is found both by
// name alone and with an empty TensorMeta list, and writes 100 into the first
// stack slot when invoked.
TEST_F(OperatorRegistryTest, ExecutorUsesFallbackKernel) {
  Kernel kernel_1 = Kernel(
      "test::corge",
      KernelKey{},
      [](KernelRuntimeContext& context, EValue** stack) {
        (void)context;
        *(stack[0]) = Scalar(100);
      });
  auto s1 = register_kernels({&kernel_1, 1});
  EXPECT_EQ(s1, Error::Ok);

  EXPECT_TRUE(registry_has_op_function("test::corge"));
  EXPECT_TRUE(registry_has_op_function("test::corge", {}));

  Result<OpFunction> func = get_op_function_from_registry("test::corge", {});
  // ASSERT rather than EXPECT: `*func` is dereferenced below, and an errored
  // Result must not be invoked, so stop this test if the lookup failed.
  ASSERT_EQ(func.error(), Error::Ok);

  // Single-slot argument stack; the kernel overwrites slot 0.
  EValue values[1];
  values[0] = Scalar(0);
  EValue* stack[1];
  stack[0] = &values[0];
  KernelRuntimeContext context{};
  (*func)(context, stack);

  auto val = values[0].toScalar().to<int64_t>();
  ASSERT_EQ(val, 100);
}
305