xref: /aosp_15_r20/external/executorch/runtime/executor/test/kernel_resolution_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <memory>
#include <utility>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>

#include <gtest/gtest.h>
26 
27 using namespace ::testing;
28 using exec_aten::Scalar;
29 using exec_aten::ScalarType;
30 using executorch::runtime::Error;
31 using executorch::runtime::EValue;
32 using executorch::runtime::Kernel;
33 using executorch::runtime::KernelKey;
34 using executorch::runtime::KernelRuntimeContext;
35 using executorch::runtime::Method;
36 using executorch::runtime::Program;
37 using executorch::runtime::register_kernel;
38 using executorch::runtime::Result;
39 using executorch::runtime::TensorMeta;
40 using executorch::runtime::testing::ManagedMemoryManager;
41 using torch::executor::util::FileDataLoader;
42 
// Sizes, in bytes, handed to ManagedMemoryManager by each test before it loads
// the "forward" method: the first constructor argument, then the second.
constexpr size_t kDefaultNonConstMemBytes = size_t{32} * 1024U; // 32 KiB
constexpr size_t kDefaultRuntimeMemBytes = size_t{32} * 1024U; // 32 KiB
45 
46 class KernelResolutionTest : public ::testing::Test {
47  protected:
SetUp()48   void SetUp() override {
49     // Since these tests cause ET_LOG to be called, the PAL must be initialized
50     // first.
51     executorch::runtime::runtime_init();
52 
53     // Create a loader for the serialized ModuleAdd program.
54     const char* path = std::getenv("ET_MODULE_ADD_PATH");
55     Result<FileDataLoader> loader = FileDataLoader::from(path);
56     ASSERT_EQ(loader.error(), Error::Ok);
57     loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
58 
59     // Use it to load the program.
60     Result<Program> program = Program::load(
61         loader_.get(), Program::Verification::InternalConsistency);
62     ASSERT_EQ(program.error(), Error::Ok);
63     program_ = std::make_unique<Program>(std::move(program.get()));
64   }
65 
66   std::unique_ptr<FileDataLoader> loader_;
67   std::unique_ptr<Program> program_;
68 };
69 
70 /**
71  * Test if the program can initialize properly.
72  */
TEST_F(KernelResolutionTest,InitExecutionPlanSuccess)73 TEST_F(KernelResolutionTest, InitExecutionPlanSuccess) {
74   // register kernel with fallback kernel key
75   Kernel kernel_1 = Kernel(
76       "aten::add.out", {}, [](KernelRuntimeContext& context, EValue** stack) {
77         (void)context;
78         *(stack[0]) = Scalar(100);
79       });
80   auto s1 = register_kernel(kernel_1);
81   EXPECT_EQ(s1, executorch::runtime::Error::Ok);
82 
83   ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
84   auto method = program_->load_method("forward", &mmm.get());
85   ASSERT_EQ(method.error(), Error::Ok);
86 }
87 
88 /**
89  * Test if we can resolve the kernel key correctly.
90  */
TEST_F(KernelResolutionTest,ResolveKernelKeySuccess)91 TEST_F(KernelResolutionTest, ResolveKernelKeySuccess) {
92   // getting all these TensorMeta from args to this kernel_call in the program.
93   // particularly for aten::add.out:
94   // add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) ->
95   // Tensor(a!) The arguments are: `self, other, out, out` (we repeat out
96   // argument in the program) Also since we traced using randn(2, 2), all the
97   // args are Float with dim order (0, 1)
98 
99   // Construct a kernel key with the following meta:
100   // exec_aten::DimOrderType contiguous[] = {0, 1};
101   // TensorMeta float_contiguous[] = {
102   //     TensorMeta(ScalarType::Float, contiguous),
103   //     TensorMeta(ScalarType::Float, contiguous),
104   //     TensorMeta(ScalarType::Float, contiguous),
105   //     TensorMeta(ScalarType::Float, contiguous)};
106   KernelKey key = KernelKey("v1/6;0,1|6;0,1|6;0,1|6;0,1");
107   Kernel kernel_1 = Kernel(
108       "aten::add.out", key, [](KernelRuntimeContext& context, EValue** stack) {
109         (void)context;
110         *(stack[0]) = Scalar(100);
111       });
112   auto s1 = register_kernel(kernel_1);
113   EXPECT_EQ(s1, executorch::runtime::Error::Ok);
114 
115   ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
116   auto method = program_->load_method("forward", &mmm.get());
117   ASSERT_EQ(method.error(), Error::Ok);
118 }
119