#!/usr/bin/env python3
#
# Measure distributed training iteration time.
#
# This program performs a sweep over a) a number of model architectures, and
# b) an increasing number of processes. This produces a 1-GPU baseline,
# an 8-GPU baseline (if applicable), as well as measurements for however
# many processes can participate in training.
#

import argparse
import itertools
import json
import os
import shlex
import subprocess
import sys
import time

import numpy as np
import torchvision

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim


def allgather_object(obj):
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out


def allgather_run(cmd):
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    assert proc.returncode == 0
    return allgather_object(proc.stdout.decode("utf-8"))


def allequal(iterator):
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == rest for rest in iterator)


def benchmark_process_group(pg, benchmark, use_ddp_for_single_rank=True):
    torch.manual_seed(pg.rank())
    torch.cuda.manual_seed(pg.rank())

    model = benchmark.create_model()
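    # A single synthetic batch is generated once and replayed for every
    # iteration below, so iteration time reflects compute and communication
    # rather than data loading.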
    data = [(benchmark.generate_inputs(), benchmark.generate_target())]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=1e-4)
    if use_ddp_for_single_rank or pg.size() > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            process_group=pg,
            bucket_cap_mb=benchmark.bucket_size,
        )

    measurements = []
    warmup_iterations = 5
    measured_iterations = 10
    for inputs, target in data * (warmup_iterations + measured_iterations):
        start = time.time()
        output = model(*inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        measurements.append(time.time() - start)

    # Throw away measurements for warmup iterations
    return measurements[warmup_iterations:]


def run_benchmark(benchmark, ranks, opts):
    group = dist.new_group(ranks=ranks, backend=benchmark.distributed_backend)
    measurements = []
    if dist.get_rank() in set(ranks):
        if not opts:
            opts = {}
        measurements = benchmark_process_group(group, benchmark, **opts)
    dist.destroy_process_group(group)
    dist.barrier()

    # Aggregate measurements for better estimation of percentiles
    return list(itertools.chain(*allgather_object(measurements)))


def sweep(benchmark):
    # Synthesize the set of benchmarks to run.
    # This list contains tuples of ("string prefix", [rank...], opts).
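    # With a world size of 16, for example, this yields entries for 1, 2, 4,
    # 8, and 16 ranks, plus a warmup entry in which each process runs alone.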
    benchmarks = []

    def append_benchmark(prefix, ranks, opts=None):
        prefix = f"{len(ranks):4} GPUs -- {prefix}"
        benchmarks.append((prefix, ranks, opts))

    def local_print(msg):
        if dist.get_rank() == 0:
            print(msg, end="", flush=True)  # noqa: E999

    def print_header():
        local_print("\n")
        local_print("%22s" % "")
        for p in [50, 75, 90, 95]:
            local_print("%14s%10s" % ("sec/iter", "ex/sec"))
        local_print("\n")

    def print_measurements(prefix, nelem, measurements):
        measurements = sorted(measurements)
        local_print("%8s:" % prefix)
        for p in [50, 75, 90, 95]:
            v = np.percentile(measurements, p)
            local_print("  p%02d:  %1.3fs  %6d/s" % (p, v, nelem / v))
        local_print("\n")

    # Every process runs once by itself to warm up (CUDA init, etc).
    append_benchmark("  warmup", [dist.get_rank()], {"use_ddp_for_single_rank": False})

    # Single machine baselines
    append_benchmark("  no ddp", range(1), {"use_ddp_for_single_rank": False})
    append_benchmark("   1M/1G", range(1))
    append_benchmark("   1M/2G", range(2))
    append_benchmark("   1M/4G", range(4))

    # Multi-machine benchmarks
    for i in range(1, (dist.get_world_size() // 8) + 1):
        append_benchmark("   %dM/8G" % i, range(i * 8))

    # Run benchmarks in order of increasing number of GPUs
    print_header()
    results = []
    for prefix, ranks, opts in sorted(benchmarks, key=lambda tup: len(tup[1])):
        # Turn range into materialized list.
        ranks = list(ranks)
        measurements = run_benchmark(benchmark, ranks, opts)
        if "warmup" not in prefix:
            print_measurements(prefix, benchmark.batch_size, measurements)
            results.append({"ranks": ranks, "measurements": measurements})

    return results


class Benchmark:
    def __init__(self, device, distributed_backend, bucket_size):
        self.device = device
        self.batch_size = 32
        self.distributed_backend = distributed_backend
        self.bucket_size = bucket_size

    def __str__(self):
        raise NotImplementedError

    def create_model(self):
        raise NotImplementedError

    def generate_inputs(self):
        raise NotImplementedError

    def generate_target(self):
        raise NotImplementedError


class TorchvisionBenchmark(Benchmark):
    def __init__(self, device, distributed_backend, bucket_size, model):
        super().__init__(
            device,
            distributed_backend,
            bucket_size,
        )
        self.model = model

    def __str__(self):
        return f"{self.model} with batch size {self.batch_size}"

    def create_model(self):
        return torchvision.models.__dict__[self.model]().to(self.device)

    def generate_inputs(self):
        return [torch.rand([self.batch_size, 3, 224, 224], device=self.device)]

    def generate_target(self):
        return torch.tensor([1] * self.batch_size, dtype=torch.long, device=self.device)


def main():
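    # Illustrative launch (hostnames and port are placeholders): start one
    # process per GPU on every machine, e.g. for 2 machines with 8 GPUs each:
    #   RANK=<global rank> python benchmark.py --world-size 16 \
    #       --master-addr host0 --master-port 29500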
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark suite")
    parser.add_argument("--rank", type=int, default=os.environ["RANK"])
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--distributed-backend", type=str, default="nccl")
    parser.add_argument("--bucket-size", type=int, default=25)
    parser.add_argument("--master-addr", type=str, required=True)
    parser.add_argument("--master-port", type=str, required=True)
    parser.add_argument("--model", type=str)
    parser.add_argument(
        "--json", type=str, metavar="PATH", help="Write file with benchmark results"
    )
    args = parser.parse_args()

    num_gpus_per_node = torch.cuda.device_count()
    assert num_gpus_per_node == 8, "Expected 8 GPUs per machine"

    # The global process group used only for communicating benchmark
    # metadata, like measurements. Not for benchmarking itself.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
        rank=args.rank,
        world_size=args.world_size,
    )

    output = allgather_run("nvidia-smi topo -m")
    if not allequal(output):
        print('Output of "nvidia-smi topo -m" differs between machines')
        sys.exit(1)

    if args.rank == 0:
        print("-----------------------------------")
        print("PyTorch distributed benchmark suite")
        print("-----------------------------------")
        print()
        print(f"* PyTorch version: {torch.__version__}")
        print(f"* CUDA version: {torch.version.cuda}")
        print(f"* Distributed backend: {args.distributed_backend}")
        print(f"* Maximum bucket size: {args.bucket_size}MB")
        print()
        print("--- nvidia-smi topo -m ---")
        print()
        print(output[0])
        print("--------------------------")
        print()

    torch.cuda.set_device(dist.get_rank() % 8)
    device = torch.device("cuda:%d" % (dist.get_rank() % 8))

    benchmarks = []
    if args.model:
        benchmarks.append(
            TorchvisionBenchmark(
                device=device,
                distributed_backend=args.distributed_backend,
                bucket_size=args.bucket_size,
                model=args.model,
            )
        )
    else:
        for model in ["resnet50", "resnet101", "resnext50_32x4d", "resnext101_32x8d"]:
            benchmarks.append(
                TorchvisionBenchmark(
                    device=device,
                    distributed_backend=args.distributed_backend,
                    bucket_size=args.bucket_size,
                    model=model,
                )
            )

    benchmark_results = []
    for benchmark in benchmarks:
        if args.rank == 0:
            print(f"\nBenchmark: {str(benchmark)}")
        result = sweep(benchmark)
        benchmark_results.append(
            {
                "model": benchmark.model,
                "batch_size": benchmark.batch_size,
                "result": result,
            }
        )

    # Write file with benchmark results if applicable
    if args.rank == 0 and args.json:
        report = {
            "pytorch_version": torch.__version__,
            "cuda_version": torch.version.cuda,
            "distributed_backend": args.distributed_backend,
            "bucket_size": args.bucket_size,
            "benchmark_results": benchmark_results,
        }
        with open(args.json, "w") as f:
            json.dump(report, f)
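

# A minimal sketch, not part of the original tool, of how the JSON report
# written above could be consumed offline. The function name and the choice
# of a p50 summary are illustrative assumptions; only the report layout
# ("benchmark_results" -> "result" -> "ranks"/"measurements") comes from
# the code above.
def summarize_report(path):
    with open(path) as f:
        report = json.load(f)
    for bench in report["benchmark_results"]:
        for entry in bench["result"]:
            # Median iteration time across all gathered measurements.
            p50 = np.percentile(entry["measurements"], 50)
            print(
                "%s: %2d ranks, p50 %.3f sec/iter"
                % (bench["model"], len(entry["ranks"]), p50)
            )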


if __name__ == "__main__":
    main()