#!/usr/bin/env python3
#
# Measure distributed training iteration time.
#
# This program performs a sweep over a) a number of model architectures, and
# b) an increasing number of processes. This produces a 1-GPU baseline,
# an 8-GPU baseline (if applicable), as well as measurements for however
# many processes can participate in training.
#
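# A typical launch starts one process per GPU, all pointing at the same
# rendezvous address; the host name and port below are placeholders (see the
# argument parser in main() for the full flag list):
#
#   RANK=0 python benchmark.py --world-size 16 \
#       --master-addr gpu-host-0 --master-port 29500
#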

import argparse
import itertools
import json
import os
import shlex
import subprocess
import sys
import time

import numpy as np
import torchvision

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim


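# All-gather an arbitrary picklable object from every rank in the default
# (global) process group and return the list of per-rank objects.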
def allgather_object(obj):
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out


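# Run a shell command locally, assert that it succeeded, and all-gather its
# stdout across ranks so the outputs of all machines can be compared.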
def allgather_run(cmd):
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    assert proc.returncode == 0
    return allgather_object(proc.stdout.decode("utf-8"))


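# Return True if every element of the iterable compares equal to the first
# element (vacuously True for an empty iterable).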
def allequal(iterator):
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == rest for rest in iterator)


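# Train the benchmark model on the given process group for a fixed number of
# warmup and measured iterations, and return the per-iteration wall-clock
# times (in seconds) for the measured iterations only.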
def benchmark_process_group(pg, benchmark, use_ddp_for_single_rank=True):
    torch.manual_seed(pg.rank())
    torch.cuda.manual_seed(pg.rank())

    model = benchmark.create_model()
    data = [(benchmark.generate_inputs(), benchmark.generate_target())]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=1e-4)
    if use_ddp_for_single_rank or pg.size() > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            process_group=pg,
            bucket_cap_mb=benchmark.bucket_size,
        )

    measurements = []
    warmup_iterations = 5
    measured_iterations = 10
    for inputs, target in data * (warmup_iterations + measured_iterations):
        start = time.time()
        output = model(*inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        measurements.append(time.time() - start)

    # Throw away measurements for warmup iterations
    return measurements[warmup_iterations:]


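# Create a subgroup over `ranks` with the benchmark's backend, run the
# benchmark on the ranks that belong to it, then gather every rank's
# measurements so each process ends up with the full list.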
def run_benchmark(benchmark, ranks, opts):
    group = dist.new_group(ranks=ranks, backend=benchmark.distributed_backend)
    measurements = []
    if dist.get_rank() in set(ranks):
        if not opts:
            opts = {}
        measurements = benchmark_process_group(group, benchmark, **opts)
    dist.destroy_process_group(group)
    dist.barrier()

    # Aggregate measurements for better estimation of percentiles
    return list(itertools.chain(*allgather_object(measurements)))


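# Build the list of benchmark configurations (warmup, single-machine
# baselines, and one multi-machine entry per group of 8 GPUs), run them in
# order of increasing GPU count, and print/collect percentile statistics.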
def sweep(benchmark):
    # Synthesize the set of benchmarks to run.
    # This list contains tuples of ("string prefix", [ranks...], opts).
    benchmarks = []

    def append_benchmark(prefix, ranks, opts=None):
        prefix = f"{len(ranks):4} GPUs -- {prefix}"
        benchmarks.append((prefix, ranks, opts))

    def local_print(msg):
        if dist.get_rank() == 0:
            print(msg, end="", flush=True)  # noqa: E999

    def print_header():
        local_print("\n")
        local_print("%22s" % "")
        for p in [50, 75, 90, 95]:
            local_print("%14s%10s" % ("sec/iter", "ex/sec"))
        local_print("\n")

    def print_measurements(prefix, nelem, measurements):
        measurements = sorted(measurements)
        local_print("%8s:" % prefix)
        for p in [50, 75, 90, 95]:
            v = np.percentile(measurements, p)
            local_print("  p%02d:  %1.3fs  %6d/s" % (p, v, nelem / v))
        local_print("\n")

    # Every process runs once by itself to warm up (CUDA init, etc).
    append_benchmark("  warmup", [dist.get_rank()], {"use_ddp_for_single_rank": False})

    # Single machine baselines
    append_benchmark("  no ddp", range(1), {"use_ddp_for_single_rank": False})
    append_benchmark("   1M/1G", range(1))
    append_benchmark("   1M/2G", range(2))
    append_benchmark("   1M/4G", range(4))

    # Multi-machine benchmarks
    for i in range(1, (dist.get_world_size() // 8) + 1):
        append_benchmark("   %dM/8G" % i, range(i * 8))

    # Run benchmarks in order of increasing number of GPUs
    print_header()
    results = []
    for prefix, ranks, opts in sorted(benchmarks, key=lambda tup: len(tup[1])):
        # Turn range into materialized list.
        ranks = list(ranks)
        measurements = run_benchmark(benchmark, ranks, opts)
        if "warmup" not in prefix:
            print_measurements(prefix, benchmark.batch_size, measurements)
            results.append({"ranks": ranks, "measurements": measurements})

    return results


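# Base class for a benchmark configuration; subclasses provide the model and
# synthetic inputs/targets.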
class Benchmark:
    def __init__(self, device, distributed_backend, bucket_size):
        self.device = device
        self.batch_size = 32
        self.distributed_backend = distributed_backend
        self.bucket_size = bucket_size

    def __str__(self):
        raise NotImplementedError

    def create_model(self):
        raise NotImplementedError

    def generate_inputs(self):
        raise NotImplementedError

    def generate_target(self):
        raise NotImplementedError


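# Benchmark that instantiates a torchvision classification model and feeds it
# random ImageNet-sized inputs with a constant target.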
class TorchvisionBenchmark(Benchmark):
    def __init__(self, device, distributed_backend, bucket_size, model):
        super().__init__(
            device,
            distributed_backend,
            bucket_size,
        )
        self.model = model

    def __str__(self):
        return f"{self.model} with batch size {self.batch_size}"

    def create_model(self):
        return torchvision.models.__dict__[self.model]().to(self.device)

    def generate_inputs(self):
        return [torch.rand([self.batch_size, 3, 224, 224], device=self.device)]

    def generate_target(self):
        return torch.tensor([1] * self.batch_size, dtype=torch.long, device=self.device)


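# Parse arguments, initialize the global (gloo) process group used for
# metadata exchange, sanity-check that all machines report the same GPU
# topology, run the sweep for every selected model, and optionally write the
# results to a JSON file.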
def main():
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark suite")
    parser.add_argument("--rank", type=int, default=os.environ["RANK"])
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--distributed-backend", type=str, default="nccl")
    parser.add_argument("--bucket-size", type=int, default=25)
    parser.add_argument("--master-addr", type=str, required=True)
    parser.add_argument("--master-port", type=str, required=True)
    parser.add_argument("--model", type=str)
    parser.add_argument(
        "--json", type=str, metavar="PATH", help="Write file with benchmark results"
    )
    args = parser.parse_args()

    num_gpus_per_node = torch.cuda.device_count()
    assert num_gpus_per_node == 8, "Expected 8 GPUs per machine"

    # The global process group used only for communicating benchmark
    # metadata, like measurements. Not for benchmarking itself.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
        rank=args.rank,
        world_size=args.world_size,
    )

    output = allgather_run("nvidia-smi topo -m")
    if not allequal(output):
        print('Output of "nvidia-smi topo -m" differs between machines')
        sys.exit(1)

    if args.rank == 0:
        print("-----------------------------------")
        print("PyTorch distributed benchmark suite")
        print("-----------------------------------")
        print()
        print(f"* PyTorch version: {torch.__version__}")
        print(f"* CUDA version: {torch.version.cuda}")
        print(f"* Distributed backend: {args.distributed_backend}")
        print(f"* Maximum bucket size: {args.bucket_size}MB")
        print()
        print("--- nvidia-smi topo -m ---")
        print()
        print(output[0])
        print("--------------------------")
        print()

    torch.cuda.set_device(dist.get_rank() % 8)
    device = torch.device("cuda:%d" % (dist.get_rank() % 8))

    benchmarks = []
    if args.model:
        benchmarks.append(
            TorchvisionBenchmark(
                device=device,
                distributed_backend=args.distributed_backend,
                bucket_size=args.bucket_size,
                model=args.model,
            )
        )
    else:
        for model in ["resnet50", "resnet101", "resnext50_32x4d", "resnext101_32x8d"]:
            benchmarks.append(
                TorchvisionBenchmark(
                    device=device,
                    distributed_backend=args.distributed_backend,
                    bucket_size=args.bucket_size,
                    model=model,
                )
            )

    benchmark_results = []
    for benchmark in benchmarks:
        if args.rank == 0:
            print(f"\nBenchmark: {str(benchmark)}")
        result = sweep(benchmark)
        benchmark_results.append(
            {
                "model": benchmark.model,
                "batch_size": benchmark.batch_size,
                "result": result,
            }
        )

    # Write file with benchmark results if applicable
    if args.rank == 0 and args.json:
        report = {
            "pytorch_version": torch.__version__,
            "cuda_version": torch.version.cuda,
            "distributed_backend": args.distributed_backend,
            "bucket_size": args.bucket_size,
            "benchmark_results": benchmark_results,
        }
        with open(args.json, "w") as f:
            json.dump(report, f)


if __name__ == "__main__":
    main()