xref: /aosp_15_r20/external/executorch/runtime/executor/method.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/core/evalue.h>
12 #include <executorch/runtime/core/event_tracer.h>
13 #include <executorch/runtime/core/exec_aten/exec_aten.h>
14 #include <executorch/runtime/core/span.h>
15 #include <executorch/runtime/executor/memory_manager.h>
16 #include <executorch/runtime/executor/method_meta.h>
17 #include <executorch/runtime/platform/compiler.h>
18 
// Flatbuffer types are only forward declared here: this header is public, so
// the generated flatbuffer header must not be included from it.
namespace executorch_flatbuffer {
struct Chain;
struct EValue;
struct ExecutionPlan;
} // namespace executorch_flatbuffer
26 
27 namespace executorch {
28 namespace runtime {
29 
30 // Forward declare Program to avoid a circular reference.
31 class Program;
32 
33 // Forward declare internal types.
34 class BackendDelegate;
35 struct Chain;
36 class KernelRuntimeContext;
37 using OpFunction = void (*)(KernelRuntimeContext&, EValue**);
38 /// A list of pointers into the master values table that together compose the
39 /// argument list for a single instruction
40 using InstructionArgs = Span<EValue*>;
41 
42 /**
43  * An executable method of an executorch program. Maps to a python method like
44  * `forward()` on the original nn.Module.
45  */
46 class Method final {
47  public:
48   /**
49    * Move ctor. Takes ownership of resources previously owned by `rhs`,
50    * and leaves `rhs` in an uninitialized state.
51    */
Method(Method && rhs)52   Method(Method&& rhs) noexcept
53       : step_state_(rhs.step_state_),
54         program_(rhs.program_),
55         memory_manager_(rhs.memory_manager_),
56         temp_allocator_(rhs.temp_allocator_),
57         serialization_plan_(rhs.serialization_plan_),
58         event_tracer_(rhs.event_tracer_),
59         n_value_(rhs.n_value_),
60         values_(rhs.values_),
61         n_delegate_(rhs.n_delegate_),
62         delegates_(rhs.delegates_),
63         n_chains_(rhs.n_chains_),
64         chains_(rhs.chains_),
65         init_state_(rhs.init_state_) {
66     // Required: clear out fields that the dtor looks at, so that we don't free
67     // anything twice.
68     rhs.n_value_ = 0;
69     rhs.values_ = nullptr;
70     rhs.n_delegate_ = 0;
71     rhs.delegates_ = nullptr;
72 
73     // Helpful: Try to ensure that any other interactions with the old object
74     // result in failures.
75     rhs.init_state_ = InitializationState::Uninitialized;
76     rhs.step_state_ = {};
77     rhs.program_ = nullptr;
78     rhs.memory_manager_ = nullptr;
79     rhs.serialization_plan_ = nullptr;
80     rhs.event_tracer_ = nullptr;
81     rhs.n_chains_ = 0;
82     rhs.chains_ = nullptr;
83   }
84 
85   /**
86    * Sets the internal input value to be equivalent to the to the provided
87    * value.
88    *
89    * @param[in] input_evalue The evalue to copy into the method input. If the
90    *     evalue is a tensor, the data is copied in most cases, so the tensor
91    *     passed in here does not always need to outlive this call. But there is
92    *     a case where the Method will keep a pointer to the tensor's data.
93    *     Based on the memory plan of the method, the inputs may not have
94    *     buffer space pre-allocated for them. In this case the executor will
95    *     alias the memory of the tensors provided as inputs here rather then
96    *     deepcopy the input into the memory planned arena.
97    *
98    * @param[in] input_idx Zero-based index of the input to set. Must be less
99    *     than the value returned by inputs_size().
100    *
101    * @returns Error::Ok on success, non-Ok on failure.
102    */
103   ET_NODISCARD Error set_input(const EValue& input_evalue, size_t input_idx);
104 
105   /**
106    * Sets the values of all method inputs.
107    *
108    * See set_input() for a more detailed description of the behavior.
109    *
110    * @param[in] input_evalues The new values for all of the method inputs. The
111    *     type of each element must match the type of corresponding input. If the
112    *     value of an element is a tensor, attempts to allow dynamic shape, but
113    *     the dtype must always agree.
114    *
115    * @returns Error::Ok on success, non-Ok on failure.
116    */
117   ET_NODISCARD Error
118   set_inputs(const executorch::aten::ArrayRef<EValue>& input_evalues);
119 
120   /**
121    * Sets the data buffer of the specified method output to the provided value.
122    *
123    * NOTE: Based on the memory plan of the method, the output tensors may not
124    * have buffer space pre-allocated for them, in this case the executor will
125    * point those tensors to the buffer provided here, so the user should take
126    * care that the life span of this memory outlasts the executor forward.
127    *
128    * @param[in] buffer The block of memory to point the specified tensor at.
129    *
130    * @param[in] size the length of buffer in bytes, must be >= the nbytes of the
131    * specified tensor.
132    *
133    * @param[in] output_idx The index of the output to set the data_ptr for. Must
134    *     correspond to a tensor, and that tensor must not have had a buffer
135    *     allocated by the memory plan.
136    *
137    * @returns Error::Ok on success, non-Ok on failure.
138    */
139   ET_NODISCARD Error
140   set_output_data_ptr(void* buffer, size_t size, size_t output_idx);
141 
142   /**
143    * Copies the method's outputs into the provided array.
144    *
145    * WARNING: The output contains shallow copies of internal tensor outputs.
146    * Please do not mutate returned Tensor elements.
147    *
148    * TODO(T139259264): Add checks to detect output mutation, or deep-copy
149    * outputs.
150    *
151    * @param[in] output_evalues The array to copy the outputs into. The first
152    *     `outputs_size()` elements will be set to the corresponding output
153    *     values. The rest of the array will be set to the EValue value None.
154    * @param[in] length The size of the `output_evalues` array in elements. Must
155    *     be greater than or equal to `outputs_size()`.
156    *
157    * @returns Error::Ok on success, non-Ok on failure.
158    */
159   ET_NODISCARD Error get_outputs(EValue* output_evalues, size_t length);
160 
161   /**
162    * Copies the method's inputs into the provided array.
163    *
164    * WARNING: The input contains shallow copies of internal tensor inputs.
165    * Please do not mutate returned Tensor elements.
166    *
167    * @param[in] input_evalues The array to copy the inputs into. The first
168    *     `inputs_size()` elements will be set to the corresponding input
169    *     values. The rest of the array will be set to the EValue value None.
170    * @param[in] length The size of the `input_evalues` array in elements. Must
171    *     be greater than or equal to `inputs_size()`.
172    *
173    * @returns Error::Ok on success, non-Ok on failure.
174    */
175   ET_NODISCARD Error get_inputs(EValue* input_evalues, size_t length);
176 
177   /**
178    * Execute the method.
179    *
180    * NOTE: Will fail if the method has been partially executed using the
181    * `step()` api.
182    *
183    * @returns Error::Ok on success, non-Ok on failure.
184    */
185   ET_NODISCARD Error execute();
186 
187   /**
188    * EXPERIMENTAL: Advances/executes a single instruction in the method.
189    *
190    * @retval Error::Ok step succeeded
191    * @retval non-Ok step failed
192    * @retval Error::EndOfMethod method finished executing successfully
193    */
194   ET_EXPERIMENTAL ET_NODISCARD Error step();
195 
196   /// DEPRECATED: Use `step()` instead.
197   ET_DEPRECATED ET_NODISCARD Error experimental_step();
198 
199   /**
200    * EXPERIMENTAL: Resets execution state to the start of the Method. For use
201    * with the `step()` API.
202    *
203    * @retval Error:Ok on success
204    * @retval Error::InvalidState if called before step-based execution reached
205    *     the end of the Method. This means it is not possible to recover a
206    *     Method that failed mid-execution.
207    */
208   ET_EXPERIMENTAL ET_NODISCARD Error reset_execution();
209 
210   /// DEPRECATED: Use `reset_execution()` instead.
211   ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution();
212 
213   /**
214    * Returns the MethodMeta that corresponds to the calling Method.
215    */
216   MethodMeta method_meta() const;
217 
218   /**
219    * Returns the number of inputs the Method expects.
220    */
221   size_t inputs_size() const;
222 
223   /**
224    * Returns the number of outputs the Method returns.
225    */
226   size_t outputs_size() const;
227 
228   /**
229    * Retrieves the output at the specified index.
230    */
231   const EValue& get_output(size_t i) const;
232 
233   EventTracer* get_event_tracer();
234 
235   /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to
236   /// update Method inputs.
237   ET_DEPRECATED const EValue& get_input(size_t i) const;
238   /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to
239   /// update Method inputs.
240   ET_DEPRECATED EValue& mutable_input(size_t i);
241   /// DEPRECATED: Use MethodMeta instead to access metadata, and get_output to
242   /// retrieve Method outputs.
243   ET_DEPRECATED EValue& mutable_output(size_t i);
244 
245   ~Method();
246 
247  private:
248   // Delete other rule-of-five methods.
249   Method(const Method&) = delete;
250   Method& operator=(const Method&) noexcept = delete;
251   Method& operator=(Method&&) = delete;
252 
253   // Let Program call load().
254   friend class Program;
255   // Let Executor call the ctor and init().
256   friend class Executor;
257 
258   enum class InitializationState : uint8_t {
259     Uninitialized,
260     Initialized,
261     InitializationFailed,
262   };
263 
264   /// Tracks what step in program execution we are on
265   struct StepState {
266     size_t chain_idx;
267     size_t instr_idx;
268   };
269 
Method(const Program * program,MemoryManager * memory_manager,EventTracer * event_tracer,MemoryAllocator * temp_allocator)270   Method(
271       const Program* program,
272       MemoryManager* memory_manager,
273       EventTracer* event_tracer,
274       MemoryAllocator* temp_allocator)
275       : step_state_(),
276         program_(program),
277         memory_manager_(memory_manager),
278         temp_allocator_(temp_allocator),
279         serialization_plan_(nullptr),
280         event_tracer_(event_tracer),
281         n_value_(0),
282         values_(nullptr),
283         n_delegate_(0),
284         delegates_(nullptr),
285         n_chains_(0),
286         chains_(nullptr),
287         init_state_(InitializationState::Uninitialized) {}
288 
289   /// Static factory used by Program.
290   ET_NODISCARD static Result<Method> load(
291       executorch_flatbuffer::ExecutionPlan* s_plan,
292       const Program* program,
293       MemoryManager* memory_manager,
294       EventTracer* event_tracer);
295 
296   /**
297    * Initialize the method from its serialized representation.
298    *
299    * @returns Error::Ok on success, non-Ok on failure.
300    */
301   ET_NODISCARD Error init(executorch_flatbuffer::ExecutionPlan* s_plan);
302 
303   /// Returns true if the Method was successfully initialized.
initialized()304   inline bool initialized() const {
305     return init_state_ == InitializationState::Initialized;
306   }
307 
308   const EValue& get_value(size_t i) const;
309   EValue& mutable_value(size_t i);
310   size_t get_input_index(size_t i) const;
311   size_t get_output_index(size_t i) const;
312 
313   // Executes a single instruction using the state in step_state_
314   ET_NODISCARD Error execute_instruction();
315 
316   StepState step_state_;
317   const Program* program_;
318   MemoryManager* memory_manager_;
319   MemoryAllocator* temp_allocator_;
320   executorch_flatbuffer::ExecutionPlan* serialization_plan_;
321   EventTracer* event_tracer_;
322 
323   size_t n_value_;
324   EValue* values_;
325 
326   size_t n_delegate_;
327   BackendDelegate* delegates_;
328 
329   size_t n_chains_;
330   Chain* chains_;
331 
332   InitializationState init_state_;
333 
334   /**
335    * Parses the elements of the values_ array. On error, n_value_ will be set to
336    * the number of successfully-initialized entries so that ~Method doesn't try
337    * to clean up uninitialized entries.
338    */
339   ET_NODISCARD Error parse_values();
340 
341   ET_NODISCARD Error resolve_operator(
342       int32_t op_index,
343       OpFunction* kernels,
344       size_t kernel_index,
345       InstructionArgs args,
346       size_t n_args);
347 
348   void log_outputs();
349 };
350 
351 } // namespace runtime
352 } // namespace executorch
353 
354 namespace torch {
355 namespace executor {
356 // TODO(T197294990): Remove these deprecated aliases once all users have moved
357 // to the new `::executorch` namespaces.
358 using ::executorch::runtime::Method;
359 } // namespace executor
360 } // namespace torch
361