xref: /aosp_15_r20/external/executorch/extension/training/module/training_module.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <map>
12 #include <memory>
13 #include <string>
14 #include <unordered_map>
15 #include <vector>
16 
17 #include <executorch/extension/module/module.h>
18 #include <executorch/runtime/core/evalue.h>
19 #include <executorch/runtime/executor/program.h>
20 
21 namespace executorch {
22 namespace extension {
23 namespace training {
24 
25 /**
26  * A facade class for loading programs for on-device training and executing
27  * methods within them.
28  */
29 class ET_EXPERIMENTAL TrainingModule final : executorch::extension::Module {
30  public:
31   explicit TrainingModule(
32       std::unique_ptr<runtime::DataLoader> data_loader,
33       std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
34       std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
35       std::unique_ptr<runtime::EventTracer> event_tracer = nullptr)
Module(std::move (data_loader),std::move (memory_allocator),std::move (temp_allocator),std::move (event_tracer))36       : executorch::extension::Module(
37             std::move(data_loader),
38             std::move(memory_allocator),
39             std::move(temp_allocator),
40             std::move(event_tracer)),
41         method_named_gradients_({}) {}
42 
43   explicit TrainingModule(const Module&) = delete;
44   TrainingModule& operator=(const Module&) = delete;
45   explicit TrainingModule(Module&&) = delete;
46   TrainingModule& operator=(Module&&) = delete;
47 
48   /**
49    * Execute a specific method with the given input and retrieve output. Only
50    * valid if the specified method is a joint graph. Loads the program and
51    * method before executing if needed.
52    *
53    * @param[in] method_name The name of the joint graph method to execute.
54    * @param[in] input A vector of input values to be passed to the method.
55    *
56    * @returns A Result object containing the output values from the method or an
57    * error to indicate failure.
58    */
59   ET_EXPERIMENTAL runtime::Result<std::vector<runtime::EValue>>
60   execute_forward_backward(
61       const std::string& method_name,
62       const std::vector<runtime::EValue>& input);
63 
64   /**
65    * Retrieve the trainable parameters for a joint graph method.
66    *
67    * @param[in] method_name The name of the joint graph method to get the
68    * parameters for.
69    *
70    * @returns A Result object containing a map of the fully qualified name to
71    * parameter tensor, or an error if the method is not a joint graph.
72    */
73   ET_EXPERIMENTAL
74   runtime::Result<
75       const std::map<executorch::aten::string_view, executorch::aten::Tensor>>
76   named_parameters(const std::string& method_name);
77 
78   /**
79    * Retrieve the latest gradients for a joint graph method.
80    *
81    * @param[in] method_name The name of the joint graph method to get the
82    * gradients for.
83    *
84    * @returns A Result object containing a map of the fully qualified name to
85    * gradient tensor associated with that parameter from the latest
86    * forward_backward execution, or an error if the method is not a joint graph
87    * or has not been executed yet.
88    */
89   ET_EXPERIMENTAL
90   runtime::Result<
91       const std::map<executorch::aten::string_view, executorch::aten::Tensor>>
92   named_gradients(const std::string& method_name);
93 
94  private:
95   std::unordered_map<
96       std::string,
97       std::map<executorch::aten::string_view, executorch::aten::Tensor>>
98       method_named_gradients_;
99 };
100 
101 } // namespace training
102 } // namespace extension
103 } // namespace executorch
104