# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Any, Dict, List, Sequence

from executorch.exir._warnings import experimental

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch,
    _load_for_executorch_from_buffer,
    ExecuTorchModule,
)
from torch import Tensor


@experimental("This API is experimental and subject to change without notice.")
class TrainingModule:
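    """Wrapper around an ExecuTorchModule whose program was exported for training.

    For each trainable method, the program is expected to expose auxiliary
    metadata methods (named with the prefixes set in ``__init__``) that describe
    where the flattened output list splits into user outputs, gradients, and
    parameters.
    """
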
    def __init__(self, module: ExecuTorchModule):
        self.model = module

        self.gradients_method_prefix = "__et_training_gradients_index_"
        self.parameters_method_prefix = "__et_training_parameters_index_"
        self.fqn_method_prefix = "__et_training_fqn_"

        self.named_grads = None
        self.named_params = None

    def forward_backward(self, method_name: str, inputs: Sequence[Any]) -> List[Any]:
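        """Run one forward/backward pass of `method_name` on `inputs`.

        The method's flattened output list is split into user outputs,
        gradients, and parameters using the program's metadata methods. The
        gradients and parameters are cached for named_gradients() and
        named_parameters(); only the user outputs are returned.
        """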
        # The default ET model returns a large list of outputs that can logically be
        # separated into [user outputs, gradients, parameters]. We use these metadata
        # methods to find the indices at which to slice the list into those parts.
        grad_start_idx = self.model.run_method(
            self.gradients_method_prefix + method_name, ()
        )[0]
        params_start_idx = self.model.run_method(
            self.parameters_method_prefix + method_name, ()
        )[0]

        full_outputs = self.model.run_method(method_name, inputs)

        user_outs = full_outputs[:grad_start_idx]
        grads = full_outputs[grad_start_idx:params_start_idx]
        params = full_outputs[params_start_idx:]

        # It is important that the outputs are not cloned, because the optimizer must
        # be able to mutate the actual weights and not clones of them.
        fqn = self.model.run_method(
            self.fqn_method_prefix + method_name, (), clone_outputs=False
        )

        self.named_grads = dict(zip(fqn, grads))
        if self.named_params is None:
            self.named_params = dict(zip(fqn, params))

        return user_outs

    def named_gradients(self) -> Dict[str, Tensor]:
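        """Return the gradients from the most recent forward_backward call, keyed by parameter FQN."""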
        if self.named_grads is None:
            raise RuntimeError("Must call forward_backward before named_grads")
        return self.named_grads

    def named_parameters(self) -> Dict[str, Tensor]:
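        """Return the program's trainable parameters keyed by FQN.

        Note: forward_backward must have been called at least once before the
        parameters are available.
        """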
        if self.named_grads is None:
            raise RuntimeError(
                "Must call forward_backward before named_params. This will be fixed in a later version"
            )
        return self.named_params


@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_for_training(path: str) -> TrainingModule:
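    """Load a training-enabled ExecuTorch program from `path` and wrap it in a TrainingModule."""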
    et_module = _load_for_executorch(path)
    return TrainingModule(et_module)


@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_for_training_from_buffer(
    buffer: bytes,
) -> TrainingModule:
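    """Load a training-enabled ExecuTorch program from a bytes buffer and wrap it in a TrainingModule."""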
    et_module = _load_for_executorch_from_buffer(buffer)
    return TrainingModule(et_module)
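
# A minimal usage sketch, assuming a program that was exported for training.
# The file name ("model.pte"), method name ("forward"), and the example input
# tensors are hypothetical placeholders, not part of this module:
#
#     module = _load_for_executorch_for_training("model.pte")
#     user_outs = module.forward_backward("forward", (example_input, example_label))
#     grads = module.named_gradients()
#     params = module.named_parameters()
#     # An optimizer can now update the tensors in `params` in place using `grads`.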