xref: /aosp_15_r20/external/executorch/examples/llm_pte_finetuning/model_loading_lib.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Tuple

import torch
from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
from executorch.exir import to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune import config
from torchtune.modules.peft import get_adapter_params, set_trainable_params
from torchtune.training.precision import get_dtype, set_default_dtype
from torchtune.utils._device import get_device


def load_checkpoint(cfg: Any) -> Dict[str, Any]:  # pyre-ignore[2]
26    """
27    Extract the checkpoint state from file and validate. This includes the
28    base model weights. If resume_from_checkpoint is True, this also includes
29    the adapter weights and recipe state
30    """
    checkpointer = config.instantiate(
        cfg.checkpointer,
        resume_from_checkpoint=cfg.resume_from_checkpoint,
    )
    checkpoint_dict = checkpointer.load_checkpoint()
    return checkpoint_dict


def setup_model(
    cfg: DictConfig,
    base_model_state_dict: Dict[str, Any],
) -> torch.nn.Module:
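    """
    Instantiate the model on the configured device and dtype, mark the LoRA
    adapter parameters as the only trainable ones, and load the base model
    weights into it.
    """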
    device = get_device(device=cfg.device)
    dtype = get_dtype(cfg.dtype, device=device)
    with set_default_dtype(dtype), device:
        model = config.instantiate(cfg.model)

    adapter_params = get_adapter_params(model)
    set_trainable_params(model, adapter_params)
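    # strict=False: the base checkpoint does not contain the LoRA adapter
    # weights, so a strict load would fail on those missing keys.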
    model.load_state_dict(base_model_state_dict, strict=False)
    return model


def export_model_lora_training(
    model: TrainingModule,
    example_args: Tuple[Any, ...],  # pyre-ignore[2]
    output_file: str,
) -> None:
    """
    Export a model with LoRA adapters to ExecuTorch for training only.
    """

    # 0. Mark the LoRA layers as trainable (requires_grad = True) so that
    # only the backward pass for these layers is exported later in the
    # export process.
    set_trainable_params(model, get_adapter_params(model))

    print("Exporting model with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.

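    # Restrict SDPA to the MATH backend: its decomposed, pure-PyTorch
    # implementation can be traced by torch.export, unlike the fused
    # flash/memory-efficient kernels.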
    with sdpa_kernel([SDPBackend.MATH]):
        exported_graph: ExportedProgram = export(model, example_args, strict=False)
        print("Creating a joint forward-backwards graph for training")
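        # _export_forward_backward is an experimental API that appends the
        # backward graph for the parameters marked trainable above, producing
        # a single joint forward-backward program.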
        joint_graph = _export_forward_backward(exported_graph)

        # 2. to_edge: Make optimizations for Edge devices.
        print("Lowering to edge dialect")
        edge_program = to_edge(joint_graph)

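        # Debug: print the joint graph module (via the private _edge_programs
        # mapping of the edge program manager).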
        print(edge_program._edge_programs["forward"].graph_module)

    # 3. to_executorch: Convert the graph to an ExecuTorch program.
    print("Exporting to executorch")
    executorch_program = edge_program.to_executorch()
    print(executorch_program.exported_program().graph_signature)
    print(f"Saving to {output_file}")
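    # executorch_program.buffer holds the serialized .pte program bytes.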
    with open(output_file, "wb") as file:
        file.write(executorch_program.buffer)
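
# A minimal usage sketch of how these helpers fit together. This is an
# illustration, not part of the recipe itself: it assumes a torchtune-style
# YAML config with `checkpointer`, `model`, `loss`, `device`, `dtype`, and
# `resume_from_checkpoint` fields, that TrainingModule pairs a model with a
# loss function as in this example's training_lib, and that the config file
# name and example batch below are hypothetical.
#
#     from omegaconf import OmegaConf
#
#     cfg = OmegaConf.load("llama3_lora_finetune.yaml")  # hypothetical config
#     checkpoint_dict = load_checkpoint(cfg)
#     model = setup_model(cfg, base_model_state_dict=checkpoint_dict["model"])
#     training_module = TrainingModule(model, config.instantiate(cfg.loss))
#     tokens, labels = ...  # one example batch of token and label tensors
#     export_model_lora_training(training_module, (tokens, labels), "model.pte")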