import argparse
import inspect
import os
import sys
import time
from datetime import timedelta

from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import torch
import torch._dynamo
from torch.utils.data import DataLoader


torch.backends.cuda.matmul.allow_tf32 = True

# You will download a dataset of around 84 GB if you run this end-to-end training/evaluation example.

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def data_processing(num_samples, batch_size):
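    """Load yelp_review_full, tokenize it with the bert-base-cased tokenizer,
    and return train/eval DataLoaders over the first num_samples examples of
    each split."""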
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    small_train_dataset = tokenized_datasets["train"].select(range(num_samples))
    small_eval_dataset = tokenized_datasets["test"].select(range(num_samples))

    train_dataloader = DataLoader(small_train_dataset, batch_size=batch_size)
    eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

    return train_dataloader, eval_dataloader


def training_iter_fn(batch, model, optimizer):
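    """Run one training step (forward, backward, optimizer step) and return the
    loss. This is the function that gets compiled when a Dynamo backend is
    selected."""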
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
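    """Train the model for num_epochs, optionally through a TorchDynamo backend,
    and optionally evaluate accuracy afterwards. Returns the loss history (the
    training loss averaged over every 100 iterations) and the accuracy metric,
    or None when evaluation is skipped."""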
    model.to(device)
    model.train()
    loss_history = []
    if not backend:
        # Run with native PyTorch
        opt_training_iter_fn = training_iter_fn
    else:
        # Supported backends: eager, aot_eager, aot_nvfuser and inductor
        opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, batch in enumerate(train_dataloader, 0):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = opt_training_iter_fn(batch, model, optimizer)
            running_loss += loss.item()
            if i % 100 == 99:
                loss_history.append(running_loss / 100)
                running_loss = 0.0

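    # Optional evaluation pass over the held-out set, again through the chosen backend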
    if evaluation:
        metric = load_metric("accuracy")
        model.eval()
        if not backend:
            opt_model = model
        else:
            opt_model = torch._dynamo.optimize(backend)(model)
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = opt_model(**batch)

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        return loss_history, metric.compute()
    else:
        return loss_history, None

def check_loss(ref_loss, res_loss):
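    """Compare the mean of the last (up to 10) logged losses: the run under test
    passes if it is no more than 0.1 above the reference run."""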
    assert len(ref_loss) == len(res_loss)
    length = len(ref_loss)
    x = min(length, 10)
    return sum(res_loss[-x:]) / x <= sum(ref_loss[-x:]) / x + 0.1


def parse_args():
    parser = argparse.ArgumentParser(
        description="TorchDynamo end to end training/evaluation benchmark"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train (default: 10)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="number of samples to train/eval (default: 1000)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 8)",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5, help="learning rate (default: 5e-5)"
    )
    parser.add_argument(
        "--backend",
        choices=torch._dynamo.list_backends(exclude_tags=None),
        default="inductor",
        help="train/evaluate model with a given backend (default: inductor)",
    )
    parser.add_argument(
        "--optimizer",
        default="Adam",
        help="train model using a given optimizer (default: Adam)",
    )
    parser.add_argument(
        "--evaluation",
        action="store_true",
        help="run evaluation after model training",
    )
    args = parser.parse_args()
    return args


def main():
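    """Train (and optionally evaluate) the model twice, first with native PyTorch
    and then with the selected TorchDynamo backend, then compare the final
    training losses and the per-epoch wall-clock time."""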
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
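    # Pass capturable=True when the optimizer supports it, so the optimizer step
    # is safe to capture (e.g., by CUDA graphs) rather than forced onto the host.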
156    if "capturable" in inspect.signature(optimizer_cls).parameters.keys():
157        optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True)
158    else:
159        optimizer = optimizer_cls(model.parameters(), lr=args.lr)
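    # Reference run: train with native (eager) PyTorch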
    native_start = time.time()
    ref_loss, accuracy = model_training_evaluation(
        None,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    native_end = time.time()
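    # Benchmark run: train the same model again through the selected TorchDynamo backend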
    res_loss, accuracy = model_training_evaluation(
        args.backend,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    dynamo_end = time.time()
    if check_loss(ref_loss, res_loss):
        print(
            "[PASSED] TorchDynamo end to end training loss is less than or equal to native PyTorch"
        )
    else:
        print(
187            "[FAILED] TorchDynamo end to end training loss is greater than native Pytorch"
        )
    if args.evaluation:
        print(f"Model accuracy: {accuracy}")
    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - native_end
    print(
        f"Train model for {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    print(f"PyTorch spent {timedelta(seconds=native_elapsed/args.epochs)} per epoch")
    print(
        f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed/args.epochs)} per epoch"
    )


if __name__ == "__main__":
    main()