xref: /aosp_15_r20/external/pytorch/test/distributed/test_c10d_nccl.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import copy
4import json
5import os
6import pickle
7import random
8import re
9import signal
10import sys
11import tempfile
12import threading
13import time
14import warnings
15from contextlib import contextmanager
16from datetime import datetime, timedelta
17from enum import auto, Enum
18from itertools import chain, product
19from unittest import mock, SkipTest
20
21import torch
22import torch.distributed as c10d
23
24
25if not c10d.is_available() or not c10d.is_nccl_available():
26    print("c10d NCCL not available, skipping tests", file=sys.stderr)
27    sys.exit(0)
28
29from typing import Dict, List
30
31import test_c10d_common
32from test_c10d_common import ConvNet, DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook
33
34import torch.distributed as dist
35import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
36import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
37import torch.nn.functional as F
38import torch.testing._internal.common_utils as common
39from torch import nn
40from torch._C._distributed_c10d import OpType
41from torch.nn.parallel import DistributedDataParallel
42from torch.testing._internal.common_cuda import TEST_MULTIGPU
43from torch.testing._internal.common_distributed import (
44    get_timeout,
45    init_multigpu_helper,
46    MultiProcessTestCase,
47    requires_gloo,
48    requires_nccl,
49    requires_nccl_version,
50    skip_if_lt_x_gpu,
51    skip_if_rocm_multiprocess,
52    TEST_SKIPS,
53    with_dist_debug_levels,
54    with_nccl_blocking_wait,
55)
56from torch.testing._internal.common_utils import (
57    instantiate_parametrized_tests,
58    parametrize,
59    retry_on_connect_failures,
60    run_tests,
61    skip_but_pass_in_sandcastle,
62    skip_but_pass_in_sandcastle_if,
63    TEST_CUDA,
64    TEST_WITH_DEV_DBG_ASAN,
65    TEST_WITH_ROCM,
66    TestCase,
67)
68
69
70if TEST_WITH_DEV_DBG_ASAN:
71    print(
72        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
73    )
74    sys.exit(0)
75
76# bfloat16 is only supported by CUDA 11+ and ROCm
77BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
78    (torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 11)
79    or torch.version.hip is not None
80)
81
82
83class RendezvousEnvTest(TestCase):
84    @retry_on_connect_failures
85    @requires_nccl()
86    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "No GPUs available, skipping test")
87    def test_common_errors(self):
88        vars = {
89            "WORLD_SIZE": "1",
90            "RANK": "0",
91            "MASTER_ADDR": "127.0.0.1",
92            "MASTER_PORT": str(common.find_free_port()),
93        }
94
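        # Context manager that temporarily replaces os.environ with the given
        # vars dict (clear=True) and restores the previous environment on exit.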
95        class Env:
96            def __init__(self, vars):
97                self.env_patcher = mock.patch.dict(os.environ, vars, clear=True)
98
99            def __enter__(self):
100                self.env_patcher.start()
101
102            def __exit__(self, type, value, traceback):
103                self.env_patcher.stop()
104
105        def without(d, key):
106            d = d.copy()
107            d.pop(key)
108            return d
109
110        def withouts(d, keys):
111            d = d.copy()
112            for key in keys:
113                d.pop(key)
114            return d
115
116        with Env(without(vars, "WORLD_SIZE")):
117            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
118            with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"):
119                gen = c10d.rendezvous("env://")
120                next(gen)
121            c10d.init_process_group(backend="nccl", world_size=1)
122            self.assertEqual(c10d.get_rank(), 0)
123            self.assertEqual(c10d.get_world_size(), 1)
124            c10d.destroy_process_group()
125
126        with Env(without(vars, "RANK")):
127            self.assertEqual(None, os.environ.get("RANK"))
128            with self.assertRaisesRegex(ValueError, "RANK expected"):
129                gen = c10d.rendezvous("env://")
130                next(gen)
131            c10d.init_process_group(backend="nccl", rank=0)
132            self.assertEqual(c10d.get_rank(), 0)
133            self.assertEqual(c10d.get_world_size(), 1)
134            c10d.destroy_process_group()
135
136        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
137            self.assertEqual(None, os.environ.get("RANK"))
138            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
139            c10d.init_process_group(backend="nccl", rank=0, world_size=1)
140            self.assertEqual(c10d.get_rank(), 0)
141            self.assertEqual(c10d.get_world_size(), 1)
142            c10d.destroy_process_group()
143
144        with Env(vars):
145            c10d.init_process_group(backend="nccl")
146            self.assertEqual(c10d.get_rank(), 0)
147            self.assertEqual(c10d.get_world_size(), 1)
148            c10d.destroy_process_group()
149
150        with Env(without(vars, "MASTER_ADDR")):
151            self.assertEqual(None, os.environ.get("MASTER_ADDR"))
152            with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"):
153                gen = c10d.rendezvous("env://")
154                next(gen)
155
156        with Env(without(vars, "MASTER_PORT")):
157            self.assertEqual(None, os.environ.get("MASTER_PORT"))
158            with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"):
159                gen = c10d.rendezvous("env://")
160                next(gen)
161
162        with Env(without(vars, "WORLD_SIZE")):
163            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
164            gen = c10d.rendezvous(f"env://?world_size={1}")
165            _, _, size = next(gen)
166            self.assertEqual(size, 1)
167
168        with Env(without(vars, "RANK")):
169            self.assertEqual(None, os.environ.get("RANK"))
170            gen = c10d.rendezvous(f"env://?rank={0}")
171            _, rank, _ = next(gen)
172            self.assertEqual(rank, 0)
173
174        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
175            self.assertEqual(None, os.environ.get("RANK"))
176            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
177            gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}")
178            _, rank, size = next(gen)
179            self.assertEqual(rank, 0)
180            self.assertEqual(size, 1)
181
182
183class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
184    @requires_nccl()
185    @retry_on_connect_failures
186    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "No GPUs available, skipping test")
187    def test_default_store_timeout_nccl(self):
188        self._test_default_store_timeout("nccl")
189
190
191class ProcessGroupNCCLNoGPUTest(TestCase):
192    MAIN_PROCESS_RANK = 0
193
194    def setUp(self):
195        self.rank = self.MAIN_PROCESS_RANK
196        self.world_size = 1
197        self.file = tempfile.NamedTemporaryFile(delete=False)
198
199    def tearDown(self):
200        pass
201
202    @requires_nccl()
203    @skip_but_pass_in_sandcastle_if(TEST_CUDA, "GPUs are available, skipping test")
204    def test_init_no_gpus(self):
205        store = c10d.FileStore(self.file.name, self.world_size)
206        with self.assertRaisesRegex(
207            ValueError, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"
208        ):
209            c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
210
211
212class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
213    def _create_process_group_nccl(self, store, opts, device_id=None):
214        # create nccl processgroup with opts
215        c10d.init_process_group(
216            "nccl",
217            world_size=self.world_size,
218            rank=self.rank,
219            store=store,
220            pg_options=opts,
221            device_id=device_id,
222        )
223        pg = c10d.distributed_c10d._get_default_group()
224        return pg
225
226    def opts(self, high_priority_stream=False):
227        opts = c10d.ProcessGroupNCCL.Options()
228        opts.is_high_priority_stream = high_priority_stream
229        return opts
230
231    def setUp(self):
232        super().setUp()
233        # Need to skip return code checking for these tests since the child
234        # processes don't exit cleanly in some CUDA versions
235        self.skip_return_code_checks = [
236            self.test_nan_assert_float16.__wrapped__,
237            self.test_nan_assert_float32.__wrapped__,
238            self.test_nan_assert_float64.__wrapped__,
239            self.test_nan_assert_bfloat16.__wrapped__,
240        ]
241
242        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
243        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
244        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
245        # self.num_gpus = torch.cuda.device_count()
246        self._spawn_processes()
247
248    def tearDown(self):
249        super().tearDown()
250        try:
251            os.remove(self.file_name)
252        except OSError:
253            pass
254
255    @property
256    def world_size(self):
257        return 2
258
259    @property
260    def rank_to_GPU(self):
261        # return rank to GPU map
262        return init_multigpu_helper(self.world_size, "nccl")
263
264    @requires_nccl()
265    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
266    @skip_if_lt_x_gpu(1)
267    def test_nccl_dist_backend_error(self):
268        store = c10d.FileStore(self.file_name, self.world_size)
269        self._create_process_group_nccl(store, self.opts())
270
271        # Both rank 0 and 1 will use the same CUDA device resulting in ncclInvalidUsage
272        with self.assertRaises(dist.DistBackendError) as cm:
273            dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0)
274        self.assertTrue(isinstance(cm.exception, dist.DistError))
275
276        self.assertIsInstance(cm.exception, RuntimeError)
277
278    @requires_nccl()
279    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
280    def test_abort_pg(self):
281        # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
282        # abort the process group.
283        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
284
285        store = c10d.FileStore(self.file_name, self.world_size)
286        self._create_process_group_nccl(store, self.opts())
287        device = self.rank_to_GPU[self.rank][0]
288
289        t = torch.rand(10, 10, device=device)
290        # First allreduce to initialize state.
291        dist.all_reduce(t)
292
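        # Helper run from a side thread: fetch the default PG's NCCL backend for
        # this device and shut it down, unblocking the stuck collective below.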
293        def abortpg():
294            c10d.distributed_c10d._get_default_group()._get_backend(
295                torch.device(device)
296            )._shutdown()
297
298        # Initialize DDP to ensure "destroy_process_group" will not call
299        # ProcessGroupNCCL destructor since DDP holds a reference to process group.
300        # Run a single iteration of DDP to initialize state.
301        model = DistributedDataParallel(
302            torch.nn.Linear(10, 10).to(device), device_ids=[device]
303        )
304        model(t).sum().backward()
305
306        # Now simulate a collective getting stuck and verify the abort gets us unstuck
307        if self.rank == 0:
308            dist.all_reduce(t)
309
310            # Schedule thread before we get stuck to abort pg.
311            thread = threading.Thread(target=abortpg)
312            thread.start()
313
314            # We would get stuck here due to d2h if we didn't abort.
315            t_cpu = t.cpu()
316
317            thread.join()
318
319    @requires_nccl()
320    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
321    def test_close_pg(self):
322        # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
323        # abort the process group.
324        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
325
326        store = c10d.FileStore(self.file_name, self.world_size)
327        pg = self._create_process_group_nccl(store, self.opts())
328        device = self.rank_to_GPU[self.rank][0]
329
330        t = torch.rand(10, 10, device=device)
331        # First allreduce to initialize state.
332        pg.allreduce(t)
333
334        # Destroy pg and validate pg is no longer valid
335        dist.destroy_process_group()
336        with self.assertRaises(dist.DistBackendError):
337            pg.allreduce([t])
338
339        del pg
340
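    # True when running on CUDA 12+; the NaN-assert tests below rely on
    # device-side asserts, which can misbehave on older CUDA versions.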
341    CUDA_12_AND_ABOVE = torch.cuda.is_available() and (
342        torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12
343    )
344
345    @requires_nccl()
346    @skip_but_pass_in_sandcastle_if(
347        not (TEST_MULTIGPU and CUDA_12_AND_ABOVE),
348        "NCCL test requires 2+ GPUs and Device side assert could cause unexpected errors in lower versions of CUDA",
349    )
350    @parametrize(
351        "type",
352        [
353            torch.float16,
354            torch.float32,
355            torch.float64,
356            torch.bfloat16,
357            torch.float8_e4m3fn,
358            torch.float8_e5m2,
359        ],
360    )
361    @skip_if_rocm_multiprocess
362    def test_nan_assert(self, type):
363        # Expecting a device-side error when NaN is detected
364        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
365        store = c10d.FileStore(self.file_name, self.world_size)
366        pg = self._create_process_group_nccl(store, self.opts())
367        device = self.rank_to_GPU[self.rank][0]
368        size = (10, 10)
369        nan_tensor = torch.full(size, self.rank, dtype=type, device=device)
370        # randomly pick a NaN element
371        i = random.randint(0, nan_tensor.size(0) - 1)
372        j = random.randint(0, nan_tensor.size(1) - 1)
373        nan_tensor[i, j] = float("nan")
374        with self.assertRaises(RuntimeError):
375            pg.allreduce(nan_tensor)
376        dist.destroy_process_group()
377        # reset env
378        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
379
380    @requires_nccl()
381    @skip_if_lt_x_gpu(2)
382    def test_nan_rank_filter(self):
383        # Putting NaN in the recv buffer; the program should not fail because
384        # the NaN checker does not check receive buffers
385        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
386        store = c10d.FileStore(self.file_name, self.world_size)
387        device = torch.device("cuda:%d" % self.rank)
388        c10d.init_process_group(
389            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
390        )
391        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
392        if self.rank != 0:
393            # Putting NaN in the recv buffer
394            t[1, 1] = float("nan")
395        # Against broadcast
396        c10d.broadcast(t, 0)
397        # Against P2P
398        if self.rank == 0:
399            c10d.send(t, 1)
400        elif self.rank == 1:
401            c10d.recv(t, 0)
402        c10d.destroy_process_group()
403        # reset env
404        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
405
406    @requires_nccl()
407    @skip_if_lt_x_gpu(2)
408    def test_nan_check(self):
409        # Not expecting an error, NaN check should not make legit code fail
410        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
411        store = c10d.FileStore(self.file_name, self.world_size)
412        device = torch.device("cuda:%d" % self.rank)
413        c10d.init_process_group(
414            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
415        )
416        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
417        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
418        c10d.broadcast(x, src=0)
419        c10d.all_reduce(t)
420        c10d.barrier()
421        c10d.destroy_process_group()
422        # reset env
423        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
424
425    @requires_nccl()
426    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
427    def test_destruct_before_terminate_pg(self):
428        # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
429        # abort the process group.
430        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
431        store = c10d.FileStore(self.file_name, self.world_size)
432        pg = self._create_process_group_nccl(store, self.opts())
433        device = self.rank_to_GPU[self.rank][0]
434
435        t = torch.rand(10, 10, device=device)
436        # First allreduce to initialize state.
437        pg.allreduce(t)
438        # force destruction before terminating comms; the destructor will terminate the comms
439        del pg
440
441    @requires_nccl()
442    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
443    def test_abort_in_destroy_pg(self):
444        # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
445        # abort the process group.
446        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
447
448        store = c10d.FileStore(self.file_name, self.world_size)
449        pg = self._create_process_group_nccl(store, self.opts())
450        device = self.rank_to_GPU[self.rank][0]
451
452        t = torch.rand(10, 10, device=device)
453        # First allreduce to initialize state.
454        pg.allreduce(t)
455
456        # Destroy pg and validate pg is NOT in working condition since
457        # we have shut down the comms
458        dist.destroy_process_group()
459        with self.assertRaises(dist.DistBackendError):
460            pg.allreduce([t])
461
462    @requires_nccl()
463    @skip_but_pass_in_sandcastle_if(
464        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
465    )
466    def test_close_multi_pg_unordered(self):
467        store = c10d.FileStore(self.file_name, self.world_size)
468        pg = self._create_process_group_nccl(store, self.opts())
469        device = self.rank_to_GPU[self.rank][0]
470        t = torch.rand(10, 10, device=device)
471        # First allreduce to initialize default PG's communicator.
472        pg.allreduce(t).wait()
473        new_pg1 = c10d.new_group([0, 1])
474        new_pg2 = c10d.new_group([0, 1])
475        if self.rank == 0 or self.rank == 1:
476            t1 = torch.rand(10, 10, device=device)
477            t2 = torch.rand(10, 10, device=device)
478            new_pg1.allreduce(t1).wait()
479            new_pg2.allreduce(t2).wait()
480        if self.rank == 0:
481            dist.destroy_process_group(new_pg2)
482            # force destruction of pg2 first
483            del new_pg2
484            dist.destroy_process_group(new_pg1)
485            del new_pg1
486        if self.rank == 1:
487            c10d.destroy_process_group(new_pg1)
488            # force destruction of pg1 first
489            del new_pg1
490            dist.destroy_process_group(new_pg2)
491            del new_pg2
492        dist.destroy_process_group()
493
494    @requires_nccl()
495    @skip_but_pass_in_sandcastle_if(
496        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
497    )
498    def test_abort_in_destroy_multi_pgs(self):
499        store = c10d.FileStore(self.file_name, self.world_size)
500        pg = self._create_process_group_nccl(store, self.opts())
501        device = self.rank_to_GPU[self.rank][0]
502        t = torch.rand(10, 10, device=device)
503        # First allreduce to initialize default PG's communicator.
504        pg.allreduce(t).wait()
505        new_pg1 = c10d.new_group([0, 1])
506        new_pg2 = c10d.new_group([0, 1])
507        t1 = torch.rand(10, 10, device=device)
508        t2 = torch.rand(10, 10, device=device)
509        new_pg1.allreduce(t1).wait()
510        new_pg2.allreduce(t2).wait()
511        backend = pg._get_backend(torch.device(device))
512        # default PG's backend should have a split count of 2
513        self.assertEqual(backend.comm_split_count(), 2)
514        # shutdown all NCCL PGs in one shot
515        dist.destroy_process_group()
516
517    @requires_nccl()
518    @skip_but_pass_in_sandcastle_if(
519        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
520    )
521    def test_abort_in_destroy_mixed_empty_pgs(self):
522        store = c10d.FileStore(self.file_name, self.world_size)
523        pg = self._create_process_group_nccl(store, self.opts())
524        device = self.rank_to_GPU[self.rank][0]
525        t = torch.rand(10, 10, device=device)
526        # First allreduce to initialize default PG's communicator.
527        pg.allreduce(t).wait()
528        # PG1 is a PG without comms initialized, since we don't call a collective on it
529        new_pg1 = c10d.new_group([0, 1])
530        new_pg2 = c10d.new_group([0, 1])
531        t2 = torch.rand(10, 10, device=device)
532
533        new_pg2.allreduce(t2).wait()
534        backend = pg._get_backend(torch.device(device))
535        # default PG's backend should have a split count of 1
536        self.assertEqual(backend.comm_split_count(), 1)
537        # shutdown all NCCL PGs in one shot
538        dist.destroy_process_group()
539
540    @requires_nccl()
541    @skip_but_pass_in_sandcastle_if(
542        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
543    )
544    def test_file_store_check(self):
545        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
546        os.environ["TORCH_NCCL_ENABLE_MONITORING"] = "0"
547        # FileStore check() would be executed
548        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
549        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "0"
550
551        # self.file_name is created using "delete=False"
552        # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name
553        store = dist.FileStore(self.file_name, self.world_size)
554        dist.init_process_group(
555            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
556        )
557        pg = dist.distributed_c10d._get_default_group()
558        self.assertEqual(pg.rank(), self.rank)
559        self.assertEqual(pg.size(), self.world_size)
560        # give enough time for check() to be executed multiple times
561        time.sleep(2)
562        dist.destroy_process_group()
563
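    # Assert that the default PG's NCCL backend options carry the expected timeout.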
564    def _check_nccl_timeout(self, expected_timeout):
565        pg = dist.distributed_c10d._get_default_group()
566        options = pg._get_backend(torch.device(f"cuda:{self.rank}")).options
567        self.assertEqual(options._timeout, expected_timeout)
568
569    @requires_nccl()
570    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "No GPUs available, skipping test")
571    def test_init_process_group_nccl_timeout(self):
572        # nccl is handled 'specially' inside init_process_group and its options class is different from the options
573        # used by the other PGs.  There are specific edge cases for nccl that need to be tested.
574
575        store = c10d.FileStore(self.file_name, self.world_size)
576        base_opts = dict(
577            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
578        )
579
580        # test the default value coming from the `init_process_group` kwarg default
581        dist.init_process_group(**base_opts)
582        self._check_nccl_timeout(torch.distributed.constants.default_pg_nccl_timeout)
583        dist.destroy_process_group()
584
585        # test that `kwarg` timeout takes effect
586        new_timeout = timedelta(seconds=123)
587        dist.init_process_group(**base_opts, timeout=new_timeout)
588        self._check_nccl_timeout(new_timeout)
589        dist.destroy_process_group()
590
591        # test that timeout value provided via `pg_options` kwarg is ignored and issues warning,
592        # 'timeout' kwarg (or its kwdefault) taking precedence
593        opts = dist.ProcessGroupNCCL.Options()
594        opts._timeout = timedelta(seconds=123)
595        with warnings.catch_warnings(record=True) as w:
596            dist.init_process_group(**base_opts, pg_options=opts)
597            # TODO(whc) I verified that we are indeed emitting this warning, but I can't figure out why I can't catch it.
598            # self.assertEqual(len(w), 1)
599            # self.assertTrue("pg_options._timeout was specified" in str(w[-1].message))
600        self._check_nccl_timeout(torch.distributed.constants.default_pg_nccl_timeout)
601        dist.destroy_process_group()
602
603        # test that timeout value provided via `pg_options` kwarg is ignored and issues warning,
604        # 'timeout' kwarg taking precedence
605        opts = dist.ProcessGroupNCCL.Options()
606        opts._timeout = timedelta(seconds=123)
607        dist.init_process_group(
608            **base_opts, pg_options=opts, timeout=timedelta(seconds=1240)
609        )
610        self._check_nccl_timeout(timedelta(seconds=1240))
611        dist.destroy_process_group()
612
613    @requires_nccl()
614    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
615    @parametrize("backend", [None, "nccl"])
616    def test_set_nccl_pg_timeout(self, backend):
617        store = c10d.FileStore(self.file_name, self.world_size)
618        opts = dict(
619            backend=backend,
620            store=store,
621            rank=self.rank,
622            world_size=self.world_size,
623            timeout=timedelta(seconds=123),
624        )
625        dist.init_process_group(**opts)
626        pg = dist.distributed_c10d._get_default_group()
627        pg.allreduce(torch.rand(10).cuda(self.rank))
628        self._check_nccl_timeout(timedelta(seconds=123))
629        pg._get_backend(torch.device(f"cuda:{self.rank}"))._set_default_timeout(
630            timedelta(seconds=23)
631        )
632        self._check_nccl_timeout(timedelta(seconds=23))
633        pg.allreduce(torch.rand(10).cuda(self.rank))
634        c10d.distributed_c10d._set_pg_timeout(timedelta(seconds=252), pg)
635        self._check_nccl_timeout(timedelta(seconds=252))
636
637    @requires_nccl()
638    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
639    @parametrize("backend", [None, "nccl"])
640    def test_extend_nccl_pg_timeout(self, backend):
641        torch.cuda.set_device(self.rank)
642        store = c10d.FileStore(self.file_name, self.world_size)
643        opts = dict(
644            backend=backend,
645            store=store,
646            rank=self.rank,
647            world_size=self.world_size,
648            timeout=timedelta(seconds=123),
649        )
650        dist.init_process_group(**opts)
651        pg = dist.distributed_c10d._get_default_group()
652        backend = pg._get_backend(torch.device(f"cuda:{self.rank}"))
653        w = pg.allreduce(torch.rand(10).cuda(self.rank))
654        self.assertTrue(backend._verify_work_timeout(w, timedelta(seconds=123)))
655        w.wait()
656        backend._set_default_timeout(timedelta(seconds=3))
657        if self.rank == 0:
658            # Ideally we would sleep for a very long time, but that is not
659            # feasible in a unit test, so this only exercises a very small case.
660            time.sleep(5)
661            pg.allreduce(torch.rand(10).cuda(self.rank))
662            time.sleep(5)
663            pg.allreduce(torch.rand(5).cuda(self.rank))
664            w = pg.allreduce(torch.rand(10).cuda(self.rank))
665            self.assertTrue(backend._verify_work_timeout(w, timedelta(seconds=3)))
666            w.wait()
667        else:
668            dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
669                timedelta(seconds=10)
670            )
671            w1 = pg.allreduce(torch.rand(10).cuda(self.rank))
672            w2 = pg.allreduce(torch.rand(5).cuda(self.rank))
673            self.assertTrue(backend._verify_work_timeout(w1, timedelta(seconds=13)))
674            self.assertTrue(backend._verify_work_timeout(w2, timedelta(seconds=13)))
675            w1.wait()
676            dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
677                timedelta(seconds=5)
678            )
679            # Since we are not using blocking wait, use a sync here to leave enough
680            # time for the watchdog to reset the first timeout extension.
681            torch.cuda.synchronize(torch.device(f"cuda:{self.rank}"))
682            w = pg.allreduce(torch.rand(10).cuda(self.rank))
683            self.assertTrue(backend._verify_work_timeout(w, timedelta(seconds=8)))
684            w.wait()
685
686    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
687    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
688    def test_comm_split_optimization(self):
689        # Test that new groups containing all world ranks use the
690        # "transparent" `ncclCommSplit` optimization.
691        store = c10d.FileStore(self.file_name, self.world_size)
692        pg = self._create_process_group_nccl(store, self.opts())
693
694        # Test lazy splitting behavior across each per-device backend.
695        for device in self.rank_to_GPU[self.rank]:
696            backend = pg._get_backend(torch.device(device))
697
698            # split doesn't happen unless the original process group has lazily
699            # created communicators, so first verify we haven't split even when
700            # making the new group and running an operation on the original pg.
701            ng = c10d.new_group()
702            tensor = torch.tensor([self.rank]).cuda(device)
703            pg.broadcast(tensor, 0)
704            self.assertEqual(backend.comm_split_count(), 0)
705
706            # The new group will force a split of the original on first use.
707            ng.broadcast(tensor, 0)
708            self.assertEqual(backend.comm_split_count(), 1)
709
710    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
711    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
712    @skip_but_pass_in_sandcastle_if(
713        torch.cuda.nccl.version()[-1] == "x", "NCCL test not for NCCLX"
714    )
715    def test_comm_split_subgroup(self):
716        # Test `ncclCommSplit` for smaller subgroups of the world when
717        # we've passed a specific device_id to init_process_group.
718        store = c10d.FileStore(self.file_name, self.world_size)
719        device = torch.device(f"cuda:{self.rank}")
720        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
721        backend = pg._get_backend(torch.device(device))
722
723        tensor = torch.full((1,), self.rank).cuda(device)
724        original_tensor = tensor.clone()
725        ng = c10d.new_group([0])
726
727        # comm split happens eagerly since device_id is passed to init_process_group.
728        self.assertEqual(backend.comm_split_count(), 1)
729        if self.rank == 0:
730            dist.broadcast(tensor, 0, group=ng)
731
732        # no additional comm split happens after a collective.
733        self.assertEqual(backend.comm_split_count(), 1)
734        self.assertEqual(tensor, original_tensor)
735        dist.destroy_process_group()
736
737    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
738    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
739    def test_comm_split_group(self):
740        # Test `c10d.split_group` when we've passed a specific device_id
741        # to init_process_group.
742        store = c10d.FileStore(self.file_name, self.world_size)
743        device = torch.device(f"cuda:{self.rank}")
744        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
745        backend = pg._get_backend(torch.device(device))
746
747        tensor = torch.full((1,), self.rank).cuda(device)
748        ng1 = c10d.split_group(pg, [[0, 1]])
749        backend1 = pg._get_backend(torch.device(device))
750
751        # check basic options are the same between parent and child
752        self.assertEqual(backend.options._timeout, backend1.options._timeout)
753        self.assertEqual(
754            backend.options.is_high_priority_stream,
755            backend1.options.is_high_priority_stream,
756        )
757        self.assertEqual(ng1.group_desc, "default_pg:split:0")
758
759        # comm split happens eagerly since device_id is passed to init_process_group.
760        self.assertEqual(backend.comm_split_count(), 1)
761        dist.broadcast(tensor, 0, group=ng1)
762        self.assertEqual(tensor, torch.full((1,), 0))
763
764        ng2 = c10d.split_group(pg, [[0, 1]])
765        self.assertEqual(ng2.group_desc, "default_pg:split:1")
766        self.assertEqual(backend.comm_split_count(), 2)
767
768        dist.destroy_process_group()
769
770    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
771    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
772    def test_non_blocking_init(self):
773        # Test creating a pg using nonblocking mode but not eagerly
774        os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
775        os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100"
776        store = c10d.FileStore(self.file_name, self.world_size)
777        device = self.rank_to_GPU[self.rank][0]
778        pg = self._create_process_group_nccl(store, self.opts())
779        backend = pg._get_backend(torch.device(device))
780        self.assertEqual(backend.comm_split_count(), 0)
781        reduce_tensor = torch.rand(10, 10, device=device)
782        # Run an allreduce, which should trigger a comm init for pg
783        pg.allreduce(reduce_tensor).wait()
784        new_pg = c10d.new_group()
785        # even after pg's collective call, the new pg's comm is not initialized until it makes its own collective calls
786        self.assertEqual(backend.comm_split_count(), 0)
787        broadcast_tensor = torch.tensor([self.rank]).cuda(device)
788        new_pg.broadcast(broadcast_tensor, 0).wait()
789        self.assertEqual(backend.comm_split_count(), 1)
790        dist.destroy_process_group()
791
792    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
793    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
794    def test_non_blocking_with_eager_init(self):
795        # Test creating a pg eagerly with nonblocking mode when
796        # we've passed a specific device_id to init_process_group.
797        os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
798        os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100"
799        store = c10d.FileStore(self.file_name, self.world_size)
800        device = torch.device(f"cuda:{self.rank}")
801        # bind the device to trigger eager init mode
802        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
803        backend = pg._get_backend(torch.device(device))
804        self.assertEqual(backend.comm_split_count(), 0)
805        reduce_tensor = torch.rand(10, 10, device=device)
806        # Run an allreduce; the comm should have already started initializing,
807        # but the allreduce is issued to the CUDA stream only after initialization succeeds
808        pg.allreduce(reduce_tensor).wait()
809        new_pg = c10d.new_group()
810        # new pg's comm is initialized eagerly
811        self.assertEqual(backend.comm_split_count(), 1)
812        broadcast_tensor = torch.tensor([self.rank]).cuda(device)
813        new_pg.broadcast(broadcast_tensor, 0).wait()
814        self.assertEqual(backend.comm_split_count(), 1)
815        dist.destroy_process_group()
816
817    @requires_nccl()
818    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
819    def test_get_uid(self):
820        store = c10d.FileStore(self.file_name, self.world_size)
821        device = torch.device(f"cuda:{self.rank}")
822        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
823        from torch.distributed.distributed_c10d import _get_process_group_uid
824
825        self.assertEqual(_get_process_group_uid(pg), 0)
826        pg_2 = c10d.new_group([0, 1])
827        self.assertEqual(_get_process_group_uid(pg_2), 1)
828
829    @requires_nccl()
830    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
831    def test_set_process_group_desc(self):
832        store = c10d.FileStore(self.file_name, self.world_size)
833        device = torch.device(f"cuda:{self.rank}")
834        pg_default = self._create_process_group_nccl(
835            store, self.opts(), device_id=device
836        )
837        self.assertEqual(pg_default.group_desc, "default_pg")
838        pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
839        self.assertEqual(pg_1.group_desc, "test_purpose")
840        pg_2 = c10d.new_group([0, 1])
841        self.assertEqual(pg_2.group_desc, "undefined")
842
843
844class DistributedDataParallelTest(
845    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
846):
847    def setUp(self):
848        super().setUp()
849        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
850        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
851        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
852        self._spawn_processes()
853
854    def _get_process_group(self):
855        store = self._get_store()
856        c10d.init_process_group(
857            "nccl", store=store, rank=self.rank, world_size=self.world_size
858        )
859        return c10d.distributed_c10d._get_default_group()
860
861    def _test_nccl_backend(
862        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
863    ):
864        process_group = self._get_process_group()
865        self._test_ddp_with_process_group(
866            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
867        )
868
869    @requires_nccl()
870    @skip_if_lt_x_gpu(2)
871    def test_nccl_propagate_error_reason(self):
872        # Need to use TORCH_NCCL_BLOCKING_WAIT and not ASYNC_ERROR_HANDLING,
873        # otherwise process will be taken down and we can't check for errors.
874        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
875        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
876        # Need to disable TORCH_NCCL_DUMP_ON_TIMEOUT otherwise this test times out
877        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "0"
878        store = c10d.FileStore(self.file_name, self.world_size)
879        # provide sufficient timeout to initialize NCCL comm.
880        pg = c10d.ProcessGroupNCCL(
881            store, self.rank, self.world_size, timeout=timedelta(seconds=15)
882        )
883        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
884        pg.barrier().wait(timedelta(seconds=5))
885        # Simulate stuckness in rank 0.
886        if self.rank == 0:
887            pg_gloo.barrier().wait()
888        inp = torch.ones(1).cuda(self.rank)
889
890        if self.rank != 0:
891            # Time out due to rank 0 not calling into allreduce.
892            with self.assertRaises(dist.DistBackendError):
893                pg.allreduce([inp]).wait(timedelta(seconds=5))
894
895            # Now when nonzero rank attempts to use communicator, original failure reason should be logged.
896            try:
897                pg.allreduce([torch.ones(2).cuda(self.rank)]).wait()
898            except dist.DistBackendError as e:
899                self.assertTrue("aborted" in str(e))
900            else:
901                self.fail("Expected error to be raised!")
902
903            # Unblock rank 0
904            pg_gloo.barrier().wait()
905
906        # TODO: We can also test that if rank 0 attempts to use the communicator,
907        # then we should error out with the info that it was aborted due to
908        # timeout on another rank. Although this would only be the case after
909        # the watchdog has run on the rank, and there is no reliable way
910        # to confirm it has run.
911
912    @requires_nccl()
913    @skip_if_lt_x_gpu(2)
914    def test_nccl_backend_multi_device_ids_not_allowed(self):
915        int_devices = list(range(torch.cuda.device_count()))
916        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
917        with self.assertRaisesRegex(
918            ValueError, "device_ids can only be None or contain a single element."
919        ):
920            self._test_nccl_backend(devices, int_devices)
921
922    @requires_nccl()
923    @skip_if_lt_x_gpu(2)
924    def test_nccl_backend_single_device_module_device_ids_None(self):
925        self._test_nccl_backend(None, None)
926
927    @requires_nccl()
928    @skip_if_lt_x_gpu(2)
929    def test_nccl_backend_single_device_module_empty_device_ids(self):
930        # This tests the backward compatibility of accepting an empty list as `device_ids`,
931        # although we no longer document this in favor of the default value of `None`,
932        # which is consistent with multi-device modules and CPU modules.
933        self._test_nccl_backend(None, [])
934
935    @requires_nccl()
936    @skip_if_lt_x_gpu(4)
937    def test_nccl_backend_multi_device_module_device_ids_None(self):
938        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
939        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
940        self._test_nccl_backend(devices, None, multi_device=True)
941
942    @requires_nccl()
943    @skip_if_lt_x_gpu(2)
944    def test_nccl_backend_1gpu_module_device_ids_integer_list(self):
945        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
946        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
947        self._test_nccl_backend(devices, int_devices)
948
949    @requires_nccl()
950    @skip_if_lt_x_gpu(2)
951    def test_nccl_backend_1gpu_module_device_ids_torch_device_list(self):
952        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
953        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
954        self._test_nccl_backend(devices, devices)
955
956    @requires_nccl()
957    @skip_if_lt_x_gpu(4)
958    def test_nccl_backend_2gpu_module(self):
959        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
960        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
961        self._test_nccl_backend(devices, None, multi_device=True)
962
963    @requires_nccl()
964    @skip_if_lt_x_gpu(8)
965    def test_nccl_backend_4gpu_module(self):
966        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
967        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
968        self._test_nccl_backend(devices, None, multi_device=True)
969
970    @requires_nccl()
971    @skip_if_lt_x_gpu(4)
972    def test_ddp_multi_device_module_config(self):
973        gpus = gpus_for_rank(self.world_size)[self.rank]
974
975        self.assertTrue(len(gpus) >= 2, "expecting at least 2 gpus per process")
976
977        process_group = self._get_process_group()
978
979        gpus = gpus[:2]
980        model = DoubleGpuNet(gpus)
981
982        with self.assertRaisesRegex(
983            ValueError,
984            "DistributedDataParallel device_ids and output_device arguments only work with "
985            "single-device/multiple-device GPU modules or CPU modules",
986        ):
987            ddp_model = DistributedDataParallel(
988                model, output_device=gpus[1], process_group=process_group
989            )
990
991        with self.assertRaisesRegex(
992            ValueError, "device_ids can only be None or contain a single element."
993        ):
994            ddp_model = DistributedDataParallel(
995                model, device_ids=gpus, process_group=process_group
996            )
997
998        with self.assertRaisesRegex(
999            ValueError, "input module must be on the same type of devices"
1000        ):
1001            model.fc1 = model.fc1.cpu()
1002            ddp_model = DistributedDataParallel(model, process_group=process_group)
1003
1004        model = model.cpu()
1005        with self.assertRaisesRegex(
1006            ValueError, "device_ids can only be None or contain a single element."
1007        ):
1008            ddp_model = DistributedDataParallel(
1009                model, device_ids=gpus, process_group=process_group
1010            )
1011
1012    def _test_fp16(self, gradient_as_bucket_view=False):
1013        process_group = self._get_process_group()
1014
1015        gpus = gpus_for_rank(self.world_size)[self.rank]
1016        model = nn.Linear(1, 1, bias=False).cuda(gpus[0]).half()
1017        nn.init.constant_(model.weight, 1)
1018        ddp_model = DistributedDataParallel(
1019            model,
1020            device_ids=[gpus[0]],
1021            process_group=process_group,
1022            bucket_cap_mb=0.001,
1023            gradient_as_bucket_view=gradient_as_bucket_view,
1024        )
1025
1026        # Input 2**15, so that the gradients will overflow with a
1027        # world_size of 2, unless we normalize the gradient by the
1028        # world_size before the reduction
1029        input = torch.tensor([[2**15]]).cuda(gpus[0]).half()
1030
1031        # Step model
1032        ddp_model.train()
1033        output = ddp_model(input)
1034        loss = output.sum()
1035        loss.backward()
1036
1037        self.assertFalse(any(torch.isinf(p.grad).any() for p in ddp_model.parameters()))
1038
1039    @requires_nccl()
1040    @skip_if_lt_x_gpu(2)
1041    def test_fp16(self):
1042        self._test_fp16()
1043
1044    @requires_nccl()
1045    @skip_if_lt_x_gpu(2)
1046    def test_fp16_grad_is_view(self):
1047        self._test_fp16(gradient_as_bucket_view=True)
1048
1049    def _test_arbitrary_forward_return_value(self, gradient_as_bucket_view=False):
1050        """
1051        Note: this test can be sped up by only running it on a CPU module
1052        once DistributedDataParallel supports them.
1053        """
1054        process_group = self._get_process_group()
1055
1056        class ForwardReturnValueModule(nn.Module):
1057            def __init__(self) -> None:
1058                super().__init__()
1059                self.fc1 = nn.Linear(2, 10, bias=False)
1060                self.fc2 = nn.Linear(10, 4, bias=False)
1061                self.fc3 = nn.Linear(4, 4, bias=False)
1062                self.relu = nn.ReLU()
1063
1064            def forward(self, x, fn):
1065                x = self.relu(self.fc1(x))
1066                x = self.relu(self.fc2(x))
1067                # The first softmax does NOT include fc3 in its autograd graph
1068                # whereas the second softmax DOES. If we pass only the first
1069                # tensor we see in the output to the reducer, it marks the
1070                # gradient for fc3 as ready (because it doesn't show up). If
1071                # downstream uses of this return value choose to differentiate
1072                # against the second output tensor, it would still receive a
1073                # gradient and a callback for this tensor, resulting in a crash.
1074                return fn(
1075                    F.softmax(x, dim=1),
1076                    F.softmax(self.fc3(x), dim=1),
1077                )
1078
1079        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1080        model = DistributedDataParallel(
1081            ForwardReturnValueModule().float().to(device_id),
1082            device_ids=[device_id],
1083            process_group=process_group,
1084            gradient_as_bucket_view=gradient_as_bucket_view,
1085        )
1086
1087        batch_size = 4
1088        criterion = nn.CrossEntropyLoss()
1089        input = torch.rand([batch_size, 2], dtype=torch.float)
1090        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
1091            device_id
1092        )
1093
1094        # Always run "backward" to ensure the reducer is called by autograd.
1095        # If we don't correctly capture the output tensors from the return value,
1096        # the reducer won't see a hook for the unused parameter and will throw an error.
1097        # The correct capture is what we're testing in this function.
1098        def test(box, unbox):
1099            output = model(input, fn=box)
1100            loss = criterion(unbox(output), target)
1101            loss.backward()
1102
1103        # Test with identity return value
1104        test(
1105            box=lambda x, y: (x, y),
1106            unbox=lambda obj: obj[1],
1107        )
1108
1109        # Test with list return value
1110        test(
1111            box=lambda x, y: ["foo", x, "bar", y],
1112            unbox=lambda obj: obj[3],
1113        )
1114
1115        # Test with tuple return value
1116        test(
1117            box=lambda x, y: ("foo", x, "bar", y),
1118            unbox=lambda obj: obj[3],
1119        )
1120
1121        # Test with dict return value
1122        test(
1123            box=lambda x, y: {"foo": "bar", "a": x, "b": y},
1124            unbox=lambda obj: obj["b"],
1125        )
1126
1127        # Test with list with dict return value
1128        test(
1129            box=lambda x, y: ["foo", "bar", {"a": x, "b": y}],
1130            unbox=lambda obj: obj[2]["b"],
1131        )
1132
1133        # Test with dict with list return value
1134        test(
1135            box=lambda x, y: {"foo": "bar", "list": [0, x, 1, y]},
1136            unbox=lambda obj: obj["list"][3],
1137        )
1138
1139    @requires_nccl()
1140    @skip_if_lt_x_gpu(2)
1141    def test_arbitrary_forward_return_value(self):
1142        self._test_arbitrary_forward_return_value()
1143
1144    @requires_nccl()
1145    @skip_if_lt_x_gpu(2)
1146    def test_arbitrary_forward_return_value_grad_is_view(self):
1147        self._test_arbitrary_forward_return_value(gradient_as_bucket_view=True)
1148
1149    @requires_nccl()
1150    @skip_if_lt_x_gpu(2)
1151    def test_ddp_with_lazy_parameters(self):
1152        process_group = self._get_process_group()
1153        with self.assertRaisesRegex(
1154            RuntimeError, "Modules with uninitialized parameters"
1155        ):
1156            DistributedDataParallel(
1157                torch.nn.LazyLinear(10), process_group=process_group
1158            )
1159
1160    def _test_find_unused_parameters_kwarg(self, gradient_as_bucket_view=False):
1161        """
1162        Note: this test can be sped up by only running it on a CPU module
1163        once DistributedDataParallel supports them.
1164        """
1165        torch.cuda.set_device(self.rank)
1166        dist.init_process_group(
1167            backend="nccl",
1168            world_size=self.world_size,
1169            rank=self.rank,
1170            init_method=f"file://{self.file_name}",
1171        )
1172        process_group = c10d.distributed_c10d._get_default_group()
1173
1174        class FindUnusedParametersModule(nn.Module):
1175            def __init__(self) -> None:
1176                super().__init__()
1177                self.fc1 = nn.Linear(2, 10, bias=False)
1178                self.fc2 = nn.Linear(10, 4, bias=False)
1179                self.fc3 = nn.Linear(4, 4, bias=False)
1180                self.relu = nn.ReLU()
1181
1182            def forward(self, x):
1183                x = self.relu(self.fc1(x))
1184                x = self.relu(self.fc2(x))
1185                # Return the fc3 module so that the caller can invoke it
1186                # outside of the forward function. While this is bad practice,
1187                # we can use it to trigger a reducer error.
1188                return (F.softmax(x, dim=1), self.fc3)
1189
1190        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1191        batch_size = 4
1192        criterion = nn.CrossEntropyLoss()
1193        input = torch.rand([batch_size, 2], dtype=torch.float)
1194        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
1195            device_id
1196        )
1197
1198        ddp_model = None
1199
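        # Helper: build a DDP model (optionally relying on the default value of
        # find_unused_parameters) and run one forward/backward pass; reducer
        # errors, if any, surface during backward.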
1200        def test_find_unused_parameters(
1201            find_unused_parameters, test_default=False, gradient_as_bucket_view=False
1202        ):
1203            if test_default:
1204                model = DistributedDataParallel(
1205                    FindUnusedParametersModule().float().to(device_id),
1206                    device_ids=[device_id],
1207                    process_group=process_group,
1208                    gradient_as_bucket_view=gradient_as_bucket_view,
1209                )
1210            else:
1211                model = DistributedDataParallel(
1212                    FindUnusedParametersModule().float().to(device_id),
1213                    device_ids=[device_id],
1214                    process_group=process_group,
1215                    find_unused_parameters=find_unused_parameters,
1216                    gradient_as_bucket_view=gradient_as_bucket_view,
1217                )
1218            nonlocal ddp_model
1219            ddp_model = model
1220
1221            output, fc3 = model(input)
1222            output = fc3(output)
1223            loss = criterion(output, target)
1224            loss.backward()
1225
1226        # First test that finding unused params under these conditions triggers
1227        # an error when `backward` is called (because fc3 is an unused
1228        # parameter and will therefore be marked ready twice).
1229        try:
1230            test_find_unused_parameters(
1231                True, gradient_as_bucket_view=gradient_as_bucket_view
1232            )
1233        except Exception as ex:
1234            self.assertTrue(
1235                str(ex).startswith(
1236                    "Expected to mark a variable ready only once.",
1237                )
1238            )
1239            unused_index = 2
1240            unused_index_str = f"Parameter at index {unused_index}"
1241            model = ddp_model.module
1242            for module_name, module in model.named_modules():
1243                if module == model.fc3:
1244                    for parameter_name, _ in module.named_parameters(recurse=False):
1245                        unused_fqn = f"{module_name}.{parameter_name}"
1246                        # Only one such parameter in model.fc3, since bias=False
1247                        break
1248
1249            if dist.get_debug_level() != dist.DebugLevel.OFF:
1250                unused_index_str += f" with name {unused_fqn}"
1251
1252            self.assertTrue(unused_index_str in str(ex))
1253        else:
1254            self.fail("Expected exception")
1255
1256        dist.barrier(process_group)
1257
1258        # Then test that the default behavior can be overridden by setting
1259        # `find_unused_parameters=False`.
1260        try:
1261            test_find_unused_parameters(
1262                False, gradient_as_bucket_view=gradient_as_bucket_view
1263            )
1264        except Exception as ex:
1265            self.fail(f"Unexpected exception: {ex}")
1266
1267        # Test find_unused_parameters defaults to False
1268        try:
1269            test_find_unused_parameters(
1270                True, test_default=True, gradient_as_bucket_view=gradient_as_bucket_view
1271            )
1272        except Exception as ex:
1273            self.fail(f"Unexpected exception: {ex}")
1274
1275    # TODO: Combine the following tests once https://github.com/pytorch/pytorch/issues/55967
1276    # is resolved.
1277    @requires_nccl()
1278    @skip_if_lt_x_gpu(2)
1279    @with_dist_debug_levels(levels=["DETAIL"])
1280    def test_find_unused_parameters_kwarg_debug_detail(self):
1281        self._test_find_unused_parameters_kwarg()
1282
1283    @requires_nccl()
1284    @skip_if_lt_x_gpu(2)
1285    @with_dist_debug_levels(levels=["INFO"])
1286    def test_find_unused_parameters_kwarg_debug_info(self):
1287        self._test_find_unused_parameters_kwarg()
1288
1289    @requires_nccl()
1290    @skip_if_lt_x_gpu(2)
1291    @with_dist_debug_levels(levels=["OFF"])
1292    def test_find_unused_parameters_kwarg_debug_off(self):
1293        self._test_find_unused_parameters_kwarg()
1294
1295    @requires_nccl()
1296    @skip_if_lt_x_gpu(2)
1297    @with_dist_debug_levels(levels=["DETAIL"])
1298    def test_find_unused_parameters_kwarg_grad_is_view_debug_detail(self):
1299        self._test_find_unused_parameters_kwarg(gradient_as_bucket_view=True)
1300
1301    @requires_nccl()
1302    @skip_if_lt_x_gpu(2)
1303    @with_dist_debug_levels(levels=["INFO"])
1304    def test_find_unused_parameters_kwarg_grad_is_view_debug_info(self):
1305        self._test_find_unused_parameters_kwarg(gradient_as_bucket_view=True)
1306
1307    @requires_nccl()
1308    @skip_if_lt_x_gpu(2)
1309    @with_dist_debug_levels(levels=["OFF"])
1310    def test_find_unused_parameters_kwarg_grad_is_view_debug_off(self):
1311        self._test_find_unused_parameters_kwarg(gradient_as_bucket_view=True)
1312
1313    def _test_multiple_outputs_multiple_backward(self, gradient_as_bucket_view=False):
1314        """
1315        Note: this test can be sped up by only running it on a CPU module
1316        once DistributedDataParallel supports CPU modules.
1317        """
1318        process_group = self._get_process_group()
1319
1320        class MultipleOutputModule(nn.Module):
1321            def __init__(self) -> None:
1322                super().__init__()
1323
1324                def define_module():
1325                    return nn.Sequential(
1326                        nn.Linear(2, 10, bias=False),
1327                        nn.ReLU(),
1328                        nn.Linear(10, 4, bias=False),
1329                        nn.ReLU(),
1330                    )
1331
1332                self.module0 = define_module()
1333                self.module1 = define_module()
1334
1335            def forward(self, x):
1336                return (
1337                    F.softmax(self.module0(x), dim=1),
1338                    F.softmax(self.module1(x), dim=1),
1339                )
1340
1341        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1342        model = DistributedDataParallel(
1343            MultipleOutputModule().float().to(device_id),
1344            device_ids=[device_id],
1345            process_group=process_group,
1346            gradient_as_bucket_view=gradient_as_bucket_view,
1347        )
1348
1349        batch_size = 4
1350        criterion = nn.CrossEntropyLoss()
1351        input = torch.rand([batch_size, 2], dtype=torch.float)
1352        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
1353            device_id
1354        )
1355
1356        # Compute loss and gradients for both outputs
1357        output1, output2 = model(input)
1358        loss1 = criterion(output1, target)
1359        loss1.backward()
1360        loss2 = criterion(output2, target)
1361        loss2.backward()
1362
1363    @requires_nccl()
1364    @skip_if_lt_x_gpu(2)
1365    def test_multiple_outputs_multiple_backward(self):
1366        self._test_multiple_outputs_multiple_backward()
1367
1368    @requires_nccl()
1369    @skip_if_lt_x_gpu(2)
1370    def test_multiple_outputs_multiple_backward_grad_is_view(self):
1371        self._test_multiple_outputs_multiple_backward(gradient_as_bucket_view=True)
1372
1373    @requires_nccl()
1374    @skip_if_lt_x_gpu(2)
1375    def test_no_grad(self):
1376        """
1377        Note: this test can be sped up by only running it on a CPU module
1378        once DistributedDataParallel supports CPU modules.
1379        """
1380        process_group = self._get_process_group()
1381
1382        class NoGradModule(nn.Module):
1383            def __init__(self) -> None:
1384                super().__init__()
1385                self.fc1 = nn.Linear(2, 10, bias=False)
1386                self.fc2 = nn.Linear(10, 4, bias=False)
1387                self.relu = nn.ReLU()
1388
1389            def forward(self, x):
1390                x = self.relu(self.fc1(x))
1391                x = self.relu(self.fc2(x))
1392                return F.softmax(x, dim=1)
1393
1394        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1395        model = DistributedDataParallel(
1396            NoGradModule().float().to(device_id),
1397            device_ids=[device_id],
1398            process_group=process_group,
1399        )
1400
1401        batch_size = 4
1402        input = torch.rand([batch_size, 2], dtype=torch.float)
1403
1404        def check_no_grads():
1405            for p in model.parameters():
1406                self.assertTrue(p.requires_grad)
1407                self.assertIsNone(p.grad)
1408
1409        # After initialization, no parameter has its gradient set.
1410        check_no_grads()
1411
1412        # Run `forward` function with torch.no_grad()
1413        with torch.no_grad():
1414            output = model(input)
1415            self.assertTrue(isinstance(output, torch.Tensor))
1416
1417        # No parameter should have its gradient set.
1418        check_no_grads()
1419
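    # Hedged usage sketch of the pattern verified above: running a DDP-wrapped
    # model for inference under no_grad, so no gradients are produced or reduced
    # (`batch` is an illustrative name):
    #
    #   ddp_model.eval()
    #   with torch.no_grad():
    #       predictions = ddp_model(batch)
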
1420    def _test_accumulate_gradients_module(self, gradient_as_bucket_view=False):
1421        # This is NOT the recommended way to implement gradient accumulation,
1422        # but we would like to make sure DDP does not interfere with the
1423        # underlying module.
1424        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
1425        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
1426        process_group = self._get_process_group()
1427        global_batch_size = self.world_size
1428
1429        model, ddp_model, input, target = self._prepare_single_device_module(
1430            process_group, devices, devices, global_batch_size, gradient_as_bucket_view
1431        )
1432
1433        def step_model(model, input, target):
1434            model.train()
1435            output = model(input)
1436            loss = F.mse_loss(output, target.to(output.device))
1437            loss.backward()
1438
1439        # Ensure that gradient accumulation works with no_grad
1440        with torch.no_grad():
1441            ddp_model.train()
1442            ddp_model.module(input)
1443
1444        # Check two model parameters over 4 iterations.
1445        # Use 4 iterations because we alternate between reducing and
1446        # not reducing and want to make sure we switch both ways.
1447        for iteration in range(4):
1448            step_model(model, input, target)
1449
1450            if iteration % 2 == 0:
1451                # Skip gradient sync by calling the inner module directly, so prepare_for_backward is never invoked
1452                step_model(
1453                    ddp_model.module,
1454                    input[self.rank : (self.rank + 1)],
1455                    target[self.rank : (self.rank + 1)],
1456                )
1457                for i, j in zip(model.parameters(), ddp_model.parameters()):
1458                    self.assertNotEqual(i.grad, j.grad)
1459            else:
1460                step_model(
1461                    ddp_model,
1462                    input[self.rank : (self.rank + 1)],
1463                    target[self.rank : (self.rank + 1)],
1464                )
1465                for i, j in zip(model.parameters(), ddp_model.parameters()):
1466                    self.assertEqual(i.grad, j.grad, rtol=1.3e-06, atol=5e-5)
1467
1468            # Shuffle the input so that DDP input is different
1469            torch.manual_seed(1337 + iteration)
1470            input = input[torch.randperm(global_batch_size)]
1471
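    # The helper above deliberately bypasses DDP to accumulate gradients; a hedged
    # sketch of the supported approach is DDP's `no_sync()` context manager, which
    # skips gradient all-reduce for the wrapped iterations (`loss_fn` and `batches`
    # are illustrative names):
    #
    #   with ddp_model.no_sync():
    #       for batch in batches[:-1]:
    #           loss_fn(ddp_model(batch)).backward()  # grads accumulate locally
    #   loss_fn(ddp_model(batches[-1])).backward()    # this step triggers all-reduce
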
1472    @requires_nccl()
1473    @skip_if_lt_x_gpu(2)
1474    def test_accumulate_gradients_module(self):
1475        self._test_accumulate_gradients_module()
1476
1477    @requires_nccl()
1478    @skip_if_lt_x_gpu(2)
1479    def test_accumulate_gradients_module_with_grad_is_view(self):
1480        self._test_accumulate_gradients_module(gradient_as_bucket_view=True)
1481
1482    @requires_nccl()
1483    @skip_if_lt_x_gpu(2)
1484    def test_failure_recovery(self):
1485        process_group = self._get_process_group()
1486
1487        # Need to create a separate file for the recovered FileStore, because
1488        # the original one is deleted when the first FileStore is destructed.
1489        recovery_filename = self.file_name + "_recovery"
1490
1491        if self.rank == 0:
1492            # the file will be deleted by the recovered FileStore
1493            open(recovery_filename, "w").close()
1494
1495        # not necessary to run barrier here, as DDP will synchronize
1496
1497        class TestModel(nn.Module):
1498            def __init__(self) -> None:
1499                super().__init__()
1500                self.fc1 = nn.Linear(2, 10, bias=False)
1501                self.fc2 = nn.Linear(10, 4, bias=False)
1502                self.relu = nn.ReLU()
1503
1504            def forward(self, x):
1505                x = self.relu(self.fc1(x))
1506                x = self.relu(self.fc2(x))
1507                return F.softmax(x, dim=1)
1508
1509        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1510        model = TestModel().float().to(device_id)
1511        ddp = DistributedDataParallel(
1512            model,
1513            device_ids=[device_id],
1514            process_group=process_group,
1515        )
1516
1517        batch_size = 4
1518        criterion = nn.CrossEntropyLoss()
1519        input = torch.rand([batch_size, 2], dtype=torch.float)
1520        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
1521            device_id
1522        )
1523
1524        for _ in range(6):
1525            output = ddp(input)
1526            loss = criterion(output, target)
1527            loss.backward()
1528
1529        del ddp
1530        c10d.destroy_process_group(process_group)
1531
1532        store = c10d.FileStore(recovery_filename, self.world_size)
1533        c10d.init_process_group(
1534            "nccl", store=store, rank=self.rank, world_size=self.world_size
1535        )
1536        process_group = c10d.distributed_c10d._get_default_group()
1537        ddp = DistributedDataParallel(
1538            model,
1539            device_ids=[device_id],
1540            process_group=process_group,
1541        )
1542
1543        input = torch.rand([batch_size, 2], dtype=torch.float)
1544        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
1545            device_id
1546        )
1547        for _ in range(6):
1548            output = ddp(input)
1549            loss = criterion(output, target)
1550            loss.backward()
1551
1552    @requires_nccl()
1553    @skip_if_lt_x_gpu(2)
1554    def test_pass_default_pg(self):
1555        dist.init_process_group(
1556            "nccl",
1557            init_method=f"file://{self.file_name}",
1558            world_size=self.world_size,
1559            rank=self.rank,
1560        )
1561
1562        default_pg = c10d.distributed_c10d._get_default_group()
1563        dist.destroy_process_group(default_pg)
1564        self.assertFalse(dist.is_initialized())
1565
1566    def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size):
1567        process_group = self._get_process_group()
1568
1569        global_batch_size = local_batch_size * self.world_size
1570
1571        # Carry out some trials with small buckets and some with big buckets.
1572        bucketsizes = (0.000001, 25)
1573        # Tuples of lists.  Each list describes per-layer characteristics for one trial.
1574        layer_formats = (
1575            [torch.contiguous_format] * 4,
1576            [torch.channels_last] * 2 + [torch.contiguous_format] * 2,
1577            [torch.channels_last] * 4,
1578        )
1579        layer_dtypes = (
1580            [torch.float] * 4,
1581            [torch.float] * 2 + [torch.half] * 2,
1582            [torch.half] * 4,
1583        )
1584
1585        input_dev = layer_devs[0] if isinstance(layer_devs, list) else layer_devs
1586        target_dev = layer_devs[-1] if isinstance(layer_devs, list) else layer_devs
1587        input = torch.randn(
1588            (global_batch_size, 8, 8, 8), device=input_dev, dtype=torch.float
1589        )
1590        target = torch.randn(
1591            (global_batch_size, 8, 4, 4), device=target_dev, dtype=torch.float
1592        )
1593        local_batch_start = self.rank * local_batch_size
1594        local_batch_end = (self.rank + 1) * local_batch_size
1595
1596        # Reducer.cpp sneakily creates one "initial bucket" that ignores the "bucket_cap_mb"
1597        # argument.  The following makes sure the initial bucket also complies.
1598        @contextmanager
1599        def first_bucket_size(ddp_bucket_mb):
1600            old_DEFAULT_FIRST_BUCKET_BYTES = dist._DEFAULT_FIRST_BUCKET_BYTES
1601            dist._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.0e6)
1602            try:
1603                yield
1604            finally:
1605                dist._DEFAULT_FIRST_BUCKET_BYTES = old_DEFAULT_FIRST_BUCKET_BYTES
1606
1607        with torch.backends.cudnn.flags(
1608            enabled=True, deterministic=True, benchmark=False
1609        ):
1610            for formats, dtypes, bucketsize in product(
1611                layer_formats, layer_dtypes, bucketsizes
1612            ):
1613                with first_bucket_size(bucketsize):
1614                    model_msg = f"rank = {self.rank} formats = {formats} dtypes = {dtypes} bucketsize = {bucketsize} "
1615                    try:
1616                        m = ConvNet(layer_devs, formats, dtypes)
1617                        m_ddp = DistributedDataParallel(
1618                            copy.deepcopy(m),
1619                            device_ids=replica_devices,
1620                            process_group=process_group,
1621                            bucket_cap_mb=bucketsize,
1622                        )
1623                        opt = torch.optim.SGD(m.parameters(), lr=0.1)
1624                        opt_ddp = torch.optim.SGD(m_ddp.parameters(), lr=0.1)
1625                        has_half = any(p.dtype is torch.half for p in m.parameters())
1626                        tol = 1.0e-3 if has_half else 1.0e-5
1627                    except BaseException:
1628                        # Prints case-specific debugging info to narrow down failing case.
1629                        print(
1630                            "Caught exception during model creation for " + model_msg,
1631                            flush=True,
1632                        )
1633                        raise
1634                    # 3 iters:  First iter creates grads, second iter retests after rebucketing,
1635                    # third iter tries zeroed grads.
1636                    for it in range(3):
1637                        iter_msg = f"iter = {it} " + model_msg
1638                        named_msg = iter_msg
1639                        try:
1640                            F.mse_loss(m(input).float(), target).backward()
1641                            F.mse_loss(
1642                                m_ddp(input[local_batch_start:local_batch_end]).float(),
1643                                target[local_batch_start:local_batch_end],
1644                            ).backward()
1645                            for i, ((layer_name, m_child), m_ddp_child) in enumerate(
1646                                zip(m.named_children(), m_ddp.module.children())
1647                            ):
1648                                named_msg = layer_name + ".weight" + " " + iter_msg
1649                                self.assertTrue(
1650                                    m_child.weight.grad.is_contiguous(
1651                                        memory_format=formats[i]
1652                                    ),
1653                                    named_msg,
1654                                )
1655                                self.assertTrue(
1656                                    m_ddp_child.weight.grad.is_contiguous(
1657                                        memory_format=formats[i]
1658                                    ),
1659                                    named_msg,
1660                                )
1661                                for j, ((param_name, p), p_ddp) in enumerate(
1662                                    zip(
1663                                        m_child.named_parameters(),
1664                                        m_ddp_child.parameters(),
1665                                    )
1666                                ):
1667                                    named_msg = (
1668                                        layer_name + "." + param_name + " " + iter_msg
1669                                    )
1670                                    self.assertEqual(
1671                                        p.grad, p_ddp.grad, rtol=tol, atol=tol
1672                                    )
1673                            opt.step()
1674                            opt_ddp.step()
1675                            if it == 0:
1676                                for p, p_ddp in zip(m.parameters(), m_ddp.parameters()):
1677                                    p.grad = None
1678                                    p_ddp.grad = None
1679                            else:
1680                                m.zero_grad()
1681                                m_ddp.zero_grad()
1682                        except BaseException:
1683                            # Makes sure we still get info if an error occurred somewhere other than the asserts.
1684                            print(
1685                                "Caught exception during iterations at " + named_msg,
1686                                flush=True,
1687                            )
1688                            raise
1689
1690    @requires_nccl()
1691    @skip_if_lt_x_gpu(2)
1692    def test_grad_layout_1devicemodule_1replicaperprocess(self):
1693        dev0 = torch.device("cuda:" + str(gpus_for_rank(self.world_size)[self.rank][0]))
1694        # Tells DDP to use just one device.
1695        replica_devices = [dev0]
1696        # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
1697        layer_devs = dev0
1698        local_batch_size = 8
1699        self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
1700
1701    @requires_nccl()
1702    @skip_if_lt_x_gpu(4)
1703    @skip_if_rocm_multiprocess
1704    def test_grad_layout_2devicemodule(self):
1705        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
1706        dev0 = torch.device("cuda:" + str(int_devices[0]))
1707        dev1 = torch.device("cuda:" + str(int_devices[1]))
1708        # DDP's default behavior for a multi-device module is "don't replicate."
1709        replica_devices = None
1710        # Tells _test_grad_layout to construct this process's ConvNet on 2 devices, with 2 layers on each device.
1711        layer_devs = [dev0] * 2 + [dev1] * 2
1712        local_batch_size = 8
1713        self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
1714
1715    @requires_nccl()
1716    @skip_if_lt_x_gpu(2)
1717    def test_param_layout_mismatch_error(self):
1718        process_group = self._get_process_group()
1719
1720        dev0 = torch.device("cuda:" + str(gpus_for_rank(self.world_size)[self.rank][0]))
1721        layer_devs = dev0
1722        layer_formats = (
1723            [torch.contiguous_format] * 4
1724            if self.rank == 0
1725            else [torch.channels_last] * 4
1726        )
1727        layer_dtypes = [torch.float] * 4
1728
1729        m = ConvNet(layer_devs, layer_formats, layer_dtypes)
1730        if self.rank == 0:
1731            m_ddp = DistributedDataParallel(
1732                m, device_ids=[dev0], process_group=process_group
1733            )
1734        else:
1735            with self.assertRaisesRegex(
1736                RuntimeError,
1737                ".* appears not to match strides of the same param in process 0",
1738            ):
1739                m_ddp = DistributedDataParallel(
1740                    m, device_ids=[dev0], process_group=process_group
1741                )
1742
1743    def _gpu_model_with_ddp_comm_hook(
1744        self,
1745        process_group,
1746        hook=None,
1747        gradient_as_bucket_view=False,
1748        state=None,
1749        static_graph=False,
1750    ):
1751        device_id = gpus_for_rank(self.world_size)[self.rank][0]
1752        gpu_model = DistributedDataParallel(
1753            ModuleForDdpCommHook().to(device_id),
1754            device_ids=[device_id],
1755            process_group=process_group,
1756            gradient_as_bucket_view=gradient_as_bucket_view,
1757            static_graph=static_graph,
1758        )
1759
1760        # Register a DDP communication hook if any.
1761        if hook is not None:
1762            gpu_model.register_comm_hook(state, hook)
1763
1764        return gpu_model
1765
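    # Hedged sketch of the user-facing registration done by the helper above: a
    # comm hook receives `(state, bucket)` and returns a `torch.futures.Future`
    # that resolves to the reduced bucket tensor (`pg` is an already-initialized
    # process group in this sketch):
    #
    #   def allreduce_avg_hook(state, bucket):
    #       tensor = bucket.buffer() / state.size()
    #       return (
    #           state.allreduce([tensor]).get_future().then(lambda fut: fut.value()[0])
    #       )
    #
    #   gpu_model.register_comm_hook(state=pg, hook=allreduce_avg_hook)
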
1766    @requires_nccl()
1767    @skip_if_lt_x_gpu(2)
1768    def test_ddp_comm_hook_future_passing_gpu_nccl(self):
1769        """
1770        This unit test verifies whether the Future object is passed properly with the NCCL backend.
1771        The hook callback function creates a Future object and sets a value on it.
1772        """
1773        process_group = self._get_process_group()
1774
1775        # Get GPU model with simple_hook registered.
1776        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)
1777
1778        # check whether the grads are equal to what simple_hook's then callback returns.
1779        # without the comm_hook, result would be 0.25 * torch.ones(2, 2).
1780        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))
1781
1782    def _test_ddp_comm_hook_allreduce_hook_nccl(
1783        self, gradient_as_bucket_view=False, static_graph=False
1784    ):
1785        """
1786        This unit test verifies whether a DDP communication hook that just calls
1787        allreduce gives the same result as the case where no hook is registered.
1788        Without the then callback, the future_value in the reducer is no longer
1789        a PyObject, and this unit test verifies that future_value is properly checked.
1790        """
1791        process_group = self._get_process_group()
1792
1793        def allreduce_hook(
1794            state: object, bucket: dist.GradBucket
1795        ) -> torch.futures.Future[torch.Tensor]:
1796            tensors = [bucket.buffer() / self.world_size]
1797            return (
1798                process_group.allreduce(tensors)
1799                .get_future()
1800                .then(lambda fut: fut.value()[0])
1801            )
1802
1803        # Get GPU model with allreduce_hook registered.
1804        gpu_model = self._gpu_model_with_ddp_comm_hook(
1805            process_group, allreduce_hook, gradient_as_bucket_view, static_graph
1806        )
1807
1808        # check whether the grads are equal to what DDP without hook would return.
1809        self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
1810
1811    def _test_default_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
1812        """
1813        This unit test verifies whether the default Python DDP communication hooks ALLREDUCE, FP16_COMPRESS,
1814        and BF16_COMPRESS can give the same result as the case where no hook is registered.
1815        """
1816        process_group = self._get_process_group()
1817
1818        # For these default DDP comm hooks, the only state is process group.
1819        state = process_group
1820        hook_options = [default.allreduce_hook, default.fp16_compress_hook]
1821        if (
1822            not TEST_WITH_ROCM
1823            and BFLOAT16_AVAILABLE
1824            and c10d.is_nccl_available()
1825            and torch.cuda.nccl.version() >= (2, 10)
1826        ):
1827            hook_options.append(default.bf16_compress_hook)
1828        for hook in hook_options:
1829            # Get GPU model with the hook registered.
1830            # The first arg 'process_group' is used for initializing the test environment,
1831            # so it cannot be replaced by 'state', although they have the same value.
1832            gpu_model = self._gpu_model_with_ddp_comm_hook(
1833                process_group, hook, gradient_as_bucket_view, state
1834            )
1835
1836            # check whether the grads are equal to what DDP without hook would return.
1837            self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
1838
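    # Hedged sketch of registering one of the stock hooks covered above on a user
    # model; `default` is the `ddp_comm_hooks.default_hooks` module imported at the
    # top of this file, and `pg` is an initialized process group:
    #
    #   gpu_model.register_comm_hook(state=pg, hook=default.fp16_compress_hook)
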
1839    def _test_fp16_compress_wrapper(self, gradient_as_bucket_view=False):
1840        """
1841        This unit test verifies whether wrapping the ALLREDUCE and POWER_SGD hooks with
1842        the FP16_WRAPPER can give the same result as when there is no hook registered.
1843        """
1844        process_group = self._get_process_group()
1845        powerSGD_state = powerSGD.PowerSGDState(process_group=process_group)
1846
1847        hook_args = [
1848            (powerSGD.powerSGD_hook, powerSGD_state),
1849            (default.allreduce_hook, process_group),
1850        ]
1851
1852        for hook, state in hook_args:
1853            gpu_model = self._gpu_model_with_ddp_comm_hook(
1854                process_group,
1855                default.fp16_compress_wrapper(hook),
1856                gradient_as_bucket_view,
1857                state,
1858            )
1859
1860            # check whether the grads are equal to what DDP without hook would return.
1861            self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
1862
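    # Hedged sketch of the wrapper pattern exercised above: the wrapper casts each
    # bucket to FP16 before delegating to the inner hook and casts the result back
    # afterwards (`pg` is an initialized process group):
    #
    #   state = powerSGD.PowerSGDState(process_group=pg, matrix_approximation_rank=1)
    #   gpu_model.register_comm_hook(
    #       state, default.fp16_compress_wrapper(powerSGD.powerSGD_hook)
    #   )
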
1863    def _test_bf16_compress_wrapper(self, gradient_as_bucket_view=False):
1864        """
1865        This unit test verifies whether wrapping the ALLREDUCE and POWER_SGD hooks with
1866        the BF16_WRAPPER can give the same result as when there is no hook registered.
1867        """
1868        process_group = self._get_process_group()
1869        powerSGD_state = powerSGD.PowerSGDState(process_group=process_group)
1870
1871        hook_args = [
1872            (powerSGD.powerSGD_hook, powerSGD_state),
1873            (default.allreduce_hook, process_group),
1874        ]
1875
1876        for hook, state in hook_args:
1877            gpu_model = self._gpu_model_with_ddp_comm_hook(
1878                process_group,
1879                default.bf16_compress_wrapper(hook),
1880                gradient_as_bucket_view,
1881                state,
1882            )
1883
1884            # check whether the grads are equal to what DDP without hook would return.
1885            self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
1886
1887    def _test_powerSGD_ddp_comm_hook_nccl(self, gradient_as_bucket_view=False):
1888        """
1889        This unit test verifies whether the Python DDP communication hook POWER_SGD
1890        can give the same result as the case where no hook is registered.
1891        """
1892        process_group = self._get_process_group()
1893
1894        # Get GPU model with the hook registered.
1895        # Test the hook with different algorithmic configs.
1896        for use_error_feedback, warm_start, batch_tensors_with_same_shape in product(
1897            [True, False],
1898            [True, False],
1899            [True, False],
1900        ):
1901            state = powerSGD.PowerSGDState(
1902                process_group=process_group,
1903                matrix_approximation_rank=1,
1904                use_error_feedback=use_error_feedback,
1905                warm_start=warm_start,
1906                batch_tensors_with_same_shape=batch_tensors_with_same_shape,
1907            )
1908            for hook in [powerSGD.powerSGD_hook, powerSGD.batched_powerSGD_hook]:
1909                gpu_model = self._gpu_model_with_ddp_comm_hook(
1910                    process_group, hook, gradient_as_bucket_view, state
1911                )
1912
1913                # check whether the grads are equal to what DDP without hook would return.
1914                self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
1915
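    # Hedged configuration sketch for the PowerSGD state swept above; the knobs
    # shown are the ones this test toggles plus `start_powerSGD_iter`, and the
    # values are illustrative rather than recommendations:
    #
    #   state = powerSGD.PowerSGDState(
    #       process_group=pg,
    #       matrix_approximation_rank=1,   # size of the low-rank factors
    #       start_powerSGD_iter=10,        # run vanilla allreduce for the first iters
    #       use_error_feedback=True,       # compensate for compression error
    #       warm_start=True,               # reuse factors across iterations
    #   )
    #   gpu_model.register_comm_hook(state, powerSGD.powerSGD_hook)
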
1916    def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
1917        """
1918        This unit test verifies whether built-in C++ DDP communication hooks ALLREDUCE and FP16_COMPRESS
1919        can give the same result as the case where no hook is registered.
1920        """
1921        process_group = self._get_process_group()
1922
1923        for comm_hook_type in [
1924            dist.BuiltinCommHookType.ALLREDUCE,
1925            dist.BuiltinCommHookType.FP16_COMPRESS,
1926        ]:
1927            # Get GPU model with the built-in communication hook.
1928            gpu_model = self._gpu_model_with_builtin_ddp_comm_hook(
1929                process_group, comm_hook_type, gradient_as_bucket_view
1930            )
1931
1932            # check whether the grads are equal to what DDP without hook would return.
1933            self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
1934
1935    @requires_nccl()
1936    @skip_if_lt_x_gpu(2)
1937    def test_ddp_comm_hook_allreduce_hook_nccl(self):
1938        self._test_ddp_comm_hook_allreduce_hook_nccl()
1939
1940    @requires_nccl()
1941    @skip_if_lt_x_gpu(2)
1942    def test_default_ddp_comm_hooks_nccl(self):
1943        self._test_default_ddp_comm_hooks_nccl()
1944
1945    @requires_nccl()
1946    @skip_if_lt_x_gpu(2)
1947    def test_fp16_compress_wrapper_nccl(self):
1948        self._test_fp16_compress_wrapper()
1949
1950    @requires_nccl()
1951    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
1952    @skip_but_pass_in_sandcastle_if(
1953        not BFLOAT16_AVAILABLE,
1954        "BFloat16 is only supported by CUDA 11+",
1955    )
1956    @skip_if_lt_x_gpu(2)
1957    def test_bf16_compress_wrapper_nccl(self):
1958        self._test_bf16_compress_wrapper()
1959
1960    @requires_nccl()
1961    @skip_if_lt_x_gpu(2)
1962    def test_builtin_ddp_comm_hooks_nccl(self):
1963        self._test_builtin_ddp_comm_hooks_nccl()
1964
1965    @requires_nccl()
1966    @skip_if_lt_x_gpu(2)
1967    def test_powerSGD_ddp_comm_hook_nccl(self):
1968        self._test_powerSGD_ddp_comm_hook_nccl()
1969
1970    @requires_nccl()
1971    @skip_if_lt_x_gpu(2)
1972    def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self):
1973        self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True)
1974
1975    @requires_nccl()
1976    @skip_if_lt_x_gpu(2)
1977    def test_ddp_comm_hook_allreduce_hook_nccl_static_graph(self):
1978        self._test_ddp_comm_hook_allreduce_hook_nccl(static_graph=True)
1979
1980    @requires_nccl()
1981    @skip_if_lt_x_gpu(2)
1982    def test_default_ddp_comm_hooks_nccl_is_view(self):
1983        self._test_default_ddp_comm_hooks_nccl(gradient_as_bucket_view=True)
1984
1985    @requires_nccl()
1986    @skip_if_lt_x_gpu(2)
1987    def test_fp16_compress_wrapper_is_view(self):
1988        self._test_fp16_compress_wrapper(gradient_as_bucket_view=True)
1989
1990    @requires_nccl()
1991    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
1992    @skip_but_pass_in_sandcastle_if(
1993        not BFLOAT16_AVAILABLE,
1994        "BFloat16 is only supported by CUDA 11+",
1995    )
1996    @skip_if_lt_x_gpu(2)
1997    def test_bf16_compress_wrapper_is_view(self):
1998        self._test_bf16_compress_wrapper(gradient_as_bucket_view=True)
1999
2000    @requires_nccl()
2001    @skip_if_lt_x_gpu(2)
2002    def test_builtin_ddp_comm_hooks_nccl_grad_is_view(self):
2003        self._test_builtin_ddp_comm_hooks_nccl(gradient_as_bucket_view=True)
2004
2005    @requires_nccl()
2006    @skip_if_lt_x_gpu(2)
2007    def test_powerSGD_ddp_comm_hook_nccl_grad_is_view(self):
2008        self._test_powerSGD_ddp_comm_hook_nccl(gradient_as_bucket_view=True)
2009
2010    @requires_nccl()
2011    @skip_if_lt_x_gpu(2)
2012    def test_ddp_comm_hook_allreduce_with_then_hook_nccl(self):
2013        """
2014        This unit test verifies whether a DDP communication hook that calls allreduce and then
2015        multiplies the result by ten and divides by two gives the expected result.
2016        """
2017        process_group = self._get_process_group()
2018
2019        def allreduce_with_then_hook(
2020            state: object, bucket: dist.GradBucket
2021        ) -> torch.futures.Future[torch.Tensor]:
2022            tensors = [bucket.buffer() / self.world_size]
2023            fut = process_group.allreduce(tensors).get_future()
2024
2025            def mult(fut):
2026                # Multiply the result by 10.
2027                return 10 * fut.value()[0]
2028
2029            def div(fut):
2030                # Divide the result by 2.
2031                return 0.5 * fut.value()
2032
2033            return fut.then(mult).then(div)
2034
2035        # Get GPU model with allreduce_with_then_hook registered.
2036        gpu_model = self._gpu_model_with_ddp_comm_hook(
2037            process_group, allreduce_with_then_hook
2038        )
2039
2040        # Check whether the grads are equal to what allreduce returns multiplied by 5.
2041        # Without the comm_hook, the result would still be 0.25 * torch.ones(2, 2).
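        # (Arithmetic: the averaged allreduce result is 0.25 * torch.ones(2, 2);
        # the `then` chain scales it by 10 and then by 0.5, i.e. 0.25 * 10 * 0.5 = 1.25.)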
2042        self._run_and_verify_hook(gpu_model, 8, 1.25 * torch.ones(2, 2))
2043
2044    class AcceptsParam(torch.nn.Module):
2045        def __init__(self, p, factor):
2046            super().__init__()
2047            self.a = p
2048            self.f = factor
2049
2050        def forward(self, input):
2051            return input + self.a * self.f
2052
2053    @requires_nccl()
2054    @skip_if_lt_x_gpu(2)
2055    def test_ddp_weight_sharing(self):
2056        process_group = self._get_process_group()
2057
2058        size = 2048 * 2048
2059        dev = self.rank
2060        world = self.world_size
2061
2062        p = torch.nn.Parameter(torch.randn(size, requires_grad=True))
2063
2064        for try_set_to_none, use_bucket_view in product((False, True), (False, True)):
2065            m = torch.nn.Sequential(
2066                self.AcceptsParam(p, dev + 1), self.AcceptsParam(p, dev + 1)
2067            ).cuda(dev)
2068
2069            m = torch.nn.parallel.DistributedDataParallel(
2070                m,
2071                bucket_cap_mb=1,
2072                gradient_as_bucket_view=use_bucket_view,
2073                device_ids=[dev],
2074                process_group=process_group,
2075            )
2076
2077            for i in range(3):
2078                m.zero_grad(set_to_none=try_set_to_none)
2079                m(1).sum().backward()
2080
2081                # Each param value is multiplied by "rank + 1" twice in forward, so the grad
2082                # values produced by a particular rank should be 2. * (rank + 1).
2083                # Summing these over ranks and dividing by world size gives the expected result:
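                # (Worked example for world = 2: rank 0 contributes 2 * 1 and rank 1
                # contributes 2 * 2, so the averaged gradient is (2 + 4) / 2 = 3,
                # i.e. world + 1, matching the formula below.)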
2084                analytic = torch.full_like(
2085                    p, 2.0 * (world * (world + 1.0) / 2.0) / world, device=dev
2086                )
2087                for name, p in m.named_parameters():
2088                    self.assertEqual(
2089                        p.grad,
2090                        analytic,
2091                        "mismatch at "
2092                        + name
2093                        + ".grad for "
2094                        + f"set_to_none = {try_set_to_none}, use_bucket_view = {use_bucket_view}",
2095                    )
2096
2097    @requires_nccl()
2098    @skip_if_lt_x_gpu(2)
2099    def test_ddp_packed_sequence(self):
2100        """
2101        Tests that DDP with ``device_ids`` specified can run a forward and
2102        backward pass with ``PackedSequence`` s with parity compared to a local
2103        version of the model.
2104        """
2105        store = c10d.FileStore(self.file_name, self.world_size)
2106        process_group = dist.init_process_group(
2107            "nccl",
2108            world_size=self.world_size,
2109            rank=self.rank,
2110            store=store,
2111        )
2112        seqs = ["sequence_sequence", "seq", "sequence"]
2113        vocab = ["<pad>"] + sorted({ch for seq in seqs for ch in seq})
2114        vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
2115        # Set the seed to make the embedding and LSTM deterministic (even
2116        # across ranks since DDP broadcasts parameters from rank 0)
2117        torch.manual_seed(0)
2118        embed = nn.Embedding(len(vocab), 4)  # keep on CPU
2119        lstm = nn.LSTM(input_size=4, hidden_size=2, batch_first=True).to(self.rank)
2120        lstm_ddp = DistributedDataParallel(
2121            copy.deepcopy(lstm),
2122            device_ids=[self.rank],
2123            process_group=process_group,
2124        )
2125        for p1, p2 in zip(lstm.parameters(), lstm_ddp.module.parameters()):
2126            self.assertEqual(p1, p2)
2127        seq_lengths = torch.LongTensor(list(map(len, vectorized_seqs)))
2128        seq_tensor = torch.Tensor(
2129            torch.zeros((len(vectorized_seqs), seq_lengths.max()))
2130        ).long()
2131        for i, (seq, seq_len) in enumerate(zip(vectorized_seqs, seq_lengths)):
2132            seq_tensor[i, :seq_len] = torch.LongTensor(seq)
2133        seq_lengths, permutation_idx = seq_lengths.sort(0, descending=True)
2134        seq_tensor = seq_tensor[permutation_idx]
2135        embedded_seq_tensor = embed(seq_tensor)
2136        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
2137            embedded_seq_tensor,
2138            seq_lengths,
2139            batch_first=True,
2140        )
2141        packed_input_ddp = torch.nn.utils.rnn.pack_padded_sequence(
2142            embedded_seq_tensor.detach().clone(),
2143            seq_lengths,
2144            batch_first=True,
2145        )
2146        # Move the input to GPU explicitly for the local model
2147        packed_output, (ht, ct) = lstm(packed_input.to(self.rank))
2148        # Let DDP move the input to GPU internally
2149        packed_output_ddp, (ht_ddp, ct_ddp) = lstm_ddp(packed_input_ddp)
2150        self.assertEqual(packed_output.data, packed_output_ddp.data)
2151        self.assertEqual(ht, ht_ddp)
2152        self.assertEqual(ct, ct_ddp)
2153        packed_output.data.sum().backward()
2154        packed_output_ddp.data.sum().backward()
2155        for p1, p2 in zip(lstm.parameters(), lstm_ddp.parameters()):
2156            self.assertEqual(p1.grad, p2.grad)
2157
2158    @requires_nccl()
2159    @skip_if_lt_x_gpu(2)
2160    def test_channels_last_contig(self):
2161        process_group = self._get_process_group()
2162        device = torch.device(f"cuda:{self.rank}")
2163        tensor = torch.ones((2, 16, 768, 1152), dtype=torch.float32, device=device).to(
2164            memory_format=torch.channels_last
2165        )
2166        process_group.broadcast([tensor]).wait()
2167
2168    @requires_nccl()
2169    @skip_if_lt_x_gpu(2)
2170    def test_ddp_complex_params(self):
2171        class FFTModel(nn.Module):
2172            def __init__(self, hin, win, n_features):
2173                super().__init__()
2174                self.hin = hin
2175                self.win = win
2176                self.weight = nn.Parameter(
2177                    torch.ones(
2178                        (n_features, n_features, hin, win // 2 + 1), dtype=torch.cfloat
2179                    )
2180                )
2181
2182            def forward(self, x):
2183                xc = torch.fft.rfft2(
2184                    x, s=(self.hin, self.win), dim=(-2, -1), norm="ortho"
2185                )
2186                xcw = torch.einsum("nchw,cohw->nohw", xc, self.weight)
2187                x = torch.fft.irfft2(xcw, dim=(-2, -1), norm="ortho")
2188                return x
2189
2190        process_group = self._get_process_group()
2191        device_id = gpus_for_rank(self.world_size)[self.rank][0]
2192        N, C, H, W = 1, 16, 64, 64
2193        ddp_model = DistributedDataParallel(
2194            FFTModel(hin=H, win=W, n_features=C).to(device_id),
2195            device_ids=[device_id],
2196            process_group=process_group,
2197        )
2198        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)
2199
2200        inp = torch.ones((N, C, H, W), dtype=torch.float32)
2201
2202        # train step
2203        out = ddp_model(inp)
2204        loss = torch.sum(out)
2205        loss.backward()
2206        optimizer.step()
2207
2208        torch.cuda.synchronize(device=device_id)
2209
2210
2211class WorkHookTest(MultiProcessTestCase):
2212    @property
2213    def world_size(self):
2214        return 2
2215
2216    def setUp(self):
2217        super().setUp()
2218        # set TORCH_NCCL_ENABLE_TIMING to enable timing for CUDAEvents
2219        # in ProcessGroup Work
2220        os.environ["TORCH_NCCL_ENABLE_TIMING"] = "1"
2221        self._spawn_processes()
2222
2223    def tearDown(self):
2224        super().tearDown()
2225        del os.environ["TORCH_NCCL_ENABLE_TIMING"]
2226        try:
2227            os.remove(self.file_name)
2228        except OSError:
2229            pass
2230
2231    def _get_store(self):
2232        return dist.FileStore(self.file_name, self.world_size)
2233
2234    def _get_process_group(self):
2235        store = self._get_store()
2236        c10d.init_process_group(
2237            "nccl", store=store, rank=self.rank, world_size=self.world_size
2238        )
2239        return c10d.distributed_c10d._get_default_group()
2240
2241    @requires_nccl()
2242    @skip_if_lt_x_gpu(2)
2243    def test_on_completion_hook_broadcast(self):
2244        pg = self._get_process_group()
2245        num_hook_fired = 0
2246        durations: List[float] = []
2247
2248        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
2249            nonlocal num_hook_fired, durations
2250            num_hook_fired += 1
2251            durations.append(work_info.active_duration.total_seconds())
2252
2253        pg._register_on_completion_hook(hook)
2254        tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank
2255        pg.broadcast([tensor]).wait()
2256        pg.broadcast([tensor]).wait()
2257
2258        # N.B.: destroy_process_group is necessary to wait for
2259        # all pending works to finish.
2260        c10d.destroy_process_group(pg)
2261
2262        self.assertEqual(num_hook_fired, 2)
2263        self.assertEqual(len(durations), 2)
2264        for duration in durations:
2265            self.assertTrue(duration > 0)
2266
2267        self.assertEqual(tensor, torch.zeros([2, 3]).cuda(self.rank))
2268
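    # Hedged reference for the `WorkInfo` payload consumed by the hooks in this
    # class: these tests only read `op_type`, `seq`, and `active_duration` (a
    # `datetime.timedelta`); other fields are not relied upon here.
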
2269    @requires_nccl()
2270    @skip_if_lt_x_gpu(2)
2271    def test_on_completion_hook_mixed_ops(self):
2272        pg = self._get_process_group()
2273        num_hook_fired = 0
2274        durations: List[float] = []
2275
2276        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
2277            nonlocal num_hook_fired, durations
2278            num_hook_fired += 1
2279            durations.append(work_info.active_duration.total_seconds())
2280
2281        pg._register_on_completion_hook(hook)
2282        tensor = torch.ones([2, 3]).cuda(self.rank)
2283        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
2284        # intentionally using async ops.
2285        pg.allreduce(tensor)
2286        pg.allgather(tensor_list, tensor)
2287        pg.allreduce(tensor)
2288
2289        # N.B.: destroy_process_group is necessary to wait for
2290        # all pending works to finish.
2291        c10d.destroy_process_group(pg)
2292
2293        self.assertEqual(num_hook_fired, 3)
2294        self.assertEqual(len(durations), 3)
2295        for duration in durations:
2296            self.assertTrue(duration > 0)
2297
2298        self.assertEqual(
2299            tensor,
2300            torch.ones([2, 3]).cuda(self.rank) * self.world_size * self.world_size,
2301        )
2302
2303        self.assertEqual(
2304            tensor_list,
2305            [
2306                torch.ones([2, 3]).cuda(self.rank) * self.world_size
2307                for _ in range(self.world_size)
2308            ],
2309        )
2310
2311    @requires_nccl()
2312    @skip_if_lt_x_gpu(2)
2313    def test_on_completion_hook_with_ddp(self):
2314        pg = self._get_process_group()
2315        num_hook_fired: Dict[int, int] = {}
2316        durations: Dict[OpType, List[float]] = {}
2317
2318        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
2319            nonlocal num_hook_fired, durations
2320            op_type = work_info.op_type
2321            if op_type not in num_hook_fired:
2322                num_hook_fired[op_type] = 0
2323                durations[op_type] = []
2324            num_hook_fired[op_type] += 1
2325            durations[op_type].append(work_info.active_duration.total_seconds())
2326
2327        pg._register_on_completion_hook(hook)
2328
2329        nlayers = 10
2330        net = nn.Sequential(
2331            *[nn.Linear(1000, 1000, bias=False) for _ in range(nlayers)]
2332        ).to(self.rank)
2333
2334        ddp = DistributedDataParallel(
2335            net,
2336            device_ids=[self.rank],
2337            process_group=pg,
2338            bucket_cap_mb=1,
2339        )
2340
2341        pg._wait_for_pending_works()
2342
2343        # DDP is expected to synchronize model parameter by broadcasting
2344        # DDP is expected to synchronize model parameters by broadcasting
2345        # from rank 0 to other ranks. However, this is DDP's internal implementation,
2346        self.assertTrue(num_hook_fired[OpType.BROADCAST] > 0)
2347        ctor_allreduce = (
2348            num_hook_fired[OpType.ALLREDUCE]
2349            if OpType.ALLREDUCE in num_hook_fired
2350            else 0
2351        )
2352
2353        x = torch.zeros(2, 1000).cuda(self.rank)
2354        ddp(x).sum().backward()
2355
2356        c10d.destroy_process_group(pg)
2357
2358        self.assertTrue(OpType.ALLREDUCE in num_hook_fired)
2359        # The number of allreduce ops depend on DDP internal implementation, but
2360        # The number of allreduce ops depends on DDP's internal implementation, but
2361        self.assertTrue(num_hook_fired[OpType.ALLREDUCE] - ctor_allreduce > 0)
2362        self.assertTrue(all(duration > 0 for duration in chain(*(durations.values()))))
2363
2364    # Not testing FSDP due to https://github.com/pytorch/pytorch/issues/90848.
2365    # We cannot disable workCleanupLoop() as hooks are fired in that thread.
2366
2367    @requires_nccl()
2368    @skip_if_lt_x_gpu(2)
2369    def test_on_completion_hook_all_gather_object(self):
2370        torch.cuda.set_device(self.rank)
2371
2372        pg = self._get_process_group()
2373        num_hook_fired: Dict[int, int] = {}
2374        durations: Dict[OpType, List[float]] = {}
2375
2376        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
2377            nonlocal num_hook_fired, durations
2378            op_type = work_info.op_type
2379            if op_type not in num_hook_fired:
2380                num_hook_fired[op_type] = 0
2381                durations[op_type] = []
2382            num_hook_fired[op_type] += 1
2383            durations[op_type].append(work_info.active_duration.total_seconds())
2384
2385        pg._register_on_completion_hook(hook)
2386
2387        obj = {"rank": self.rank, "world_size": self.world_size}
2388        obj_list = [None for _ in range(self.world_size)]
2389
2390        c10d.all_gather_object(obj_list, obj, group=pg)
2391
2392        for r, o in enumerate(obj_list):
2393            self.assertTrue(isinstance(o, dict))
2394            self.assertEqual(set(o.keys()), {"rank", "world_size"})
2395            self.assertEqual(o["rank"], r)
2396            self.assertEqual(o["world_size"], self.world_size)
2397
2398        c10d.destroy_process_group(pg)
2399
2400        self.assertTrue(OpType.ALLGATHER in num_hook_fired)
2401        self.assertEqual(len(num_hook_fired), 1)
2402        # two allgathers, one for size and another for values
2403        self.assertEqual(num_hook_fired[OpType.ALLGATHER], 2)
2404        self.assertTrue(all(duration > 0 for duration in durations[OpType.ALLGATHER]))
2405
2406    @requires_nccl()
2407    @skip_if_lt_x_gpu(2)
2408    def test_on_completion_hook_seq(self):
2409        pg = self._get_process_group()
2410        num_hook_fired = 0
2411        seq: int = -1
2412        work: int = 0
2413
2414        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
2415            nonlocal num_hook_fired, seq
2416            num_hook_fired += 1
2417            seq = work_info.seq
2418
2419        pg._register_on_completion_hook(hook)
2420        tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank
2421        work_count = 3
2422        for i in range(work_count):
2423            work += 1
2424            pg.broadcast([tensor]).wait()
2425
2426        # N.B.: destroy_process_group is necessary to wait for
2427        # all pending works to finish.
2428        c10d.destroy_process_group(pg)
2429
2430        self.assertEqual(num_hook_fired, work_count)
2431        self.assertEqual(work, seq)
2432
2433
2434class NcclErrorHandlingTest(MultiProcessTestCase):
2435    def setUp(self):
2436        super().setUp()
2437        # Need to skip return code checking for these tests since the child
2438        # processes don't exit cleanly.
2439        self.skip_return_code_checks = [
2440            self.test_nccl_errors_blocking_abort.__wrapped__,
2441            self.test_nccl_errors_blocking_sigkill.__wrapped__,
2442            self.test_nccl_errors_blocking_sigterm.__wrapped__,
2443            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
2444        ]
2445        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
2446        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING, hence tests
2447        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
2448        self._spawn_processes()
2449
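    # Hedged summary of the environment knobs these tests toggle (semantics as
    # exercised in this class, not an exhaustive specification):
    #
    #   TORCH_NCCL_BLOCKING_WAIT=1         # wait() blocks and raises on collective timeout
    #   TORCH_NCCL_ASYNC_ERROR_HANDLING=1  # the watchdog tears the process down on async NCCL errors
    #
    # When both are set, BLOCKING_WAIT takes precedence, which is why setUp() can
    # enable ASYNC_ERROR_HANDLING unconditionally.
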
2450    def tearDown(self):
2451        super().tearDown()
2452        try:
2453            os.remove(self.file_name)
2454        except OSError:
2455            pass
2456
2457    @property
2458    def op_timeout_sec(self):
2459        return 3
2460
2461    @property
2462    def world_size(self):
2463        return 3
2464
2465    @property
2466    def blocking_wait_error_msg(self):
2467        return "timeout"
2468
2469    def _run_all_reduce(self, pg):
2470        pg.allreduce(torch.rand(10).cuda(self.rank))
2471
2472    @requires_nccl()
2473    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
2474    @skip_if_lt_x_gpu(3)
2475    @skip_if_rocm_multiprocess
2476    @skip_but_pass_in_sandcastle("Test does not pass when run locally")
2477    def test_nccl_errors_nonblocking(self):
2478        # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
2479        # since test_c10d_common runs with async error handling by default, but this
2480        # tests behavior when it is not enabled.
2481        prev_nccl_async_error_handling = os.environ.get(
2482            "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
2483        )
2484        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
2485        store = c10d.FileStore(self.file_name, self.world_size)
2486        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
2487        process_group.allreduce(torch.rand(10).cuda(self.rank))
2488        if self.rank == 0:
2489            # This allreduce does not block Python thread as allreduce enqueues
2490            # This allreduce does not block the Python thread, as allreduce only
2491            # enqueues the CUDA operation and wait() only blocks the current CUDA
2492            # stream.
2493            work.wait()
2494
2495            # Now the work scheduled next should hang forever since the previous
2496            # allreduce will never complete.
2497            t = threading.Thread(target=self._run_all_reduce, args=(process_group,))
2498            t.daemon = True
2499            t.start()
2500            t.join(int(get_timeout(self.id()) / 5))
2501            self.assertTrue(t.is_alive())
2502
2503        if prev_nccl_async_error_handling is not None:
2504            os.environ[
2505                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
2506            ] = prev_nccl_async_error_handling
2507
2508    def _test_nccl_errors_blocking(self, func):
2509        store = c10d.FileStore(self.file_name, self.world_size)
2510        process_group = c10d.ProcessGroupNCCL(
2511            store,
2512            self.rank,
2513            self.world_size,
2514            timeout=timedelta(seconds=10),
2515        )
2516        process_group.allreduce(torch.rand(10).cuda(self.rank))
2517        if self.rank == 0:
2518            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
2519            with self.assertRaisesRegex(dist.DistBackendError, ""):
2520                # The error message can differ depending on whether the test is
2521                # run on a CI machine or a devGPU.  Skipping the error message
2522                # check to make both sides happy.
2523                work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
2524            # Run some GPU operations to make sure CUDA has not gotten stuck.
2525            # It was observed that CUDA could get stuck if NCCL communicators were
2526            # not properly aborted before throwing RuntimeError.
2527            a = torch.rand(10).cuda(self.rank)
2528        elif self.rank == 1:
2529            # Clean up structures (e.g., files for the FileStore) before going down.
2530            del process_group
2531            func()
2532
2533    @with_nccl_blocking_wait
2534    @requires_nccl()
2535    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
2536    @skip_if_lt_x_gpu(3)
2537    @skip_if_rocm_multiprocess
2538    def test_nccl_errors_blocking_clean_exit(self):
2539        self._test_nccl_errors_blocking(lambda: sys.exit(0))
2540
2541    @with_nccl_blocking_wait
2542    @requires_nccl()
2543    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
2544    @skip_if_lt_x_gpu(3)
2545    @skip_if_rocm_multiprocess
2546    def test_nccl_errors_blocking_nonzero_exit(self):
2547        self._test_nccl_errors_blocking(lambda: sys.exit(1))
2548
2549    @with_nccl_blocking_wait
2550    @requires_nccl()
2551    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
2552    @skip_if_lt_x_gpu(3)
2553    @skip_if_rocm_multiprocess
2554    @skip_but_pass_in_sandcastle(
2555        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
2556    )
2557    def test_nccl_errors_blocking_abort(self):
2558        self._test_nccl_errors_blocking(lambda: os.abort())
2559
2560    @with_nccl_blocking_wait
2561    @requires_nccl()
2562    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
2563    @skip_if_lt_x_gpu(3)
2564    @skip_if_rocm_multiprocess
2565    def test_nccl_errors_blocking_sigkill(self):
2566        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
2567
2568    @with_nccl_blocking_wait
2569    @requires_nccl()
2570    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
2571    @skip_if_lt_x_gpu(3)
2572    @skip_if_rocm_multiprocess
2573    def test_nccl_errors_blocking_sigterm(self):
2574        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
2575
2576    @with_nccl_blocking_wait
2577    @requires_nccl()
2578    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
2579    @skip_if_lt_x_gpu(3)
2580    def test_nccl_blocking_wait_with_barrier(self):
2581        store = c10d.FileStore(self.file_name, self.world_size)
2582        process_group = c10d.ProcessGroupNCCL(
2583            store,
2584            self.rank,
2585            self.world_size,
2586            timeout=timedelta(seconds=10),
2587        )
2588        process_group.barrier().wait()
2589        if self.rank == 0:
2590            with self.assertRaisesRegex(dist.DistBackendError, ""):
2591                # The error message can differ depending on whether the test is
2592                # run on a CI machine or a devGPU.  Skipping the error message
2593                # check to make both sides happy.
2594                process_group.barrier().wait(
2595                    timeout=timedelta(seconds=self.op_timeout_sec)
2596                )
2597
2598    def _run_invalid_nccl_blocking_wait_env(self, val):
2599        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
2600        store = c10d.FileStore(self.file_name, self.world_size)
2601        with self.assertRaises(RuntimeError):
2602            process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
2603
2604    @requires_nccl()
2605    @skip_if_lt_x_gpu(3)
2606    def test_invalid_nccl_blocking_wait_env(self):
2607        self._run_invalid_nccl_blocking_wait_env("abc")
2608        self._run_invalid_nccl_blocking_wait_env("-1")
2609        self._run_invalid_nccl_blocking_wait_env("2147483647")
2610        self._run_invalid_nccl_blocking_wait_env("4294967295")
2611
2612    @with_nccl_blocking_wait
2613    @requires_nccl()
2614    @requires_gloo()
2615    @skip_if_lt_x_gpu(3)
2616    def test_nccl_timeout(self):
2617        store = c10d.FileStore(self.file_name, self.world_size)
2618
2619        # Initialize process_group.
2620        process_group = c10d.ProcessGroupNCCL(
2621            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
2622        )
2623        # Control gloo pg used as a go-ahead signal/barrier
2624        # to coordinate between ranks.
2625        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
2626        failed_collective_timeout = timedelta(milliseconds=100)
2627        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
2628            timeout=timedelta(seconds=5)
2629        )
2630
2631        if self.rank == 0:
2632            # This should time out quickly, since the wait below uses the 100 ms failed_collective_timeout.
2633            # The watchdog may abort the timed-out work, producing an NCCL error instead of "operation timed out".
2634            with self.assertRaisesRegex(
2635                dist.DistBackendError, self.blocking_wait_error_msg
2636            ):
2637                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
2638                    timeout=failed_collective_timeout
2639                )
2640            # Now do a barrier to tell other rank to go ahead.
2641            pg_gloo.barrier().wait()
2642        else:
2643            # Wait on rank 0 to fail.
2644            try:
2645                pg_gloo.barrier().wait()
2646            except Exception as e:
2647                raise ValueError(
2648                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
2649                ) from e
2650
2651
2652class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
2653    @property
2654    def device(self):
2655        return f"cuda:{self.rank}"
2656
2657    def setUp(self):
2658        super().setUp()
2659        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
2660        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
2661        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
2662        self._spawn_processes()
2663
2664    def tearDown(self):
2665        super().tearDown()
2666        try:
2667            os.remove(self.file_name)
2668        except OSError:
2669            pass
2670
2671    def _test_broadcast_coalesced(self, process_group, device, root_rank):
2672        half = torch.float16
2673
2674        # No support for float16 for CPU tensors
2675        if device == torch.device("cpu"):
2676            half = torch.float32
2677
2678        target = torch.arange(60, dtype=half, device=device).chunk(5)
2679        target += torch.arange(60, dtype=torch.float32, device=device).chunk(5)
2680        target += torch.arange(60, dtype=half, device=device).chunk(5)
2681        target += torch.arange(60, dtype=torch.float64, device=device).chunk(5)
2682        target += torch.arange(60, dtype=half, device=device).chunk(5)
2683        target += torch.arange(60, dtype=torch.float32, device=device).chunk(5)
2684
2685        # The tensors to pass to broadcast are identical to the target
2686        # only on the process that is the root of the broadcast.
2687        if self.rank == root_rank:
2688            tensors = [tensor.clone() for tensor in target]
2689        else:
2690            tensors = [torch.zeros_like(tensor) for tensor in target]
2691
2692        if self.rank != root_rank:
2693            self.assertNotEqual(tensors, target)
2694
2695        c10d._broadcast_coalesced(
2696            process_group, tensors, buffer_size=256, src=root_rank
2697        )
2698
2699        if self.rank != root_rank:
2700            self.assertEqual(tensors, target)
2701
2702    @requires_nccl()
2703    @skip_if_lt_x_gpu(2)
2704    def test_broadcast_coalesced_nccl(self):
2705        store = c10d.FileStore(self.file_name, self.world_size)
2706        c10d.init_process_group(
2707            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
2708        )
2709        process_group = c10d.distributed_c10d._get_default_group()
2710        device = torch.device("cuda:%d" % self.rank)
2711        ranks = [0, 1]
2712        for root_rank in ranks:
2713            self._test_broadcast_coalesced(process_group, device, root_rank)
2714
2715    @requires_nccl()
2716    @skip_if_lt_x_gpu(2)
2717    def test_all_reduce_coalesced_nccl(self):
2718        store = c10d.FileStore(self.file_name, self.world_size)
2719        c10d.init_process_group(
2720            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
2721        )
2722        process_group = c10d.distributed_c10d._get_default_group()
2723        device = torch.device("cuda:%d" % self.rank)
2724        tensors = [
2725            torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float)
2726            for i in range(5)
2727        ]
2728        torch.distributed.all_reduce_coalesced(tensors, group=process_group)
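            # each rank contributes (rank + 1 + i) to tensor i, so the all-reduced value is
            # sum over ranks = world_size * i + world_size * (world_size + 1) / 2
            #                = world_size * (i + (world_size + 1) / 2)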
2729        for i, t in enumerate(tensors):
2730            self.assertEqual(
2731                t,
2732                torch.full_like(
2733                    t, self.world_size * (i + (self.world_size + 1.0) / 2.0)
2734                ),
2735            )
2736
2737    @requires_nccl()
2738    @skip_if_lt_x_gpu(2)
2739    def test_all_reduce_coalesced_nccl_float8_errors(self):
2740        store = c10d.FileStore(self.file_name, self.world_size)
2741        c10d.init_process_group(
2742            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
2743        )
2744        process_group = c10d.distributed_c10d._get_default_group()
2745        device = torch.device("cuda:%d" % self.rank)
2746        tensors = [
2747            torch.full(
2748                (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float
2749            ).to(torch.float8_e4m3fn)
2750            for i in range(5)
2751        ]
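            # note: "currenlty" (sic) deliberately matches the misspelling in the
            # error message raised by the NCCL backend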
2752        with self.assertRaisesRegex(
2753            RuntimeError,
2754            "Float8 dtypes are not currenlty supported for NCCL reductions",
2755        ):
2756            torch.distributed.all_reduce_coalesced(tensors, group=process_group)
2757
2758    @requires_nccl()
2759    @skip_if_lt_x_gpu(2)
2760    def test_all_reduce_coalesced_manager_nccl(self):
2761        store = c10d.FileStore(self.file_name, self.world_size)
2762        c10d.init_process_group(
2763            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
2764        )
2765        process_group = c10d.distributed_c10d._get_default_group()
2766        device = torch.device("cuda:%d" % self.rank)
2767        tensors = [
2768            torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float)
2769            for i in range(5)
2770        ]
2771        with torch.distributed._coalescing_manager(
2772            group=process_group, device=device, async_ops=True
2773        ) as cm:
2774            for tensor in tensors:
2775                torch.distributed.all_reduce(tensor)
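            # with async_ops=True the coalescing manager flushes all the queued
            # all-reduces as a single coalesced Work, hence exactly one entry below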
2776        self.assertEqual(len(cm.works), 1)
2777        cm.wait()
2778        for i, t in enumerate(tensors):
2779            self.assertEqual(
2780                t,
2781                torch.full_like(
2782                    t, self.world_size * (i + (self.world_size + 1.0) / 2.0)
2783                ),
2784            )
2785
2786    @requires_nccl()
2787    @skip_if_lt_x_gpu(2)
2788    @skip_if_rocm_multiprocess
2789    def test_intra_node_comm_all_reduce(self):
2790        from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
2791        from torch.testing._internal.common_cuda import SM80OrLater
2792
2793        for peer in range(self.world_size):
2794            if peer == self.rank:
2795                continue
2796            if not torch._C._cuda_canDeviceAccessPeer(self.rank, peer):
2797                raise SkipTest("Test requires p2p access")
2798
2799        if not SM80OrLater:
2800            raise SkipTest("Test requires sm>=80")
2801
2802        store = c10d.FileStore(self.file_name, self.world_size)
2803        os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
2804        os.environ["TEST_INTRA_NODE_COMM"] = "1"
2805        torch.cuda.set_device(self.rank)
2806        c10d.init_process_group(
2807            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
2808        )
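            # SUM all-reduce over per-rank values 0..world_size-1 yields the
            # triangular number world_size * (world_size - 1) / 2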
2809        expect = self.world_size * (self.world_size - 1) // 2
2810
2811        # IntraNodeComm currently only supports sum and bf16.
2812        # Verify that it is not used in the next two configurations.
2813        t = torch.full((4 * 1024 // 2,), self.rank).cuda()
2814        c10d.all_reduce(t, c10d.ReduceOp.SUM)
2815        self.assertTrue(t.eq(expect).all())
2816        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
2817
2818        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
2819        c10d.all_reduce(t, c10d.ReduceOp.AVG)
2820        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
2821
2822        # Verify that IntraNodeComm is used up to 10MB
2823        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
2824        c10d.all_reduce(t, c10d.ReduceOp.SUM)
2825        self.assertTrue(t.eq(expect).all())
2826        self.assertEqual(_get_intra_node_comm_usage_counter(), 1)
2827
2828        t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
2829        c10d.all_reduce(t, c10d.ReduceOp.SUM)
2830        self.assertTrue(t.eq(expect).all())
2831        self.assertEqual(_get_intra_node_comm_usage_counter(), 2)
2832
2833        t = torch.full((10 * 1024**2 // 2,), self.rank, dtype=torch.bfloat16).cuda()
2834        c10d.all_reduce(t, c10d.ReduceOp.SUM)
2835        self.assertTrue(t.eq(expect).all())
2836        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
2837
2838        # Verify that IntraNodeComm is not used beyond 10MB
2839        t = torch.full(
2840            (10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16
2841        ).cuda()
2842        c10d.all_reduce(t, c10d.ReduceOp.SUM)
2843        self.assertTrue(t.eq(expect).all())
2844        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
2845
2846        c10d.destroy_process_group()
2847
2848    @requires_nccl()
2849    @skip_if_lt_x_gpu(2)
2850    def test_sequence_num_set_default_pg_nccl(self):
2851        torch.cuda.set_device(self.rank)
2852        self._test_sequence_num_set_default_pg(backend="nccl")
2853
2854    @skip_if_lt_x_gpu(2)
2855    @requires_nccl()
2856    def test_sequence_num_incremented_nccl_default(self):
2857        self._test_sequence_num_incremented_default_group("nccl")
2858
2859    @skip_if_lt_x_gpu(4)
2860    @requires_nccl()
2861    def test_sequence_num_incremented_nccl_subgroup(self):
2862        if self.world_size < 4:
2863            return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
2864        self._test_sequence_num_incremented_subgroup("nccl")
2865
2866    @requires_nccl()
2867    @skip_if_lt_x_gpu(2)
2868    def test_sequence_num_set_nccl_new_group(self):
2869        torch.cuda.set_device(self.rank)
2870        self._test_sequence_num_set_new_group(backend="nccl")
2871
2872    def _test_pass_nccl_options(self, pg_opts):
2873        store = c10d.FileStore(self.file_name, self.world_size)
2874        # Test init_process_group accepts options
2875        dist.init_process_group(
2876            "nccl",
2877            world_size=self.world_size,
2878            rank=self.rank,
2879            store=store,
2880            pg_options=pg_opts,
2881        )
2882
2883        # Test with new_group
2884        pg = c10d.new_group([0, 1], pg_options=pg_opts)
2885        # test the process group works as expected
2886        t = torch.tensor([self.rank + 1] * 10).cuda(self.rank)
2887        pg.allreduce(t).wait()
2888        expected_tensor = torch.tensor([3] * 10).cuda(self.rank)
2889        self.assertEqual(expected_tensor, t)
2890
2891    @requires_nccl()
2892    @skip_if_lt_x_gpu(2)
2893    def test_pass_nccl_options_high_priority_stream(self):
2894        pg_opts = c10d.ProcessGroupNCCL.Options()
2895        pg_opts.is_high_priority_stream = True
2896        self._test_pass_nccl_options(pg_opts)
2897
2898    @requires_nccl()
2899    @requires_nccl_version(
2900        (2, 18), "Need NCCL 2.18+ for configuring NCCL communicators"
2901    )
2902    @skip_if_lt_x_gpu(2)
2903    def test_pass_nccl_options_config(self):
2904        pg_opts = c10d.ProcessGroupNCCL.Options()
2905        pg_opts.config.max_ctas = 4
2906        pg_opts.config.min_ctas = 2
2907        pg_opts.config.cga_cluster_size = 2
2908        pg_opts.config.net_name = "Socket"
2909        pg_opts.config.split_share = 1
2910        nccl_debug_file = tempfile.NamedTemporaryFile()
2911        os.environ["NCCL_DEBUG"] = "INFO"
2912        os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
2913
2914        # Tests functionality when passing nccl config
2915        self._test_pass_nccl_options(pg_opts)
2916
2917        # Tests if comms were configured
2918        nccl_debug_file_content = nccl_debug_file.read()
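            # each pattern ends with "|$" so re.search always returns a match; if the
            # expected line is missing from the debug log, group(1) is None and the
            # int()/decode() conversions below fail, surfacing the misconfiguration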
2919        max_ctas = re.search(rb"Max CTAs.*(\d+)|$", nccl_debug_file_content).group(1)
2920        min_ctas = re.search(rb"Min CTAs.*(\d+)|$", nccl_debug_file_content).group(1)
2921        split_share = re.search(
2922            rb"Split share.*(\d+)|$", nccl_debug_file_content
2923        ).group(1)
2924        cga_cluster_size = re.search(
2925            rb"CGA cluster.*(\d+)|$", nccl_debug_file_content
2926        ).group(1)
2927        net_name = re.search(
2928            rb"Using network.([a-zA-Z]+)|$", nccl_debug_file_content
2929        ).group(1)
2930        self.assertEqual(pg_opts.config.max_ctas, int(max_ctas))
2931        self.assertEqual(pg_opts.config.min_ctas, int(min_ctas))
2932        self.assertEqual(pg_opts.config.cga_cluster_size, int(cga_cluster_size))
2933        self.assertEqual(pg_opts.config.net_name, net_name.decode())
2934        self.assertEqual(pg_opts.config.split_share, int(split_share))
2935
2936    @requires_nccl()
2937    @skip_if_lt_x_gpu(4)
2938    def test_nccl_barrier(self):
2939        store = c10d.FileStore(self.file_name, self.world_size)
2940        c10d.init_process_group(
2941            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
2942        )
2943
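            # the test requires 4 GPUs for a world size of 2, so each rank uses GPU 2 * rank;
            # all-reducing (rank + 1) across the two ranks gives 1 + 2 = 3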
2944        t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
2945        c10d.all_reduce(t)
2946        expected_tensor = torch.tensor([3] * 10).cuda(2 * self.rank)
2947        self.assertEqual(expected_tensor, t)
2948
2949        # Test with new_group
2950        pg = c10d.new_group([0, 1])
2951        t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
2952        pg.allreduce(t).wait()
2953        self.assertEqual(expected_tensor, t)
2954
2955        pg = c10d.new_group([0])
2956        if self.rank == 0:
2957            t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
2958            expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
2959            pg.allreduce(t).wait()
2960            self.assertEqual(expected_tensor, t)
2961
2962        pg = c10d.new_group([1])
2963        if self.rank == 1:
2964            t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
2965            expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
2966            pg.allreduce(t).wait()
2967            self.assertEqual(expected_tensor, t)
2968
2969    @requires_nccl()
2970    @skip_if_lt_x_gpu(2)
2971    def test_nccl_barrier_device_ids(self):
2972        store = c10d.FileStore(self.file_name, self.world_size)
2973        c10d.init_process_group(
2974            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
2975        )
2976
2977        c10d.barrier(device_ids=[self.rank])
2978
2979    @requires_nccl()
2980    @skip_if_lt_x_gpu(2)
2981    def test_nccl_barrier_device_ids_function_argument(self):
2982        store = c10d.FileStore(self.file_name, self.world_size)
2983        c10d.init_process_group(
2984            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
2985        )
2986
2987        with self.assertRaisesRegex(TypeError, "Invalid function argument"):
2988            c10d.barrier(device_ids=self.rank)
2989
2990    @requires_nccl()
2991    @skip_if_lt_x_gpu(2)
2992    @with_dist_debug_levels(levels=["DETAIL"])
2993    def test_nccl_warn_not_in_group_debug_detail(self):
2994        self._test_warn_not_in_group(backend="nccl")
2995
2996    @requires_nccl()
2997    @skip_if_lt_x_gpu(2)
2998    @with_dist_debug_levels(levels=["INFO"])
2999    def test_nccl_warn_not_in_group_debug_info(self):
3000        self._test_warn_not_in_group(backend="nccl")
3001
3002    @requires_nccl()
3003    @skip_if_lt_x_gpu(2)
3004    @with_dist_debug_levels(levels=["OFF"])
3005    def test_nccl_warn_not_in_group_debug_off(self):
3006        self._test_warn_not_in_group(backend="nccl")
3007
3008    @requires_nccl()
3009    @skip_if_lt_x_gpu(2)
3010    def test_nccl_rank_membership(self):
3011        self._test_rank_membership(backend="nccl")
3012
3013    @requires_nccl()
3014    @skip_if_lt_x_gpu(2)
3015    def test_tensor_dtype_mismatch(self):
3016        self._test_tensor_dtype_mismatch(backend="nccl")
3017
3018    @requires_nccl()
3019    @skip_if_lt_x_gpu(2)
3020    def test_tensor_dtype_complex(self):
3021        self._test_tensor_dtype_complex(backend="nccl")
3022
3023    @requires_nccl()
3024    @skip_if_lt_x_gpu(2)
3025    def test_reduce_scatter_base_k(self):
3026        store = dist.FileStore(self.file_name, self.world_size)
3027        dist.init_process_group(
3028            "nccl",
3029            world_size=self.world_size,
3030            rank=self.rank,
3031            store=store,
3032        )
3033        output_tensor = torch.zeros(2, dtype=torch.int64).to(self.rank)
3034        input_tensors = torch.arange(self.world_size * 2, dtype=torch.int64).to(
3035            self.rank
3036        )
3037        input_tensors = torch.reshape(input_tensors, (self.world_size, 2))
3038        dist.reduce_scatter_tensor(output_tensor, input_tensors)
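            # every rank supplies the same input matrix, so rank r receives row r
            # summed over all ranks, i.e. input_tensors[r] * world_size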
3039        self.assertEqual(output_tensor, input_tensors[self.rank] * self.world_size)
3040
3041    @requires_nccl()
3042    @skip_if_lt_x_gpu(2)
3043    def test_reduce_scatter_tensor_coalesced(self):
3044        store = dist.FileStore(self.file_name, self.world_size)
3045        dist.init_process_group(
3046            "nccl",
3047            world_size=self.world_size,
3048            rank=self.rank,
3049            store=store,
3050        )
3051        output_tensors = torch.zeros(2, 2).to(self.rank)
3052        input_tensors = [torch.ones(2, 2).to(self.rank) for _ in range(self.world_size)]
3053        with dist._coalescing_manager():
3054            for i in range(self.world_size):
3055                dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
3056        self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)
3057
3058    @requires_nccl()
3059    @skip_if_lt_x_gpu(2)
3060    def test_reduce_scatter_base_k_float8_errors(self):
3061        store = dist.FileStore(self.file_name, self.world_size)
3062        dist.init_process_group(
3063            "nccl",
3064            world_size=self.world_size,
3065            rank=self.rank,
3066            store=store,
3067        )
3068        output_tensor = (
3069            torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank)
3070        )
3071        input_tensors = (
3072            torch.arange(self.world_size * 2, dtype=torch.float32)
3073            .to(torch.float8_e4m3fn)
3074            .to(self.rank)
3075        )
3076        input_tensors = torch.reshape(input_tensors, (self.world_size, 2))
3077        with self.assertRaisesRegex(
3078            RuntimeError,
3079            "Float8 dtypes are not currenlty supported for NCCL reductions",
3080        ):
3081            dist.reduce_scatter_tensor(output_tensor, input_tensors)
3082
3083    @requires_nccl()
3084    @skip_if_lt_x_gpu(2)
3085    def test_reduce_scatter_tensor_coalesced_float8_errors(self):
3086        store = dist.FileStore(self.file_name, self.world_size)
3087        dist.init_process_group(
3088            "nccl",
3089            world_size=self.world_size,
3090            rank=self.rank,
3091            store=store,
3092        )
3093        output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank)
3094        input_tensors = [
3095            torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank)
3096            for _ in range(self.world_size)
3097        ]
3098
3099        with self.assertRaisesRegex(
3100            RuntimeError,
3101            "Float8 dtypes are not currenlty supported for NCCL reductions",
3102        ):
3103            with dist._coalescing_manager():
3104                for i in range(self.world_size):
3105                    dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
3106            self.assertEqual(output_tensors, input_tensors[self.rank])
3107
3108
3109class SetDeviceMethod(Enum):
3110    TORCH_CUDA_SET = auto()  # torch.cuda.set_device
3111    COLLECTIVE_ARGUMENT = auto()  # broadcast_object_list(device=)
3112
3113
3114class NcclProcessGroupWithDispatchedCollectivesTests(
3115    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
3116):
3117    @requires_nccl()
3118    @skip_if_lt_x_gpu(1)
3119    def test_collectives(self):
3120        self._test_collectives(backend="nccl")
3121
3122    @requires_nccl()
3123    @skip_if_lt_x_gpu(1)
3124    def test_allreduce_coalesced(self):
3125        self._test_allreduce_coalesced(backend="nccl")
3126
3127    @requires_nccl()
3128    @skip_if_lt_x_gpu(1)
3129    def test_all_to_all_single(self):
3130        self._test_all_to_all_single(backend="nccl")
3131
3132    @requires_nccl()
3133    @skip_if_lt_x_gpu(1)
3134    def test_allgather_base(self):
3135        store = dist.FileStore(self.file_name, self.world_size)
3136        dist.init_process_group(
3137            "nccl",
3138            world_size=self.world_size,
3139            rank=self.rank,
3140            store=store,
3141        )
3142        device = "cuda"
3143        tensor = torch.ones(10, 10, device=torch.device(device))
3144        output_tensor = torch.zeros(10, 10, device=torch.device(device))
3145        dist.all_gather_into_tensor(output_tensor, tensor)
3146        self.assertEqual(output_tensor, tensor)
3147
3148    @requires_nccl()
3149    @skip_if_lt_x_gpu(1)
3150    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
3151    def test_allgather_float8(self, float8_dtype):
3152        store = dist.FileStore(self.file_name, self.world_size)
3153        dist.init_process_group(
3154            "nccl",
3155            world_size=self.world_size,
3156            rank=self.rank,
3157            store=store,
3158        )
3159        device = "cuda"
3160        tensor = torch.ones(10, 16, device=torch.device(device)).to(float8_dtype)
3161        output_tensor = torch.zeros(10, 16, device=torch.device(device)).to(
3162            float8_dtype
3163        )
3164        dist.all_gather_into_tensor(output_tensor, tensor)
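            # compare via a bitwise float32 reinterpretation of the float8 data,
            # presumably because direct comparison support for float8 is limited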
3165        self.assertEqual(output_tensor.view(torch.float32), tensor.view(torch.float32))
3166
3167
3168instantiate_parametrized_tests(NcclProcessGroupWithDispatchedCollectivesTests)
3169
3170
3171class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
3172    def setUp(self):
3173        super().setUp()
3174        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
3175        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
3176        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
3177        self._spawn_processes()
3178
3179    def tearDown(self):
3180        super().tearDown()
3181        try:
3182            os.remove(self.file_name)
3183        except OSError:
3184            pass
3185
3186    @property
3187    def device(self):
3188        return self.rank
3189
3190    @requires_nccl()
3191    @skip_if_lt_x_gpu(4)
3192    def test_new_group_local_sync(self):
3193        self._test_new_group_local_sync(backend="nccl")
3194
3195    @requires_nccl()
3196    @skip_if_lt_x_gpu(4)
3197    def test_new_group_local_sync_sanity_check(self):
3198        self._test_new_group_local_sync_sanity_check(backend="nccl")
3199
3200    @requires_nccl()
3201    @skip_if_lt_x_gpu(4)
3202    def test_new_group_local_sync_duplicated_pg(self):
3203        self._test_new_group_local_sync_duplicate_pg(backend="nccl")
3204
3205    def _init_two_pg2_subgroups(self, world_size: int = 4):
3206        if world_size != 4:
3207            raise NotImplementedError(
3208                f"need world size of 4 to get 2 subgroup PGs, but got world size of {world_size}"
3209            )
3210        store = c10d.FileStore(self.file_name, world_size)
3211        c10d.init_process_group(
3212            backend="nccl", store=store, rank=self.rank, world_size=world_size
3213        )
3214        # every rank creates the same sub groups,
3215        # including the sub group that the current rank is not a member of
3216        a_group = c10d.new_group([0, 1])
3217        b_group = c10d.new_group([2, 3])
3218        return a_group if self.rank < 2 else b_group
3219
3220    @requires_nccl()
3221    @skip_if_lt_x_gpu(4)
3222    def test_gather_subgroup(self):
3223        world_size = 4
3224        if self.rank >= world_size:
3225            # just easier to write the test for exactly 4 GPUs, even if this test class is scaled up to 8 GPUs later
3226            return
3227
3228        subgroup = self._init_two_pg2_subgroups(world_size)
3229        device = torch.device("cuda:%d" % self.rank)
3230        input = torch.ones((10,), device=device) * self.rank
3231        if self.rank == 0 or self.rank == 2:
3232            gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]
3233            torch.distributed.gather(
3234                input,
3235                gather_list=gather_list,
3236                dst=self.rank,
3237                group=subgroup,
3238                async_op=False,
3239            )
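                # subgroup-local rank src corresponds to global rank (self.rank + src),
                # and each rank contributed ones * its own global rank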
3240            for src in range(len(gather_list)):
3241                expected = (torch.ones_like(input) * self.rank) + src
3242                self.assertEqual(gather_list[src], expected)
3243        else:
3244            torch.distributed.gather(
3245                input,
3246                gather_list=None,
3247                dst=self.rank - 1,
3248                group=subgroup,
3249                async_op=False,
3250            )
3251
3252    @requires_nccl()
3253    @skip_if_lt_x_gpu(4)
3254    def test_gather_object_subgroup(self):
3255        world_size = 4
3256        if self.rank >= world_size:
3257            # just easier to write the test for exactly 4 GPUs, even if this test class is scaled up to 8 GPUs later
3258            return
3259
3260        subgroup = self._init_two_pg2_subgroups(world_size)
3261
3262        # discrepancy #1
3263        # have to set the device, or else gather_object picks the wrong device from 'current_device = _get_pg_default_device(group)'
3264        torch.cuda.set_device(self.rank)
3265
3266        input = {"rank": self.rank}
3267        if self.rank == 0 or self.rank == 2:
3268            # discrepancy #2
3269            # another oddity: why do I have to pre-populate the list with empty placeholder objects?
3270            # an empty list should arguably be valid, but it throws an error.
3271            gather_list = [{}, {}]
3272            torch.distributed.gather_object(
3273                input, object_gather_list=gather_list, dst=self.rank, group=subgroup
3274            )
3275            for src in range(len(gather_list)):
3276                self.assertEqual(gather_list[src]["rank"], self.rank + src)
3277        else:
3278            torch.distributed.gather_object(
3279                input, object_gather_list=None, dst=self.rank - 1, group=subgroup
3280            )
3281
3282    @requires_nccl()
3283    @skip_if_lt_x_gpu(4)
3284    def test_reduce_subgroup(self):
3285        world_size = 4
3286        if self.rank >= world_size:
3287            return
3288        subgroup = self._init_two_pg2_subgroups(world_size)
3289        device = torch.device("cuda:%d" % self.rank)
3290        x = torch.ones((10,), device=device) * self.rank
3291        if self.rank == 0 or self.rank == 2:
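                # the only other member of this 2-rank subgroup is global rank (self.rank + 1)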
3292            expected = x + torch.ones((10,), device=device) * (self.rank + 1)
3293            c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False)
3294            self.assertEqual(x, expected)
3295        else:
3296            c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False)
3297
3298    @requires_nccl()
3299    @skip_if_lt_x_gpu(4)
3300    @parametrize("async_op", [True, False])
3301    def test_send_recv_subgroup(self, async_op):
3302        world_size = 4
3303        if self.rank >= world_size:
3304            return
3305        subgroup = self._init_two_pg2_subgroups(world_size)
3306        device = torch.device("cuda:%d" % self.rank)
3307        if self.rank == 0 or self.rank == 2:
3308            x = torch.empty((10,), device=device)
3309            if async_op:
3310                c10d.irecv(x, src=self.rank + 1, group=subgroup).wait()
3311            else:
3312                c10d.recv(x, src=self.rank + 1, group=subgroup)
3313            expected = torch.ones((10,), device=device) * (self.rank + 1)
3314            self.assertEqual(x, expected)
3315        else:
3316            x = torch.ones((10,), device=device) * self.rank
3317            if async_op:
3318                c10d.isend(x, dst=self.rank - 1, group=subgroup).wait()
3319            else:
3320                c10d.send(x, dst=self.rank - 1, group=subgroup)
3321
3322    @requires_nccl()
3323    @skip_if_lt_x_gpu(4)
3324    def test_broadcast_subgroup(self):
3325        world_size = 4
3326        if self.rank >= world_size:
3327            return
3328        subgroup = self._init_two_pg2_subgroups(world_size)
3329        device = torch.device("cuda:%d" % self.rank)
3330        if self.rank == 0 or self.rank == 2:
3331            x = torch.empty((10,), device=device)
3332            c10d.broadcast(x, src=self.rank + 1, group=subgroup)
3333            expected = torch.ones((10,), device=device) * (self.rank + 1)
3334            self.assertEqual(x, expected)
3335        else:
3336            x = torch.ones((10,), device=device) * self.rank
3337            c10d.broadcast(x, src=self.rank, group=subgroup)
3338
3339    @requires_nccl()
3340    @skip_if_lt_x_gpu(4)
3341    @parametrize(
3342        "set_device",
3343        [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
3344    )
3345    def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
3346        world_size = 4
3347        if self.rank >= world_size:
3348            return
3349        subgroup = self._init_two_pg2_subgroups(world_size)
3350        if set_device == SetDeviceMethod.TORCH_CUDA_SET:
3351            torch.cuda.set_device(self.rank)
3352            device = None
3353        else:
3354            device = torch.device("cuda:%d" % self.rank)
3355        if self.rank == 0 or self.rank == 2:
3356            x = [{}]
3357            c10d.recv_object_list(x, src=self.rank + 1, group=subgroup, device=device)
3358            expected = [{"rank": self.rank + 1}]
3359            self.assertEqual(x, expected)
3360        else:
3361            x = [{"rank": self.rank}]
3362            c10d.send_object_list(x, dst=self.rank - 1, group=subgroup, device=device)
3363
3364    @requires_nccl()
3365    @skip_if_lt_x_gpu(4)
3366    @parametrize(
3367        "set_device",
3368        [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
3369    )
3370    def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
3371        world_size = 4
3372        if self.rank >= world_size:
3373            return
3374        subgroup = self._init_two_pg2_subgroups(world_size)
3375        if set_device == SetDeviceMethod.TORCH_CUDA_SET:
3376            torch.cuda.set_device(self.rank)
3377            device = None
3378        else:
3379            device = torch.device("cuda:%d" % self.rank)
3380        if self.rank == 0 or self.rank == 2:
3381            x = [{}]
3382            c10d.broadcast_object_list(
3383                x, src=self.rank + 1, group=subgroup, device=device
3384            )
3385            expected = [{"rank": self.rank + 1}]
3386            self.assertEqual(x, expected)
3387        else:
3388            x = [{"rank": self.rank}]
3389            c10d.broadcast_object_list(x, src=self.rank, group=subgroup, device=device)
3390
3391    @requires_nccl()
3392    @skip_if_lt_x_gpu(4)
3393    def test_scatter_subgroup(self):
3394        world_size = 4
3395        if self.rank >= world_size:
3396            return
3397        subgroup = self._init_two_pg2_subgroups(world_size)
3398        device = torch.device("cuda:%d" % self.rank)
3399        x = torch.empty((10,), device=device)
3400        expected = torch.ones((10,), device=device) * self.rank
3401        if self.rank == 0 or self.rank == 2:
3402            c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
3403        else:
3404            scatter_list = [
3405                torch.ones((10,), device=device) * (self.rank - 1),
3406                torch.ones((10,), device=device) * self.rank,
3407            ]
3408            c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup)
3409        self.assertEqual(x, expected)
3410
3411    @requires_nccl()
3412    @skip_if_lt_x_gpu(4)
3413    def test_scatter_object_list_subgroup(self):
3414        world_size = 4
3415        if self.rank >= world_size:
3416            return
3417        subgroup = self._init_two_pg2_subgroups(world_size)
3418        torch.cuda.set_device(self.rank)
3419        scatter_object_output_list = [None]
3420        expected = [{"rank": self.rank}]
3421        if self.rank == 0 or self.rank == 2:
3422            c10d.scatter_object_list(
3423                scatter_object_output_list=scatter_object_output_list,
3424                scatter_object_input_list=None,
3425                src=self.rank + 1,
3426                group=subgroup,
3427            )
3428
3429        else:
3430            scatter_object_input_list = [
3431                {"rank": self.rank - 1},
3432                {"rank": self.rank},
3433            ]
3434            c10d.scatter_object_list(
3435                scatter_object_output_list=scatter_object_output_list,
3436                scatter_object_input_list=scatter_object_input_list,
3437                src=self.rank,
3438                group=subgroup,
3439            )
3440        self.assertEqual(scatter_object_output_list, expected)
3441
3442
3443instantiate_parametrized_tests(LargeCommTest)
3444
3445
3446class SparseCollective(MultiProcessTestCase):
3447    @property
3448    def world_size(self):
3449        return 1
3450
3451    def setUp(self):
3452        super().setUp()
3453        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
3454        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
3455        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
3456        # self.num_gpus = torch.cuda.device_count()
3457        self._spawn_processes()
3458
3459    def tearDown(self):
3460        super().tearDown()
3461        try:
3462            os.remove(self.file_name)
3463        except OSError:
3464            pass
3465
3466    class ToyModel(nn.Module):
3467        def __init__(self, rank, vocab_size, embedding_dim):
3468            super().__init__()
3469            self.embedding = nn.Embedding(vocab_size, embedding_dim, sparse=True).to(
3470                rank
3471            )
3472            self.linear = nn.Linear(embedding_dim, 1).to(rank)
3473
3474        def forward(self, inputs):
3475            embedded = self.embedding(inputs)
3476            # embedded shape: (batch_size, sequence_length, embedding_dim)
3477            flattened = torch.mean(embedded, dim=1)
3478            # flattened shape: (batch_size, embedding_dim)
3479            output = self.linear(flattened)
3480            # output shape: (batch_size, 1)
3481            return output
3482
3483    @requires_nccl()
3484    @skip_if_lt_x_gpu(1)
3485    def test_ddp_set_sparse_metadata(self):
3486        store = dist.FileStore(self.file_name, self.world_size)
3487        dist.init_process_group(
3488            "nccl",
3489            world_size=self.world_size,
3490            rank=self.rank,
3491            store=store,
3492        )
3493
3494        vocab_size = 5
3495
3496        model = SparseCollective.ToyModel(
3497            self.rank, vocab_size=vocab_size, embedding_dim=10
3498        )
3499        ddp_model = DistributedDataParallel(model)
3500        inputs = torch.tensor([[1, 0, 0], [0, 0, 0], [0, 0, 0]]).to(self.rank)
3501        # set sparse metadata on the DDP model
3502        indices = torch.Tensor(list(range(vocab_size)))
3503        ddp_model._set_sparse_metadata({"embedding.weight": indices})
3504        # forward pass
3505        try:
3506            output = ddp_model(inputs)
3507            loss = output.sum()
3508
3509            # backward pass
3510            loss.backward()
3511            self.assertTrue(ddp_model.module.embedding.weight.grad.indices, indices)
3512        except RuntimeError as e:
3513            if "NCCL does not support all_reduce with sparse tensors" in str(e):
3514                pass
3515            else:
3516                # Rethrow the exception if it's a different error
3517                raise
3518
3519
3520class NCCLTraceTestBase(MultiProcessTestCase):
3521    def setUp(self):
3522        super().setUp()
3523        os.environ[
3524            "TORCH_NCCL_ENABLE_TIMING"
3525        ] = "0"  # see 'timing_enabled' parametrized tests
3526        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"
3527        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
3528        self.tempdir = tempfile.TemporaryDirectory()
3529        os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
3530        os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
3531        self._spawn_processes()
3532
3533    @classmethod
3534    def _run(
3535        cls,
3536        parent_conn,
3537        rank: int,
3538        test_name: str,
3539        file_name: str,
3540        parent_pipe,
3541        **kwargs,
3542    ) -> None:
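            # stash this child's pipe end on the class so test bodies can talk to the
            # main process via self.parent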
3543        cls.parent = parent_conn
3544        super()._run(rank, test_name, file_name, parent_pipe)
3545
3546    @property
3547    def local_device(self):
3548        return torch.device("cuda", self.rank_to_GPU[self.rank][0])
3549
3550    def _join_processes(self, fn):
3551        # We need to patch sys.exit() because skip_if uses sys.exit() and
3552        # the exit code from this process would otherwise not be caught.
3553        with mock.patch("sys.exit") as exit_mock:
3554            fn()
3555        super()._join_processes(fn)
3556
3557    def _spawn_processes(self) -> None:
3558        proc = torch.multiprocessing.get_context("spawn").Process
3559        self.children_pipes = []
3560        parent_pipes = []
3561        for i in range(self.world_size):
3562            parent_conn, child_conn = torch.multiprocessing.Pipe()
3563            self.children_pipes.append(child_conn)
3564            parent_pipes.append(parent_conn)
3565        piter = iter(parent_pipes)
3566
3567        def wrap(*positional, args, **kwargs):
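            # wrap the spawn target so each child additionally receives its pipe end as
            # the first positional argument (consumed by _run above)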
3568            args = (next(piter), *args)
3569            return proc(*positional, args=args, **kwargs)
3570
3571        self._start_processes(wrap)
3572
3573    def _create_process_group_nccl(self):
3574        store = dist.FileStore(self.file_name, self.world_size)
3575        c10d.init_process_group(
3576            "nccl", world_size=self.world_size, rank=self.rank, store=store
3577        )
3578        pg = c10d.distributed_c10d._get_default_group()
3579        return pg
3580
3581    def tearDown(self):
3582        super().tearDown()
3583        try:
3584            os.remove(self.file_name)
3585        except OSError:
3586            pass
3587
3588    @property
3589    def world_size(self):
3590        return 2
3591
3592    @property
3593    def rank_to_GPU(self):
3594        # return rank to GPU map
3595        return init_multigpu_helper(self.world_size, "nccl")
3596
3597    def _trace_basename(self):
3598        # we pass the base to the env, and the dump util will append rank
3599        return os.path.join(self.tempdir.name, "trace_")
3600
3601    def _trace_name(self, rank):
3602        return self._trace_basename() + str(rank)
3603
3604    def started_or_scheduled(self, timing_enabled):
3605        return "started" if timing_enabled else "scheduled"
3606
3607
3608class NCCLTraceTest(NCCLTraceTestBase):
3609    def _verify_trace(self, t, include_collectives, timing_enabled, is_json):
3610        ver = t["version"]
3611        self.assertEqual(ver, "2.4")
3612        pg_config = t["pg_config"]
3613        self.assertEqual(len(pg_config), 1)
3614        default_pg_info = pg_config["0"]
3615        self.assertIn("name", default_pg_info)
3616        self.assertIn("desc", default_pg_info)
3617        self.assertIn("ranks", default_pg_info)
3618        pg_status = t["pg_status"]
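            # the calling tests issue exactly two collectives, so both the last enqueued
            # and last completed ids should be 2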
3619        self.assertEqual(len(pg_status), 1)
3620        self.assertEqual(str(pg_status["0"]["last_enqueued_collective"]), "2")
3621        self.assertEqual(str(pg_status["0"]["last_completed_collective"]), "2")
3622        self.assertEqual(
3623            str(pg_status["0"]["last_started_collective"]),
3624            "2" if timing_enabled else "-1",
3625        )
3626        global_ranks = pg_config["0"]["ranks"]
3627        self.assertEqual(len(json.loads(global_ranks)), self.world_size)
3628        if include_collectives:
3629            self.assertEqual(len(t["entries"]), 2)
3630            t = t["entries"]
3631            last = t[-1]
3632            self.assertEqual(last["process_group"], ("0", "default_pg"))
3633            self.assertEqual(last["state"], "completed")
3634            s = last["time_discovered_started_ns"]
3635            f = last["time_discovered_completed_ns"]
3636            self.assertEqual(last["record_id"], 1)
3637            self.assertIsNotNone(f)
3638            if timing_enabled:
3639                self.assertIsNotNone(s)
3640                self.assertTrue(s <= f)
3641            # we don't collect stack traces in JSON at the moment
3642            if not is_json:
3643                self.assertIn("test_c10d_nccl.py", str(last["frames"]))
3644            self.assertEqual(last["input_sizes"], ((3, 4),))
3645            self.assertEqual(last["input_dtypes"], ["Float"])
3646            self.assertEqual(last["output_sizes"], ((3, 4),))
3647            self.assertEqual(last["output_dtypes"], ["Float"])
3648            self.assertEqual(last["collective_seq_id"], 2)
3649            self.assertEqual(last["timeout_ms"], 600000)
3650            now = datetime.now()
3651            event_created_time = datetime.fromtimestamp(
3652                last["time_created_ns"] / 1000000000
3653            )
3654            before_test = now - timedelta(minutes=1)
3655            self.assertTrue(before_test < event_created_time < now)
3656            if timing_enabled:
3657                # very loose bounds, measured 0.036 ms on devgpu
3658                self.assertTrue(0 < last["duration_ms"] < 100)
3659            else:
3660                self.assertTrue("duration_ms" not in last)
3661        else:
3662            self.assertTrue("entries" not in t)
3663
3664    @requires_nccl()
3665    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3666    @parametrize("timing_enabled", [True, False])
3667    @parametrize("include_collectives", [True, False])
3668    def test_short_json(self, timing_enabled, include_collectives):
3669        if self.rank == self.MAIN_PROCESS_RANK:
3670            return
3671        pg = self._create_process_group_nccl()
3672        if timing_enabled:
3673            pg._enable_collectives_timing()
3674        device = self.local_device
3675        a = torch.full((3, 4), float(self.rank), device=device)
3676        for i in range(2):
3677            f = pg.allreduce(a)
3678        f.wait()
3679        torch.cuda.synchronize(device=device)
3680        # duration_ms is populated best-effort since it can only be filled in outside the "dump()" API, so give it a moment
3681        time.sleep(1)
3682        t = json.loads(
3683            torch._C._distributed_c10d._dump_nccl_trace_json(
3684                includeCollectives=include_collectives
3685            )
3686        )
3687        self._verify_trace(t, include_collectives, timing_enabled, True)
3688        dist.destroy_process_group()
3689
3690    @requires_nccl()
3691    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3692    @parametrize("timing_enabled", [True, False])
3693    @parametrize("include_collectives", [True, False])
3694    def test_short_pickle(self, timing_enabled, include_collectives):
3695        if self.rank == self.MAIN_PROCESS_RANK:
3696            return
3697        pg = self._create_process_group_nccl()
3698        if timing_enabled:
3699            pg._enable_collectives_timing()
3700        device = self.local_device
3701        a = torch.full((3, 4), float(self.rank), device=device)
3702        for i in range(2):
3703            f = pg.allreduce(a)
3704        f.wait()
3705        torch.cuda.synchronize(device=device)
3706        # duration_ms is populated best-effort since it can only be filled in outside the "dump()" API, so give it a moment
3707        time.sleep(1)
3708        t = pickle.loads(
3709            torch._C._distributed_c10d._dump_nccl_trace(
3710                includeCollectives=include_collectives
3711            )
3712        )
3713        self._verify_trace(
3714            t,
3715            include_collectives=include_collectives,
3716            timing_enabled=timing_enabled,
3717            is_json=True,
3718        )
3719        dist.destroy_process_group()
3720
3721    @requires_nccl()
3722    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3723    def test_dump_pipe(self):
3724        def open_file_with_timeout(file_path, mode, timeout=1.0):
3725            start_time = time.time()
3726            while time.time() - start_time < timeout:
3727                if os.path.exists(file_path):
3728                    return open(file_path, mode)
3729                time.sleep(0.1)
3730            raise FileNotFoundError
3731
3732        if self.rank == self.MAIN_PROCESS_RANK:
3733            for c in self.children_pipes:
3734                self.assertEqual(c.recv(), "next")
3735
3736            dump_file = self._trace_name(rank=0)
3737            pipe_file = dump_file + ".pipe"
3738            with open_file_with_timeout(pipe_file, "w") as f:
3739                f.write("1\n")
3740            with open_file_with_timeout(dump_file, "rb", timeout=10.0) as f:
3741                self.assertTrue("all_reduce" in str(pickle.load(f)))
3742
3743            for c in self.children_pipes:
3744                c.send("next")
3745            return
3746
3747        pg = self._create_process_group_nccl()
3748        device = self.local_device
3749        a = torch.full((3, 4), float(self.rank), device=device)
3750        for i in range(2):
3751            f = pg.allreduce(a)
3752        f.wait()
3753        torch.cuda.synchronize(device=device)
3754        self.parent.send("next")
3755        self.parent.recv()
3756
3757    @requires_nccl()
3758    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3759    def test_long(self):
3760        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
3761        if self.rank == self.MAIN_PROCESS_RANK:
3762            return
3763        pg = self._create_process_group_nccl()
3764        device = self.local_device
3765        a = torch.full((3, 4), float(self.rank), device=device)
3766        for i in range(2):
3767            # test some other primitives to make sure
3768            # their strings are valid
3769            xs = [torch.ones(3, 4, device=device)]
3770            pg.broadcast(xs).wait()
3771            pg.allreduce(xs).wait()
3772            pg.reduce(xs).wait()
3773            ys = [[torch.empty(3, 4, device=device) for _ in range(self.world_size)]]
3774            pg.allgather(ys, xs).wait()
3775            pg.reduce_scatter(xs, ys).wait()
3776            f = pg.allreduce(a)
3777        f.wait()
3778        torch.cuda.synchronize(device=device)
3779        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
3780        t = t["entries"]
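            # 2 iterations x 6 collectives = 12 entries, but the ring buffer only keeps the last 10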
3781        self.assertEqual(len(t), 10)
3782        first = t[0]
3783        last = t[-1]
3784        self.assertEqual(last["profiling_name"], "nccl:all_reduce")
3785        self.assertEqual(last["state"], "completed")
3786        self.assertIn("test_c10d_nccl.py", str(last["frames"]))
3787        self.assertEqual(last["input_sizes"], ((3, 4),))
3788        self.assertEqual(last["input_dtypes"], ["Float"])
3789        self.assertEqual(last["output_sizes"], ((3, 4),))
3790        self.assertEqual(last["output_dtypes"], ["Float"])
3791        self.assertEqual(last["timeout_ms"], 600000)
3792        self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9)
3793        dist.destroy_process_group()
3794
3795    @requires_nccl()
3796    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3797    def test_trace_while_all_works_retired(self):
3798        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
3799        if self.rank == self.MAIN_PROCESS_RANK:
3800            return
3801        pg = self._create_process_group_nccl()
3802        device = self.local_device
3803        # enqueue more works than the trace buffer size so that earlier entries get overwritten
3804        for i in range(12):
3805            a = [torch.ones(3, 4, device=device)]
3806            pg.broadcast(a).wait()
3807        torch.cuda.synchronize(device=device)
3808
3809        # wait for all works to be retired
3810        pg._wait_for_pending_works()
3811        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
3812        t = t["entries"]
3813        self.assertEqual(len(t), 10)
3814        last = t[-1]
3815        self.assertEqual(last["retired"], True)
3816        self.assertEqual(last["state"], "completed")
3817
3818    @requires_nccl()
3819    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3820    @parametrize("timing_enabled", [True, False])
3821    @parametrize("only_active", [True, False])
3822    def test_trace_while_active(self, timing_enabled, only_active):
3823        if self.rank == self.MAIN_PROCESS_RANK:
3824            for c in self.children_pipes:
3825                self.assertEqual(c.recv(), "next")
3826            for c in self.children_pipes:
3827                c.send("next")
3828            return
3829
3830        pg = self._create_process_group_nccl()
3831        if timing_enabled:
3832            pg._enable_collectives_timing()
3833        device = self.local_device
3834        with torch.cuda.device(device):
3835            a = torch.full((3, 4), float(self.rank), device=device)
3836
3837            pg.allreduce(a).wait()
3838            e = torch.cuda.Event()
3839            e.record()
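                # rank 0 deliberately skips this second allreduce so that, on the other
                # ranks, the entry for it remains active (started/scheduled) when dumped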
3840            if self.rank != 0:
3841                pg.allreduce(a).wait()
3842            e.synchronize()
3843            t = pickle.loads(
3844                torch._C._distributed_c10d._dump_nccl_trace(onlyActive=only_active)
3845            )
3846            t = t["entries"]
3847            if only_active:
3848                if self.rank == 0:
3849                    self.assertEqual(len(t), 0)
3850                else:
3851                    self.assertEqual(len(t), 1)
3852            if not only_active:
3853                if self.rank == 0:
3854                    self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
3855                    self.assertEqual(t[-1]["collective_seq_id"], 1)
3856                    self.assertEqual(t[-1]["state"], "completed")
3857                else:
3858                    self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
3859                    self.assertEqual(t[-1]["collective_seq_id"], 2)
3860                    self.assertEqual(
3861                        t[-1]["state"], self.started_or_scheduled(timing_enabled)
3862                    )
3863
3864            self.parent.send("next")
3865            self.assertEqual("next", self.parent.recv())
3866            if self.rank == 0:
3867                pg.allreduce(a).wait()
3868            torch.cuda.synchronize(device=device)
3869
3870    @requires_nccl()
3871    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3872    @parametrize("timing_enabled", [True, False])
3873    def test_trace_while_stuck(self, timing_enabled):
3874        if self.rank == self.MAIN_PROCESS_RANK:
3875            for c in self.children_pipes:
3876                self.assertEqual(c.recv(), "next")
3877            for c in self.children_pipes:
3878                c.send("next")
3879            return
3880
3881        pg = self._create_process_group_nccl()
3882        if timing_enabled:
3883            pg._enable_collectives_timing()
3884
3885        device = self.local_device
3886        with torch.cuda.device(device):
3887            a = torch.full((3, 4), float(self.rank), device=device)
3888
3889            pg.allreduce(a).wait()
3890            e = torch.cuda.Event()
3891            e.record()
3892
3893            def gather_trace():
3894                e.synchronize()
3895                # give the other thread some time to fill the cuda buffer
3896                time.sleep(5)
3897                t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
3898                t = t["entries"]
3899                self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
3900                if self.rank == 0:
3901                    self.assertEqual(t[-1]["collective_seq_id"], 1)
3902                    self.assertEqual(t[-1]["state"], "completed")
3903                else:
3904                    self.assertEqual(t[-1]["collective_seq_id"], 2)
3905                    self.assertEqual(
3906                        t[-1]["state"], self.started_or_scheduled(timing_enabled)
3907                    )
3908                    self.assertIsNone(t[-1]["time_discovered_completed_ns"])
3909                # this will eventually cause the missing rank 0
3910                # to continue which will unblock the non-zero ranks
3911                self.parent.send("next")
3912
3913            if self.rank != 0:
3914                pg.allreduce(a).wait()
3915                th = threading.Thread(target=gather_trace)
3916                th.start()
3917                # fill the cuda event buffer; at around 1024 pending events
3918                # this will stall
3919                for i in range(2000):
3920                    a = a + a
3921                th.join()
3922            else:
3923                gather_trace()
3924
3925            self.assertEqual("next", self.parent.recv())
3926            if self.rank == 0:
3927                pg.allreduce(a).wait()
3928            torch.cuda.synchronize(device=device)
3929
3930    @requires_nccl()
3931    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
3932    @parametrize(
3933        "op_sizes_per_coalesce",
3934        [
3935            [(2, 3)],
3936            [(2, 3), (5, 5), (1,)],
3937        ],
3938    )
3939    @parametrize("timing_enabled", [True, False])
3940    def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled):
3941        """
3942        'WorkEnqueue' was skipped for isend/irecv, leading to a segfault in dump_entries when update_state tried to use
3943        a destructed Work obj's cuda events
3944        """
3945
3946        if self.rank == self.MAIN_PROCESS_RANK:
3947            return
3948        pg = self._create_process_group_nccl()
3949        if timing_enabled:
3950            pg._enable_collectives_timing()
3951
3952        num_coalesced_ops = 20
3953        ops_per_coalesce = len(op_sizes_per_coalesce)
3954        for i in range(num_coalesced_ops):
3955            ops = []
3956            for input_sizes in op_sizes_per_coalesce:
3957                tensor = torch.zeros(input_sizes).to(self.local_device)
3958                if self.rank == 0:
3959                    ops.append(dist.P2POp(dist.irecv, tensor, 1))
3960                elif self.rank == 1:
3961                    tensor *= 2
3962                    ops.append(dist.P2POp(dist.isend, tensor, 0))
3963
3964            dist.batch_isend_irecv(ops).pop().wait()
3965
3966        torch.cuda.synchronize(device=self.local_device)
3967
3968        if timing_enabled:
3969            # wait for watchdog thread to process the queue of works
3970            time.sleep(1)
3971
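        # Expected flight-recorder layout (a sketch based on the assertions below):
        # each batch_isend_irecv call records one entry per individual P2P op plus
        # one trailing "nccl:coalesced" entry carrying the group's timing/state,
        # hence num_coalesced_ops * (ops_per_coalesce + 1) entries in total.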
3972        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
3973        self.assertEqual(len(t["entries"]), num_coalesced_ops * (ops_per_coalesce + 1))
3974
3975        expected_record_id = 0
3976        expected_seq = 1
3977        expected_op_id = 1
3978        for seq in range(num_coalesced_ops):
3979            first_op = seq * (ops_per_coalesce + 1)
3980            coalesced_op = first_op + ops_per_coalesce
3981            for p2p_op_idx, input_sizes in zip(
3982                range(first_op, coalesced_op, 1), op_sizes_per_coalesce
3983            ):
3984                # the individual ops inside the coalescing group record the individual op metadata,
3985                # but not the timing info, which comes from the actual coalesced kernel
3986                profiling_name = (
3987                    "nccl:recv 0<-1" if self.rank == 0 else "nccl:send 1->0"
3988                )
3989                self.assertEqual(
3990                    t["entries"][p2p_op_idx]["record_id"], expected_record_id
3991                )
3992                expected_record_id += 1
3993                self.assertEqual(
3994                    t["entries"][p2p_op_idx]["profiling_name"], profiling_name
3995                )
3996                self.assertEqual(
3997                    t["entries"][p2p_op_idx]["collective_seq_id"], expected_seq
3998                )
3999                self.assertEqual(t["entries"][p2p_op_idx]["op_id"], expected_op_id)
4000                expected_op_id += 1
4001                self.assertEqual(t["entries"][p2p_op_idx]["input_sizes"], [input_sizes])
4002                self.assertEqual(
4003                    t["entries"][p2p_op_idx]["output_sizes"], [input_sizes]
4004                )
4005                # duration doesn't get tagged onto individual ops yet, nor is their state updated
4006                self.assertEqual(t["entries"][p2p_op_idx]["state"], "scheduled")
4007                self.assertTrue("duration_ms" not in t["entries"][p2p_op_idx])
4008
4009            # the coalesced op has no metadata but indicates that coalescing was used,
4010            # and accurately reflects the timing and state info for the whole group
4011            self.assertEqual(
4012                t["entries"][coalesced_op]["record_id"], expected_record_id
4013            )
4014            expected_record_id += 1
4015            self.assertEqual(
4016                t["entries"][coalesced_op]["profiling_name"], "nccl:coalesced"
4017            )
4018            self.assertEqual(
4019                t["entries"][coalesced_op]["collective_seq_id"], expected_seq
4020            )
4021            expected_seq += 1
4022            self.assertEqual(t["entries"][coalesced_op]["state"], "completed")
4023            self.assertEqual(t["entries"][coalesced_op]["input_sizes"], [])
4024            self.assertEqual(t["entries"][coalesced_op]["output_sizes"], [])
4025            if timing_enabled:
4026                duration = t["entries"][coalesced_op]["duration_ms"]
4027                self.assertTrue(0.001 < duration < 10000, duration)
4028            else:
4029                self.assertTrue("duration_ms" not in t["entries"][coalesced_op])
4030            self.assertEqual(t["entries"][coalesced_op]["timeout_ms"], 600000)
4031
4032    @requires_nccl()
4033    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
4034    @parametrize(
4035        "op_sizes",
4036        [
4037            [(2, 3)],
4038            [(2, 3), (5, 5), (1,)],
4039        ],
4040    )
4041    @parametrize("timing_enabled", [True, False])
4042    def test_individual_send_recv(self, op_sizes, timing_enabled):
4043        """
4044        'WorkEnqueue' was skipped for isendirecv, leading to a segfault in dump_entries when update_state tried to use
4045        a destroyed Work object's cuda events
4046        """
4047
4048        if self.rank == self.MAIN_PROCESS_RANK:
4049            return
4050        pg = self._create_process_group_nccl()
4051        if timing_enabled:
4052            pg._enable_collectives_timing()
4053        num_repeats = 10
4054        ops_per_repeat = len(op_sizes)
4055        for i in range(num_repeats):
4056            for input_sizes in op_sizes:
4057                tensor = torch.zeros(input_sizes).to(self.local_device)
4058                if self.rank == 0:
4059                    dist.recv(tensor, 1)
4060                elif self.rank == 1:
4061                    tensor *= 2
4062                    dist.send(tensor, 0)
4063
4064        torch.cuda.synchronize(device=self.local_device)
4065        if timing_enabled:
4066            # wait for watchdog thread to process the queue of works
4067            time.sleep(1)
4068
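        # Expected flight-recorder layout (a sketch based on the assertions below):
        # each dist.send/dist.recv records its own entry with monotonically
        # increasing p2p_seq_id and op_id values, so num_repeats * ops_per_repeat
        # entries are expected, each already in the "completed" state.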
4069        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
4070        self.assertEqual(len(t["entries"]), num_repeats * (ops_per_repeat))
4071        expected_seq = 1
4072        expected_op_id = 1
4073        for seq in range(num_repeats * ops_per_repeat):
4074            input_sizes = op_sizes[seq % ops_per_repeat]
4075            profiling_name = "nccl:recv 0<-1" if self.rank == 0 else "nccl:send 1->0"
4076            self.assertEqual(t["entries"][seq]["profiling_name"], profiling_name)
4077            self.assertEqual(t["entries"][seq]["p2p_seq_id"], expected_seq)
4078            expected_seq += 1
4079            self.assertEqual(t["entries"][seq]["op_id"], expected_op_id)
4080            expected_op_id += 1
4081            self.assertEqual(t["entries"][seq]["input_sizes"], [input_sizes])
4082            self.assertEqual(t["entries"][seq]["output_sizes"], [input_sizes])
4083            self.assertEqual(t["entries"][seq]["state"], "completed")
4084
4085            if timing_enabled:
4086                duration = t["entries"][seq]["duration_ms"]
4087                self.assertTrue(0.001 < duration < 10000, duration)
4088            else:
4089                self.assertTrue("duration_ms" not in t["entries"][seq])
4090
4091    # TODO(whc) support and test coalesced collectives that use the c++ start/end group API instead of the python
4092    # coalescing manager
4093
4094    # TODO(whc) test out other ops (and combinations of ops, if that's valid?)
4095    @requires_nccl()
4096    @skip_if_lt_x_gpu(2)
4097    @parametrize("timing_enabled", [True, False])
4098    def test_coalescing_manager_collective(self, timing_enabled):
4099        """
4100        The coalescing manager API works by accumulating operations in python via a contextmanager, and then making
4101        one call into c++ to an <op>_coalesced API.  It supports a limited set of ops and was added recently to
4102        avoid the overhead of making individual py-cpp calls.  This complicates flight recording.
4103
4104        For now, flight recording of coalescing_manager collectives is less detailed than cpp coalesced collectives.
4105        """
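        # Usage pattern exercised below (a sketch mirroring this test's code):
        #
        #     with dist._coalescing_manager():
        #         for i in range(self.world_size):
        #             dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
        #
        # Every op issued inside the context is flushed as a single
        # reduce_scatter_tensor_coalesced call when the context exits, and the
        # flight recorder records one entry for the whole group.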
4106        if self.rank == self.MAIN_PROCESS_RANK:
4107            return
4108        pg = self._create_process_group_nccl()
4109        if timing_enabled:
4110            pg._enable_collectives_timing()
4111
4112        output_tensors = torch.zeros(2, 2).to(self.rank)
4113        input_tensors = [torch.ones(2, 2).to(self.rank) for _ in range(self.world_size)]
4114
4115        # TODO(whc) make this work with a bigger world size
4116        self.assertEqual(self.world_size, 2, self.world_size)
4117
4118        with dist._coalescing_manager():
4119            for i in range(self.world_size):
4120                dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
4121        self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)
4122
4123        torch.cuda.synchronize(device=self.rank)
4124
4125        if timing_enabled:
4126            # wait for watchdog thread to process the queue of works
4127            time.sleep(1)
4128
4129        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
4130
4131        self.assertEqual(
4132            len(t["entries"]), 1
4133        )  # a single entry for the reduce_scatter_tensor_coalesced call; endCoalescing does not add its own entry
4134        self.assertEqual(
4135            t["entries"][0]["profiling_name"], "nccl:reduce_scatter_tensor_coalesced"
4136        )
4137        self.assertEqual(t["entries"][0]["collective_seq_id"], 1)
4138        self.assertEqual(t["entries"][0]["input_sizes"], [[2, 2], [2, 2]])
4139        self.assertEqual(
4140            t["entries"][0]["output_sizes"],
4141            [
4142                [
4143                    2,
4144                ],
4145                [
4146                    2,
4147                ],
4148            ],
4149        )
4150        self.assertEqual(t["entries"][0]["state"], "completed")
4151        if timing_enabled:
4152            duration = t["entries"][0]["duration_ms"]
4153            self.assertTrue(0.001 < duration < 10000, duration)
4154        else:
4155            self.assertTrue("duration_ms" not in t["entries"][0])
4156
4157
4158def check_if_test_is_skipped(fn):
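    # If the first child process exited with one of the known skip exit codes,
    # fall back to the base class return-code check instead of the decorated,
    # test-specific override (which expects abort/crash exit codes).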
4159    def wrapper(self, *args, **kwargs):
4160        for skip in TEST_SKIPS.values():
4161            if self.processes[0].exitcode == skip.exit_code:
4162                return MultiProcessTestCase._check_return_codes(self, *args, **kwargs)
4163        return fn(self, *args, **kwargs)
4164
4165    return wrapper
4166
4167
4168class NCCLTraceTestDumpOnTimeoutBase(NCCLTraceTestBase):
4169    timeout_sec = 1
4170
4171    def _create_process_group_nccl(self):
4172        store = dist.FileStore(self.file_name, self.world_size)
4173        c10d.init_process_group(
4174            "nccl",
4175            world_size=self.world_size,
4176            rank=self.rank,
4177            store=store,
4178            timeout=timedelta(seconds=NCCLTraceTestDumpOnTimeoutBase.timeout_sec),
4179        )
4180        pg = c10d.distributed_c10d._get_default_group()
4181        return pg
4182
4183    @check_if_test_is_skipped
4184    def _check_return_codes(self, elapsed_time):
4185        # the base test infra assumes processes exit with matching return codes,
4186        # but we want rank0 to abort and rank1 to exit cleanly in this test
4187        self.assertEqual(self.processes[0].exitcode, -6)
4188        self.assertEqual(self.processes[1].exitcode, 0)
4189
4190    def _wait_process(self, rank, timeout):
4191        try:
4192            self.processes[rank].join(timeout)
4193            return self.processes[rank].exitcode
4194        except TimeoutError:
4195            return None
4196
4197
4198@skip_but_pass_in_sandcastle
4199class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
4200    @requires_nccl()
4201    @skip_if_lt_x_gpu(2)
4202    @parametrize("timing_enabled", [True, False])
4203    def test_timeout_dumps(self, timing_enabled):
4204        # dump on the heartbeat monitor thread
4205        os.environ["TORCH_NCCL_COORD_CHECK_MILSEC"] = "1000"
4206        # need rank0 to crash quickly after detecting the timeout
4207        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "1"
4208
4209        if self.rank == self.MAIN_PROCESS_RANK:
4210            # wait for rank0 to crash before looking for its output file
4211            # we rely on rank0 holding off its abort long enough to dump the debug info
4212            self.assertEqual(self._wait_process(0, timeout=90), -6)
4213            with open(self._trace_name(rank=0), "rb") as f:
4214                t = pickle.load(f)
4215                t = t["entries"]
4216                self.assertEqual(len(t), 2)
4217                self.assertEqual(t[0]["collective_seq_id"], 1)
4218                self.assertEqual(t[0]["state"], "completed")
4219                self.assertEqual(t[1]["collective_seq_id"], 2)
4220                self.assertEqual(
4221                    t[1]["state"], self.started_or_scheduled(timing_enabled)
4222                )
4223
4224            self.assertFalse(os.path.exists(self._trace_name(rank=1)))
4225
4226            return
4227
4228        pg = self._create_process_group_nccl()
4229        if timing_enabled:
4230            # timing is force-disabled in setup, since there is no 'disable' function
4231            pg._enable_collectives_timing()
4232
4233        device = self.local_device
4234        with torch.cuda.device(device):
4235            a = torch.full((3, 4), float(self.rank), device=device)
4236
4237            pg.allreduce(a).wait()
4238            if self.rank == 0:
4239                pg.allreduce(a).wait()
4240
4241            # rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly
4242            torch.cuda.synchronize(device=device)
4243
4244
4245instantiate_parametrized_tests(ProcessGroupNCCLGroupTest)
4246instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout)
4247instantiate_parametrized_tests(NCCLTraceTest)
4248
4249
4250@skip_but_pass_in_sandcastle
4251class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
4252    @check_if_test_is_skipped
4253    def _check_return_codes(self, elapsed_time):
4254        # the base test infra expects all processes to exit cleanly,
4255        # but in this test we expect both rank0 and rank1 to abort
4256        self.assertEqual(self.processes[0].exitcode, -6)
4257        self.assertEqual(self.processes[1].exitcode, -6)
4258
4259    @requires_nccl()
4260    @skip_if_lt_x_gpu(2)
4261    def test_timeout_dumps_on_stuck_ranks(self):
4262        # need rank0 to crash quicker after detecting timeout
4263        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "1"
4264        # restore this env var to its prior default in case another test changed it
4265        os.environ["TORCH_NCCL_COORD_CHECK_MILSEC"] = "1000"
4266
4267        if self.rank == self.MAIN_PROCESS_RANK:
4268            # wait for both rank0 and rank1 to crash before looking for both ranks'
4269            # output files; we rely on rank1 sleeping long enough to dump the debug info.
4270            self.assertEqual(self._wait_process(0, timeout=90), -6)
4271            self.assertEqual(self._wait_process(1, timeout=90), -6)
4272            self.assertTrue(os.path.exists(self._trace_name(rank=1)))
4273            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
4274            with open(self._trace_name(rank=0), "rb") as f:
4275                t = pickle.load(f)
4276                t = t["entries"]
4277                self.assertEqual(len(t), 2)
4278            with open(self._trace_name(rank=1), "rb") as f:
4279                t = pickle.load(f)
4280                t = t["entries"]
4281                self.assertEqual(len(t), 1)
4282                self.assertEqual(t[0]["collective_seq_id"], 1)
4283                self.assertEqual(t[0]["state"], "completed")
4284            return
4285
4286        pg = self._create_process_group_nccl()
4287        device = self.local_device
4288        with torch.cuda.device(device):
4289            a = torch.full((3, 4), float(self.rank), device=device)
4290
4291            pg.allreduce(a).wait()
4292            if self.rank == 0:
4293                pg.allreduce(a).wait()
4294
4295            # rank 0 will get stuck, time out, and then signal a timeout to all ranks.
4296            torch.cuda.synchronize(device=device)
4297
4298            if self.rank == 1:
4299                # Force rank 1 to idle so that it will eventually time out as well after
4300                # getting the global signal to dump the debugging info.
4301                time.sleep(600)
4302
4303
4304@skip_but_pass_in_sandcastle
4305class NcclErrorDumpTest(NCCLTraceTestBase):
4306    def _wait_process(self, rank, timeout):
4307        try:
4308            self.processes[rank].join(timeout)
4309            return self.processes[rank].exitcode
4310        except TimeoutError:
4311            return None
4312
4313    @check_if_test_is_skipped
4314    def _check_return_codes(self, elapsed_time):
4315        # the base test infra assumes processes exit with matching return codes,
4316        # but we want rank0 to abort after the exception and rank1 to exit with code 1
4317        self.assertEqual(self.processes[0].exitcode, -6)
4318        self.assertEqual(self.processes[1].exitcode, 1)
4319
4320    @requires_nccl()
4321    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
4322    @skip_if_lt_x_gpu(2)
4323    @skip_if_rocm_multiprocess
4324    def test_nccl_errors_dump(self):
4325        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
4326        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"
4327        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
4328        # need rank0 to dump before abort
4329        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "5"
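        # Taken together (and assuming the default behavior of these knobs), the
        # watchdog should detect the collective timeout on rank0, dump the
        # flight-recorder trace to the file checked below, and then abort (exit -6).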
4330
4331        if self.rank == self.MAIN_PROCESS_RANK:
4332            # wait for both rank0 and 1 to crash before looking for dump
4333            self.assertEqual(self._wait_process(0, timeout=90), -6)
4334            self.assertEqual(self._wait_process(1, timeout=90), 1)
4335            # verify that the trace file exists for rank0
4336            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
4337            return
4338
4339        store = c10d.FileStore(self.file_name, self.world_size)
4340        process_group = c10d.ProcessGroupNCCL(
4341            store,
4342            self.rank,
4343            self.world_size,
4344            timeout=timedelta(seconds=10),
4345        )
4346        process_group.allreduce(torch.rand(10).cuda(self.rank))
4347        if self.rank == 0:
4348            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
4349            # expect an error to be raised
4350            with self.assertRaisesRegex(dist.DistBackendError, ""):
4351                # Block the current stream on the NCCL stream
4352                work.wait()
4353                # Run some GPU operations
4354                a = torch.rand(10).cuda(self.rank)
4355        elif self.rank == 1:
4356            # Clean up structures (e.g., files for FileStore) before going down
4357            del process_group
4358            sys.exit(1)
4359
4360
4361# tests that need to be run with a larger world size
4362class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
4363    def _create_process_group_nccl(self, store, opts, device_id=None):
4364        # create nccl processgroup with opts
4365        c10d.init_process_group(
4366            "nccl",
4367            world_size=self.world_size,
4368            rank=self.rank,
4369            store=store,
4370            pg_options=opts,
4371            device_id=device_id,
4372        )
4373        pg = c10d.distributed_c10d._get_default_group()
4374        return pg
4375
4376    def opts(self, high_priority_stream=False):
4377        opts = c10d.ProcessGroupNCCL.Options()
4378        opts.is_high_priority_stream = high_priority_stream
4379        return opts
4380
4381    def setUp(self):
4382        super().setUp()
4383        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING, hence tests
4384        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
4385        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
4386        # self.num_gpus = torch.cuda.device_count()
4387        self._spawn_processes()
4388
4389    def tearDown(self):
4390        super().tearDown()
4391        try:
4392            os.remove(self.file_name)
4393        except OSError:
4394            pass
4395
4396    @property
4397    def world_size(self):
4398        return 8
4399
4400    @property
4401    def rank_to_GPU(self):
4402        # return rank to GPU map
4403        return init_multigpu_helper(self.world_size, "nccl")
4404
4405    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
4406    @skip_if_lt_x_gpu(8)
4407    def test_comm_split_group_larger_scale(self):
4408        store = c10d.FileStore(self.file_name, self.world_size)
4409        device = torch.device(f"cuda:{self.rank}")
4410        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
4411        backend = pg._get_backend(torch.device(device))
4412
4413        tensor = torch.full((1,), self.rank).cuda(device)
4414        ng1 = c10d.split_group(pg, [[0, 1], [2, 3, 4, 5, 6, 7]])
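        # split_group partitions the parent pg into disjoint subgroups via
        # ncclCommSplit; each rank gets back the handle of the subgroup it belongs
        # to (and, as exercised further below, None if it is in no listed group).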
4415        backend1 = ng1._get_backend(torch.device(device))
4416
4417        # comm split happens eagerly since device_id is passed to init_process_group.
4418        self.assertEqual(backend.comm_split_count(), 1)
4419        # dist.broadcast takes the source rank on the global process group
4420        if self.rank < 2:
4421            dist.broadcast(tensor, 0, group=ng1)
4422            self.assertEqual(tensor, torch.full((1,), 0))
4423        else:
4424            dist.broadcast(tensor, 2, group=ng1)
4425            self.assertEqual(tensor, torch.full((1,), 2))
4426
4427        # test split with only one colored group; the other ranks should get a no-color split.
4428        ng2 = c10d.split_group(pg, [[5, 6, 7]])
4429        self.assertEqual(backend.comm_split_count(), 2)
4430
4431        if self.rank >= 5:
4432            tensor2 = torch.full((1,), self.rank).cuda(device)
4433            dist.broadcast(tensor2, 7, group=ng2)
4434            self.assertEqual(tensor2, torch.full((1,), 7))
4435        else:
4436            self.assertEqual(ng2, None)
4437        # a barrier and a cuda sync before destroying all pgs.
4438        dist.barrier(pg)
4439        torch.cuda.synchronize()
4440        dist.destroy_process_group()
4441
4442    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
4443    @skip_if_lt_x_gpu(8)
4444    def test_comm_recursive_split_group(self):
4445        store = c10d.FileStore(self.file_name, self.world_size)
4446        device = torch.device(f"cuda:{self.rank}")
4447        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
4448        backend = pg._get_backend(torch.device(device))
4449
4450        # split the default PG into 2 subgroups; each subgroup (ng1) has 4 ranks.
4451        tensor1 = torch.full((1,), self.rank).cuda(device)
4452        ng1 = c10d.split_group(pg, [[0, 1, 2, 3], [4, 5, 6, 7]])
4453        backend1 = ng1._get_backend(torch.device(device))
4454        if self.rank < 4:
4455            dist.broadcast(tensor1, 0, group=ng1)
4456            self.assertEqual(tensor1, torch.full((1,), 0))
4457        else:
4458            dist.broadcast(tensor1, 4, group=ng1)
4459            self.assertEqual(tensor1, torch.full((1,), 4))
4460
4461        # comm split happens eagerly since device_id is passed to init_process_group.
4462        self.assertEqual(backend.comm_split_count(), 1)
4463        self.assertEqual(backend1.comm_split_count(), 0)
4464
4465        # further split ng1 into 2 subgroups; each subgroup (ng2) has 2 ranks.
4466        tensor2 = torch.full((1,), self.rank).cuda(device)
4467        ng2 = c10d.split_group(ng1, [[0, 1], [2, 3]])
4468        backend2 = ng2._get_backend(torch.device(device))
4469        self.assertEqual(backend.comm_split_count(), 1)
4470        self.assertEqual(backend1.comm_split_count(), 1)
4471        self.assertEqual(backend2.comm_split_count(), 0)
4472
4473        # execute collective calls within each 2-rank pg
4474        if self.rank == 0 or self.rank == 1:
4475            dist.broadcast(tensor2, 1, group=ng2)
4476            self.assertEqual(tensor2, torch.full((1,), 1))
4477
4478        if self.rank == 2 or self.rank == 3:
4479            dist.broadcast(tensor2, 2, group=ng2)
4480            self.assertEqual(tensor2, torch.full((1,), 2))
4481
4482        if self.rank == 4 or self.rank == 5:
4483            dist.broadcast(tensor2, 5, group=ng2)
4484            self.assertEqual(tensor2, torch.full((1,), 5))
4485
4486        if self.rank == 6 or self.rank == 7:
4487            dist.broadcast(tensor2, 6, group=ng2)
4488            self.assertEqual(tensor2, torch.full((1,), 6))
4489        # a barrier and a cuda sync before destroying all pgs.
4490        dist.barrier(pg)
4491        torch.cuda.synchronize()
4492        dist.destroy_process_group()
4493
4494
4495if __name__ == "__main__":
4496    assert (
4497        not torch.cuda._initialized
4498    ), "test_distributed must not have initialized CUDA context on main process"
4499
4500    run_tests()
4501