1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <map> 12 #include <memory> 13 #include <string> 14 #include <unordered_map> 15 #include <vector> 16 17 #include <executorch/extension/module/module.h> 18 #include <executorch/runtime/core/evalue.h> 19 #include <executorch/runtime/executor/program.h> 20 21 namespace executorch { 22 namespace extension { 23 namespace training { 24 25 /** 26 * A facade class for loading programs for on-device training and executing 27 * methods within them. 28 */ 29 class ET_EXPERIMENTAL TrainingModule final : executorch::extension::Module { 30 public: 31 explicit TrainingModule( 32 std::unique_ptr<runtime::DataLoader> data_loader, 33 std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr, 34 std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr, 35 std::unique_ptr<runtime::EventTracer> event_tracer = nullptr) Module(std::move (data_loader),std::move (memory_allocator),std::move (temp_allocator),std::move (event_tracer))36 : executorch::extension::Module( 37 std::move(data_loader), 38 std::move(memory_allocator), 39 std::move(temp_allocator), 40 std::move(event_tracer)), 41 method_named_gradients_({}) {} 42 43 explicit TrainingModule(const Module&) = delete; 44 TrainingModule& operator=(const Module&) = delete; 45 explicit TrainingModule(Module&&) = delete; 46 TrainingModule& operator=(Module&&) = delete; 47 48 /** 49 * Execute a specific method with the given input and retrieve output. Only 50 * valid if the specified method is a joint graph. Loads the program and 51 * method before executing if needed. 52 * 53 * @param[in] method_name The name of the joint graph method to execute. 54 * @param[in] input A vector of input values to be passed to the method. 55 * 56 * @returns A Result object containing the output values from the method or an 57 * error to indicate failure. 58 */ 59 ET_EXPERIMENTAL runtime::Result<std::vector<runtime::EValue>> 60 execute_forward_backward( 61 const std::string& method_name, 62 const std::vector<runtime::EValue>& input); 63 64 /** 65 * Retrieve the trainable parameters for a joint graph method. 66 * 67 * @param[in] method_name The name of the joint graph method to get the 68 * parameters for. 69 * 70 * @returns A Result object containing a map of the fully qualified name to 71 * parameter tensor, or an error if the method is not a joint graph. 72 */ 73 ET_EXPERIMENTAL 74 runtime::Result< 75 const std::map<executorch::aten::string_view, executorch::aten::Tensor>> 76 named_parameters(const std::string& method_name); 77 78 /** 79 * Retrieve the latest gradients for a joint graph method. 80 * 81 * @param[in] method_name The name of the joint graph method to get the 82 * gradients for. 83 * 84 * @returns A Result object containing a map of the fully qualified name to 85 * gradient tensor associated with that parameter from the latest 86 * forward_backward execution, or an error if the method is not a joint graph 87 * or has not been executed yet. 88 */ 89 ET_EXPERIMENTAL 90 runtime::Result< 91 const std::map<executorch::aten::string_view, executorch::aten::Tensor>> 92 named_gradients(const std::string& method_name); 93 94 private: 95 std::unordered_map< 96 std::string, 97 std::map<executorch::aten::string_view, executorch::aten::Tensor>> 98 method_named_gradients_; 99 }; 100 101 } // namespace training 102 } // namespace extension 103 } // namespace executorch 104