# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.training_lib import (
    eval_model,
    get_dataloader,
    update_function,
)

from executorch.extension.pybindings.aten_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config
from tqdm import tqdm

parser = argparse.ArgumentParser(
    prog="Runner",
    description="Fine-tunes a LoRA model using ExecuTorch.",
    epilog="Expects a model that was exported for fine-tuning.",
)
parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--model_file", type=str, help="Path to the ET model file.")


def main() -> None:
    args = parser.parse_args()
    config_file = args.cfg
    model_file = args.model_file
    cfg = OmegaConf.load(config_file)
    tokenizer = config.instantiate(cfg.tokenizer)

    loss_fn = config.instantiate(cfg.loss)

    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
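    # The 80/20 split above holds out data for evaluation; get_dataloader
    # (from training_lib) is assumed to batch and collate the tokenized
    # samples into the tensors the exported module expects.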
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
    val_dataloader = get_dataloader(cfg, val_set, tokenizer, loss_fn)

    max_seq_len = cfg.tokenizer.max_seq_len
    # Number of training steps to run; assumed to stay within one epoch.
    num_steps = 100
    with open(model_file, "rb") as f:
        model_bytes = f.read()
        et_mod = _load_for_executorch_from_buffer(model_bytes)

        # Evaluate the model before training.
        print("Evaluating the model before training")
        eval_loss = eval_model(
            model=et_mod,
            dataloader=val_dataloader,
            loss_fn=loss_fn,
            max_seq_len=max_seq_len,
            num_eval_steps=10,
        )
        print("Eval loss: ", eval_loss)

        # Based on executorch/extension/training/module/training_module.cpp:
        # gradients occupy output slots [grad_start, param_start) and
        # parameters occupy [param_start, len(out)).
        grad_start = et_mod.run_method("__et_training_gradients_index_forward", [])[0]
        param_start = et_mod.run_method("__et_training_parameters_index_forward", [])[0]
        learning_rate = 5e-3
        losses = []
        for i, batch in tqdm(enumerate(train_dataloader), total=num_steps):
            # Run for a limited number of steps.
            if i >= num_steps:
                break
            tokens, labels = batch["tokens"], batch["labels"]
            token_size = tokens.shape[1]
            labels_size = labels.shape[1]

            # Inputs must match the fixed shapes of the example inputs used at
            # export time, so pad or truncate each batch to max_seq_len.
            if token_size > max_seq_len:
                tokens = tokens[:, :max_seq_len]
            else:
                tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

            if labels_size > max_seq_len:
                labels = labels[:, :max_seq_len]
            else:
                labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)
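            # For example, with max_seq_len = 512 (illustrative; the real value
            # comes from cfg.tokenizer.max_seq_len), a (1, 300) batch is
            # right-padded with zeros to (1, 512) and a (1, 600) batch is
            # truncated to (1, 512).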

            # Do not clone outputs, since we want the original weights to be returned
            # for us to update with the gradients in-place.
            # See https://github.com/pytorch/executorch/blob/main/extension/pybindings/pybindings.cpp#L736
            # for more info.
            out = et_mod.forward((tokens, labels), clone_outputs=False)
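            # Illustrative layout of `out` (the exact slot counts vary per
            # model; the real indices come from the run_method queries above):
            #   out = [loss, ..., grad_0, ..., grad_{n-1}, param_0, ..., param_{n-1}]
            # grad_i pairs with param_i, which is what lets the zip below line up.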

            loss = out[0]
            losses.append(loss.item())
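            # Take a gradient step on every parameter in place. update_function
            # (from training_lib) is assumed to apply plain SGD, i.e.
            # param -= learning_rate * grad.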
            with torch.no_grad():
                for grad, param in zip(out[grad_start:param_start], out[param_start:]):
                    update_function(param, grad, learning_rate)

        print("Losses: ", losses)
        # Evaluate the model after training.
        eval_loss = eval_model(
            model=et_mod,
            dataloader=val_dataloader,
            loss_fn=loss_fn,
            max_seq_len=max_seq_len,
            num_eval_steps=10,
        )
    print("Eval loss: ", eval_loss)


if __name__ == "__main__":
    main()