#!/usr/bin/env python3
#
# Measure distributed training iteration time.
#
# This program performs a sweep over a) a number of model architectures, and
# b) an increasing number of processes. This produces a 1-GPU baseline,
# an 8-GPU baseline (if applicable), as well as measurements for however
# many processes can participate in training.
#
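#
# Example invocation (illustrative only; the host, port, and world size below
# are placeholders, and one process must be launched per GPU on every machine):
#
#   python benchmark.py --world-size 16 \
#       --master-addr 10.0.0.1 --master-port 29500 \
#       --model resnet50 --json results.json
#
# Each process picks up its rank from --rank or the RANK environment variable.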
import argparse
import itertools
import json
import os
import shlex
import subprocess
import sys
import time

import numpy as np
import torchvision

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim


def allgather_object(obj):
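    """Gather a picklable object from every rank over the global (gloo) group."""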
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out


def allgather_run(cmd):
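    """Run `cmd` locally and gather its decoded stdout from every rank."""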
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    assert proc.returncode == 0
    return allgather_object(proc.stdout.decode("utf-8"))


def allequal(iterator):
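    """Return True if all elements of `iterator` are equal (True when empty)."""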
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == rest for rest in iterator)


def benchmark_process_group(pg, benchmark, use_ddp_for_single_rank=True):
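    """Train `benchmark` on the process group `pg` and time each iteration.

    The model is wrapped in DistributedDataParallel unless this is a
    single-rank run with use_ddp_for_single_rank=False. Returns per-iteration
    wall-clock times in seconds for the measured iterations only.
    """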
    torch.manual_seed(pg.rank())
    torch.cuda.manual_seed(pg.rank())

    model = benchmark.create_model()
    data = [(benchmark.generate_inputs(), benchmark.generate_target())]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=1e-4)
    if use_ddp_for_single_rank or pg.size() > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            process_group=pg,
            bucket_cap_mb=benchmark.bucket_size,
        )

    measurements = []
    warmup_iterations = 5
    measured_iterations = 10
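    # The same synthetic batch is replayed for every iteration, so data loading
    # is excluded from the timing; torch.cuda.synchronize() makes each sample
    # include the full GPU work for that iteration.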
    for inputs, target in data * (warmup_iterations + measured_iterations):
        start = time.time()
        output = model(*inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        measurements.append(time.time() - start)

    # Throw away measurements for warmup iterations
    return measurements[warmup_iterations:]


def run_benchmark(benchmark, ranks, opts):
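    """Benchmark on a subgroup that spans exactly `ranks`.

    Ranks outside the subgroup skip the measurement and only join the final
    barrier/allgather; measurements from all participating ranks are
    concatenated and returned on every rank.
    """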
    group = dist.new_group(ranks=ranks, backend=benchmark.distributed_backend)
    measurements = []
    if dist.get_rank() in set(ranks):
        if not opts:
            opts = {}
        measurements = benchmark_process_group(group, benchmark, **opts)
    dist.destroy_process_group(group)
    dist.barrier()

    # Aggregate measurements for better estimation of percentiles
    return list(itertools.chain(*allgather_object(measurements)))


def sweep(benchmark):
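    """Run `benchmark` over configurations with an increasing number of GPUs.

    Prints a percentile table (p50/p75/p90/p95 of sec/iter and examples/sec)
    and returns a list of {"ranks": ..., "measurements": ...} dicts, excluding
    the per-rank warmup configuration.
    """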
    # Synthesize the set of benchmarks to run.
    # This list contains ("string prefix", [rank...], opts) tuples.
    benchmarks = []

    def append_benchmark(prefix, ranks, opts=None):
        prefix = f"{len(ranks):4} GPUs -- {prefix}"
        benchmarks.append((prefix, ranks, opts))

    def local_print(msg):
        if dist.get_rank() == 0:
            print(msg, end="", flush=True)  # noqa: E999

    def print_header():
        local_print("\n")
        local_print("%22s" % "")
        for _ in [50, 75, 90, 95]:
            local_print("%14s%10s" % ("sec/iter", "ex/sec"))
        local_print("\n")

    def print_measurements(prefix, nelem, measurements):
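        # nelem is the per-iteration batch size; dividing it by the
        # per-iteration time gives examples-per-second throughput.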
        measurements = sorted(measurements)
        local_print("%8s:" % prefix)
        for p in [50, 75, 90, 95]:
            v = np.percentile(measurements, p)
            local_print("  p%02d:  %1.3fs  %6d/s" % (p, v, nelem / v))
        local_print("\n")

    # Every process runs once by itself to warm up (CUDA init, etc.).
    append_benchmark("  warmup", [dist.get_rank()], {"use_ddp_for_single_rank": False})

    # Single-machine baselines
    append_benchmark("  no ddp", range(1), {"use_ddp_for_single_rank": False})
    append_benchmark("   1M/1G", range(1))
    append_benchmark("   1M/2G", range(2))
    append_benchmark("   1M/4G", range(4))

    # Multi-machine benchmarks
    for i in range(1, (dist.get_world_size() // 8) + 1):
        append_benchmark("   %dM/8G" % i, range(i * 8))

    # Run benchmarks in order of increasing number of GPUs
    print_header()
    results = []
    for prefix, ranks, opts in sorted(benchmarks, key=lambda tup: len(tup[1])):
        # Turn range into a materialized list.
        ranks = list(ranks)
        measurements = run_benchmark(benchmark, ranks, opts)
        if "warmup" not in prefix:
            print_measurements(prefix, benchmark.batch_size, measurements)
            results.append({"ranks": ranks, "measurements": measurements})

    return results


class Benchmark:
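    """Describes one benchmark; subclasses supply the model, inputs, and targets."""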
    def __init__(self, device, distributed_backend, bucket_size):
        self.device = device
        self.batch_size = 32
        self.distributed_backend = distributed_backend
        self.bucket_size = bucket_size

    def __str__(self):
        raise NotImplementedError

    def create_model(self):
        raise NotImplementedError

    def generate_inputs(self):
        raise NotImplementedError

    def generate_target(self):
        raise NotImplementedError


class TorchvisionBenchmark(Benchmark):
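    """Benchmark a torchvision classification model on synthetic 224x224 inputs."""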
    def __init__(self, device, distributed_backend, bucket_size, model):
        super().__init__(
            device,
            distributed_backend,
            bucket_size,
        )
        self.model = model

    def __str__(self):
        return f"{self.model} with batch size {self.batch_size}"

    def create_model(self):
        return torchvision.models.__dict__[self.model]().to(self.device)

    def generate_inputs(self):
        return [torch.rand([self.batch_size, 3, 224, 224], device=self.device)]

    def generate_target(self):
        return torch.tensor([1] * self.batch_size, dtype=torch.long, device=self.device)


def main():
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark suite")
    parser.add_argument("--rank", type=int, default=os.environ["RANK"])
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--distributed-backend", type=str, default="nccl")
    parser.add_argument("--bucket-size", type=int, default=25)
    parser.add_argument("--master-addr", type=str, required=True)
    parser.add_argument("--master-port", type=str, required=True)
    parser.add_argument("--model", type=str)
    parser.add_argument(
        "--json", type=str, metavar="PATH", help="Write file with benchmark results"
    )
    args = parser.parse_args()

    num_gpus_per_node = torch.cuda.device_count()
    assert num_gpus_per_node == 8, "Expected 8 GPUs per machine"

    # The global process group is used only for communicating benchmark
    # metadata (e.g. measurements), not for the benchmark itself.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
        rank=args.rank,
        world_size=args.world_size,
    )

    output = allgather_run("nvidia-smi topo -m")
    if not allequal(output):
        print('Output of "nvidia-smi topo -m" differs between machines')
        sys.exit(1)

    if args.rank == 0:
        print("-----------------------------------")
        print("PyTorch distributed benchmark suite")
        print("-----------------------------------")
        print()
        print(f"* PyTorch version: {torch.__version__}")
        print(f"* CUDA version: {torch.version.cuda}")
        print(f"* Distributed backend: {args.distributed_backend}")
        print(f"* Maximum bucket size: {args.bucket_size}MB")
        print()
        print("--- nvidia-smi topo -m ---")
        print()
        print(output[0])
        print("--------------------------")
        print()

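    # Each rank drives one local GPU; rank % 8 assumes ranks are assigned
    # contiguously across hosts with 8 GPUs each (matching the assert above).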
    torch.cuda.set_device(dist.get_rank() % 8)
    device = torch.device("cuda:%d" % (dist.get_rank() % 8))

    benchmarks = []
    if args.model:
        benchmarks.append(
            TorchvisionBenchmark(
                device=device,
                distributed_backend=args.distributed_backend,
                bucket_size=args.bucket_size,
                model=args.model,
            )
        )
    else:
        for model in ["resnet50", "resnet101", "resnext50_32x4d", "resnext101_32x8d"]:
            benchmarks.append(
                TorchvisionBenchmark(
                    device=device,
                    distributed_backend=args.distributed_backend,
                    bucket_size=args.bucket_size,
                    model=model,
                )
            )

    benchmark_results = []
    for benchmark in benchmarks:
        if args.rank == 0:
            print(f"\nBenchmark: {benchmark}")
        result = sweep(benchmark)
        benchmark_results.append(
            {
                "model": benchmark.model,
                "batch_size": benchmark.batch_size,
                "result": result,
            }
        )

    # Write file with benchmark results if applicable
    if args.rank == 0 and args.json:
        report = {
            "pytorch_version": torch.__version__,
            "cuda_version": torch.version.cuda,
            "distributed_backend": args.distributed_backend,
            "bucket_size": args.bucket_size,
            "benchmark_results": benchmark_results,
        }
        with open(args.json, "w") as f:
            json.dump(report, f)


if __name__ == "__main__":
    main()