xref: /aosp_15_r20/external/executorch/extension/module/module.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <executorch/runtime/executor/program.h>

19 namespace executorch {
20 namespace extension {
21 
22 /**
23  * A facade class for loading programs and executing methods within them.
24  */
25 class Module {
26  public:
27   /**
28    * Enum to define loading behavior.
29    */
30   enum class LoadMode {
31     /// Load the whole file as a buffer.
32     File,
33     /// Use mmap to load pages into memory.
34     Mmap,
35     /// Use memory locking and handle errors.
36     MmapUseMlock,
37     /// Use memory locking and ignore errors.
38     MmapUseMlockIgnoreErrors,
39   };
40 
41   /**
42    * Constructs an instance by loading a program from a file with specified
43    * memory locking behavior.
44    *
45    * @param[in] file_path The path to the ExecuTorch program file to load.
46    * @param[in] load_mode The loading mode to use.
47    * @param[in] event_tracer A EventTracer used for tracking and logging events.
48    */
49   explicit Module(
50       const std::string& file_path,
51       const LoadMode load_mode = LoadMode::MmapUseMlock,
52       std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
53 
54   /**
55    * Constructs an instance with the provided data loader and memory allocator.
56    *
57    * @param[in] data_loader A DataLoader used for loading program data.
58    * @param[in] memory_allocator A MemoryAllocator used for memory management.
59    * @param[in] temp_allocator A MemoryAllocator to use when allocating
60    * temporary data during kernel or delegate execution.
61    * @param[in] event_tracer A EventTracer used for tracking and logging events.
62    */
63   explicit Module(
64       std::unique_ptr<runtime::DataLoader> data_loader,
65       std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
66       std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
67       std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
68 
69   /**
70    * Constructs an instance using an existing shared program.
71    *
72    * @param[in] program The shared program to use. It's required the data loader
73    * the program uses is valid for the lifetime of the program.
74    * @param[in] memory_allocator A MemoryAllocator used for memory management.
75    * @param[in] temp_allocator A MemoryAllocator to use when allocating
76    * temporary data.
77    * @param[in] event_tracer A EventTracer used for tracking and logging events.
78    */
79   explicit Module(
80       std::shared_ptr<runtime::Program> program,
81       std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
82       std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
83       std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
84 
85   Module(const Module&) = delete;
86   Module& operator=(const Module&) = delete;
87   Module(Module&&) = delete;
88   Module& operator=(Module&&) = delete;
89 
90   /**
91    * Loads the program if needed.
92    *
93    * @param[in] verification The type of verification to do before returning
94    * success.
95    *
96    * @returns An Error to indicate success or failure of the loading process.
97    */
98   ET_NODISCARD
99   runtime::Error load(
100       const runtime::Program::Verification verification =
101           runtime::Program::Verification::Minimal);
102 
103   /**
104    * Checks if the program is loaded.
105    *
106    * @returns true if the program is loaded, false otherwise.
107    */
is_loaded()108   inline bool is_loaded() const {
109     return program_ != nullptr;
110   }
111 
112   /**
113    * Get the program. The data loader used by the program is guaranteed to be
114    * valid for the lifetime of the program.
115    *
116    * @returns Shared pointer to the program or nullptr if it's not yet loaded.
117    */
program()118   inline std::shared_ptr<runtime::Program> program() const {
119     return program_;
120   }
121 
122   /**
123    * Get a list of method names available in the loaded program.
124    * Loads the program and method if needed.
125    *
126    * @returns A set of strings containing the names of the methods, or an error
127    * if the program or method failed to load.
128    */
129   runtime::Result<std::unordered_set<std::string>> method_names();
130 
131   /**
132    * Load a specific method from the program and set up memory management if
133    * needed. The loaded method is cached to reuse the next time it's executed.
134    *
135    * @param[in] method_name The name of the method to load.
136    * @param[in] event_tracer Per-method event tracer to profile/trace methods
137    * individually. When not given, the event tracer passed to the Module
138    * constructor is used. Otherwise, this per-method event tracer takes
139    * precedence.
140    *
141    * @returns An Error to indicate success or failure.
142    */
143   ET_NODISCARD
144   runtime::Error load_method(
145       const std::string& method_name,
146       torch::executor::EventTracer* event_tracer = nullptr);
147 
148   /**
149    * Load the 'forward' method from the program and set up memory management if
150    * needed. The loaded method is cached to reuse the next time it's executed.
151    *
152    * @param[in] event_tracer An event tracer used for tracking and logging
153    * events.
154    *
155    * @returns An Error to indicate success or failure.
156    */
157   ET_NODISCARD inline runtime::Error load_forward(
158       torch::executor::EventTracer* event_tracer = nullptr) {
159     return load_method("forward", event_tracer);
160   }
161 
162   /**
163    * Checks if a specific method is loaded.
164    *
165    * @param[in] method_name The name of the method to check.
166    *
167    * @returns true if the method specified by method_name is loaded, false
168    * otherwise.
169    */
is_method_loaded(const std::string & method_name)170   inline bool is_method_loaded(const std::string& method_name) const {
171     return methods_.count(method_name);
172   }
173 
174   /**
175    * Get a method metadata struct by method name.
176    * Loads the program and method if needed.
177    *
178    * @param[in] method_name The name of the method to get the metadata for.
179    *
180    * @returns A method metadata, or an error if the program or method failed to
181    * load.
182    */
183   runtime::Result<runtime::MethodMeta> method_meta(
184       const std::string& method_name);
185 
186   /**
187    * Execute a specific method with the given input values and retrieve the
188    * output values. Loads the program and method before executing if needed.
189    *
190    * @param[in] method_name The name of the method to execute.
191    * @param[in] input_values A vector of input values to be passed to the
192    * method.
193    *
194    * @returns A Result object containing either a vector of output values
195    *          from the method or an error to indicate failure.
196    */
197   ET_NODISCARD
198   runtime::Result<std::vector<runtime::EValue>> execute(
199       const std::string& method_name,
200       const std::vector<runtime::EValue>& input_values);
201 
202   /**
203    * Execute a specific method with a single input value.
204    * Loads the program and method before executing if needed.
205    *
206    * @param[in] method_name The name of the method to execute.
207    * @param[in] input_value A value to be passed to the method.
208    *
209    * @returns A Result object containing either a vector of output values
210    *          from the method or an error to indicate failure.
211    */
execute(const std::string & method_name,const runtime::EValue & input_value)212   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute(
213       const std::string& method_name,
214       const runtime::EValue& input_value) {
215     return execute(method_name, std::vector<runtime::EValue>{input_value});
216   }
217 
218   /**
219    * Execute a specific method without any input values.
220    * Loads the program and method before executing if needed.
221    *
222    * @param[in] method_name The name of the method to execute.
223    *
224    * @returns A Result object containing either a vector of output values
225    *          from the method or an error to indicate failure.
226    */
execute(const std::string & method_name)227   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute(
228       const std::string& method_name) {
229     return execute(method_name, std::vector<runtime::EValue>{});
230   }
231 
232   /**
233    * Retrieve the output value of a specific method with the given input values.
234    * Loads the program and method before execution if needed.
235    *
236    * @param[in] method_name The name of the method to execute.
237    * @param[in] input_values A vector of input values to be passed to the
238    * method.
239    *
240    * @returns A Result object containing either the first output value from the
241    * method or an error to indicate failure.
242    */
get(const std::string & method_name,const std::vector<runtime::EValue> & input_values)243   ET_NODISCARD inline runtime::Result<runtime::EValue> get(
244       const std::string& method_name,
245       const std::vector<runtime::EValue>& input_values) {
246     auto result = ET_UNWRAP(execute(method_name, input_values));
247     if (result.empty()) {
248       return runtime::Error::InvalidArgument;
249     }
250     return result[0];
251   }
252 
253   /**
254    * Retrieve the output value of a specific method with a single input value.
255    * Loads the program and method before execution if needed.
256    *
257    * @param[in] method_name The name of the method to execute.
258    * @param[in] input_value A value to be passed to the method.
259    *
260    * @returns A Result object containing either the first output value from the
261    * method or an error to indicate failure.
262    */
get(const std::string & method_name,const runtime::EValue & input_value)263   ET_NODISCARD inline runtime::Result<runtime::EValue> get(
264       const std::string& method_name,
265       const runtime::EValue& input_value) {
266     return get(method_name, std::vector<runtime::EValue>{input_value});
267   }
268 
269   /**
270    * Retrieve the output value of a specific method without any input values.
271    * Loads the program and method before execution if needed.
272    *
273    * @param[in] method_name The name of the method to execute.
274    *
275    * @returns A Result object containing either the first output value from the
276    * method or an error to indicate failure.
277    */
get(const std::string & method_name)278   ET_NODISCARD inline runtime::Result<runtime::EValue> get(
279       const std::string& method_name) {
280     return get(method_name, std::vector<runtime::EValue>{});
281   }
282 
283   /**
284    * Execute the 'forward' method with the given input values and retrieve the
285    * output values. Loads the program and method before executing if needed.
286    *
287    * @param[in] input_values A vector of input values for the 'forward' method.
288    *
289    * @returns A Result object containing either a vector of output values
290    *          from the 'forward' method or an error to indicate failure.
291    */
forward(const std::vector<runtime::EValue> & input_values)292   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
293       const std::vector<runtime::EValue>& input_values) {
294     return execute("forward", input_values);
295   }
296 
297   /**
298    * Execute the 'forward' method with a single value.
299    * Loads the program and method before executing if needed.
300    *
301    * @param[in] input_value A value for the 'forward' method.
302    *
303    * @returns A Result object containing either a vector of output values
304    *          from the 'forward' method or an error to indicate failure.
305    */
forward(const runtime::EValue & input_value)306   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
307       const runtime::EValue& input_value) {
308     return forward(std::vector<runtime::EValue>{input_value});
309   }
310 
311   /**
312    * Execute the 'forward' method without any input values.
313    * Loads the program and method before executing if needed.
314    *
315    * @returns A Result object containing either a vector of output values
316    *          from the 'forward' method or an error to indicate failure.
317    */
forward()318   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward() {
319     return forward(std::vector<runtime::EValue>{});
320   }
321 
322   /**
323    * Sets a single input value for a specific method.
324    *
325    * @param[in] method_name The name of the method.
326    * @param[in] input_value The EValue to set as the method input.
327    * @param[in] input_index Zero-based index of the input to set.
328    *
329    * @returns An Error to indicate success or failure.
330    */
331   ET_NODISCARD
332   runtime::Error set_input(
333       const std::string& method_name,
334       const runtime::EValue& input_value,
335       size_t input_index);
336 
337   /**
338    * Sets a single input value for the "forward" method.
339    *
340    * @param[in] input_value The EValue to set as the method input.
341    * @param[in] input_index Zero-based index of the input to set.
342    *
343    * @returns An Error to indicate success or failure.
344    */
345   ET_NODISCARD
set_input(const runtime::EValue & input_value,size_t input_index)346   inline runtime::Error set_input(
347       const runtime::EValue& input_value,
348       size_t input_index) {
349     return set_input("forward", input_value, input_index);
350   }
351 
352   /**
353    * Sets all input values for a specific method.
354    *
355    * @param[in] method_name The name of the method.
356    * @param[in] input_values A vector of EValues to set as the method inputs.
357    *
358    * @returns An Error to indicate success or failure.
359    */
360   ET_NODISCARD
361   runtime::Error set_inputs(
362       const std::string& method_name,
363       const std::vector<runtime::EValue>& input_values);
364 
365   /**
366    * Sets all input values for the "forward" method.
367    *
368    * @param[in] input_values A vector of EValues to set as the method inputs.
369    *
370    * @returns An Error to indicate success or failure.
371    */
372   ET_NODISCARD
set_inputs(const std::vector<runtime::EValue> & input_values)373   inline runtime::Error set_inputs(
374       const std::vector<runtime::EValue>& input_values) {
375     return set_inputs("forward", input_values);
376   }
377 
378   /**
379    * Sets the output tensor for a specific method.
380    *
381    * @param[in] method_name The name of the method.
382    * @param[in] output_value The EValue containing the Tensor to set as the
383    * method output.
384    * @param[in] output_index Zero-based index of the output to set.
385    *
386    * @returns An Error to indicate success or failure.
387    *
388    * @note Only Tensor outputs are currently supported for setting.
389    */
390   ET_NODISCARD
391   runtime::Error set_output(
392       const std::string& method_name,
393       runtime::EValue output_value,
394       size_t output_index = 0);
395 
396   /**
397    * Sets the output tensor for the "forward" method.
398    *
399    * @param[in] output_value The EValue containing the Tensor to set as the
400    * method output.
401    * @param[in] output_index Zero-based index of the output to set.
402    *
403    * @returns An Error to indicate success or failure.
404    *
405    * @note Only Tensor outputs are currently supported for setting.
406    */
407   ET_NODISCARD
408   inline runtime::Error set_output(
409       runtime::EValue output_value,
410       size_t output_index = 0) {
411     return set_output("forward", std::move(output_value), output_index);
412   }
413 
414   /**
415    * Retrieves the EventTracer instance being used by the Module.
416    * EventTracer is used for tracking and logging events during the execution
417    * of methods.
418    *
419    * @returns A pointer to the EventTracer instance. Returns nullptr if no
420    * EventTracer is set.
421    */
event_tracer()422   inline runtime::EventTracer* event_tracer() const {
423     return event_tracer_.get();
424   }
425 
426  private:
427   struct MethodHolder {
428     std::vector<std::vector<uint8_t>> planned_buffers;
429     std::vector<runtime::Span<uint8_t>> planned_spans;
430     std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
431     std::unique_ptr<runtime::MemoryManager> memory_manager;
432     std::unique_ptr<runtime::Method> method;
433     std::vector<runtime::EValue> inputs;
434   };
435 
436  private:
437   std::string file_path_;
438   LoadMode load_mode_{LoadMode::MmapUseMlock};
439   std::shared_ptr<runtime::Program> program_;
440   std::unique_ptr<runtime::DataLoader> data_loader_;
441   std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
442   std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
443   std::unique_ptr<runtime::EventTracer> event_tracer_;
444 
445  protected:
446   std::unordered_map<std::string, MethodHolder> methods_;
447 
448   friend class ExecuTorchJni;
449 };
450 
451 } // namespace extension
452 } // namespace executorch
453 
454 namespace torch {
455 namespace executor {
456 // TODO(T197294990): Remove these deprecated aliases once all users have moved
457 // to the new `::executorch` namespaces.
458 using ::executorch::extension::Module;
459 } // namespace executor
460 } // namespace torch
461