# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from functools import partial
from typing import Any

import torch
from executorch.extension.pybindings.aten_lib import ExecuTorchModule  # @manual

from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torchtune.data import AlpacaToMessages
from torchtune.data._collate import padded_collate_sft
from torchtune.datasets import PackedDataset, SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from tqdm import tqdm


class TrainingModule(torch.nn.Module):
    """
    The model being trained should return the loss from forward(). This
    class wraps the actual model and computes the loss for an LLM
    fine-tuning task. The loss is computed as the cross entropy between
    the logits and a shifted version of the labels, so the model learns to
    predict the next token.
    """

    def __init__(
        self, model: torch.nn.Module, loss: torch.nn.modules.loss._Loss
    ) -> None:
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # The model output is of shape (batch_size, seq_len, vocab_size).
        logits = self.model(input)
        # Drop the last position's logits and the first label so that
        # position t predicts token t + 1.
        logits = logits[..., :-1, :].contiguous()
        labels = labels[..., 1:].contiguous()
        # Cross entropy expects the class (vocab) dimension second.
        logits = logits.transpose(1, 2)
        return self.loss(logits, labels)


def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset:
    """
    Python code instruction-input-output pairs from
    iamtarun/python_code_instructions_18k_alpaca, templated with Alpaca.
    """
    ds = SFTDataset(
        # pyre-ignore[6]: Incompatible parameter type
        model_transform=tokenizer,
        source="iamtarun/python_code_instructions_18k_alpaca",
        message_transform=AlpacaToMessages(
            train_on_input=False,
        ),
        # pyre-ignore[6]: Incompatible parameter type
        split="train",
    )
    if tokenizer.max_seq_len is None:
        raise ValueError(
            "PackedDataset requires a max_seq_len to be set on the tokenizer."
        )
    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False)


def update_function(
    param: torch.Tensor,
    grad: torch.Tensor,
    learning_rate: float,
    weight_decay: float = 1.0,
) -> None:
    """SGD update with L2 weight decay, applied to ``param`` in place."""
    grad = grad + weight_decay * param
    param.sub_(learning_rate * grad)


def eval_model(
    model: ExecuTorchModule,
    dataloader: DataLoader,
    loss_fn: torch.nn.modules.loss._Loss,
    max_seq_len: int,
    num_eval_steps: int,
) -> float:
    """Run ``num_eval_steps`` batches through the exported model and return
    the average loss."""
    total_loss = 0
    for i, batch in tqdm(enumerate(dataloader), total=num_eval_steps):
        if i >= num_eval_steps:
            break
        tokens, labels = batch["tokens"], batch["labels"]
        token_size = tokens.shape[1]
        labels_size = labels.shape[1]

        # Fixed length for now. We need to resize because the input shapes
        # must match the example inputs that were passed to the export
        # function.
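        # For example, with max_seq_len == 512, a (1, 700) batch of tokens is
        # truncated to (1, 512), while a (1, 300) batch is right-padded with
        # zeros to (1, 512).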
        if token_size > max_seq_len:
            tokens = tokens[:, :max_seq_len]
        else:
            tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

        if labels_size > max_seq_len:
            labels = labels[:, :max_seq_len]
        else:
            labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

        # The exported program computes the loss internally (see
        # TrainingModule) and returns it as its first output, so loss_fn is
        # not applied here.
        out = model.forward((tokens, labels))
        loss = out[0]
        total_loss += loss
    return total_loss / num_eval_steps


def get_dataloader(
    cfg: Any,  # pyre-ignore[2]
    ds: Dataset[Any],  # pyre-ignore[2]
    tokenizer: Any,  # pyre-ignore[2]
    loss_fn: torch.nn.modules.loss._Loss,
) -> DataLoader:
    """Given a dataset, tokenizer, and loss function, return a dataloader."""
    packed = cfg.dataset.get("packed", False)

    sampler = DistributedSampler(
        ds,
        num_replicas=1,
        rank=0,
        shuffle=cfg.shuffle,
        seed=0,
    )
    dataloader = DataLoader(
        dataset=ds,
        sampler=sampler,
        batch_size=cfg.batch_size,
        # Packed samples are already a fixed length, so only unpacked
        # datasets need the padding collate.
        collate_fn=(
            partial(
                padded_collate_sft,
                padding_idx=tokenizer.pad_id,
                ignore_idx=loss_fn.ignore_index,
            )
            if not packed
            else None
        ),
    )
    return dataloader
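

# ---------------------------------------------------------------------------
# Usage sketch (comments only, not executed): a minimal example of wiring the
# helpers above together, assuming an OmegaConf-style config and a Llama 3
# tokenizer. The paths, config values, and tokenizer choice are placeholders,
# not part of this module.
#
#   from omegaconf import OmegaConf
#   from torchtune.models.llama3 import llama3_tokenizer
#
#   cfg = OmegaConf.create(
#       {"batch_size": 1, "shuffle": False, "dataset": {"packed": True}}
#   )
#   tokenizer = llama3_tokenizer("/path/to/tokenizer.model", max_seq_len=512)
#   ds = python_code_instructions_alpaca(tokenizer)
#   loss_fn = torch.nn.CrossEntropyLoss()
#   dataloader = get_dataloader(cfg, ds, tokenizer, loss_fn)
#   # et_module: an ExecuTorchModule loaded from the exported .pte program.
#   avg_loss = eval_model(
#       et_module, dataloader, loss_fn, max_seq_len=512, num_eval_steps=10
#   )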