xref: /aosp_15_r20/external/executorch/runtime/executor/test/test_backend_compiler_lib.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/profiler.h>
#include <cstdio>
#include <cstdlib> /* strtol */
#include <cstring> /* strstr, strncmp, strlen */
15 
16 using executorch::runtime::ArrayRef;
17 using executorch::runtime::Backend;
18 using executorch::runtime::BackendExecutionContext;
19 using executorch::runtime::BackendInitContext;
20 using executorch::runtime::BackendInterface;
21 using executorch::runtime::CompileSpec;
22 using executorch::runtime::DelegateHandle;
23 using executorch::runtime::Error;
24 using executorch::runtime::EValue;
25 using executorch::runtime::FreeableBuffer;
26 using executorch::runtime::MemoryAllocator;
27 using executorch::runtime::Result;
28 
// A single parsed instruction from the delegate payload.
// `name` and `dtype` alias the processed buffer and are NOT
// null-terminated at the field boundary — compare with strncmp and
// keep the buffer alive for as long as this struct is used.
struct DemoOp {
  const char* name; // e.g. "demo::aten.add.Tensor"
  long numel; // element count the op operates on
  const char* dtype; // e.g. "torch.float32"
  long debug_handle; // AOT-assigned handle, used in error messages
};
35 
36 struct DemoOpList {
37   DemoOp* ops;
38   size_t numops;
39 };
40 
41 class BackendWithCompiler final : public BackendInterface {
42   int max_shape = 4;
43 
44  public:
45   ~BackendWithCompiler() override = default;
46 
is_available() const47   bool is_available() const override {
48     return true;
49   }
50 
51   // The delegate blob schema will be a list of instruction:
52   // {op: {str}, numel: {long}, dtype: {type}}<debug_handle>n
53   // Instruction will be separated by #, for example:
54   // 'op:demo::mul.Tensor, numel:4, dtype:torch.float32<debug_handle>2\
55   // #op:demo::add.Tensor, numel:4, dtype:torch.float32<debug_handle>4#'
parse_delegate(const char * str,const char * sub,DemoOp * op_list) const56   void parse_delegate(const char* str, const char* sub, DemoOp* op_list) const {
57     const char* kOpLiteral = "op:";
58     const char* kNumelLiteral = "numel:";
59     const char* kDtypeliteral = "dtype:";
60     const char* kDebugHandleLiteral = "<debug_handle>";
61 
62     const char* kComma = ",";
63 
64     int cnt = 0;
65     const char* left = str;
66     const char* right;
67 
68     // iter 0:
69     // op:demo::sin.default, numel:1, dtype:torch.float32<debug_handle>1#
70     // |<--left                                                 right-->|
71     // iter 1:
72     // op:demo::add.Tensor, numel:4, dtype:torch.float32<debug_handle>4#
73     // |<--left                                                right-->|
74     while ((right = strstr(left, sub))) {
75       // Get operator name
76       const char* op_start = strstr(left, kOpLiteral) + strlen(kOpLiteral);
77       const char* op_end = strstr(op_start, kComma);
78 
79       op_list[cnt].name = op_start;
80 
81       // Get numel
82       const char* numel_start =
83           strstr(op_end, kNumelLiteral) + strlen(kNumelLiteral);
84       char* numel_end = const_cast<char*>(strstr(numel_start, kComma));
85       op_list[cnt].numel = strtol(numel_start, &numel_end, 10);
86 
87       // Get dtype
88       const char* dtype_start =
89           strstr(numel_end, kDtypeliteral) + strlen(kDtypeliteral);
90       const char* dtype_end = strstr(dtype_start, kDebugHandleLiteral);
91       op_list[cnt].dtype = dtype_start;
92 
93       // Get debug handle
94       const char* debug_handle_start =
95           strstr(dtype_end, kDebugHandleLiteral) + strlen(kDebugHandleLiteral);
96       char* debug_end = const_cast<char*>(strstr(debug_handle_start, kComma));
97       op_list[cnt].debug_handle = strtol(debug_handle_start, &debug_end, 10);
98 
99       // Move left pointer to the start of next instruction
100       left = right + 1;
101       cnt++;
102     }
103   }
104 
init(BackendInitContext & context,FreeableBuffer * processed,ArrayRef<CompileSpec> compile_specs) const105   Result<DelegateHandle*> init(
106       BackendInitContext& context,
107       FreeableBuffer* processed,
108       ArrayRef<CompileSpec> compile_specs) const override {
109     MemoryAllocator* runtime_allocator = context.get_runtime_allocator();
110     int shape = *(int*)(compile_specs.at(0).value.buffer);
111     ET_CHECK_OR_RETURN_ERROR(
112         shape <= max_shape,
113         InvalidArgument,
114         "The input number is %d and it's larger than the max number %d "
115         "supported by this backend.",
116         shape,
117         max_shape);
118 
119     const char* kSignLiteral = "#";
120     // The first number is the number of total instruction
121     const char* start = static_cast<const char*>(processed->data());
122 
123     const char* kVersion = "version:";
124     const long int kRuntimeVersion = 0;
125     char* version_start =
126         const_cast<char*>(strstr(start, kVersion)) + strlen(kVersion);
127     char* version_end;
128     char* instruction_set_start =
129         const_cast<char*>(strstr(start, kSignLiteral));
130 
131     long int version = strtol(version_start, &version_end, 10);
132     ET_CHECK_OR_RETURN_ERROR(
133         version == kRuntimeVersion,
134         DelegateInvalidCompatibility,
135         "The version of BackendWithCompiler runtime is %ld, but received an incompatible version %ld instead.",
136         kRuntimeVersion,
137         version);
138     char* instruction_number_end;
139     long int instruction_number = strtol(start, &instruction_number_end, 10);
140 
141     ET_CHECK_OR_RETURN_ERROR(
142         instruction_number >= 0,
143         InvalidArgument,
144         "Instruction count must be non-negative: %ld",
145         instruction_number);
146 
147     auto op_list =
148         ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(runtime_allocator, DemoOpList);
149     op_list->ops = ET_ALLOCATE_LIST_OR_RETURN_ERROR(
150         runtime_allocator, DemoOp, instruction_number);
151     op_list->numops = static_cast<size_t>(instruction_number);
152 
153     parse_delegate(instruction_set_start + 1, kSignLiteral, op_list->ops);
154 
155     // Can't call `processed->Free()` because op_list points into it.
156 
157     return op_list;
158   }
159 
160   // Function that actually executes the model in the backend. Here there is
161   // nothing to dispatch to, so the backend is implemented locally within
162   // execute and it only supports add, subtract, and constant. In a non toy
163   // backend you can imagine how this function could be used to actually
164   // dispatch the inputs to the relevant backend/device.
execute(ET_UNUSED BackendExecutionContext & context,DelegateHandle * handle,EValue ** args) const165   Error execute(
166       ET_UNUSED BackendExecutionContext& context,
167       DelegateHandle* handle,
168       EValue** args) const override {
169     EXECUTORCH_SCOPE_PROF("BackendWithCompiler::execute");
170 
171     // example: [('prim::Constant#1', 14), ('aten::add', 15)]
172     auto op_list = static_cast<const DemoOpList*>(handle);
173 
174     const char* kDemoAdd = "demo::aten.add.Tensor";
175     const char* kDemoMul = "demo::aten.mm.default";
176     const char* kDemoSin = "demo::aten.sin.default";
177     const char* kTorchFloat32 = "torch.float32";
178 
179     for (size_t index = 0; index < op_list->numops; index++) {
180       auto instruction = op_list->ops[index];
181       ET_CHECK_OR_RETURN_ERROR(
182           strncmp(instruction.dtype, kTorchFloat32, strlen(kTorchFloat32)) == 0,
183           NotSupported,
184           "BackendWithCompiler only support float and doesn't support %s, "
185           "debug handle is: %ld",
186           instruction.dtype,
187           instruction.debug_handle);
188       if (strncmp(instruction.name, kDemoAdd, strlen(kDemoAdd)) == 0) {
189         // z = z + b
190         const float* b_ptr = args[2]->toTensor().const_data_ptr<float>();
191         float* z_ptr = args[3]->toTensor().mutable_data_ptr<float>();
192         for (size_t j = 0; j < instruction.numel; j++) {
193           z_ptr[j] = b_ptr[j] + z_ptr[j];
194         }
195       } else if (strncmp(instruction.name, kDemoMul, strlen(kDemoMul)) == 0) {
196         ET_CHECK_OR_RETURN_ERROR(
197             instruction.numel == 4,
198             NotSupported,
199             "BackendWithCompiler only support 2 x 2 matrix multiplication, "
200             "debug handle is %ld",
201             instruction.debug_handle);
202         // z = a * x
203         const float* a_ptr = args[0]->toTensor().const_data_ptr<float>();
204         const float* x_ptr = args[1]->toTensor().const_data_ptr<float>();
205         float* z_ptr = args[3]->toTensor().mutable_data_ptr<float>();
206 
207         z_ptr[0] = a_ptr[0] * x_ptr[0] + a_ptr[1] * x_ptr[2];
208         z_ptr[1] = a_ptr[0] * x_ptr[1] + a_ptr[1] * x_ptr[3];
209         z_ptr[2] = a_ptr[2] * x_ptr[0] + a_ptr[3] * x_ptr[2];
210         z_ptr[3] = a_ptr[2] * x_ptr[1] + a_ptr[3] * x_ptr[3];
211       } else if (strncmp(instruction.name, kDemoSin, strlen(kDemoSin)) == 0) {
212         const float* x_ptr = args[0]->toTensor().const_data_ptr<float>();
213         float* y_ptr = args[1]->toTensor().mutable_data_ptr<float>();
214         // Taylor series: an approximation of sin x around the point x = 0
215         // sin(x) = x - x^3 / 3! + x^5 / 5! - x^7 / 7! ...
216         // Use the first two items as proof of concept
217         for (size_t j = 0; j < instruction.numel; j++) {
218           y_ptr[j] = x_ptr[j] - x_ptr[j] * x_ptr[j] * x_ptr[j] / 6.0;
219         }
220       }
221     }
222     return Error::Ok;
223   }
224 };
225 
226 namespace {
227 auto cls = BackendWithCompiler();
228 Backend backend{"BackendWithCompilerDemo", &cls};
229 static auto success_with_compiler = register_backend(backend);
230 } // namespace
231