xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/training_loss.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport inspect
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport time
6*da0073e9SAndroid Build Coastguard Workerfrom datetime import timedelta
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom datasets import load_dataset, load_metric
9*da0073e9SAndroid Build Coastguard Workerfrom transformers import AutoModelForSequenceClassification, AutoTokenizer
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerimport torch
12*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
13*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.data import DataLoader
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
# Turn on the allow_tf32 flag for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True

# NOTE: running this end to end training/evaluation example downloads a
# dataset of around 84G.

# Set before any tokenizer is created.
# NOTE(review): presumably this avoids tokenizers' parallelism warnings — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
def data_processing(num_samples, batch_size):
    """Build train/eval dataloaders over tokenized Yelp review subsets.

    Loads the "yelp_review_full" dataset, tokenizes its "text" column with the
    "bert-base-cased" tokenizer (padded to max length, truncated), drops the
    raw text, renames "label" to "labels" and formats the result as torch
    tensors.

    Args:
        num_samples: number of leading examples selected from each of the
            "train" and "test" splits.
        batch_size: batch size used by both dataloaders.

    Returns:
        A ``(train_dataloader, eval_dataloader)`` tuple.
    """
    raw_dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized = (
        raw_dataset.map(tokenize_function, batched=True)
        .remove_columns(["text"])
        .rename_column("label", "labels")
    )
    tokenized.set_format("torch")

    # Both splits are cut down to their first `num_samples` rows.
    leading_rows = range(num_samples)
    train_dataloader, eval_dataloader = (
        DataLoader(tokenized[split].select(leading_rows), batch_size=batch_size)
        for split in ("train", "test")
    )
    return train_dataloader, eval_dataloader
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
def training_iter_fn(batch, model, optimizer):
    """Run one optimization step on ``batch`` and return its loss.

    Args:
        batch: mapping of keyword arguments forwarded to ``model``.
        model: callable whose output exposes a ``loss`` attribute.
        optimizer: optimizer stepped after the backward pass; its gradients
            are cleared before returning.

    Returns:
        The loss produced by the forward pass for this batch.
    """
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    # Clear gradients so the next iteration starts from a clean state.
    optimizer.zero_grad()
    return loss
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
    """Train ``model`` and optionally evaluate its accuracy.

    Args:
        backend: TorchDynamo backend passed to ``torch._dynamo.optimize``; a
            falsy value runs the training step and model without it.
        train_dataloader: iterable of training batches (dicts of tensors).
        eval_dataloader: iterable of evaluation batches; each batch must
            contain a "labels" entry.
        model: model to train; moved to the module-level ``device``.
        optimizer: optimizer used by the training step.
        num_epochs: number of passes over ``train_dataloader``.
        evaluation: when truthy, compute accuracy after training.

    Returns:
        ``(loss_history, metric_result)``. ``loss_history`` holds the mean
        loss of every completed 100-step window; a partial window at the end
        of an epoch is dropped. ``metric_result`` is the computed "accuracy"
        metric, or ``None`` when ``evaluation`` is falsy.
    """
    model.to(device)
    model.train()

    # Supported backends: eager, aot_eager, aot_nvfuser and inductor.
    # Without a backend the plain Python training step runs natively.
    step_fn = (
        torch._dynamo.optimize(backend)(training_iter_fn)
        if backend
        else training_iter_fn
    )

    loss_history = []
    window = 100
    for _ in range(num_epochs):
        window_loss = 0.0
        for step, batch in enumerate(train_dataloader, start=1):
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            window_loss += step_fn(batch, model, optimizer).item()
            if step % window == 0:
                loss_history.append(window_loss / window)
                window_loss = 0.0

    if not evaluation:
        return loss_history, None

    metric = load_metric("accuracy")
    model.eval()
    eval_model = torch._dynamo.optimize(backend)(model) if backend else model
    for batch in eval_dataloader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        with torch.no_grad():
            outputs = eval_model(**batch)

        predictions = torch.argmax(outputs.logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])

    return loss_history, metric.compute()
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker
def check_loss(ref_loss, res_loss):
    """Check that the trailing mean of ``res_loss`` is close to ``ref_loss``.

    The mean of the last ``min(len, 10)`` entries of each history is compared,
    and ``res_loss`` is accepted when its mean is at most 0.1 above the
    reference mean.

    Args:
        ref_loss: reference loss history (sequence of numbers).
        res_loss: loss history under test; must have the same length.

    Returns:
        True when the trailing mean of ``res_loss`` is less than or equal to
        the trailing mean of ``ref_loss`` plus 0.1. Two empty histories
        compare as True because there is nothing to contradict the check.

    Raises:
        AssertionError: if the two histories differ in length.
    """
    assert len(ref_loss) == len(res_loss)
    window = min(len(ref_loss), 10)
    if window == 0:
        # Nothing was recorded; also avoids dividing by zero below.
        return True
    # Divide by the real window size so the 0.1 tolerance applies to the
    # mean even when fewer than 10 entries are available.
    ref_mean = sum(ref_loss[-window:]) / window
    res_mean = sum(res_loss[-window:]) / window
    return res_mean <= ref_mean + 0.1
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker
def parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``epochs``, ``num_samples``,
        ``batch_size``, ``lr``, ``backend``, ``optimizer`` and ``evaluation``.
    """
    cli = argparse.ArgumentParser(
        description="TorchDynamo end to end training/evaluation benchmark"
    )
    # (flag, add_argument keyword arguments), registered in display order.
    option_specs = (
        (
            "--epochs",
            {
                "type": int,
                "default": 10,
                "help": "number of epochs to train (default: 10)",
            },
        ),
        (
            "--num-samples",
            {
                "type": int,
                "default": 1000,
                "help": "number of samples to train/eval (default: 1000)",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 8,
                "help": "input batch size for training (default: 8)",
            },
        ),
        (
            "--lr",
            {
                "type": float,
                "default": 5e-5,
                "help": "learning rate (default: 5e-5)",
            },
        ),
        (
            "--backend",
            {
                "choices": torch._dynamo.list_backends(exclude_tags=None),
                "default": "inductor",
                "help": "train/evaluate model with a given backend (default: inductor)",
            },
        ),
        (
            "--optimizer",
            {
                "default": "Adam",
                "help": "train model using a given optimizer (default: Adam)",
            },
        ),
        (
            "--evaluation",
            {
                "action": "store_true",
                "help": "running evaluation after model training",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker
def main():
    """Train natively and with TorchDynamo, then report loss and timing.

    Builds the dataloaders, model and optimizer from the command-line
    options, runs ``model_training_evaluation`` once without a backend and
    once with ``args.backend``, prints whether ``check_loss`` passed, the
    accuracy (when ``--evaluation`` is set) and the per-epoch wall time of
    each run.
    """
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )

    # Look the optimizer class up by name; pass capturable=True only to
    # optimizers whose constructor accepts it.
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
    optimizer_kwargs = {"lr": args.lr}
    if "capturable" in inspect.signature(optimizer_cls).parameters:
        optimizer_kwargs["capturable"] = True
    optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)

    # NOTE(review): both runs share the same model and optimizer objects, so
    # the backend run starts from the state left by the native run — confirm
    # this is intended for the loss comparison.
    shared_run_args = (
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    native_start = time.time()
    ref_loss, accuracy = model_training_evaluation(None, *shared_run_args)
    native_end = time.time()
    # `accuracy` is rebound here, so only the backend run's value is printed.
    res_loss, accuracy = model_training_evaluation(args.backend, *shared_run_args)
    dynamo_end = time.time()

    verdict = (
        "[PASSED] TorchDynamo end to end training loss is less than or equal to native PyTorch"
        if check_loss(ref_loss, res_loss)
        else "[FAILED] TorchDynamo end to end training loss is greater than native Pytorch"
    )
    print(verdict)
    if args.evaluation:
        print(f"Model accuracy: {accuracy}")

    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - native_end
    print(
        f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    native_per_epoch = timedelta(seconds=native_elapsed / args.epochs)
    print(f"PyTorch spent {native_per_epoch} per epoch")
    dynamo_per_epoch = timedelta(seconds=dynamo_elapsed / args.epochs)
    print(f"TorchDynamo spent {dynamo_per_epoch} per epoch")
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
204