1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker #pragma once 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Worker #include <memory> 12*523fa7a6SAndroid Build Coastguard Worker #include <string> 13*523fa7a6SAndroid Build Coastguard Worker #include <unordered_map> 14*523fa7a6SAndroid Build Coastguard Worker #include <unordered_set> 15*523fa7a6SAndroid Build Coastguard Worker #include <vector> 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/executor/program.h> 18*523fa7a6SAndroid Build Coastguard Worker 19*523fa7a6SAndroid Build Coastguard Worker namespace executorch { 20*523fa7a6SAndroid Build Coastguard Worker namespace extension { 21*523fa7a6SAndroid Build Coastguard Worker 22*523fa7a6SAndroid Build Coastguard Worker /** 23*523fa7a6SAndroid Build Coastguard Worker * A facade class for loading programs and executing methods within them. 24*523fa7a6SAndroid Build Coastguard Worker */ 25*523fa7a6SAndroid Build Coastguard Worker class Module { 26*523fa7a6SAndroid Build Coastguard Worker public: 27*523fa7a6SAndroid Build Coastguard Worker /** 28*523fa7a6SAndroid Build Coastguard Worker * Enum to define loading behavior. 29*523fa7a6SAndroid Build Coastguard Worker */ 30*523fa7a6SAndroid Build Coastguard Worker enum class LoadMode { 31*523fa7a6SAndroid Build Coastguard Worker /// Load the whole file as a buffer. 32*523fa7a6SAndroid Build Coastguard Worker File, 33*523fa7a6SAndroid Build Coastguard Worker /// Use mmap to load pages into memory. 34*523fa7a6SAndroid Build Coastguard Worker Mmap, 35*523fa7a6SAndroid Build Coastguard Worker /// Use memory locking and handle errors. 36*523fa7a6SAndroid Build Coastguard Worker MmapUseMlock, 37*523fa7a6SAndroid Build Coastguard Worker /// Use memory locking and ignore errors. 38*523fa7a6SAndroid Build Coastguard Worker MmapUseMlockIgnoreErrors, 39*523fa7a6SAndroid Build Coastguard Worker }; 40*523fa7a6SAndroid Build Coastguard Worker 41*523fa7a6SAndroid Build Coastguard Worker /** 42*523fa7a6SAndroid Build Coastguard Worker * Constructs an instance by loading a program from a file with specified 43*523fa7a6SAndroid Build Coastguard Worker * memory locking behavior. 44*523fa7a6SAndroid Build Coastguard Worker * 45*523fa7a6SAndroid Build Coastguard Worker * @param[in] file_path The path to the ExecuTorch program file to load. 46*523fa7a6SAndroid Build Coastguard Worker * @param[in] load_mode The loading mode to use. 47*523fa7a6SAndroid Build Coastguard Worker * @param[in] event_tracer A EventTracer used for tracking and logging events. 48*523fa7a6SAndroid Build Coastguard Worker */ 49*523fa7a6SAndroid Build Coastguard Worker explicit Module( 50*523fa7a6SAndroid Build Coastguard Worker const std::string& file_path, 51*523fa7a6SAndroid Build Coastguard Worker const LoadMode load_mode = LoadMode::MmapUseMlock, 52*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::EventTracer> event_tracer = nullptr); 53*523fa7a6SAndroid Build Coastguard Worker 54*523fa7a6SAndroid Build Coastguard Worker /** 55*523fa7a6SAndroid Build Coastguard Worker * Constructs an instance with the provided data loader and memory allocator. 56*523fa7a6SAndroid Build Coastguard Worker * 57*523fa7a6SAndroid Build Coastguard Worker * @param[in] data_loader A DataLoader used for loading program data. 58*523fa7a6SAndroid Build Coastguard Worker * @param[in] memory_allocator A MemoryAllocator used for memory management. 59*523fa7a6SAndroid Build Coastguard Worker * @param[in] temp_allocator A MemoryAllocator to use when allocating 60*523fa7a6SAndroid Build Coastguard Worker * temporary data during kernel or delegate execution. 61*523fa7a6SAndroid Build Coastguard Worker * @param[in] event_tracer A EventTracer used for tracking and logging events. 62*523fa7a6SAndroid Build Coastguard Worker */ 63*523fa7a6SAndroid Build Coastguard Worker explicit Module( 64*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::DataLoader> data_loader, 65*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr, 66*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr, 67*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::EventTracer> event_tracer = nullptr); 68*523fa7a6SAndroid Build Coastguard Worker 69*523fa7a6SAndroid Build Coastguard Worker /** 70*523fa7a6SAndroid Build Coastguard Worker * Constructs an instance using an existing shared program. 71*523fa7a6SAndroid Build Coastguard Worker * 72*523fa7a6SAndroid Build Coastguard Worker * @param[in] program The shared program to use. It's required the data loader 73*523fa7a6SAndroid Build Coastguard Worker * the program uses is valid for the lifetime of the program. 74*523fa7a6SAndroid Build Coastguard Worker * @param[in] memory_allocator A MemoryAllocator used for memory management. 75*523fa7a6SAndroid Build Coastguard Worker * @param[in] temp_allocator A MemoryAllocator to use when allocating 76*523fa7a6SAndroid Build Coastguard Worker * temporary data. 77*523fa7a6SAndroid Build Coastguard Worker * @param[in] event_tracer A EventTracer used for tracking and logging events. 78*523fa7a6SAndroid Build Coastguard Worker */ 79*523fa7a6SAndroid Build Coastguard Worker explicit Module( 80*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<runtime::Program> program, 81*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr, 82*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr, 83*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::EventTracer> event_tracer = nullptr); 84*523fa7a6SAndroid Build Coastguard Worker 85*523fa7a6SAndroid Build Coastguard Worker Module(const Module&) = delete; 86*523fa7a6SAndroid Build Coastguard Worker Module& operator=(const Module&) = delete; 87*523fa7a6SAndroid Build Coastguard Worker Module(Module&&) = delete; 88*523fa7a6SAndroid Build Coastguard Worker Module& operator=(Module&&) = delete; 89*523fa7a6SAndroid Build Coastguard Worker 90*523fa7a6SAndroid Build Coastguard Worker /** 91*523fa7a6SAndroid Build Coastguard Worker * Loads the program if needed. 92*523fa7a6SAndroid Build Coastguard Worker * 93*523fa7a6SAndroid Build Coastguard Worker * @param[in] verification The type of verification to do before returning 94*523fa7a6SAndroid Build Coastguard Worker * success. 95*523fa7a6SAndroid Build Coastguard Worker * 96*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure of the loading process. 97*523fa7a6SAndroid Build Coastguard Worker */ 98*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD 99*523fa7a6SAndroid Build Coastguard Worker runtime::Error load( 100*523fa7a6SAndroid Build Coastguard Worker const runtime::Program::Verification verification = 101*523fa7a6SAndroid Build Coastguard Worker runtime::Program::Verification::Minimal); 102*523fa7a6SAndroid Build Coastguard Worker 103*523fa7a6SAndroid Build Coastguard Worker /** 104*523fa7a6SAndroid Build Coastguard Worker * Checks if the program is loaded. 105*523fa7a6SAndroid Build Coastguard Worker * 106*523fa7a6SAndroid Build Coastguard Worker * @returns true if the program is loaded, false otherwise. 107*523fa7a6SAndroid Build Coastguard Worker */ is_loaded()108*523fa7a6SAndroid Build Coastguard Worker inline bool is_loaded() const { 109*523fa7a6SAndroid Build Coastguard Worker return program_ != nullptr; 110*523fa7a6SAndroid Build Coastguard Worker } 111*523fa7a6SAndroid Build Coastguard Worker 112*523fa7a6SAndroid Build Coastguard Worker /** 113*523fa7a6SAndroid Build Coastguard Worker * Get the program. The data loader used by the program is guaranteed to be 114*523fa7a6SAndroid Build Coastguard Worker * valid for the lifetime of the program. 115*523fa7a6SAndroid Build Coastguard Worker * 116*523fa7a6SAndroid Build Coastguard Worker * @returns Shared pointer to the program or nullptr if it's not yet loaded. 117*523fa7a6SAndroid Build Coastguard Worker */ program()118*523fa7a6SAndroid Build Coastguard Worker inline std::shared_ptr<runtime::Program> program() const { 119*523fa7a6SAndroid Build Coastguard Worker return program_; 120*523fa7a6SAndroid Build Coastguard Worker } 121*523fa7a6SAndroid Build Coastguard Worker 122*523fa7a6SAndroid Build Coastguard Worker /** 123*523fa7a6SAndroid Build Coastguard Worker * Get a list of method names available in the loaded program. 124*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method if needed. 125*523fa7a6SAndroid Build Coastguard Worker * 126*523fa7a6SAndroid Build Coastguard Worker * @returns A set of strings containing the names of the methods, or an error 127*523fa7a6SAndroid Build Coastguard Worker * if the program or method failed to load. 128*523fa7a6SAndroid Build Coastguard Worker */ 129*523fa7a6SAndroid Build Coastguard Worker runtime::Result<std::unordered_set<std::string>> method_names(); 130*523fa7a6SAndroid Build Coastguard Worker 131*523fa7a6SAndroid Build Coastguard Worker /** 132*523fa7a6SAndroid Build Coastguard Worker * Load a specific method from the program and set up memory management if 133*523fa7a6SAndroid Build Coastguard Worker * needed. The loaded method is cached to reuse the next time it's executed. 134*523fa7a6SAndroid Build Coastguard Worker * 135*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to load. 136*523fa7a6SAndroid Build Coastguard Worker * @param[in] event_tracer Per-method event tracer to profile/trace methods 137*523fa7a6SAndroid Build Coastguard Worker * individually. When not given, the event tracer passed to the Module 138*523fa7a6SAndroid Build Coastguard Worker * constructor is used. Otherwise, this per-method event tracer takes 139*523fa7a6SAndroid Build Coastguard Worker * precedence. 140*523fa7a6SAndroid Build Coastguard Worker * 141*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 142*523fa7a6SAndroid Build Coastguard Worker */ 143*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD 144*523fa7a6SAndroid Build Coastguard Worker runtime::Error load_method( 145*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 146*523fa7a6SAndroid Build Coastguard Worker torch::executor::EventTracer* event_tracer = nullptr); 147*523fa7a6SAndroid Build Coastguard Worker 148*523fa7a6SAndroid Build Coastguard Worker /** 149*523fa7a6SAndroid Build Coastguard Worker * Load the 'forward' method from the program and set up memory management if 150*523fa7a6SAndroid Build Coastguard Worker * needed. The loaded method is cached to reuse the next time it's executed. 151*523fa7a6SAndroid Build Coastguard Worker * 152*523fa7a6SAndroid Build Coastguard Worker * @param[in] event_tracer An event tracer used for tracking and logging 153*523fa7a6SAndroid Build Coastguard Worker * events. 154*523fa7a6SAndroid Build Coastguard Worker * 155*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 156*523fa7a6SAndroid Build Coastguard Worker */ 157*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Error load_forward( 158*523fa7a6SAndroid Build Coastguard Worker torch::executor::EventTracer* event_tracer = nullptr) { 159*523fa7a6SAndroid Build Coastguard Worker return load_method("forward", event_tracer); 160*523fa7a6SAndroid Build Coastguard Worker } 161*523fa7a6SAndroid Build Coastguard Worker 162*523fa7a6SAndroid Build Coastguard Worker /** 163*523fa7a6SAndroid Build Coastguard Worker * Checks if a specific method is loaded. 164*523fa7a6SAndroid Build Coastguard Worker * 165*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to check. 166*523fa7a6SAndroid Build Coastguard Worker * 167*523fa7a6SAndroid Build Coastguard Worker * @returns true if the method specified by method_name is loaded, false 168*523fa7a6SAndroid Build Coastguard Worker * otherwise. 169*523fa7a6SAndroid Build Coastguard Worker */ is_method_loaded(const std::string & method_name)170*523fa7a6SAndroid Build Coastguard Worker inline bool is_method_loaded(const std::string& method_name) const { 171*523fa7a6SAndroid Build Coastguard Worker return methods_.count(method_name); 172*523fa7a6SAndroid Build Coastguard Worker } 173*523fa7a6SAndroid Build Coastguard Worker 174*523fa7a6SAndroid Build Coastguard Worker /** 175*523fa7a6SAndroid Build Coastguard Worker * Get a method metadata struct by method name. 176*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method if needed. 177*523fa7a6SAndroid Build Coastguard Worker * 178*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to get the metadata for. 179*523fa7a6SAndroid Build Coastguard Worker * 180*523fa7a6SAndroid Build Coastguard Worker * @returns A method metadata, or an error if the program or method failed to 181*523fa7a6SAndroid Build Coastguard Worker * load. 182*523fa7a6SAndroid Build Coastguard Worker */ 183*523fa7a6SAndroid Build Coastguard Worker runtime::Result<runtime::MethodMeta> method_meta( 184*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name); 185*523fa7a6SAndroid Build Coastguard Worker 186*523fa7a6SAndroid Build Coastguard Worker /** 187*523fa7a6SAndroid Build Coastguard Worker * Execute a specific method with the given input values and retrieve the 188*523fa7a6SAndroid Build Coastguard Worker * output values. Loads the program and method before executing if needed. 189*523fa7a6SAndroid Build Coastguard Worker * 190*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to execute. 191*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_values A vector of input values to be passed to the 192*523fa7a6SAndroid Build Coastguard Worker * method. 193*523fa7a6SAndroid Build Coastguard Worker * 194*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either a vector of output values 195*523fa7a6SAndroid Build Coastguard Worker * from the method or an error to indicate failure. 196*523fa7a6SAndroid Build Coastguard Worker */ 197*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD 198*523fa7a6SAndroid Build Coastguard Worker runtime::Result<std::vector<runtime::EValue>> execute( 199*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 200*523fa7a6SAndroid Build Coastguard Worker const std::vector<runtime::EValue>& input_values); 201*523fa7a6SAndroid Build Coastguard Worker 202*523fa7a6SAndroid Build Coastguard Worker /** 203*523fa7a6SAndroid Build Coastguard Worker * Execute a specific method with a single input value. 204*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method before executing if needed. 205*523fa7a6SAndroid Build Coastguard Worker * 206*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to execute. 207*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_value A value to be passed to the method. 208*523fa7a6SAndroid Build Coastguard Worker * 209*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either a vector of output values 210*523fa7a6SAndroid Build Coastguard Worker * from the method or an error to indicate failure. 211*523fa7a6SAndroid Build Coastguard Worker */ execute(const std::string & method_name,const runtime::EValue & input_value)212*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute( 213*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 214*523fa7a6SAndroid Build Coastguard Worker const runtime::EValue& input_value) { 215*523fa7a6SAndroid Build Coastguard Worker return execute(method_name, std::vector<runtime::EValue>{input_value}); 216*523fa7a6SAndroid Build Coastguard Worker } 217*523fa7a6SAndroid Build Coastguard Worker 218*523fa7a6SAndroid Build Coastguard Worker /** 219*523fa7a6SAndroid Build Coastguard Worker * Execute a specific method without any input values. 220*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method before executing if needed. 221*523fa7a6SAndroid Build Coastguard Worker * 222*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to execute. 223*523fa7a6SAndroid Build Coastguard Worker * 224*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either a vector of output values 225*523fa7a6SAndroid Build Coastguard Worker * from the method or an error to indicate failure. 226*523fa7a6SAndroid Build Coastguard Worker */ execute(const std::string & method_name)227*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute( 228*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name) { 229*523fa7a6SAndroid Build Coastguard Worker return execute(method_name, std::vector<runtime::EValue>{}); 230*523fa7a6SAndroid Build Coastguard Worker } 231*523fa7a6SAndroid Build Coastguard Worker 232*523fa7a6SAndroid Build Coastguard Worker /** 233*523fa7a6SAndroid Build Coastguard Worker * Retrieve the output value of a specific method with the given input values. 234*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method before execution if needed. 235*523fa7a6SAndroid Build Coastguard Worker * 236*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to execute. 237*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_values A vector of input values to be passed to the 238*523fa7a6SAndroid Build Coastguard Worker * method. 239*523fa7a6SAndroid Build Coastguard Worker * 240*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either the first output value from the 241*523fa7a6SAndroid Build Coastguard Worker * method or an error to indicate failure. 242*523fa7a6SAndroid Build Coastguard Worker */ get(const std::string & method_name,const std::vector<runtime::EValue> & input_values)243*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<runtime::EValue> get( 244*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 245*523fa7a6SAndroid Build Coastguard Worker const std::vector<runtime::EValue>& input_values) { 246*523fa7a6SAndroid Build Coastguard Worker auto result = ET_UNWRAP(execute(method_name, input_values)); 247*523fa7a6SAndroid Build Coastguard Worker if (result.empty()) { 248*523fa7a6SAndroid Build Coastguard Worker return runtime::Error::InvalidArgument; 249*523fa7a6SAndroid Build Coastguard Worker } 250*523fa7a6SAndroid Build Coastguard Worker return result[0]; 251*523fa7a6SAndroid Build Coastguard Worker } 252*523fa7a6SAndroid Build Coastguard Worker 253*523fa7a6SAndroid Build Coastguard Worker /** 254*523fa7a6SAndroid Build Coastguard Worker * Retrieve the output value of a specific method with a single input value. 255*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method before execution if needed. 256*523fa7a6SAndroid Build Coastguard Worker * 257*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to execute. 258*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_value A value to be passed to the method. 259*523fa7a6SAndroid Build Coastguard Worker * 260*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either the first output value from the 261*523fa7a6SAndroid Build Coastguard Worker * method or an error to indicate failure. 262*523fa7a6SAndroid Build Coastguard Worker */ get(const std::string & method_name,const runtime::EValue & input_value)263*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<runtime::EValue> get( 264*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 265*523fa7a6SAndroid Build Coastguard Worker const runtime::EValue& input_value) { 266*523fa7a6SAndroid Build Coastguard Worker return get(method_name, std::vector<runtime::EValue>{input_value}); 267*523fa7a6SAndroid Build Coastguard Worker } 268*523fa7a6SAndroid Build Coastguard Worker 269*523fa7a6SAndroid Build Coastguard Worker /** 270*523fa7a6SAndroid Build Coastguard Worker * Retrieve the output value of a specific method without any input values. 271*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method before execution if needed. 272*523fa7a6SAndroid Build Coastguard Worker * 273*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method to execute. 274*523fa7a6SAndroid Build Coastguard Worker * 275*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either the first output value from the 276*523fa7a6SAndroid Build Coastguard Worker * method or an error to indicate failure. 277*523fa7a6SAndroid Build Coastguard Worker */ get(const std::string & method_name)278*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<runtime::EValue> get( 279*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name) { 280*523fa7a6SAndroid Build Coastguard Worker return get(method_name, std::vector<runtime::EValue>{}); 281*523fa7a6SAndroid Build Coastguard Worker } 282*523fa7a6SAndroid Build Coastguard Worker 283*523fa7a6SAndroid Build Coastguard Worker /** 284*523fa7a6SAndroid Build Coastguard Worker * Execute the 'forward' method with the given input values and retrieve the 285*523fa7a6SAndroid Build Coastguard Worker * output values. Loads the program and method before executing if needed. 286*523fa7a6SAndroid Build Coastguard Worker * 287*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_values A vector of input values for the 'forward' method. 288*523fa7a6SAndroid Build Coastguard Worker * 289*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either a vector of output values 290*523fa7a6SAndroid Build Coastguard Worker * from the 'forward' method or an error to indicate failure. 291*523fa7a6SAndroid Build Coastguard Worker */ forward(const std::vector<runtime::EValue> & input_values)292*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward( 293*523fa7a6SAndroid Build Coastguard Worker const std::vector<runtime::EValue>& input_values) { 294*523fa7a6SAndroid Build Coastguard Worker return execute("forward", input_values); 295*523fa7a6SAndroid Build Coastguard Worker } 296*523fa7a6SAndroid Build Coastguard Worker 297*523fa7a6SAndroid Build Coastguard Worker /** 298*523fa7a6SAndroid Build Coastguard Worker * Execute the 'forward' method with a single value. 299*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method before executing if needed. 300*523fa7a6SAndroid Build Coastguard Worker * 301*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_value A value for the 'forward' method. 302*523fa7a6SAndroid Build Coastguard Worker * 303*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either a vector of output values 304*523fa7a6SAndroid Build Coastguard Worker * from the 'forward' method or an error to indicate failure. 305*523fa7a6SAndroid Build Coastguard Worker */ forward(const runtime::EValue & input_value)306*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward( 307*523fa7a6SAndroid Build Coastguard Worker const runtime::EValue& input_value) { 308*523fa7a6SAndroid Build Coastguard Worker return forward(std::vector<runtime::EValue>{input_value}); 309*523fa7a6SAndroid Build Coastguard Worker } 310*523fa7a6SAndroid Build Coastguard Worker 311*523fa7a6SAndroid Build Coastguard Worker /** 312*523fa7a6SAndroid Build Coastguard Worker * Execute the 'forward' method without any input values. 313*523fa7a6SAndroid Build Coastguard Worker * Loads the program and method before executing if needed. 314*523fa7a6SAndroid Build Coastguard Worker * 315*523fa7a6SAndroid Build Coastguard Worker * @returns A Result object containing either a vector of output values 316*523fa7a6SAndroid Build Coastguard Worker * from the 'forward' method or an error to indicate failure. 317*523fa7a6SAndroid Build Coastguard Worker */ forward()318*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward() { 319*523fa7a6SAndroid Build Coastguard Worker return forward(std::vector<runtime::EValue>{}); 320*523fa7a6SAndroid Build Coastguard Worker } 321*523fa7a6SAndroid Build Coastguard Worker 322*523fa7a6SAndroid Build Coastguard Worker /** 323*523fa7a6SAndroid Build Coastguard Worker * Sets a single input value for a specific method. 324*523fa7a6SAndroid Build Coastguard Worker * 325*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method. 326*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_value The EValue to set as the method input. 327*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_index Zero-based index of the input to set. 328*523fa7a6SAndroid Build Coastguard Worker * 329*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 330*523fa7a6SAndroid Build Coastguard Worker */ 331*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD 332*523fa7a6SAndroid Build Coastguard Worker runtime::Error set_input( 333*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 334*523fa7a6SAndroid Build Coastguard Worker const runtime::EValue& input_value, 335*523fa7a6SAndroid Build Coastguard Worker size_t input_index); 336*523fa7a6SAndroid Build Coastguard Worker 337*523fa7a6SAndroid Build Coastguard Worker /** 338*523fa7a6SAndroid Build Coastguard Worker * Sets a single input value for the "forward" method. 339*523fa7a6SAndroid Build Coastguard Worker * 340*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_value The EValue to set as the method input. 341*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_index Zero-based index of the input to set. 342*523fa7a6SAndroid Build Coastguard Worker * 343*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 344*523fa7a6SAndroid Build Coastguard Worker */ 345*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD set_input(const runtime::EValue & input_value,size_t input_index)346*523fa7a6SAndroid Build Coastguard Worker inline runtime::Error set_input( 347*523fa7a6SAndroid Build Coastguard Worker const runtime::EValue& input_value, 348*523fa7a6SAndroid Build Coastguard Worker size_t input_index) { 349*523fa7a6SAndroid Build Coastguard Worker return set_input("forward", input_value, input_index); 350*523fa7a6SAndroid Build Coastguard Worker } 351*523fa7a6SAndroid Build Coastguard Worker 352*523fa7a6SAndroid Build Coastguard Worker /** 353*523fa7a6SAndroid Build Coastguard Worker * Sets all input values for a specific method. 354*523fa7a6SAndroid Build Coastguard Worker * 355*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method. 356*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_values A vector of EValues to set as the method inputs. 357*523fa7a6SAndroid Build Coastguard Worker * 358*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 359*523fa7a6SAndroid Build Coastguard Worker */ 360*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD 361*523fa7a6SAndroid Build Coastguard Worker runtime::Error set_inputs( 362*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 363*523fa7a6SAndroid Build Coastguard Worker const std::vector<runtime::EValue>& input_values); 364*523fa7a6SAndroid Build Coastguard Worker 365*523fa7a6SAndroid Build Coastguard Worker /** 366*523fa7a6SAndroid Build Coastguard Worker * Sets all input values for the "forward" method. 367*523fa7a6SAndroid Build Coastguard Worker * 368*523fa7a6SAndroid Build Coastguard Worker * @param[in] input_values A vector of EValues to set as the method inputs. 369*523fa7a6SAndroid Build Coastguard Worker * 370*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 371*523fa7a6SAndroid Build Coastguard Worker */ 372*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD set_inputs(const std::vector<runtime::EValue> & input_values)373*523fa7a6SAndroid Build Coastguard Worker inline runtime::Error set_inputs( 374*523fa7a6SAndroid Build Coastguard Worker const std::vector<runtime::EValue>& input_values) { 375*523fa7a6SAndroid Build Coastguard Worker return set_inputs("forward", input_values); 376*523fa7a6SAndroid Build Coastguard Worker } 377*523fa7a6SAndroid Build Coastguard Worker 378*523fa7a6SAndroid Build Coastguard Worker /** 379*523fa7a6SAndroid Build Coastguard Worker * Sets the output tensor for a specific method. 380*523fa7a6SAndroid Build Coastguard Worker * 381*523fa7a6SAndroid Build Coastguard Worker * @param[in] method_name The name of the method. 382*523fa7a6SAndroid Build Coastguard Worker * @param[in] output_value The EValue containing the Tensor to set as the 383*523fa7a6SAndroid Build Coastguard Worker * method output. 384*523fa7a6SAndroid Build Coastguard Worker * @param[in] output_index Zero-based index of the output to set. 385*523fa7a6SAndroid Build Coastguard Worker * 386*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 387*523fa7a6SAndroid Build Coastguard Worker * 388*523fa7a6SAndroid Build Coastguard Worker * @note Only Tensor outputs are currently supported for setting. 389*523fa7a6SAndroid Build Coastguard Worker */ 390*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD 391*523fa7a6SAndroid Build Coastguard Worker runtime::Error set_output( 392*523fa7a6SAndroid Build Coastguard Worker const std::string& method_name, 393*523fa7a6SAndroid Build Coastguard Worker runtime::EValue output_value, 394*523fa7a6SAndroid Build Coastguard Worker size_t output_index = 0); 395*523fa7a6SAndroid Build Coastguard Worker 396*523fa7a6SAndroid Build Coastguard Worker /** 397*523fa7a6SAndroid Build Coastguard Worker * Sets the output tensor for the "forward" method. 398*523fa7a6SAndroid Build Coastguard Worker * 399*523fa7a6SAndroid Build Coastguard Worker * @param[in] output_value The EValue containing the Tensor to set as the 400*523fa7a6SAndroid Build Coastguard Worker * method output. 401*523fa7a6SAndroid Build Coastguard Worker * @param[in] output_index Zero-based index of the output to set. 402*523fa7a6SAndroid Build Coastguard Worker * 403*523fa7a6SAndroid Build Coastguard Worker * @returns An Error to indicate success or failure. 404*523fa7a6SAndroid Build Coastguard Worker * 405*523fa7a6SAndroid Build Coastguard Worker * @note Only Tensor outputs are currently supported for setting. 406*523fa7a6SAndroid Build Coastguard Worker */ 407*523fa7a6SAndroid Build Coastguard Worker ET_NODISCARD 408*523fa7a6SAndroid Build Coastguard Worker inline runtime::Error set_output( 409*523fa7a6SAndroid Build Coastguard Worker runtime::EValue output_value, 410*523fa7a6SAndroid Build Coastguard Worker size_t output_index = 0) { 411*523fa7a6SAndroid Build Coastguard Worker return set_output("forward", std::move(output_value), output_index); 412*523fa7a6SAndroid Build Coastguard Worker } 413*523fa7a6SAndroid Build Coastguard Worker 414*523fa7a6SAndroid Build Coastguard Worker /** 415*523fa7a6SAndroid Build Coastguard Worker * Retrieves the EventTracer instance being used by the Module. 416*523fa7a6SAndroid Build Coastguard Worker * EventTracer is used for tracking and logging events during the execution 417*523fa7a6SAndroid Build Coastguard Worker * of methods. 418*523fa7a6SAndroid Build Coastguard Worker * 419*523fa7a6SAndroid Build Coastguard Worker * @returns A pointer to the EventTracer instance. Returns nullptr if no 420*523fa7a6SAndroid Build Coastguard Worker * EventTracer is set. 421*523fa7a6SAndroid Build Coastguard Worker */ event_tracer()422*523fa7a6SAndroid Build Coastguard Worker inline runtime::EventTracer* event_tracer() const { 423*523fa7a6SAndroid Build Coastguard Worker return event_tracer_.get(); 424*523fa7a6SAndroid Build Coastguard Worker } 425*523fa7a6SAndroid Build Coastguard Worker 426*523fa7a6SAndroid Build Coastguard Worker private: 427*523fa7a6SAndroid Build Coastguard Worker struct MethodHolder { 428*523fa7a6SAndroid Build Coastguard Worker std::vector<std::vector<uint8_t>> planned_buffers; 429*523fa7a6SAndroid Build Coastguard Worker std::vector<runtime::Span<uint8_t>> planned_spans; 430*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::HierarchicalAllocator> planned_memory; 431*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryManager> memory_manager; 432*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::Method> method; 433*523fa7a6SAndroid Build Coastguard Worker std::vector<runtime::EValue> inputs; 434*523fa7a6SAndroid Build Coastguard Worker }; 435*523fa7a6SAndroid Build Coastguard Worker 436*523fa7a6SAndroid Build Coastguard Worker private: 437*523fa7a6SAndroid Build Coastguard Worker std::string file_path_; 438*523fa7a6SAndroid Build Coastguard Worker LoadMode load_mode_{LoadMode::MmapUseMlock}; 439*523fa7a6SAndroid Build Coastguard Worker std::shared_ptr<runtime::Program> program_; 440*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::DataLoader> data_loader_; 441*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> memory_allocator_; 442*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::MemoryAllocator> temp_allocator_; 443*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<runtime::EventTracer> event_tracer_; 444*523fa7a6SAndroid Build Coastguard Worker 445*523fa7a6SAndroid Build Coastguard Worker protected: 446*523fa7a6SAndroid Build Coastguard Worker std::unordered_map<std::string, MethodHolder> methods_; 447*523fa7a6SAndroid Build Coastguard Worker 448*523fa7a6SAndroid Build Coastguard Worker friend class ExecuTorchJni; 449*523fa7a6SAndroid Build Coastguard Worker }; 450*523fa7a6SAndroid Build Coastguard Worker 451*523fa7a6SAndroid Build Coastguard Worker } // namespace extension 452*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch 453*523fa7a6SAndroid Build Coastguard Worker 454*523fa7a6SAndroid Build Coastguard Worker namespace torch { 455*523fa7a6SAndroid Build Coastguard Worker namespace executor { 456*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved 457*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces. 458*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::Module; 459*523fa7a6SAndroid Build Coastguard Worker } // namespace executor 460*523fa7a6SAndroid Build Coastguard Worker } // namespace torch 461