import argparse
import inspect
import os
import sys
import time
from datetime import timedelta

from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import torch
import torch._dynamo
from torch.utils.data import DataLoader


# TF32 matmuls trade a little precision for faster GEMMs on Ampere and newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True

# You will download a dataset of around 84 GB if you run this end-to-end training/evaluation example.

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def data_processing(num_samples, batch_size):
    # Tokenize the Yelp review dataset and build dataloaders over a small subset.
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    small_train_dataset = tokenized_datasets["train"].select(range(num_samples))
    small_eval_dataset = tokenized_datasets["test"].select(range(num_samples))

    train_dataloader = DataLoader(small_train_dataset, batch_size=batch_size)
    eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

    return train_dataloader, eval_dataloader


def training_iter_fn(batch, model, optimizer):
    # One training step: forward pass, backward pass, and optimizer update.
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
    model.to(device)
    model.train()
    loss_history = []
    if not backend:
        # Run with native PyTorch
        opt_training_iter_fn = training_iter_fn
    else:
        # Supported backends: eager, aot_eager, aot_nvfuser, and inductor
        opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, batch in enumerate(train_dataloader, 0):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = opt_training_iter_fn(batch, model, optimizer)
            running_loss += loss.item()
            if i % 100 == 99:
                # Record the average training loss over the last 100 iterations.
                loss_history.append(running_loss / 100)
                running_loss = 0.0

    if evaluation:
        metric = load_metric("accuracy")
        model.eval()
        if not backend:
            opt_model = model
        else:
            opt_model = torch._dynamo.optimize(backend)(model)
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = opt_model(**batch)

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        return loss_history, metric.compute()
    else:
        return loss_history, None


def check_loss(ref_loss, res_loss):
    # Compare the mean of the last (up to) 10 recorded losses, allowing a small
    # tolerance for numerical differences between backends.
    assert len(ref_loss) == len(res_loss)
    length = len(ref_loss)
    x = min(length, 10)
    return sum(res_loss[-x:]) / x <= sum(ref_loss[-x:]) / x + 0.1


def parse_args():
    parser = argparse.ArgumentParser(
        description="TorchDynamo end-to-end training/evaluation benchmark"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train (default: 10)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="number of samples to train/eval (default: 1000)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 8)",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5, help="learning rate (default: 5e-5)"
    )
    parser.add_argument(
        "--backend",
        choices=torch._dynamo.list_backends(exclude_tags=None),
        default="inductor",
        help="train/evaluate model with a given backend (default: inductor)",
    )
    parser.add_argument(
        "--optimizer",
        default="Adam",
        help="train model using a given optimizer (default: Adam)",
    )
    parser.add_argument(
        "--evaluation",
        action="store_true",
        help="run evaluation after model training",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
    # Prefer the capturable variant when the optimizer supports it, so the
    # optimizer step can be captured by CUDA graphs.
    if "capturable" in inspect.signature(optimizer_cls).parameters.keys():
        optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True)
    else:
        optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    # Reference run with native (eager) PyTorch.
    native_start = time.time()
    ref_loss, accuracy = model_training_evaluation(
        None,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    native_end = time.time()
    # Second run through TorchDynamo with the selected backend; it starts right
    # at native_end, so its elapsed time is dynamo_end - native_end.
    res_loss, accuracy = model_training_evaluation(
        args.backend,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    dynamo_end = time.time()
    if check_loss(ref_loss, res_loss):
        print(
            "[PASSED] TorchDynamo end to end training loss is less than or equal to native PyTorch"
        )
    else:
        print(
            "[FAILED] TorchDynamo end to end training loss is greater than native PyTorch"
        )
    if args.evaluation:
        print(f"Model accuracy: {accuracy}")
    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - native_end
    print(
        f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    print(f"PyTorch spent {timedelta(seconds=native_elapsed/args.epochs)} per epoch")
    print(
        f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed/args.epochs)} per epoch"
    )


if __name__ == "__main__":
    main()
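# Example invocation (a sketch only; the file name "training_loss.py" is an
# assumption for illustration and is not given in the source, while the flags
# are exactly those defined by parse_args() above):
#
#     python training_loss.py --epochs 1 --num-samples 200 --batch-size 8 \
#         --backend inductor --optimizer Adam --evaluation
#
# The script first trains (and optionally evaluates) with native PyTorch, then
# repeats the run through torch._dynamo with the chosen backend, and finally
# reports a loss-parity check plus per-epoch timings for both runs.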