xref: /aosp_15_r20/external/executorch/extension/module/module.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <memory>
12*523fa7a6SAndroid Build Coastguard Worker #include <string>
13*523fa7a6SAndroid Build Coastguard Worker #include <unordered_map>
14*523fa7a6SAndroid Build Coastguard Worker #include <unordered_set>
15*523fa7a6SAndroid Build Coastguard Worker #include <vector>
16*523fa7a6SAndroid Build Coastguard Worker 
17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/program.h>
18*523fa7a6SAndroid Build Coastguard Worker 
19*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
20*523fa7a6SAndroid Build Coastguard Worker namespace extension {
21*523fa7a6SAndroid Build Coastguard Worker 
22*523fa7a6SAndroid Build Coastguard Worker /**
23*523fa7a6SAndroid Build Coastguard Worker  * A facade class for loading programs and executing methods within them.
24*523fa7a6SAndroid Build Coastguard Worker  */
25*523fa7a6SAndroid Build Coastguard Worker class Module {
26*523fa7a6SAndroid Build Coastguard Worker  public:
27*523fa7a6SAndroid Build Coastguard Worker   /**
28*523fa7a6SAndroid Build Coastguard Worker    * Enum to define loading behavior.
29*523fa7a6SAndroid Build Coastguard Worker    */
30*523fa7a6SAndroid Build Coastguard Worker   enum class LoadMode {
31*523fa7a6SAndroid Build Coastguard Worker     /// Load the whole file as a buffer.
32*523fa7a6SAndroid Build Coastguard Worker     File,
33*523fa7a6SAndroid Build Coastguard Worker     /// Use mmap to load pages into memory.
34*523fa7a6SAndroid Build Coastguard Worker     Mmap,
35*523fa7a6SAndroid Build Coastguard Worker     /// Use memory locking and handle errors.
36*523fa7a6SAndroid Build Coastguard Worker     MmapUseMlock,
37*523fa7a6SAndroid Build Coastguard Worker     /// Use memory locking and ignore errors.
38*523fa7a6SAndroid Build Coastguard Worker     MmapUseMlockIgnoreErrors,
39*523fa7a6SAndroid Build Coastguard Worker   };
40*523fa7a6SAndroid Build Coastguard Worker 
41*523fa7a6SAndroid Build Coastguard Worker   /**
42*523fa7a6SAndroid Build Coastguard Worker    * Constructs an instance by loading a program from a file with specified
43*523fa7a6SAndroid Build Coastguard Worker    * memory locking behavior.
44*523fa7a6SAndroid Build Coastguard Worker    *
45*523fa7a6SAndroid Build Coastguard Worker    * @param[in] file_path The path to the ExecuTorch program file to load.
46*523fa7a6SAndroid Build Coastguard Worker    * @param[in] load_mode The loading mode to use.
47*523fa7a6SAndroid Build Coastguard Worker    * @param[in] event_tracer A EventTracer used for tracking and logging events.
48*523fa7a6SAndroid Build Coastguard Worker    */
49*523fa7a6SAndroid Build Coastguard Worker   explicit Module(
50*523fa7a6SAndroid Build Coastguard Worker       const std::string& file_path,
51*523fa7a6SAndroid Build Coastguard Worker       const LoadMode load_mode = LoadMode::MmapUseMlock,
52*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
53*523fa7a6SAndroid Build Coastguard Worker 
54*523fa7a6SAndroid Build Coastguard Worker   /**
55*523fa7a6SAndroid Build Coastguard Worker    * Constructs an instance with the provided data loader and memory allocator.
56*523fa7a6SAndroid Build Coastguard Worker    *
57*523fa7a6SAndroid Build Coastguard Worker    * @param[in] data_loader A DataLoader used for loading program data.
58*523fa7a6SAndroid Build Coastguard Worker    * @param[in] memory_allocator A MemoryAllocator used for memory management.
59*523fa7a6SAndroid Build Coastguard Worker    * @param[in] temp_allocator A MemoryAllocator to use when allocating
60*523fa7a6SAndroid Build Coastguard Worker    * temporary data during kernel or delegate execution.
61*523fa7a6SAndroid Build Coastguard Worker    * @param[in] event_tracer A EventTracer used for tracking and logging events.
62*523fa7a6SAndroid Build Coastguard Worker    */
63*523fa7a6SAndroid Build Coastguard Worker   explicit Module(
64*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::DataLoader> data_loader,
65*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
66*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
67*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
68*523fa7a6SAndroid Build Coastguard Worker 
69*523fa7a6SAndroid Build Coastguard Worker   /**
70*523fa7a6SAndroid Build Coastguard Worker    * Constructs an instance using an existing shared program.
71*523fa7a6SAndroid Build Coastguard Worker    *
72*523fa7a6SAndroid Build Coastguard Worker    * @param[in] program The shared program to use. It's required the data loader
73*523fa7a6SAndroid Build Coastguard Worker    * the program uses is valid for the lifetime of the program.
74*523fa7a6SAndroid Build Coastguard Worker    * @param[in] memory_allocator A MemoryAllocator used for memory management.
75*523fa7a6SAndroid Build Coastguard Worker    * @param[in] temp_allocator A MemoryAllocator to use when allocating
76*523fa7a6SAndroid Build Coastguard Worker    * temporary data.
77*523fa7a6SAndroid Build Coastguard Worker    * @param[in] event_tracer A EventTracer used for tracking and logging events.
78*523fa7a6SAndroid Build Coastguard Worker    */
79*523fa7a6SAndroid Build Coastguard Worker   explicit Module(
80*523fa7a6SAndroid Build Coastguard Worker       std::shared_ptr<runtime::Program> program,
81*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
82*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
83*523fa7a6SAndroid Build Coastguard Worker       std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
84*523fa7a6SAndroid Build Coastguard Worker 
85*523fa7a6SAndroid Build Coastguard Worker   Module(const Module&) = delete;
86*523fa7a6SAndroid Build Coastguard Worker   Module& operator=(const Module&) = delete;
87*523fa7a6SAndroid Build Coastguard Worker   Module(Module&&) = delete;
88*523fa7a6SAndroid Build Coastguard Worker   Module& operator=(Module&&) = delete;
89*523fa7a6SAndroid Build Coastguard Worker 
90*523fa7a6SAndroid Build Coastguard Worker   /**
91*523fa7a6SAndroid Build Coastguard Worker    * Loads the program if needed.
92*523fa7a6SAndroid Build Coastguard Worker    *
93*523fa7a6SAndroid Build Coastguard Worker    * @param[in] verification The type of verification to do before returning
94*523fa7a6SAndroid Build Coastguard Worker    * success.
95*523fa7a6SAndroid Build Coastguard Worker    *
96*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure of the loading process.
97*523fa7a6SAndroid Build Coastguard Worker    */
98*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
99*523fa7a6SAndroid Build Coastguard Worker   runtime::Error load(
100*523fa7a6SAndroid Build Coastguard Worker       const runtime::Program::Verification verification =
101*523fa7a6SAndroid Build Coastguard Worker           runtime::Program::Verification::Minimal);
102*523fa7a6SAndroid Build Coastguard Worker 
103*523fa7a6SAndroid Build Coastguard Worker   /**
104*523fa7a6SAndroid Build Coastguard Worker    * Checks if the program is loaded.
105*523fa7a6SAndroid Build Coastguard Worker    *
106*523fa7a6SAndroid Build Coastguard Worker    * @returns true if the program is loaded, false otherwise.
107*523fa7a6SAndroid Build Coastguard Worker    */
is_loaded()108*523fa7a6SAndroid Build Coastguard Worker   inline bool is_loaded() const {
109*523fa7a6SAndroid Build Coastguard Worker     return program_ != nullptr;
110*523fa7a6SAndroid Build Coastguard Worker   }
111*523fa7a6SAndroid Build Coastguard Worker 
112*523fa7a6SAndroid Build Coastguard Worker   /**
113*523fa7a6SAndroid Build Coastguard Worker    * Get the program. The data loader used by the program is guaranteed to be
114*523fa7a6SAndroid Build Coastguard Worker    * valid for the lifetime of the program.
115*523fa7a6SAndroid Build Coastguard Worker    *
116*523fa7a6SAndroid Build Coastguard Worker    * @returns Shared pointer to the program or nullptr if it's not yet loaded.
117*523fa7a6SAndroid Build Coastguard Worker    */
program()118*523fa7a6SAndroid Build Coastguard Worker   inline std::shared_ptr<runtime::Program> program() const {
119*523fa7a6SAndroid Build Coastguard Worker     return program_;
120*523fa7a6SAndroid Build Coastguard Worker   }
121*523fa7a6SAndroid Build Coastguard Worker 
122*523fa7a6SAndroid Build Coastguard Worker   /**
123*523fa7a6SAndroid Build Coastguard Worker    * Get a list of method names available in the loaded program.
124*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method if needed.
125*523fa7a6SAndroid Build Coastguard Worker    *
126*523fa7a6SAndroid Build Coastguard Worker    * @returns A set of strings containing the names of the methods, or an error
127*523fa7a6SAndroid Build Coastguard Worker    * if the program or method failed to load.
128*523fa7a6SAndroid Build Coastguard Worker    */
129*523fa7a6SAndroid Build Coastguard Worker   runtime::Result<std::unordered_set<std::string>> method_names();
130*523fa7a6SAndroid Build Coastguard Worker 
131*523fa7a6SAndroid Build Coastguard Worker   /**
132*523fa7a6SAndroid Build Coastguard Worker    * Load a specific method from the program and set up memory management if
133*523fa7a6SAndroid Build Coastguard Worker    * needed. The loaded method is cached to reuse the next time it's executed.
134*523fa7a6SAndroid Build Coastguard Worker    *
135*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to load.
136*523fa7a6SAndroid Build Coastguard Worker    * @param[in] event_tracer Per-method event tracer to profile/trace methods
137*523fa7a6SAndroid Build Coastguard Worker    * individually. When not given, the event tracer passed to the Module
138*523fa7a6SAndroid Build Coastguard Worker    * constructor is used. Otherwise, this per-method event tracer takes
139*523fa7a6SAndroid Build Coastguard Worker    * precedence.
140*523fa7a6SAndroid Build Coastguard Worker    *
141*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
142*523fa7a6SAndroid Build Coastguard Worker    */
143*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
144*523fa7a6SAndroid Build Coastguard Worker   runtime::Error load_method(
145*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
146*523fa7a6SAndroid Build Coastguard Worker       torch::executor::EventTracer* event_tracer = nullptr);
147*523fa7a6SAndroid Build Coastguard Worker 
148*523fa7a6SAndroid Build Coastguard Worker   /**
149*523fa7a6SAndroid Build Coastguard Worker    * Load the 'forward' method from the program and set up memory management if
150*523fa7a6SAndroid Build Coastguard Worker    * needed. The loaded method is cached to reuse the next time it's executed.
151*523fa7a6SAndroid Build Coastguard Worker    *
152*523fa7a6SAndroid Build Coastguard Worker    * @param[in] event_tracer An event tracer used for tracking and logging
153*523fa7a6SAndroid Build Coastguard Worker    * events.
154*523fa7a6SAndroid Build Coastguard Worker    *
155*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
156*523fa7a6SAndroid Build Coastguard Worker    */
157*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Error load_forward(
158*523fa7a6SAndroid Build Coastguard Worker       torch::executor::EventTracer* event_tracer = nullptr) {
159*523fa7a6SAndroid Build Coastguard Worker     return load_method("forward", event_tracer);
160*523fa7a6SAndroid Build Coastguard Worker   }
161*523fa7a6SAndroid Build Coastguard Worker 
162*523fa7a6SAndroid Build Coastguard Worker   /**
163*523fa7a6SAndroid Build Coastguard Worker    * Checks if a specific method is loaded.
164*523fa7a6SAndroid Build Coastguard Worker    *
165*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to check.
166*523fa7a6SAndroid Build Coastguard Worker    *
167*523fa7a6SAndroid Build Coastguard Worker    * @returns true if the method specified by method_name is loaded, false
168*523fa7a6SAndroid Build Coastguard Worker    * otherwise.
169*523fa7a6SAndroid Build Coastguard Worker    */
is_method_loaded(const std::string & method_name)170*523fa7a6SAndroid Build Coastguard Worker   inline bool is_method_loaded(const std::string& method_name) const {
171*523fa7a6SAndroid Build Coastguard Worker     return methods_.count(method_name);
172*523fa7a6SAndroid Build Coastguard Worker   }
173*523fa7a6SAndroid Build Coastguard Worker 
174*523fa7a6SAndroid Build Coastguard Worker   /**
175*523fa7a6SAndroid Build Coastguard Worker    * Get a method metadata struct by method name.
176*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method if needed.
177*523fa7a6SAndroid Build Coastguard Worker    *
178*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to get the metadata for.
179*523fa7a6SAndroid Build Coastguard Worker    *
180*523fa7a6SAndroid Build Coastguard Worker    * @returns A method metadata, or an error if the program or method failed to
181*523fa7a6SAndroid Build Coastguard Worker    * load.
182*523fa7a6SAndroid Build Coastguard Worker    */
183*523fa7a6SAndroid Build Coastguard Worker   runtime::Result<runtime::MethodMeta> method_meta(
184*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name);
185*523fa7a6SAndroid Build Coastguard Worker 
186*523fa7a6SAndroid Build Coastguard Worker   /**
187*523fa7a6SAndroid Build Coastguard Worker    * Execute a specific method with the given input values and retrieve the
188*523fa7a6SAndroid Build Coastguard Worker    * output values. Loads the program and method before executing if needed.
189*523fa7a6SAndroid Build Coastguard Worker    *
190*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to execute.
191*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_values A vector of input values to be passed to the
192*523fa7a6SAndroid Build Coastguard Worker    * method.
193*523fa7a6SAndroid Build Coastguard Worker    *
194*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either a vector of output values
195*523fa7a6SAndroid Build Coastguard Worker    *          from the method or an error to indicate failure.
196*523fa7a6SAndroid Build Coastguard Worker    */
197*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
198*523fa7a6SAndroid Build Coastguard Worker   runtime::Result<std::vector<runtime::EValue>> execute(
199*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
200*523fa7a6SAndroid Build Coastguard Worker       const std::vector<runtime::EValue>& input_values);
201*523fa7a6SAndroid Build Coastguard Worker 
202*523fa7a6SAndroid Build Coastguard Worker   /**
203*523fa7a6SAndroid Build Coastguard Worker    * Execute a specific method with a single input value.
204*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method before executing if needed.
205*523fa7a6SAndroid Build Coastguard Worker    *
206*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to execute.
207*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_value A value to be passed to the method.
208*523fa7a6SAndroid Build Coastguard Worker    *
209*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either a vector of output values
210*523fa7a6SAndroid Build Coastguard Worker    *          from the method or an error to indicate failure.
211*523fa7a6SAndroid Build Coastguard Worker    */
execute(const std::string & method_name,const runtime::EValue & input_value)212*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute(
213*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
214*523fa7a6SAndroid Build Coastguard Worker       const runtime::EValue& input_value) {
215*523fa7a6SAndroid Build Coastguard Worker     return execute(method_name, std::vector<runtime::EValue>{input_value});
216*523fa7a6SAndroid Build Coastguard Worker   }
217*523fa7a6SAndroid Build Coastguard Worker 
218*523fa7a6SAndroid Build Coastguard Worker   /**
219*523fa7a6SAndroid Build Coastguard Worker    * Execute a specific method without any input values.
220*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method before executing if needed.
221*523fa7a6SAndroid Build Coastguard Worker    *
222*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to execute.
223*523fa7a6SAndroid Build Coastguard Worker    *
224*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either a vector of output values
225*523fa7a6SAndroid Build Coastguard Worker    *          from the method or an error to indicate failure.
226*523fa7a6SAndroid Build Coastguard Worker    */
execute(const std::string & method_name)227*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute(
228*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name) {
229*523fa7a6SAndroid Build Coastguard Worker     return execute(method_name, std::vector<runtime::EValue>{});
230*523fa7a6SAndroid Build Coastguard Worker   }
231*523fa7a6SAndroid Build Coastguard Worker 
232*523fa7a6SAndroid Build Coastguard Worker   /**
233*523fa7a6SAndroid Build Coastguard Worker    * Retrieve the output value of a specific method with the given input values.
234*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method before execution if needed.
235*523fa7a6SAndroid Build Coastguard Worker    *
236*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to execute.
237*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_values A vector of input values to be passed to the
238*523fa7a6SAndroid Build Coastguard Worker    * method.
239*523fa7a6SAndroid Build Coastguard Worker    *
240*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either the first output value from the
241*523fa7a6SAndroid Build Coastguard Worker    * method or an error to indicate failure.
242*523fa7a6SAndroid Build Coastguard Worker    */
get(const std::string & method_name,const std::vector<runtime::EValue> & input_values)243*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<runtime::EValue> get(
244*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
245*523fa7a6SAndroid Build Coastguard Worker       const std::vector<runtime::EValue>& input_values) {
246*523fa7a6SAndroid Build Coastguard Worker     auto result = ET_UNWRAP(execute(method_name, input_values));
247*523fa7a6SAndroid Build Coastguard Worker     if (result.empty()) {
248*523fa7a6SAndroid Build Coastguard Worker       return runtime::Error::InvalidArgument;
249*523fa7a6SAndroid Build Coastguard Worker     }
250*523fa7a6SAndroid Build Coastguard Worker     return result[0];
251*523fa7a6SAndroid Build Coastguard Worker   }
252*523fa7a6SAndroid Build Coastguard Worker 
253*523fa7a6SAndroid Build Coastguard Worker   /**
254*523fa7a6SAndroid Build Coastguard Worker    * Retrieve the output value of a specific method with a single input value.
255*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method before execution if needed.
256*523fa7a6SAndroid Build Coastguard Worker    *
257*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to execute.
258*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_value A value to be passed to the method.
259*523fa7a6SAndroid Build Coastguard Worker    *
260*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either the first output value from the
261*523fa7a6SAndroid Build Coastguard Worker    * method or an error to indicate failure.
262*523fa7a6SAndroid Build Coastguard Worker    */
get(const std::string & method_name,const runtime::EValue & input_value)263*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<runtime::EValue> get(
264*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
265*523fa7a6SAndroid Build Coastguard Worker       const runtime::EValue& input_value) {
266*523fa7a6SAndroid Build Coastguard Worker     return get(method_name, std::vector<runtime::EValue>{input_value});
267*523fa7a6SAndroid Build Coastguard Worker   }
268*523fa7a6SAndroid Build Coastguard Worker 
269*523fa7a6SAndroid Build Coastguard Worker   /**
270*523fa7a6SAndroid Build Coastguard Worker    * Retrieve the output value of a specific method without any input values.
271*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method before execution if needed.
272*523fa7a6SAndroid Build Coastguard Worker    *
273*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method to execute.
274*523fa7a6SAndroid Build Coastguard Worker    *
275*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either the first output value from the
276*523fa7a6SAndroid Build Coastguard Worker    * method or an error to indicate failure.
277*523fa7a6SAndroid Build Coastguard Worker    */
get(const std::string & method_name)278*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<runtime::EValue> get(
279*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name) {
280*523fa7a6SAndroid Build Coastguard Worker     return get(method_name, std::vector<runtime::EValue>{});
281*523fa7a6SAndroid Build Coastguard Worker   }
282*523fa7a6SAndroid Build Coastguard Worker 
283*523fa7a6SAndroid Build Coastguard Worker   /**
284*523fa7a6SAndroid Build Coastguard Worker    * Execute the 'forward' method with the given input values and retrieve the
285*523fa7a6SAndroid Build Coastguard Worker    * output values. Loads the program and method before executing if needed.
286*523fa7a6SAndroid Build Coastguard Worker    *
287*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_values A vector of input values for the 'forward' method.
288*523fa7a6SAndroid Build Coastguard Worker    *
289*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either a vector of output values
290*523fa7a6SAndroid Build Coastguard Worker    *          from the 'forward' method or an error to indicate failure.
291*523fa7a6SAndroid Build Coastguard Worker    */
forward(const std::vector<runtime::EValue> & input_values)292*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
293*523fa7a6SAndroid Build Coastguard Worker       const std::vector<runtime::EValue>& input_values) {
294*523fa7a6SAndroid Build Coastguard Worker     return execute("forward", input_values);
295*523fa7a6SAndroid Build Coastguard Worker   }
296*523fa7a6SAndroid Build Coastguard Worker 
297*523fa7a6SAndroid Build Coastguard Worker   /**
298*523fa7a6SAndroid Build Coastguard Worker    * Execute the 'forward' method with a single value.
299*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method before executing if needed.
300*523fa7a6SAndroid Build Coastguard Worker    *
301*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_value A value for the 'forward' method.
302*523fa7a6SAndroid Build Coastguard Worker    *
303*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either a vector of output values
304*523fa7a6SAndroid Build Coastguard Worker    *          from the 'forward' method or an error to indicate failure.
305*523fa7a6SAndroid Build Coastguard Worker    */
forward(const runtime::EValue & input_value)306*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward(
307*523fa7a6SAndroid Build Coastguard Worker       const runtime::EValue& input_value) {
308*523fa7a6SAndroid Build Coastguard Worker     return forward(std::vector<runtime::EValue>{input_value});
309*523fa7a6SAndroid Build Coastguard Worker   }
310*523fa7a6SAndroid Build Coastguard Worker 
311*523fa7a6SAndroid Build Coastguard Worker   /**
312*523fa7a6SAndroid Build Coastguard Worker    * Execute the 'forward' method without any input values.
313*523fa7a6SAndroid Build Coastguard Worker    * Loads the program and method before executing if needed.
314*523fa7a6SAndroid Build Coastguard Worker    *
315*523fa7a6SAndroid Build Coastguard Worker    * @returns A Result object containing either a vector of output values
316*523fa7a6SAndroid Build Coastguard Worker    *          from the 'forward' method or an error to indicate failure.
317*523fa7a6SAndroid Build Coastguard Worker    */
forward()318*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward() {
319*523fa7a6SAndroid Build Coastguard Worker     return forward(std::vector<runtime::EValue>{});
320*523fa7a6SAndroid Build Coastguard Worker   }
321*523fa7a6SAndroid Build Coastguard Worker 
322*523fa7a6SAndroid Build Coastguard Worker   /**
323*523fa7a6SAndroid Build Coastguard Worker    * Sets a single input value for a specific method.
324*523fa7a6SAndroid Build Coastguard Worker    *
325*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method.
326*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_value The EValue to set as the method input.
327*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_index Zero-based index of the input to set.
328*523fa7a6SAndroid Build Coastguard Worker    *
329*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
330*523fa7a6SAndroid Build Coastguard Worker    */
331*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
332*523fa7a6SAndroid Build Coastguard Worker   runtime::Error set_input(
333*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
334*523fa7a6SAndroid Build Coastguard Worker       const runtime::EValue& input_value,
335*523fa7a6SAndroid Build Coastguard Worker       size_t input_index);
336*523fa7a6SAndroid Build Coastguard Worker 
337*523fa7a6SAndroid Build Coastguard Worker   /**
338*523fa7a6SAndroid Build Coastguard Worker    * Sets a single input value for the "forward" method.
339*523fa7a6SAndroid Build Coastguard Worker    *
340*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_value The EValue to set as the method input.
341*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_index Zero-based index of the input to set.
342*523fa7a6SAndroid Build Coastguard Worker    *
343*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
344*523fa7a6SAndroid Build Coastguard Worker    */
345*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
set_input(const runtime::EValue & input_value,size_t input_index)346*523fa7a6SAndroid Build Coastguard Worker   inline runtime::Error set_input(
347*523fa7a6SAndroid Build Coastguard Worker       const runtime::EValue& input_value,
348*523fa7a6SAndroid Build Coastguard Worker       size_t input_index) {
349*523fa7a6SAndroid Build Coastguard Worker     return set_input("forward", input_value, input_index);
350*523fa7a6SAndroid Build Coastguard Worker   }
351*523fa7a6SAndroid Build Coastguard Worker 
352*523fa7a6SAndroid Build Coastguard Worker   /**
353*523fa7a6SAndroid Build Coastguard Worker    * Sets all input values for a specific method.
354*523fa7a6SAndroid Build Coastguard Worker    *
355*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method.
356*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_values A vector of EValues to set as the method inputs.
357*523fa7a6SAndroid Build Coastguard Worker    *
358*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
359*523fa7a6SAndroid Build Coastguard Worker    */
360*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
361*523fa7a6SAndroid Build Coastguard Worker   runtime::Error set_inputs(
362*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
363*523fa7a6SAndroid Build Coastguard Worker       const std::vector<runtime::EValue>& input_values);
364*523fa7a6SAndroid Build Coastguard Worker 
365*523fa7a6SAndroid Build Coastguard Worker   /**
366*523fa7a6SAndroid Build Coastguard Worker    * Sets all input values for the "forward" method.
367*523fa7a6SAndroid Build Coastguard Worker    *
368*523fa7a6SAndroid Build Coastguard Worker    * @param[in] input_values A vector of EValues to set as the method inputs.
369*523fa7a6SAndroid Build Coastguard Worker    *
370*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
371*523fa7a6SAndroid Build Coastguard Worker    */
372*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
set_inputs(const std::vector<runtime::EValue> & input_values)373*523fa7a6SAndroid Build Coastguard Worker   inline runtime::Error set_inputs(
374*523fa7a6SAndroid Build Coastguard Worker       const std::vector<runtime::EValue>& input_values) {
375*523fa7a6SAndroid Build Coastguard Worker     return set_inputs("forward", input_values);
376*523fa7a6SAndroid Build Coastguard Worker   }
377*523fa7a6SAndroid Build Coastguard Worker 
378*523fa7a6SAndroid Build Coastguard Worker   /**
379*523fa7a6SAndroid Build Coastguard Worker    * Sets the output tensor for a specific method.
380*523fa7a6SAndroid Build Coastguard Worker    *
381*523fa7a6SAndroid Build Coastguard Worker    * @param[in] method_name The name of the method.
382*523fa7a6SAndroid Build Coastguard Worker    * @param[in] output_value The EValue containing the Tensor to set as the
383*523fa7a6SAndroid Build Coastguard Worker    * method output.
384*523fa7a6SAndroid Build Coastguard Worker    * @param[in] output_index Zero-based index of the output to set.
385*523fa7a6SAndroid Build Coastguard Worker    *
386*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
387*523fa7a6SAndroid Build Coastguard Worker    *
388*523fa7a6SAndroid Build Coastguard Worker    * @note Only Tensor outputs are currently supported for setting.
389*523fa7a6SAndroid Build Coastguard Worker    */
390*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
391*523fa7a6SAndroid Build Coastguard Worker   runtime::Error set_output(
392*523fa7a6SAndroid Build Coastguard Worker       const std::string& method_name,
393*523fa7a6SAndroid Build Coastguard Worker       runtime::EValue output_value,
394*523fa7a6SAndroid Build Coastguard Worker       size_t output_index = 0);
395*523fa7a6SAndroid Build Coastguard Worker 
396*523fa7a6SAndroid Build Coastguard Worker   /**
397*523fa7a6SAndroid Build Coastguard Worker    * Sets the output tensor for the "forward" method.
398*523fa7a6SAndroid Build Coastguard Worker    *
399*523fa7a6SAndroid Build Coastguard Worker    * @param[in] output_value The EValue containing the Tensor to set as the
400*523fa7a6SAndroid Build Coastguard Worker    * method output.
401*523fa7a6SAndroid Build Coastguard Worker    * @param[in] output_index Zero-based index of the output to set.
402*523fa7a6SAndroid Build Coastguard Worker    *
403*523fa7a6SAndroid Build Coastguard Worker    * @returns An Error to indicate success or failure.
404*523fa7a6SAndroid Build Coastguard Worker    *
405*523fa7a6SAndroid Build Coastguard Worker    * @note Only Tensor outputs are currently supported for setting.
406*523fa7a6SAndroid Build Coastguard Worker    */
407*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD
408*523fa7a6SAndroid Build Coastguard Worker   inline runtime::Error set_output(
409*523fa7a6SAndroid Build Coastguard Worker       runtime::EValue output_value,
410*523fa7a6SAndroid Build Coastguard Worker       size_t output_index = 0) {
411*523fa7a6SAndroid Build Coastguard Worker     return set_output("forward", std::move(output_value), output_index);
412*523fa7a6SAndroid Build Coastguard Worker   }
413*523fa7a6SAndroid Build Coastguard Worker 
414*523fa7a6SAndroid Build Coastguard Worker   /**
415*523fa7a6SAndroid Build Coastguard Worker    * Retrieves the EventTracer instance being used by the Module.
416*523fa7a6SAndroid Build Coastguard Worker    * EventTracer is used for tracking and logging events during the execution
417*523fa7a6SAndroid Build Coastguard Worker    * of methods.
418*523fa7a6SAndroid Build Coastguard Worker    *
419*523fa7a6SAndroid Build Coastguard Worker    * @returns A pointer to the EventTracer instance. Returns nullptr if no
420*523fa7a6SAndroid Build Coastguard Worker    * EventTracer is set.
421*523fa7a6SAndroid Build Coastguard Worker    */
event_tracer()422*523fa7a6SAndroid Build Coastguard Worker   inline runtime::EventTracer* event_tracer() const {
423*523fa7a6SAndroid Build Coastguard Worker     return event_tracer_.get();
424*523fa7a6SAndroid Build Coastguard Worker   }
425*523fa7a6SAndroid Build Coastguard Worker 
private:
// Per-method state: planned-memory buffers and their spans, a
// HierarchicalAllocator, a MemoryManager, the runtime::Method itself, and the
// method's input EValues.
//
// NOTE(review): C++ destroys members in reverse declaration order, so `inputs`
// and `method` are torn down first and `planned_buffers` last. The order looks
// deliberate -- presumably each member borrows from the ones declared above
// it -- so confirm before reordering any field.
struct MethodHolder {
  // Owning storage for planned memory.
  std::vector<std::vector<uint8_t>> planned_buffers;
  // Non-owning spans; presumably views over `planned_buffers` -- TODO confirm.
  std::vector<runtime::Span<uint8_t>> planned_spans;
  // NOTE(review): presumably built from `planned_spans` -- confirm.
  std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
  // NOTE(review): presumably references `planned_memory` -- confirm.
  std::unique_ptr<runtime::MemoryManager> memory_manager;
  std::unique_ptr<runtime::Method> method;
  // Input values for this method; presumably populated by set_inputs().
  std::vector<runtime::EValue> inputs;
};
435*523fa7a6SAndroid Build Coastguard Worker 
private:
// Path of the program file. NOTE(review): presumably empty when the Module is
// built around an already-loaded Program -- confirm.
std::string file_path_;
// Loading behavior for the program file; defaults to LoadMode::MmapUseMlock.
LoadMode load_mode_{LoadMode::MmapUseMlock};
// Loaded program. Shared ownership, so it may outlive this Module.
std::shared_ptr<runtime::Program> program_;
// NOTE(review): declaration order below fixes destruction order (reverse of
// declaration); keep it unless the dependencies are re-verified.
std::unique_ptr<runtime::DataLoader> data_loader_;
std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
// Owned tracer returned (non-owning) by event_tracer(); may be null.
std::unique_ptr<runtime::EventTracer> event_tracer_;

protected:
// Loaded methods, keyed by method name (presumably the same `method_name`
// strings the public API takes). Declared last, so it is destroyed before
// every member above -- NOTE(review): presumably required because loaded
// methods refer to the program/allocators/tracer; keep it last.
std::unordered_map<std::string, MethodHolder> methods_;

// NOTE(review): presumably the Java/JNI bindings, which need access to
// non-public state such as `methods_`.
friend class ExecuTorchJni;
449*523fa7a6SAndroid Build Coastguard Worker };
450*523fa7a6SAndroid Build Coastguard Worker 
451*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
452*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
453*523fa7a6SAndroid Build Coastguard Worker 
454*523fa7a6SAndroid Build Coastguard Worker namespace torch {
455*523fa7a6SAndroid Build Coastguard Worker namespace executor {
456*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved
457*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces.
458*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::Module;
459*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
460*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
461