import argparse
import copy
import inspect
import os
import sys
import time
from datetime import timedelta

from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import torch
import torch._dynamo
from torch.utils.data import DataLoader


torch.backends.cuda.matmul.allow_tf32 = True

# Running this end-to-end training/evaluation example downloads around 84 GB of data.

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def data_processing(num_samples, batch_size):
    """Build train/eval DataLoaders over the first ``num_samples`` Yelp reviews.

    Reviews are tokenized with the ``bert-base-cased`` tokenizer
    (``padding="max_length"``, ``truncation=True``). The raw ``text`` column is
    dropped and ``label`` is renamed to ``labels`` so each batch can be passed
    straight to the model as keyword arguments.

    Returns:
        A ``(train_dataloader, eval_dataloader)`` tuple; the eval loader draws
        from the dataset's ``test`` split.
    """
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    small_train_dataset = tokenized_datasets["train"].select(range(num_samples))
    small_eval_dataset = tokenized_datasets["test"].select(range(num_samples))

    train_dataloader = DataLoader(small_train_dataset, batch_size=batch_size)
    eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

    return train_dataloader, eval_dataloader


def training_iter_fn(batch, model, optimizer):
    """Run one forward/backward/optimizer step on ``batch`` and return the loss tensor."""
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
    """Train ``model`` for ``num_epochs`` and optionally evaluate its accuracy.

    Args:
        backend: TorchDynamo backend name; a falsy value runs native PyTorch.
        train_dataloader: Batches (dicts of tensors) used for training.
        eval_dataloader: Batches used for evaluation when ``evaluation`` is set.
        model: The model to train; moved to the module-level ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_dataloader``.
        evaluation: Whether to compute accuracy on ``eval_dataloader``.

    Returns:
        ``(loss_history, metric_result)``. ``loss_history`` holds the mean loss
        of every full window of 100 training iterations within each epoch (a
        trailing partial window is discarded). ``metric_result`` is the
        accuracy metric's ``compute()`` result, or ``None`` when
        ``evaluation`` is false.
    """
    model.to(device)
    model.train()
    loss_history = []
    if not backend:
        # Run with native PyTorch.
        opt_training_iter_fn = training_iter_fn
    else:
        # Compile the whole training step with the requested TorchDynamo backend.
        opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, batch in enumerate(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = opt_training_iter_fn(batch, model, optimizer)
            running_loss += loss.item()
            if i % 100 == 99:
                loss_history.append(running_loss / 100)
                running_loss = 0.0

    if evaluation:
        metric = load_metric("accuracy")
        model.eval()
        if not backend:
            opt_model = model
        else:
            opt_model = torch._dynamo.optimize(backend)(model)
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = opt_model(**batch)

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        return loss_history, metric.compute()
    else:
        return loss_history, None


def check_loss(ref_loss, res_loss):
    """Return True if ``res_loss`` ends no more than 0.1 above ``ref_loss``.

    Compares the mean of the last ``min(len, 10)`` entries of each history.
    Two empty histories compare as passing.

    Raises:
        ValueError: if the two histories have different lengths.
    """
    if len(ref_loss) != len(res_loss):
        raise ValueError(
            f"loss histories differ in length: {len(ref_loss)} vs {len(res_loss)}"
        )
    window = min(len(ref_loss), 10)
    if window == 0:
        return True
    # Divide by the actual window size; a fixed divisor of 10 would distort
    # the averages (and inflate the 0.1 tolerance) for short histories.
    res_avg = sum(res_loss[-window:]) / window
    ref_avg = sum(ref_loss[-window:]) / window
    return res_avg <= ref_avg + 0.1


def parse_args():
    """Parse the benchmark's command-line options."""
    parser = argparse.ArgumentParser(
        description="TorchDynamo end to end training/evaluation benchmark"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train (default: 10)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="number of samples to train/eval (default: 1000)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 8)",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5, help="learning rate (default: 5e-5)"
    )
    parser.add_argument(
        "--backend",
        choices=torch._dynamo.list_backends(exclude_tags=None),
        default="inductor",
        help="train/evaluate model with a given backend (default: inductor)",
    )
    parser.add_argument(
        "--optimizer",
        default="Adam",
        help="train model using a given optimizer (default: Adam)",
    )
    parser.add_argument(
        "--evaluation",
        action="store_true",
        help="running evaluation after model training",
    )
    args = parser.parse_args()
    return args


def _make_optimizer(optimizer_cls, model, lr):
    """Build a fresh ``optimizer_cls`` over ``model``'s parameters with learning rate ``lr``."""
    kwargs = {"lr": lr}
    # capturable=True keeps optimizer step state on-device; PyTorch supports it
    # only for CUDA-like devices, so request it only when running on CUDA.
    if (
        device.type == "cuda"
        and "capturable" in inspect.signature(optimizer_cls).parameters
    ):
        kwargs["capturable"] = True
    return optimizer_cls(model.parameters(), **kwargs)


def main():
    """Train BERT natively and with TorchDynamo from identical starting weights, then compare."""
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    # Snapshot the freshly loaded weights so the TorchDynamo run starts from
    # the same point as the native run rather than continuing to train the
    # already-trained model, which would skew the loss comparison.
    initial_state = copy.deepcopy(model.state_dict())
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)

    optimizer = _make_optimizer(optimizer_cls, model, args.lr)
    native_start = time.perf_counter()
    ref_loss, _ = model_training_evaluation(
        None,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    native_end = time.perf_counter()

    # Reset weights and optimizer state (e.g. Adam moments) before the second run.
    model.load_state_dict(initial_state)
    optimizer = _make_optimizer(optimizer_cls, model, args.lr)
    dynamo_start = time.perf_counter()
    res_loss, accuracy = model_training_evaluation(
        args.backend,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    dynamo_end = time.perf_counter()
    if check_loss(ref_loss, res_loss):
        print(
            "[PASSED] TorchDynamo end to end training loss is less than or equal to native PyTorch"
        )
    else:
        print(
            "[FAILED] TorchDynamo end to end training loss is greater than native Pytorch"
        )
    if args.evaluation:
        print(f"Model accuracy: {accuracy}")
    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - dynamo_start
    print(
        f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    print(f"PyTorch spent {timedelta(seconds=native_elapsed/args.epochs)} per epoch")
    print(
        f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed/args.epochs)} per epoch"
    )


if __name__ == "__main__":
    main()