# mypy: allow-untyped-defs
r"""
Module ``torch.distributed.launch``.

``torch.distributed.launch`` is a module that spawns multiple distributed
training processes on each of the training nodes.

.. warning::

    This module is going to be deprecated in favor of :ref:`torchrun <launcher-api>`.

The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
CPU training or GPU training. If the utility is used for GPU training,
each distributed process will be operating on a single GPU. This can
significantly improve single-node training performance. It can also be used in
multi-node distributed training by spawning multiple processes on each node,
improving multi-node training performance as well.
This is especially beneficial for systems with multiple InfiniBand
interfaces that have direct GPU support, since all of them can be utilized for
aggregated communication bandwidth.

In both cases of single-node distributed training or multi-node distributed
training, this utility will launch the given number of processes per node
(``--nproc-per-node``). If used for GPU training, this number needs to be less
than or equal to the number of GPUs on the current system (``nproc_per_node``),
and each process will be operating on a single GPU from *GPU 0 to
GPU (nproc_per_node - 1)*.

**How to use this module:**

1. Single-Node multi-process distributed training

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
        YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
        arguments of your training script)

2. Multi-Node multi-process distributed training (e.g. two nodes)

Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
        --nnodes=2 --node-rank=0 --master-addr="192.168.1.1"
        --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
        and all other arguments of your training script)

Node 2:

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
        --nnodes=2 --node-rank=1 --master-addr="192.168.1.1"
        --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
        and all other arguments of your training script)

3. To look up what optional arguments this module offers:

::

    python -m torch.distributed.launch --help
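
Since this module is slated for deprecation, the single-node example above can
usually be expressed with :ref:`torchrun <launcher-api>` (the entry point for
``torch.distributed.run``, from which this module borrows its argument parser).
The command below is only a sketch: ``torchrun`` enables the ``--use-env``
behavior by default, so the training script is expected to read ``LOCAL_RANK``
from the environment instead of parsing ``--local-rank`` (see notice 5 below).

::

    torchrun --nproc-per-node=NUM_GPUS_YOU_HAVE
        YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
        arguments of your training script)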

**Important Notices:**

1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently achieve the best performance using
the NCCL distributed backend. Thus, the NCCL backend is the recommended backend
to use for GPU training.

2. In your training program, you must parse the command-line argument
``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
If your training program uses GPUs, you should ensure that your code only
runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

::

    >>> # xdoctest: +SKIP
    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> parser.add_argument("--local-rank", "--local_rank", type=int)
    >>> args = parser.parse_args()

Set your device to local rank using either

::

    >>> torch.cuda.set_device(args.local_rank)  # before your code runs

or

::

    >>> with torch.cuda.device(args.local_rank):
    >>>     # your code to run
    >>>     ...

.. versionchanged:: 2.0.0

    The launcher passes the ``--local-rank=<rank>`` argument to your script.
    From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
    previously used underscored ``--local_rank``.

    For backward compatibility, it may be necessary for users to handle both
    cases in their argument parsing code. This means including both ``"--local-rank"``
    and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
    provided, the launcher will trigger an error: "error: unrecognized arguments:
    --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
    including ``"--local-rank"`` should be sufficient.

3. In your training program, you should call the following function
at the beginning to initialize the distributed backend. It is strongly recommended
that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work,
but ``env://`` is the one that is officially supported by this module.

::

    >>> torch.distributed.init_process_group(backend='YOUR BACKEND',
    >>>                                      init_method='env://')

4. In your training program, you can either use regular distributed functions
or use the :func:`torch.nn.parallel.DistributedDataParallel` module. If your
training program uses GPUs for training and you would like to use the
:func:`torch.nn.parallel.DistributedDataParallel` module,
here is how to configure it:

::

    >>> model = torch.nn.parallel.DistributedDataParallel(model,
    >>>                                                    device_ids=[args.local_rank],
    >>>                                                    output_device=args.local_rank)

Please ensure that the ``device_ids`` argument is set to the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, ``device_ids`` needs to be ``[args.local_rank]``,
and ``output_device`` needs to be ``args.local_rank`` in order to use this
utility.

5. Another way to pass ``local_rank`` to the subprocesses is via the environment
variable ``LOCAL_RANK``. This behavior is enabled when you launch the script with
``--use-env=True``. You must adjust the subprocess example above to replace
``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
will not pass ``--local-rank`` when you specify this flag.
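
For example, a minimal sketch of this adjustment when launching with
``--use-env=True`` (reading the local rank from the environment instead of
parsing ``--local-rank``):

::

    >>> # xdoctest: +SKIP
    >>> import os
    >>> local_rank = int(os.environ["LOCAL_RANK"])
    >>> torch.cuda.set_device(local_rank)
    >>> model = torch.nn.parallel.DistributedDataParallel(model,
    >>>                                                    device_ids=[local_rank],
    >>>                                                    output_device=local_rank)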

.. warning::

    ``local_rank`` is NOT globally unique: it is only unique per process
    on a machine. Thus, don't use it to decide if you should, e.g.,
    write to a networked filesystem. See
    https://github.com/pytorch/pytorch/issues/12042 for an example of
    how things can go wrong if you don't do this correctly.

"""

from typing_extensions import deprecated as _deprecated

from torch.distributed.run import get_args_parser, run


def parse_args(args):
    # Extend the torch.distributed.run parser with the legacy --use-env flag.
    parser = get_args_parser()
    parser.add_argument(
        "--use-env",
        "--use_env",
        default=False,
        action="store_true",
        help="Use environment variable to pass "
        "'local rank'. For legacy reasons, the default value is False. "
        "If set to True, the script will not pass "
        "--local-rank as argument, and will instead set LOCAL_RANK.",
    )
    return parser.parse_args(args)


def launch(args):
    # --no-python only works together with --use-env: the local rank is then
    # read from the LOCAL_RANK environment variable rather than --local-rank.
    if args.no_python and not args.use_env:
        raise ValueError(
            "When using the '--no-python' flag,"
            " you must also set the '--use-env' flag."
        )
    run(args)


@_deprecated(
    "The module torch.distributed.launch is deprecated\n"
    "and will be removed in a future release. Use torchrun.\n"
    "Note that --use-env is set by default in torchrun.\n"
    "If your script expects `--local-rank` argument to be set, please\n"
    "change it to read from `os.environ['LOCAL_RANK']` instead. See\n"
    "https://pytorch.org/docs/stable/distributed.html#launch-utility for\n"
    "further instructions\n",
    category=FutureWarning,
)
def main(args=None):
    args = parse_args(args)
    launch(args)


if __name__ == "__main__":
    main()