# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Any, Dict, List, Optional, Sequence

from executorch.exir._warnings import experimental

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch,
    _load_for_executorch_from_buffer,
    ExecuTorchModule,
)
from torch import Tensor


@experimental("This API is experimental and subject to change without notice.")
class TrainingModule:
    """Wraps an ExecuTorchModule exported for training and exposes the user
    outputs, gradients, and parameters produced by its methods."""

    def __init__(self, module: ExecuTorchModule):
        self.model = module

        self.gradients_method_prefix = "__et_training_gradients_index_"
        self.parameters_method_prefix = "__et_training_parameters_index_"
        self.fqn_method_prefix = "__et_training_fqn_"

        self.named_grads: Optional[Dict[str, Tensor]] = None
        self.named_params: Optional[Dict[str, Tensor]] = None

    def forward_backward(self, method_name: str, inputs: Sequence[Any]) -> List[Any]:
        """Runs `method_name` on `inputs`, caches the gradients and parameters
        emitted by the joint forward/backward graph, and returns only the
        user-visible outputs."""
        # The default ET model returns a large list of outputs that can logically be
        # separated into [user outputs, gradients, parameters]. We can use these
        # metadata methods to slice the list into the correct parts.
        grad_start_idx = self.model.run_method(
            self.gradients_method_prefix + method_name, ()
        )[0]
        params_start_idx = self.model.run_method(
            self.parameters_method_prefix + method_name, ()
        )[0]

        full_outputs = self.model.run_method(method_name, inputs)

        user_outs = full_outputs[:grad_start_idx]
        grads = full_outputs[grad_start_idx:params_start_idx]
        params = full_outputs[params_start_idx:]

        # It is important that these outputs are not cloned, because the optimizer
        # must be able to mutate the actual weights rather than copies of them.
        fqn = self.model.run_method(
            self.fqn_method_prefix + method_name, (), clone_outputs=False
        )

        self.named_grads = dict(zip(fqn, grads))
        if self.named_params is None:
            self.named_params = dict(zip(fqn, params))

        return user_outs

    def named_gradients(self) -> Dict[str, Tensor]:
        """Returns the gradients from the most recent forward_backward call,
        keyed by parameter fully-qualified name."""
        if self.named_grads is None:
            raise RuntimeError("Must call forward_backward before named_gradients")
        return self.named_grads

    def named_parameters(self) -> Dict[str, Tensor]:
        """Returns the program's parameters keyed by fully-qualified name.
        Currently only available after forward_backward has been called."""
        if self.named_params is None:
            raise RuntimeError(
                "Must call forward_backward before named_parameters. This will be fixed in a later version"
            )
        return self.named_params


@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_for_training(path: str) -> TrainingModule:
    """Loads an ExecuTorch program from `path` and wraps it in a TrainingModule."""
    et_module = _load_for_executorch(path)
    return TrainingModule(et_module)


@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_for_training_from_buffer(
    buffer: bytes,
) -> TrainingModule:
    """Loads an ExecuTorch program from an in-memory buffer and wraps it in a
    TrainingModule."""
    et_module = _load_for_executorch_from_buffer(buffer)
    return TrainingModule(et_module)
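

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): one training
# step built solely on the functions defined above plus torch. The .pte path,
# the method name "forward", and the input shapes are assumptions made for the
# example; substitute whatever your exported training program actually
# provides. The weight update is written as a plain in-place SGD step so the
# sketch does not depend on any particular optimizer binding.
#
#   import torch
#
#   module = _load_for_executorch_for_training("model.pte")  # hypothetical file
#
#   example_inputs = (torch.randn(1, 8), torch.tensor([1]))  # assumed signature
#   user_outs = module.forward_backward("forward", example_inputs)
#
#   params = module.named_parameters()
#   grads = module.named_gradients()
#   with torch.no_grad():
#       for name, param in params.items():
#           # Because the parameters are not cloned, this in-place update
#           # mutates the weights held by the ExecuTorch runtime.
#           param -= 0.01 * grads[name]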