1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/core/evalue.h> 12 #include <executorch/runtime/core/event_tracer.h> 13 #include <executorch/runtime/core/exec_aten/exec_aten.h> 14 #include <executorch/runtime/core/span.h> 15 #include <executorch/runtime/executor/memory_manager.h> 16 #include <executorch/runtime/executor/method_meta.h> 17 #include <executorch/runtime/platform/compiler.h> 18 19 // Forward declare flatbuffer types. This is a public header and must not 20 // include the generated flatbuffer header. 21 namespace executorch_flatbuffer { 22 struct Chain; 23 struct ExecutionPlan; 24 struct EValue; 25 } // namespace executorch_flatbuffer 26 27 namespace executorch { 28 namespace runtime { 29 30 // Forward declare Program to avoid a circular reference. 31 class Program; 32 33 // Forward declare internal types. 34 class BackendDelegate; 35 struct Chain; 36 class KernelRuntimeContext; 37 using OpFunction = void (*)(KernelRuntimeContext&, EValue**); 38 /// A list of pointers into the master values table that together compose the 39 /// argument list for a single instruction 40 using InstructionArgs = Span<EValue*>; 41 42 /** 43 * An executable method of an executorch program. Maps to a python method like 44 * `forward()` on the original nn.Module. 45 */ 46 class Method final { 47 public: 48 /** 49 * Move ctor. Takes ownership of resources previously owned by `rhs`, 50 * and leaves `rhs` in an uninitialized state. 51 */ Method(Method && rhs)52 Method(Method&& rhs) noexcept 53 : step_state_(rhs.step_state_), 54 program_(rhs.program_), 55 memory_manager_(rhs.memory_manager_), 56 temp_allocator_(rhs.temp_allocator_), 57 serialization_plan_(rhs.serialization_plan_), 58 event_tracer_(rhs.event_tracer_), 59 n_value_(rhs.n_value_), 60 values_(rhs.values_), 61 n_delegate_(rhs.n_delegate_), 62 delegates_(rhs.delegates_), 63 n_chains_(rhs.n_chains_), 64 chains_(rhs.chains_), 65 init_state_(rhs.init_state_) { 66 // Required: clear out fields that the dtor looks at, so that we don't free 67 // anything twice. 68 rhs.n_value_ = 0; 69 rhs.values_ = nullptr; 70 rhs.n_delegate_ = 0; 71 rhs.delegates_ = nullptr; 72 73 // Helpful: Try to ensure that any other interactions with the old object 74 // result in failures. 75 rhs.init_state_ = InitializationState::Uninitialized; 76 rhs.step_state_ = {}; 77 rhs.program_ = nullptr; 78 rhs.memory_manager_ = nullptr; 79 rhs.serialization_plan_ = nullptr; 80 rhs.event_tracer_ = nullptr; 81 rhs.n_chains_ = 0; 82 rhs.chains_ = nullptr; 83 } 84 85 /** 86 * Sets the internal input value to be equivalent to the to the provided 87 * value. 88 * 89 * @param[in] input_evalue The evalue to copy into the method input. If the 90 * evalue is a tensor, the data is copied in most cases, so the tensor 91 * passed in here does not always need to outlive this call. But there is 92 * a case where the Method will keep a pointer to the tensor's data. 93 * Based on the memory plan of the method, the inputs may not have 94 * buffer space pre-allocated for them. In this case the executor will 95 * alias the memory of the tensors provided as inputs here rather then 96 * deepcopy the input into the memory planned arena. 97 * 98 * @param[in] input_idx Zero-based index of the input to set. Must be less 99 * than the value returned by inputs_size(). 100 * 101 * @returns Error::Ok on success, non-Ok on failure. 102 */ 103 ET_NODISCARD Error set_input(const EValue& input_evalue, size_t input_idx); 104 105 /** 106 * Sets the values of all method inputs. 107 * 108 * See set_input() for a more detailed description of the behavior. 109 * 110 * @param[in] input_evalues The new values for all of the method inputs. The 111 * type of each element must match the type of corresponding input. If the 112 * value of an element is a tensor, attempts to allow dynamic shape, but 113 * the dtype must always agree. 114 * 115 * @returns Error::Ok on success, non-Ok on failure. 116 */ 117 ET_NODISCARD Error 118 set_inputs(const executorch::aten::ArrayRef<EValue>& input_evalues); 119 120 /** 121 * Sets the data buffer of the specified method output to the provided value. 122 * 123 * NOTE: Based on the memory plan of the method, the output tensors may not 124 * have buffer space pre-allocated for them, in this case the executor will 125 * point those tensors to the buffer provided here, so the user should take 126 * care that the life span of this memory outlasts the executor forward. 127 * 128 * @param[in] buffer The block of memory to point the specified tensor at. 129 * 130 * @param[in] size the length of buffer in bytes, must be >= the nbytes of the 131 * specified tensor. 132 * 133 * @param[in] output_idx The index of the output to set the data_ptr for. Must 134 * correspond to a tensor, and that tensor must not have had a buffer 135 * allocated by the memory plan. 136 * 137 * @returns Error::Ok on success, non-Ok on failure. 138 */ 139 ET_NODISCARD Error 140 set_output_data_ptr(void* buffer, size_t size, size_t output_idx); 141 142 /** 143 * Copies the method's outputs into the provided array. 144 * 145 * WARNING: The output contains shallow copies of internal tensor outputs. 146 * Please do not mutate returned Tensor elements. 147 * 148 * TODO(T139259264): Add checks to detect output mutation, or deep-copy 149 * outputs. 150 * 151 * @param[in] output_evalues The array to copy the outputs into. The first 152 * `outputs_size()` elements will be set to the corresponding output 153 * values. The rest of the array will be set to the EValue value None. 154 * @param[in] length The size of the `output_evalues` array in elements. Must 155 * be greater than or equal to `outputs_size()`. 156 * 157 * @returns Error::Ok on success, non-Ok on failure. 158 */ 159 ET_NODISCARD Error get_outputs(EValue* output_evalues, size_t length); 160 161 /** 162 * Copies the method's inputs into the provided array. 163 * 164 * WARNING: The input contains shallow copies of internal tensor inputs. 165 * Please do not mutate returned Tensor elements. 166 * 167 * @param[in] input_evalues The array to copy the inputs into. The first 168 * `inputs_size()` elements will be set to the corresponding input 169 * values. The rest of the array will be set to the EValue value None. 170 * @param[in] length The size of the `input_evalues` array in elements. Must 171 * be greater than or equal to `inputs_size()`. 172 * 173 * @returns Error::Ok on success, non-Ok on failure. 174 */ 175 ET_NODISCARD Error get_inputs(EValue* input_evalues, size_t length); 176 177 /** 178 * Execute the method. 179 * 180 * NOTE: Will fail if the method has been partially executed using the 181 * `step()` api. 182 * 183 * @returns Error::Ok on success, non-Ok on failure. 184 */ 185 ET_NODISCARD Error execute(); 186 187 /** 188 * EXPERIMENTAL: Advances/executes a single instruction in the method. 189 * 190 * @retval Error::Ok step succeeded 191 * @retval non-Ok step failed 192 * @retval Error::EndOfMethod method finished executing successfully 193 */ 194 ET_EXPERIMENTAL ET_NODISCARD Error step(); 195 196 /// DEPRECATED: Use `step()` instead. 197 ET_DEPRECATED ET_NODISCARD Error experimental_step(); 198 199 /** 200 * EXPERIMENTAL: Resets execution state to the start of the Method. For use 201 * with the `step()` API. 202 * 203 * @retval Error:Ok on success 204 * @retval Error::InvalidState if called before step-based execution reached 205 * the end of the Method. This means it is not possible to recover a 206 * Method that failed mid-execution. 207 */ 208 ET_EXPERIMENTAL ET_NODISCARD Error reset_execution(); 209 210 /// DEPRECATED: Use `reset_execution()` instead. 211 ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution(); 212 213 /** 214 * Returns the MethodMeta that corresponds to the calling Method. 215 */ 216 MethodMeta method_meta() const; 217 218 /** 219 * Returns the number of inputs the Method expects. 220 */ 221 size_t inputs_size() const; 222 223 /** 224 * Returns the number of outputs the Method returns. 225 */ 226 size_t outputs_size() const; 227 228 /** 229 * Retrieves the output at the specified index. 230 */ 231 const EValue& get_output(size_t i) const; 232 233 EventTracer* get_event_tracer(); 234 235 /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to 236 /// update Method inputs. 237 ET_DEPRECATED const EValue& get_input(size_t i) const; 238 /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to 239 /// update Method inputs. 240 ET_DEPRECATED EValue& mutable_input(size_t i); 241 /// DEPRECATED: Use MethodMeta instead to access metadata, and get_output to 242 /// retrieve Method outputs. 243 ET_DEPRECATED EValue& mutable_output(size_t i); 244 245 ~Method(); 246 247 private: 248 // Delete other rule-of-five methods. 249 Method(const Method&) = delete; 250 Method& operator=(const Method&) noexcept = delete; 251 Method& operator=(Method&&) = delete; 252 253 // Let Program call load(). 254 friend class Program; 255 // Let Executor call the ctor and init(). 256 friend class Executor; 257 258 enum class InitializationState : uint8_t { 259 Uninitialized, 260 Initialized, 261 InitializationFailed, 262 }; 263 264 /// Tracks what step in program execution we are on 265 struct StepState { 266 size_t chain_idx; 267 size_t instr_idx; 268 }; 269 Method(const Program * program,MemoryManager * memory_manager,EventTracer * event_tracer,MemoryAllocator * temp_allocator)270 Method( 271 const Program* program, 272 MemoryManager* memory_manager, 273 EventTracer* event_tracer, 274 MemoryAllocator* temp_allocator) 275 : step_state_(), 276 program_(program), 277 memory_manager_(memory_manager), 278 temp_allocator_(temp_allocator), 279 serialization_plan_(nullptr), 280 event_tracer_(event_tracer), 281 n_value_(0), 282 values_(nullptr), 283 n_delegate_(0), 284 delegates_(nullptr), 285 n_chains_(0), 286 chains_(nullptr), 287 init_state_(InitializationState::Uninitialized) {} 288 289 /// Static factory used by Program. 290 ET_NODISCARD static Result<Method> load( 291 executorch_flatbuffer::ExecutionPlan* s_plan, 292 const Program* program, 293 MemoryManager* memory_manager, 294 EventTracer* event_tracer); 295 296 /** 297 * Initialize the method from its serialized representation. 298 * 299 * @returns Error::Ok on success, non-Ok on failure. 300 */ 301 ET_NODISCARD Error init(executorch_flatbuffer::ExecutionPlan* s_plan); 302 303 /// Returns true if the Method was successfully initialized. initialized()304 inline bool initialized() const { 305 return init_state_ == InitializationState::Initialized; 306 } 307 308 const EValue& get_value(size_t i) const; 309 EValue& mutable_value(size_t i); 310 size_t get_input_index(size_t i) const; 311 size_t get_output_index(size_t i) const; 312 313 // Executes a single instruction using the state in step_state_ 314 ET_NODISCARD Error execute_instruction(); 315 316 StepState step_state_; 317 const Program* program_; 318 MemoryManager* memory_manager_; 319 MemoryAllocator* temp_allocator_; 320 executorch_flatbuffer::ExecutionPlan* serialization_plan_; 321 EventTracer* event_tracer_; 322 323 size_t n_value_; 324 EValue* values_; 325 326 size_t n_delegate_; 327 BackendDelegate* delegates_; 328 329 size_t n_chains_; 330 Chain* chains_; 331 332 InitializationState init_state_; 333 334 /** 335 * Parses the elements of the values_ array. On error, n_value_ will be set to 336 * the number of successfully-initialized entries so that ~Method doesn't try 337 * to clean up uninitialized entries. 338 */ 339 ET_NODISCARD Error parse_values(); 340 341 ET_NODISCARD Error resolve_operator( 342 int32_t op_index, 343 OpFunction* kernels, 344 size_t kernel_index, 345 InstructionArgs args, 346 size_t n_args); 347 348 void log_outputs(); 349 }; 350 351 } // namespace runtime 352 } // namespace executorch 353 354 namespace torch { 355 namespace executor { 356 // TODO(T197294990): Remove these deprecated aliases once all users have moved 357 // to the new `::executorch` namespaces. 358 using ::executorch::runtime::Method; 359 } // namespace executor 360 } // namespace torch 361