# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from functools import partial
from typing import Any

import torch
from executorch.extension.pybindings.aten_lib import ExecuTorchModule  # @manual

from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torchtune.data import AlpacaToMessages
from torchtune.data._collate import padded_collate_sft
from torchtune.datasets import PackedDataset, SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from tqdm import tqdm


class TrainingModule(torch.nn.Module):
    """
    The model being trained should return the loss from forward(). This
    class wraps the actual model and computes the loss for an LLM
    fine-tuning task. The loss is computed as the cross entropy between
    the model's next-token logits and a shifted version of the labels,
    so the model learns to predict the next token.
    """

    def __init__(
        self, model: torch.nn.Module, loss: torch.nn.modules.loss._Loss
    ) -> None:
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Output is of the shape (batch_size, seq_len, vocab_size).
        logits = self.model(input)
        # Shift so that the logits at position i predict the token at position i + 1.
        logits = logits[..., :-1, :].contiguous()
        labels = labels[..., 1:].contiguous()
        # Cross entropy expects the class (vocab) dimension second.
        logits = logits.transpose(1, 2)
        return self.loss(logits, labels)

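# Illustrative sketch, not part of the original module: how TrainingModule is
# intended to be used. The tiny embedding/linear stack below is a hypothetical
# stand-in for a real LLM; any module that maps (batch_size, seq_len) token ids
# to (batch_size, seq_len, vocab_size) logits can be wrapped the same way.
def _example_training_module_usage() -> torch.Tensor:
    vocab_size = 128
    toy_model = torch.nn.Sequential(
        torch.nn.Embedding(vocab_size, 32),
        torch.nn.Linear(32, vocab_size),
    )
    module = TrainingModule(toy_model, torch.nn.CrossEntropyLoss(ignore_index=-100))
    tokens = torch.randint(0, vocab_size, (2, 16))
    # Labels are the same token ids; forward() performs the next-token shift.
    return module(tokens, tokens.clone())
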
def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset:
    """
    Python code instruction-input-output pairs from
    iamtarun/python_code_instructions_18k_alpaca, templated with Alpaca.
    """
    ds = SFTDataset(
        # pyre-ignore[6]: Incompatible parameter type
        model_transform=tokenizer,
        source="iamtarun/python_code_instructions_18k_alpaca",
        message_transform=AlpacaToMessages(
            train_on_input=False,
        ),
        # pyre-ignore[6]: Incompatible parameter type
        split="train",
    )
    if tokenizer.max_seq_len is None:
        raise ValueError(
            "PackedDataset requires a max_seq_len to be set on the tokenizer."
        )
    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False)

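# Illustrative sketch, kept as comments because it needs tokenizer assets on
# disk: the tokenizer must be a torchtune tokenizer with max_seq_len set. The
# llama3_tokenizer builder and the local path below are assumptions, not part
# of this file.
#
#     from torchtune.models.llama3 import llama3_tokenizer
#
#     tokenizer = llama3_tokenizer("/path/to/tokenizer.model", max_seq_len=512)
#     packed_ds = python_code_instructions_alpaca(tokenizer)
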
def update_function(
    param: torch.Tensor,
    grad: torch.Tensor,
    learning_rate: float,
    weight_decay: float = 1.0,
) -> None:
    """SGD update with weight decay, applied in place:
    param <- param - learning_rate * (grad + weight_decay * param)."""
    grad = grad + weight_decay * param
    param.sub_(learning_rate * grad)

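# Illustrative sketch, not part of the original file: applying the manual SGD
# update to a single tensor. A training loop would typically call this once per
# parameter with the gradient produced for that parameter.
def _example_update_function_usage() -> torch.Tensor:
    param = torch.ones(4)
    grad = torch.full_like(param, 0.5)
    # In-place update: param <- param - lr * (grad + weight_decay * param).
    update_function(param, grad, learning_rate=0.1, weight_decay=0.0)
    return param  # tensor([0.95, 0.95, 0.95, 0.95])
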
def eval_model(
    model: ExecuTorchModule,
    dataloader: DataLoader,
    loss_fn: torch.nn.modules.loss._Loss,
    max_seq_len: int,
    num_eval_steps: int,
) -> float:
    total_loss = 0
    for i, batch in tqdm(enumerate(dataloader), total=num_eval_steps):
        if i >= num_eval_steps:
            break
        tokens, labels = batch["tokens"], batch["labels"]
        token_size = tokens.shape[1]
        labels_size = labels.shape[1]

        # Fixed length for now. We need to resize the inputs so that their
        # shapes match the example inputs passed to the export function.
        if token_size > max_seq_len:
            tokens = tokens[:, :max_seq_len]
        else:
            tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

        if labels_size > max_seq_len:
            labels = labels[:, :max_seq_len]
        else:
            labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

        out = model.forward((tokens, labels))
        # The exported TrainingModule returns the loss as its first output.
        loss = out[0]
        total_loss += loss
    return total_loss / num_eval_steps

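# Illustrative sketch, kept as comments because it needs an exported .pte
# artifact: load the program with the pybindings loader and evaluate it. The
# loader name, file path, and hyperparameters below are assumptions; the
# dataloader would be one produced by get_dataloader further down.
#
#     from executorch.extension.pybindings.aten_lib import (
#         _load_for_executorch_from_buffer,
#     )
#
#     with open("model.pte", "rb") as f:
#         et_module = _load_for_executorch_from_buffer(f.read())
#     avg_loss = eval_model(
#         et_module, dataloader, loss_fn, max_seq_len=512, num_eval_steps=10
#     )
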
def get_dataloader(
    cfg: Any,  # pyre-ignore[2]
    ds: Dataset[Any],  # pyre-ignore[2]
    tokenizer: Any,  # pyre-ignore[2]
    loss_fn: torch.nn.modules.loss._Loss,
) -> DataLoader:
    """Given a dataset, tokenizer, and loss function, return a dataloader."""
    packed = cfg.dataset.get("packed", False)

    sampler = DistributedSampler(
        ds,
        num_replicas=1,
        rank=0,
        shuffle=cfg.shuffle,
        seed=0,
    )
    dataloader = DataLoader(
        dataset=ds,
        sampler=sampler,
        batch_size=cfg.batch_size,
        collate_fn=(
            partial(
                padded_collate_sft,
                padding_idx=tokenizer.pad_id,
                ignore_idx=loss_fn.ignore_index,
            )
            if not packed
            else None
        ),
    )
    return dataloader
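
# Illustrative sketch, not part of the original file: building a dataloader
# from an in-memory list of samples. The omegaconf config and the toy tokenizer
# below are assumptions; only the fields get_dataloader actually reads
# (dataset.packed, shuffle, batch_size, tokenizer.pad_id, loss_fn.ignore_index)
# are filled in.
def _example_get_dataloader_usage() -> DataLoader:
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {"dataset": {"packed": False}, "shuffle": True, "batch_size": 2}
    )

    class _ToyTokenizer:
        pad_id = 0

    samples = [
        {"tokens": [1, 2, 3], "labels": [1, 2, 3]},
        {"tokens": [4, 5], "labels": [4, 5]},
        {"tokens": [6, 7, 8, 9], "labels": [6, 7, 8, 9]},
        {"tokens": [10], "labels": [10]},
    ]
    # Each batch is a dict of padded "tokens" and "labels" tensors, the same
    # format eval_model consumes.
    return get_dataloader(cfg, samples, _ToyTokenizer(), torch.nn.CrossEntropyLoss())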
149