1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cctype>
10 #include <filesystem>
11
12 #include <cstring>
13 #include <memory>
14
15 #include <executorch/extension/data_loader/file_data_loader.h>
16 #include <executorch/runtime/core/error.h>
17 #include <executorch/runtime/core/result.h>
18 #include <executorch/runtime/executor/method.h>
19 #include <executorch/runtime/executor/program.h>
20 #include <executorch/runtime/executor/test/managed_memory_manager.h>
21 #include <executorch/runtime/kernel/operator_registry.h>
22 #include <executorch/runtime/platform/runtime.h>
23 #include <executorch/test/utils/DeathTest.h>
24
25 #include <gtest/gtest.h>
26
27 using namespace ::testing;
28 using exec_aten::Scalar;
29 using exec_aten::ScalarType;
30 using executorch::runtime::Error;
31 using executorch::runtime::EValue;
32 using executorch::runtime::Kernel;
33 using executorch::runtime::KernelKey;
34 using executorch::runtime::KernelRuntimeContext;
35 using executorch::runtime::Method;
36 using executorch::runtime::Program;
37 using executorch::runtime::register_kernel;
38 using executorch::runtime::Result;
39 using executorch::runtime::TensorMeta;
40 using executorch::runtime::testing::ManagedMemoryManager;
41 using torch::executor::util::FileDataLoader;
42
// Byte sizes (32 KiB each) handed to ManagedMemoryManager by the tests below,
// as its first and second constructor arguments respectively.
constexpr size_t kDefaultNonConstMemBytes = 32U * 1024U;
constexpr size_t kDefaultRuntimeMemBytes = 32U * 1024U;
45
46 class KernelResolutionTest : public ::testing::Test {
47 protected:
SetUp()48 void SetUp() override {
49 // Since these tests cause ET_LOG to be called, the PAL must be initialized
50 // first.
51 executorch::runtime::runtime_init();
52
53 // Create a loader for the serialized ModuleAdd program.
54 const char* path = std::getenv("ET_MODULE_ADD_PATH");
55 Result<FileDataLoader> loader = FileDataLoader::from(path);
56 ASSERT_EQ(loader.error(), Error::Ok);
57 loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
58
59 // Use it to load the program.
60 Result<Program> program = Program::load(
61 loader_.get(), Program::Verification::InternalConsistency);
62 ASSERT_EQ(program.error(), Error::Ok);
63 program_ = std::make_unique<Program>(std::move(program.get()));
64 }
65
66 std::unique_ptr<FileDataLoader> loader_;
67 std::unique_ptr<Program> program_;
68 };
69
70 /**
71 * Test if the program can initialize properly.
72 */
TEST_F(KernelResolutionTest,InitExecutionPlanSuccess)73 TEST_F(KernelResolutionTest, InitExecutionPlanSuccess) {
74 // register kernel with fallback kernel key
75 Kernel kernel_1 = Kernel(
76 "aten::add.out", {}, [](KernelRuntimeContext& context, EValue** stack) {
77 (void)context;
78 *(stack[0]) = Scalar(100);
79 });
80 auto s1 = register_kernel(kernel_1);
81 EXPECT_EQ(s1, executorch::runtime::Error::Ok);
82
83 ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
84 auto method = program_->load_method("forward", &mmm.get());
85 ASSERT_EQ(method.error(), Error::Ok);
86 }
87
88 /**
89 * Test if we can resolve the kernel key correctly.
90 */
TEST_F(KernelResolutionTest,ResolveKernelKeySuccess)91 TEST_F(KernelResolutionTest, ResolveKernelKeySuccess) {
92 // getting all these TensorMeta from args to this kernel_call in the program.
93 // particularly for aten::add.out:
94 // add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) ->
95 // Tensor(a!) The arguments are: `self, other, out, out` (we repeat out
96 // argument in the program) Also since we traced using randn(2, 2), all the
97 // args are Float with dim order (0, 1)
98
99 // Construct a kernel key with the following meta:
100 // exec_aten::DimOrderType contiguous[] = {0, 1};
101 // TensorMeta float_contiguous[] = {
102 // TensorMeta(ScalarType::Float, contiguous),
103 // TensorMeta(ScalarType::Float, contiguous),
104 // TensorMeta(ScalarType::Float, contiguous),
105 // TensorMeta(ScalarType::Float, contiguous)};
106 KernelKey key = KernelKey("v1/6;0,1|6;0,1|6;0,1|6;0,1");
107 Kernel kernel_1 = Kernel(
108 "aten::add.out", key, [](KernelRuntimeContext& context, EValue** stack) {
109 (void)context;
110 *(stack[0]) = Scalar(100);
111 });
112 auto s1 = register_kernel(kernel_1);
113 EXPECT_EQ(s1, executorch::runtime::Error::Ok);
114
115 ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
116 auto method = program_->load_method("forward", &mmm.get());
117 ASSERT_EQ(method.error(), Error::Ok);
118 }
119