1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9from typing import Any, Dict, Tuple 10 11import torch 12from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule 13from executorch.exir import to_edge 14 15from omegaconf import DictConfig 16from torch.export import export, ExportedProgram 17from torch.export.experimental import _export_forward_backward 18from torch.nn.attention import sdpa_kernel, SDPBackend 19from torchtune import config 20from torchtune.modules.peft import get_adapter_params, set_trainable_params 21from torchtune.training.precision import get_dtype, set_default_dtype 22from torchtune.utils._device import get_device 23 24 25def load_checkpoint(cfg: Any) -> Dict[str, Any]: # pyre-ignore[2] 26 """ 27 Extract the checkpoint state from file and validate. This includes the 28 base model weights. If resume_from_checkpoint is True, this also includes 29 the adapter weights and recipe state 30 """ 31 checkpointer = config.instantiate( 32 cfg.checkpointer, 33 resume_from_checkpoint=cfg.resume_from_checkpoint, 34 ) 35 checkpoint_dict = checkpointer.load_checkpoint() 36 return checkpoint_dict 37 38 39def setup_model( 40 cfg: DictConfig, 41 base_model_state_dict: Dict[str, Any], 42) -> torch.nn.Module: 43 device = get_device(device=cfg.device) 44 dtype = get_dtype(cfg.dtype, device=device) 45 with set_default_dtype(dtype), device: 46 model = config.instantiate(cfg.model) 47 48 adapter_params = get_adapter_params(model) 49 set_trainable_params(model, adapter_params) 50 model.load_state_dict(base_model_state_dict, strict=False) 51 return model 52 53 54def export_model_lora_training( 55 model: TrainingModule, 56 example_args: Tuple[Any, ...], # pyre-ignore[2] 57 output_file: str, 58) -> None: 59 """ 60 Export model with LoRA model to executorch for training, only. 61 """ 62 63 # 0. Mark the LoRA layers as trainable (requires_grad = True) in order 64 # to just export the backwards pass for these layers later in the 65 # export process. 66 set_trainable_params(model, get_adapter_params(model)) 67 68 print("Exporting model with LoRA for training") 69 # 1. torch.export: Defines the program with the ATen operator set. 70 71 with sdpa_kernel([SDPBackend.MATH]): 72 exported_graph: ExportedProgram = export(model, example_args, strict=False) 73 print("Creating a joint forward-backwards graph for training") 74 joint_graph = _export_forward_backward(exported_graph) 75 76 # 2. to_edge: Make optimizations for Edge devices. 77 print("Lowering to edge dialect") 78 edge_program = to_edge(joint_graph) 79 80 print(edge_program._edge_programs["forward"].graph_module) 81 82 # 3. to_executorch: Convert the graph to an ExecuTorch program. 83 print("Exporting to executorch") 84 executorch_program = edge_program.to_executorch() 85 print(executorch_program.exported_program().graph_signature) 86 print(f"Saving to {output_file}") 87 with open(output_file, "wb") as file: 88 file.write(executorch_program.buffer) 89