1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <memory> 12 #include <string> 13 #include <unordered_map> 14 #include <unordered_set> 15 #include <vector> 16 17 #include <executorch/runtime/executor/program.h> 18 19 namespace executorch { 20 namespace extension { 21 22 /** 23 * A facade class for loading programs and executing methods within them. 24 */ 25 class Module { 26 public: 27 /** 28 * Enum to define loading behavior. 29 */ 30 enum class LoadMode { 31 /// Load the whole file as a buffer. 32 File, 33 /// Use mmap to load pages into memory. 34 Mmap, 35 /// Use memory locking and handle errors. 36 MmapUseMlock, 37 /// Use memory locking and ignore errors. 38 MmapUseMlockIgnoreErrors, 39 }; 40 41 /** 42 * Constructs an instance by loading a program from a file with specified 43 * memory locking behavior. 44 * 45 * @param[in] file_path The path to the ExecuTorch program file to load. 46 * @param[in] load_mode The loading mode to use. 47 * @param[in] event_tracer A EventTracer used for tracking and logging events. 48 */ 49 explicit Module( 50 const std::string& file_path, 51 const LoadMode load_mode = LoadMode::MmapUseMlock, 52 std::unique_ptr<runtime::EventTracer> event_tracer = nullptr); 53 54 /** 55 * Constructs an instance with the provided data loader and memory allocator. 56 * 57 * @param[in] data_loader A DataLoader used for loading program data. 58 * @param[in] memory_allocator A MemoryAllocator used for memory management. 59 * @param[in] temp_allocator A MemoryAllocator to use when allocating 60 * temporary data during kernel or delegate execution. 61 * @param[in] event_tracer A EventTracer used for tracking and logging events. 62 */ 63 explicit Module( 64 std::unique_ptr<runtime::DataLoader> data_loader, 65 std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr, 66 std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr, 67 std::unique_ptr<runtime::EventTracer> event_tracer = nullptr); 68 69 /** 70 * Constructs an instance using an existing shared program. 71 * 72 * @param[in] program The shared program to use. It's required the data loader 73 * the program uses is valid for the lifetime of the program. 74 * @param[in] memory_allocator A MemoryAllocator used for memory management. 75 * @param[in] temp_allocator A MemoryAllocator to use when allocating 76 * temporary data. 77 * @param[in] event_tracer A EventTracer used for tracking and logging events. 78 */ 79 explicit Module( 80 std::shared_ptr<runtime::Program> program, 81 std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr, 82 std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr, 83 std::unique_ptr<runtime::EventTracer> event_tracer = nullptr); 84 85 Module(const Module&) = delete; 86 Module& operator=(const Module&) = delete; 87 Module(Module&&) = delete; 88 Module& operator=(Module&&) = delete; 89 90 /** 91 * Loads the program if needed. 92 * 93 * @param[in] verification The type of verification to do before returning 94 * success. 95 * 96 * @returns An Error to indicate success or failure of the loading process. 97 */ 98 ET_NODISCARD 99 runtime::Error load( 100 const runtime::Program::Verification verification = 101 runtime::Program::Verification::Minimal); 102 103 /** 104 * Checks if the program is loaded. 105 * 106 * @returns true if the program is loaded, false otherwise. 107 */ is_loaded()108 inline bool is_loaded() const { 109 return program_ != nullptr; 110 } 111 112 /** 113 * Get the program. The data loader used by the program is guaranteed to be 114 * valid for the lifetime of the program. 115 * 116 * @returns Shared pointer to the program or nullptr if it's not yet loaded. 117 */ program()118 inline std::shared_ptr<runtime::Program> program() const { 119 return program_; 120 } 121 122 /** 123 * Get a list of method names available in the loaded program. 124 * Loads the program and method if needed. 125 * 126 * @returns A set of strings containing the names of the methods, or an error 127 * if the program or method failed to load. 128 */ 129 runtime::Result<std::unordered_set<std::string>> method_names(); 130 131 /** 132 * Load a specific method from the program and set up memory management if 133 * needed. The loaded method is cached to reuse the next time it's executed. 134 * 135 * @param[in] method_name The name of the method to load. 136 * @param[in] event_tracer Per-method event tracer to profile/trace methods 137 * individually. When not given, the event tracer passed to the Module 138 * constructor is used. Otherwise, this per-method event tracer takes 139 * precedence. 140 * 141 * @returns An Error to indicate success or failure. 142 */ 143 ET_NODISCARD 144 runtime::Error load_method( 145 const std::string& method_name, 146 torch::executor::EventTracer* event_tracer = nullptr); 147 148 /** 149 * Load the 'forward' method from the program and set up memory management if 150 * needed. The loaded method is cached to reuse the next time it's executed. 151 * 152 * @param[in] event_tracer An event tracer used for tracking and logging 153 * events. 154 * 155 * @returns An Error to indicate success or failure. 156 */ 157 ET_NODISCARD inline runtime::Error load_forward( 158 torch::executor::EventTracer* event_tracer = nullptr) { 159 return load_method("forward", event_tracer); 160 } 161 162 /** 163 * Checks if a specific method is loaded. 164 * 165 * @param[in] method_name The name of the method to check. 166 * 167 * @returns true if the method specified by method_name is loaded, false 168 * otherwise. 169 */ is_method_loaded(const std::string & method_name)170 inline bool is_method_loaded(const std::string& method_name) const { 171 return methods_.count(method_name); 172 } 173 174 /** 175 * Get a method metadata struct by method name. 176 * Loads the program and method if needed. 177 * 178 * @param[in] method_name The name of the method to get the metadata for. 179 * 180 * @returns A method metadata, or an error if the program or method failed to 181 * load. 182 */ 183 runtime::Result<runtime::MethodMeta> method_meta( 184 const std::string& method_name); 185 186 /** 187 * Execute a specific method with the given input values and retrieve the 188 * output values. Loads the program and method before executing if needed. 189 * 190 * @param[in] method_name The name of the method to execute. 191 * @param[in] input_values A vector of input values to be passed to the 192 * method. 193 * 194 * @returns A Result object containing either a vector of output values 195 * from the method or an error to indicate failure. 196 */ 197 ET_NODISCARD 198 runtime::Result<std::vector<runtime::EValue>> execute( 199 const std::string& method_name, 200 const std::vector<runtime::EValue>& input_values); 201 202 /** 203 * Execute a specific method with a single input value. 204 * Loads the program and method before executing if needed. 205 * 206 * @param[in] method_name The name of the method to execute. 207 * @param[in] input_value A value to be passed to the method. 208 * 209 * @returns A Result object containing either a vector of output values 210 * from the method or an error to indicate failure. 211 */ execute(const std::string & method_name,const runtime::EValue & input_value)212 ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute( 213 const std::string& method_name, 214 const runtime::EValue& input_value) { 215 return execute(method_name, std::vector<runtime::EValue>{input_value}); 216 } 217 218 /** 219 * Execute a specific method without any input values. 220 * Loads the program and method before executing if needed. 221 * 222 * @param[in] method_name The name of the method to execute. 223 * 224 * @returns A Result object containing either a vector of output values 225 * from the method or an error to indicate failure. 226 */ execute(const std::string & method_name)227 ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> execute( 228 const std::string& method_name) { 229 return execute(method_name, std::vector<runtime::EValue>{}); 230 } 231 232 /** 233 * Retrieve the output value of a specific method with the given input values. 234 * Loads the program and method before execution if needed. 235 * 236 * @param[in] method_name The name of the method to execute. 237 * @param[in] input_values A vector of input values to be passed to the 238 * method. 239 * 240 * @returns A Result object containing either the first output value from the 241 * method or an error to indicate failure. 242 */ get(const std::string & method_name,const std::vector<runtime::EValue> & input_values)243 ET_NODISCARD inline runtime::Result<runtime::EValue> get( 244 const std::string& method_name, 245 const std::vector<runtime::EValue>& input_values) { 246 auto result = ET_UNWRAP(execute(method_name, input_values)); 247 if (result.empty()) { 248 return runtime::Error::InvalidArgument; 249 } 250 return result[0]; 251 } 252 253 /** 254 * Retrieve the output value of a specific method with a single input value. 255 * Loads the program and method before execution if needed. 256 * 257 * @param[in] method_name The name of the method to execute. 258 * @param[in] input_value A value to be passed to the method. 259 * 260 * @returns A Result object containing either the first output value from the 261 * method or an error to indicate failure. 262 */ get(const std::string & method_name,const runtime::EValue & input_value)263 ET_NODISCARD inline runtime::Result<runtime::EValue> get( 264 const std::string& method_name, 265 const runtime::EValue& input_value) { 266 return get(method_name, std::vector<runtime::EValue>{input_value}); 267 } 268 269 /** 270 * Retrieve the output value of a specific method without any input values. 271 * Loads the program and method before execution if needed. 272 * 273 * @param[in] method_name The name of the method to execute. 274 * 275 * @returns A Result object containing either the first output value from the 276 * method or an error to indicate failure. 277 */ get(const std::string & method_name)278 ET_NODISCARD inline runtime::Result<runtime::EValue> get( 279 const std::string& method_name) { 280 return get(method_name, std::vector<runtime::EValue>{}); 281 } 282 283 /** 284 * Execute the 'forward' method with the given input values and retrieve the 285 * output values. Loads the program and method before executing if needed. 286 * 287 * @param[in] input_values A vector of input values for the 'forward' method. 288 * 289 * @returns A Result object containing either a vector of output values 290 * from the 'forward' method or an error to indicate failure. 291 */ forward(const std::vector<runtime::EValue> & input_values)292 ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward( 293 const std::vector<runtime::EValue>& input_values) { 294 return execute("forward", input_values); 295 } 296 297 /** 298 * Execute the 'forward' method with a single value. 299 * Loads the program and method before executing if needed. 300 * 301 * @param[in] input_value A value for the 'forward' method. 302 * 303 * @returns A Result object containing either a vector of output values 304 * from the 'forward' method or an error to indicate failure. 305 */ forward(const runtime::EValue & input_value)306 ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward( 307 const runtime::EValue& input_value) { 308 return forward(std::vector<runtime::EValue>{input_value}); 309 } 310 311 /** 312 * Execute the 'forward' method without any input values. 313 * Loads the program and method before executing if needed. 314 * 315 * @returns A Result object containing either a vector of output values 316 * from the 'forward' method or an error to indicate failure. 317 */ forward()318 ET_NODISCARD inline runtime::Result<std::vector<runtime::EValue>> forward() { 319 return forward(std::vector<runtime::EValue>{}); 320 } 321 322 /** 323 * Sets a single input value for a specific method. 324 * 325 * @param[in] method_name The name of the method. 326 * @param[in] input_value The EValue to set as the method input. 327 * @param[in] input_index Zero-based index of the input to set. 328 * 329 * @returns An Error to indicate success or failure. 330 */ 331 ET_NODISCARD 332 runtime::Error set_input( 333 const std::string& method_name, 334 const runtime::EValue& input_value, 335 size_t input_index); 336 337 /** 338 * Sets a single input value for the "forward" method. 339 * 340 * @param[in] input_value The EValue to set as the method input. 341 * @param[in] input_index Zero-based index of the input to set. 342 * 343 * @returns An Error to indicate success or failure. 344 */ 345 ET_NODISCARD set_input(const runtime::EValue & input_value,size_t input_index)346 inline runtime::Error set_input( 347 const runtime::EValue& input_value, 348 size_t input_index) { 349 return set_input("forward", input_value, input_index); 350 } 351 352 /** 353 * Sets all input values for a specific method. 354 * 355 * @param[in] method_name The name of the method. 356 * @param[in] input_values A vector of EValues to set as the method inputs. 357 * 358 * @returns An Error to indicate success or failure. 359 */ 360 ET_NODISCARD 361 runtime::Error set_inputs( 362 const std::string& method_name, 363 const std::vector<runtime::EValue>& input_values); 364 365 /** 366 * Sets all input values for the "forward" method. 367 * 368 * @param[in] input_values A vector of EValues to set as the method inputs. 369 * 370 * @returns An Error to indicate success or failure. 371 */ 372 ET_NODISCARD set_inputs(const std::vector<runtime::EValue> & input_values)373 inline runtime::Error set_inputs( 374 const std::vector<runtime::EValue>& input_values) { 375 return set_inputs("forward", input_values); 376 } 377 378 /** 379 * Sets the output tensor for a specific method. 380 * 381 * @param[in] method_name The name of the method. 382 * @param[in] output_value The EValue containing the Tensor to set as the 383 * method output. 384 * @param[in] output_index Zero-based index of the output to set. 385 * 386 * @returns An Error to indicate success or failure. 387 * 388 * @note Only Tensor outputs are currently supported for setting. 389 */ 390 ET_NODISCARD 391 runtime::Error set_output( 392 const std::string& method_name, 393 runtime::EValue output_value, 394 size_t output_index = 0); 395 396 /** 397 * Sets the output tensor for the "forward" method. 398 * 399 * @param[in] output_value The EValue containing the Tensor to set as the 400 * method output. 401 * @param[in] output_index Zero-based index of the output to set. 402 * 403 * @returns An Error to indicate success or failure. 404 * 405 * @note Only Tensor outputs are currently supported for setting. 406 */ 407 ET_NODISCARD 408 inline runtime::Error set_output( 409 runtime::EValue output_value, 410 size_t output_index = 0) { 411 return set_output("forward", std::move(output_value), output_index); 412 } 413 414 /** 415 * Retrieves the EventTracer instance being used by the Module. 416 * EventTracer is used for tracking and logging events during the execution 417 * of methods. 418 * 419 * @returns A pointer to the EventTracer instance. Returns nullptr if no 420 * EventTracer is set. 421 */ event_tracer()422 inline runtime::EventTracer* event_tracer() const { 423 return event_tracer_.get(); 424 } 425 426 private: 427 struct MethodHolder { 428 std::vector<std::vector<uint8_t>> planned_buffers; 429 std::vector<runtime::Span<uint8_t>> planned_spans; 430 std::unique_ptr<runtime::HierarchicalAllocator> planned_memory; 431 std::unique_ptr<runtime::MemoryManager> memory_manager; 432 std::unique_ptr<runtime::Method> method; 433 std::vector<runtime::EValue> inputs; 434 }; 435 436 private: 437 std::string file_path_; 438 LoadMode load_mode_{LoadMode::MmapUseMlock}; 439 std::shared_ptr<runtime::Program> program_; 440 std::unique_ptr<runtime::DataLoader> data_loader_; 441 std::unique_ptr<runtime::MemoryAllocator> memory_allocator_; 442 std::unique_ptr<runtime::MemoryAllocator> temp_allocator_; 443 std::unique_ptr<runtime::EventTracer> event_tracer_; 444 445 protected: 446 std::unordered_map<std::string, MethodHolder> methods_; 447 448 friend class ExecuTorchJni; 449 }; 450 451 } // namespace extension 452 } // namespace executorch 453 454 namespace torch { 455 namespace executor { 456 // TODO(T197294990): Remove these deprecated aliases once all users have moved 457 // to the new `::executorch` namespaces. 458 using ::executorch::extension::Module; 459 } // namespace executor 460 } // namespace torch 461