/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/profiler.h>
#include <cstdio>
#include <cstdlib> /* strtol */
#include <cstring> /* strstr, strlen, strncmp */

using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::BackendInterface;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Result;

// One parsed instruction from the delegate blob. `name` and `dtype` point
// directly into the `processed` buffer and are NOT null-terminated, so
// consumers must compare them with strncmp and a known length.
struct DemoOp {
  const char* name;
  long int numel;
  const char* dtype;
  long int debug_handle;
};

struct DemoOpList {
  DemoOp* ops;
  size_t numops;
};

class BackendWithCompiler final : public BackendInterface {
  // Upper bound for the shape value passed in via compile specs.
  int max_shape = 4;

 public:
  ~BackendWithCompiler() override = default;

  bool is_available() const override {
    return true;
  }

  // The delegate blob schema is a list of instructions:
  //   {op: {str}, numel: {long}, dtype: {type}}<debug_handle>n
  // Instructions are separated by #, for example:
  //   'op:demo::mul.Tensor, numel:4, dtype:torch.float32<debug_handle>2\
  //   #op:demo::add.Tensor, numel:4, dtype:torch.float32<debug_handle>4#'
  void parse_delegate(const char* str, const char* sub, DemoOp* op_list)
      const {
    const char* kOpLiteral = "op:";
    const char* kNumelLiteral = "numel:";
    const char* kDtypeLiteral = "dtype:";
    const char* kDebugHandleLiteral = "<debug_handle>";

    const char* kComma = ",";

    int cnt = 0;
    const char* left = str;
    const char* right;

    // iter 0:
    // op:demo::sin.default, numel:1, dtype:torch.float32<debug_handle>1#
    // |<--left                                                 right-->|
    // iter 1:
    // op:demo::add.Tensor, numel:4, dtype:torch.float32<debug_handle>4#
    // |<--left                                                right-->|
    while ((right = strstr(left, sub))) {
      // Get operator name
      const char* op_start = strstr(left, kOpLiteral) + strlen(kOpLiteral);
      const char* op_end = strstr(op_start, kComma);

      op_list[cnt].name = op_start;

      // Get numel
      const char* numel_start =
          strstr(op_end, kNumelLiteral) + strlen(kNumelLiteral);
      char* numel_end = const_cast<char*>(strstr(numel_start, kComma));
      op_list[cnt].numel = strtol(numel_start, &numel_end, 10);

      // Get dtype
      const char* dtype_start =
          strstr(numel_end, kDtypeLiteral) + strlen(kDtypeLiteral);
      const char* dtype_end = strstr(dtype_start, kDebugHandleLiteral);
      op_list[cnt].dtype = dtype_start;

      // Get debug handle
      const char* debug_handle_start =
          strstr(dtype_end, kDebugHandleLiteral) + strlen(kDebugHandleLiteral);
      char* debug_end = const_cast<char*>(strstr(debug_handle_start, kComma));
      op_list[cnt].debug_handle = strtol(debug_handle_start, &debug_end, 10);

      // Move the left pointer to the start of the next instruction.
      left = right + 1;
      cnt++;
    }
  }
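
  // For illustration, a plausible end-to-end payload (a sketch assembled
  // from the schema above and the header parsing in init() below; the exact
  // bytes emitted by the ahead-of-time compiler may differ):
  //
  //   "1version:0#op:demo::aten.sin.default, numel:1, "
  //   "dtype:torch.float32<debug_handle>1#"
  //
  // The leading "1" is the instruction count, "version:0" must match
  // kRuntimeVersion, and each '#'-terminated segment after the first '#'
  // becomes one DemoOp via parse_delegate().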

  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed,
      ArrayRef<CompileSpec> compile_specs) const override {
    MemoryAllocator* runtime_allocator = context.get_runtime_allocator();
    int shape = *(int*)(compile_specs.at(0).value.buffer);
    ET_CHECK_OR_RETURN_ERROR(
        shape <= max_shape,
        InvalidArgument,
        "The input number is %d and it's larger than the max number %d "
        "supported by this backend.",
        shape,
        max_shape);

    const char* kSignLiteral = "#";
    // The first number in the blob is the total instruction count.
    const char* start = static_cast<const char*>(processed->data());

    const char* kVersion = "version:";
    const long int kRuntimeVersion = 0;
    char* version_start =
        const_cast<char*>(strstr(start, kVersion)) + strlen(kVersion);
    char* version_end;
    char* instruction_set_start =
        const_cast<char*>(strstr(start, kSignLiteral));

    long int version = strtol(version_start, &version_end, 10);
    ET_CHECK_OR_RETURN_ERROR(
        version == kRuntimeVersion,
        DelegateInvalidCompatibility,
        "The BackendWithCompiler runtime is version %ld, but it received an "
        "incompatible version %ld instead.",
        kRuntimeVersion,
        version);
    char* instruction_number_end;
    long int instruction_number = strtol(start, &instruction_number_end, 10);

    ET_CHECK_OR_RETURN_ERROR(
        instruction_number >= 0,
        InvalidArgument,
        "Instruction count must be non-negative: %ld",
        instruction_number);

    auto op_list =
        ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(runtime_allocator, DemoOpList);
    op_list->ops = ET_ALLOCATE_LIST_OR_RETURN_ERROR(
        runtime_allocator, DemoOp, instruction_number);
    op_list->numops = static_cast<size_t>(instruction_number);

    parse_delegate(instruction_set_start + 1, kSignLiteral, op_list->ops);

    // Can't call `processed->Free()` here because op_list points into it.

    return op_list;
  }
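
  // A hedged sketch of the optional destroy() hook (not in the original
  // demo; assumes BackendInterface declares
  // `virtual void destroy(DelegateHandle*) const`). There is nothing to
  // release here: the DemoOpList lives in the runtime allocator's arena and
  // `processed` stays owned by the runtime, so a no-op is correct.
  void destroy(ET_UNUSED DelegateHandle* handle) const override {
    // Intentionally empty: arena-backed memory is reclaimed with the Method.
  }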

  // Function that actually executes the model in the backend. Here there is
  // nothing to dispatch to, so the backend is implemented locally within
  // execute() and it only supports add, 2x2 matrix multiply, and sin. In a
  // non-toy backend you can imagine how this function would dispatch the
  // inputs to the relevant backend/device.
  Error execute(
      ET_UNUSED BackendExecutionContext& context,
      DelegateHandle* handle,
      EValue** args) const override {
    EXECUTORCH_SCOPE_PROF("BackendWithCompiler::execute");

    // `handle` points at the DemoOpList built in init().
    auto op_list = static_cast<const DemoOpList*>(handle);

    const char* kDemoAdd = "demo::aten.add.Tensor";
    const char* kDemoMul = "demo::aten.mm.default";
    const char* kDemoSin = "demo::aten.sin.default";
    const char* kTorchFloat32 = "torch.float32";

    for (size_t index = 0; index < op_list->numops; index++) {
      auto instruction = op_list->ops[index];
      ET_CHECK_OR_RETURN_ERROR(
          strncmp(instruction.dtype, kTorchFloat32, strlen(kTorchFloat32)) ==
              0,
          NotSupported,
          "BackendWithCompiler only supports float and doesn't support %s, "
          "debug handle is: %ld",
          instruction.dtype,
          instruction.debug_handle);
      if (strncmp(instruction.name, kDemoAdd, strlen(kDemoAdd)) == 0) {
        // z = z + b
        const float* b_ptr = args[2]->toTensor().const_data_ptr<float>();
        float* z_ptr = args[3]->toTensor().mutable_data_ptr<float>();
        for (long j = 0; j < instruction.numel; j++) {
          z_ptr[j] = b_ptr[j] + z_ptr[j];
        }
      } else if (strncmp(instruction.name, kDemoMul, strlen(kDemoMul)) == 0) {
        ET_CHECK_OR_RETURN_ERROR(
            instruction.numel == 4,
            NotSupported,
            "BackendWithCompiler only supports 2 x 2 matrix multiplication, "
            "debug handle is %ld",
            instruction.debug_handle);
        // z = a * x
        const float* a_ptr = args[0]->toTensor().const_data_ptr<float>();
        const float* x_ptr = args[1]->toTensor().const_data_ptr<float>();
        float* z_ptr = args[3]->toTensor().mutable_data_ptr<float>();

        z_ptr[0] = a_ptr[0] * x_ptr[0] + a_ptr[1] * x_ptr[2];
        z_ptr[1] = a_ptr[0] * x_ptr[1] + a_ptr[1] * x_ptr[3];
        z_ptr[2] = a_ptr[2] * x_ptr[0] + a_ptr[3] * x_ptr[2];
        z_ptr[3] = a_ptr[2] * x_ptr[1] + a_ptr[3] * x_ptr[3];
      } else if (strncmp(instruction.name, kDemoSin, strlen(kDemoSin)) == 0) {
        const float* x_ptr = args[0]->toTensor().const_data_ptr<float>();
        float* y_ptr = args[1]->toTensor().mutable_data_ptr<float>();
        // Taylor series: an approximation of sin(x) around the point x = 0:
        //   sin(x) = x - x^3 / 3! + x^5 / 5! - x^7 / 7! + ...
        // Use the first two terms as a proof of concept.
        for (long j = 0; j < instruction.numel; j++) {
          y_ptr[j] = x_ptr[j] - x_ptr[j] * x_ptr[j] * x_ptr[j] / 6.0;
        }
      }
    }
    return Error::Ok;
  }
};

namespace {
auto cls = BackendWithCompiler();
Backend backend{"BackendWithCompilerDemo", &cls};
// register_backend is found via argument-dependent lookup on
// executorch::runtime::Backend.
static auto success_with_compiler = register_backend(backend);
} // namespace
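
// A minimal usage sketch (an assumption about the caller, not part of this
// demo): once this translation unit is linked in, the static registration
// above makes the backend discoverable by name, e.g.:
//
//   #include <executorch/runtime/backend/interface.h>
//
//   const auto* backend_impl =
//       executorch::runtime::get_backend_class("BackendWithCompilerDemo");
//   if (backend_impl != nullptr && backend_impl->is_available()) {
//     // Programs lowered to "BackendWithCompilerDemo" can now be loaded
//     // and their delegated parts executed by this backend.
//   }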