# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import load_tests, NO_MULTIPROCESSING_SPAWN


# torch.distributed.nn is not available on Windows; it errors on import.
# See #42095.
_torch_dist_nn_available = True
try:
    import torch.distributed.nn
except ImportError:
    _torch_dist_nn_available = False

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if not c10d.is_available():
    print("c10d not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print("spawn not available, skipping tests", file=sys.stderr)
    sys.exit(0)


class AbstractProcessGroupShareTensorTest:
    world_size = 2

    def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
        ws = self.world_size
        # The file store will delete the test file on destruction.
        file = tempfile.NamedTemporaryFile(delete=False)
        ctx = mp.get_context("spawn")
        c2p = ctx.Queue(2)
        p2c = ctx.Queue(2)
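        # c2p carries (rank, expected, result) tuples from the children back to
        # the parent for checking; p2c acts as an exit barrier so the children
        # stay alive until the parent has verified every result.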
        ps = []
        for i in range(ws):
            p = ctx.Process(
                target=f, args=(i, file.name, shared_tensors, ws, init_pg, c2p, p2c)
            )
            p.start()
            ps.append(p)

        for _ in range(ws * n_output):
            pid, expected, result = c2p.get()
            self.assertEqual(
                expected,
                result,
                msg=f"Expect rank {pid} to receive tensor {expected} but got {result}.",
            )

        for _ in range(ws):
            p2c.put(0)

        for p in ps:
            p.join(2)

    # Why classmethod? multiprocessing cannot pickle a TestCase subclass in
    # spawn mode. See https://bugs.python.org/issue33884.
    @classmethod
    def _test_broadcast_process(
        cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
    ):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.broadcast(xs).wait()
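        # The default broadcast root is rank 0, so every rank should end up
        # with rank 0's shared tensor, which the caller is expected to make
        # zeros(2, 2) (matching the expected value reported below).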
        c2p.put((rank, torch.zeros(2, 2), xs[0].to("cpu")))
        p2c.get()

    @classmethod
    def _test_allreduce_process(
        cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
    ):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
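        # The expected value 2 * ones(2, 2) below assumes the caller shares an
        # all-ones tensor per rank, so the SUM across world_size == 2 ranks
        # doubles it.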
        c2p.put((rank, torch.ones(2, 2) * 2, xs[0].to("cpu")))
        p2c.get()

    @classmethod
    def _test_allgather_process(
        cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
    ):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        ys = [[torch.zeros_like(xs[0]) for _ in range(world_size)]]
        pg.allgather(ys, xs).wait()
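        # Slot i of the gathered output should hold rank i's shared tensor; the
        # expected values below assume the caller shares ones(2, 2) * i on rank i.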
        for i in range(world_size):
            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))

        p2c.get()

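
# Illustrative sketch only (not part of the original file): a concrete test
# pairs the helpers above with a backend-specific `init_pg` callable, along the
# lines of the Gloo/NCCL spawn tests that live in their own files. The function
# name below is an assumption made for illustration.
def _example_init_pg_gloo(rank, filename, world_size):
    store = c10d.FileStore(filename, world_size)
    c10d.init_process_group(
        backend="gloo", store=store, rank=rank, world_size=world_size
    )
    # Return the default process group created above.
    return c10d.distributed_c10d._get_default_group()


# A subclass that also derives from a TestCase would then call, for example:
#   self._test_multiprocess(
#       AbstractProcessGroupShareTensorTest._test_broadcast_process,
#       [torch.zeros(2, 2) for _ in range(self.world_size)],
#       _example_init_pg_gloo,
#       1,
#   )
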

class TestDistributedNNFunctions(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 2

    def _test_broadcast(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call into torch.distributed
        # directly and need the world (default process group) to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.broadcast(x, 1)
        self.assertEqual(y, 1 + torch.ones(5, 5))
        z = y.sin().sum()
        z.backward()
        # The gradient of a communication op can't be checked numerically, so
        # compare it against hand-derived values.
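        # y == 2 * ones on every rank, so dz/dy == cos(2). broadcast's backward
        # sums the output grads onto the source rank and returns zeros on the
        # other ranks, hence rank 1 sees 2 * cos(x) and rank 0 sees zeros.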
        if self.rank == 1:
            self.assertEqual(x.grad, 2 * torch.cos(x))
        elif self.rank == 0:
            self.assertEqual(x.grad, torch.zeros(5, 5, device=device))

    def _test_reduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call into torch.distributed
        # directly and need the world (default process group) to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM)

        if self.rank == 1:
            self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
        # The gradient is broadcast from the destination rank back to both ranks.
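        # The reduced value on rank 1 is (1 + 2) * ones == 3 * ones, so dz/dy
        # there is cos(3); that gradient is what every rank should see for x,
        # which is what x_g below encodes.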
        x_g = (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_allreduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call into torch.distributed
        # directly and need the world (default process group) to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM)

        self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
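        # y == 3 * ones on both ranks, so each rank's local dz/dy is cos(3);
        # all_reduce's backward all-reduces the gradients as well, giving
        # x.grad == world_size * cos(3) == 2 * cos(3).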
        x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_all_gather(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call into torch.distributed
        # directly and need the world (default process group) to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        tensors = torch.distributed.nn.all_gather(x)
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + i)
        y = torch.sum(torch.stack(tensors), axis=0)
        z = y.sin().sum()
        z.backward()

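        # tensors == [1 * ones, 2 * ones], so y == 3 * ones and dz/dy == cos(3)
        # on every rank; all_gather's backward sums, for each input, the grads
        # of its copies across ranks, so x.grad == world_size * cos(3).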
        x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)

    def _test_all_to_all(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call into torch.distributed
        # directly and need the world (default process group) to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x0 = torch.ones(5, 5, device=device) + 2 * self.rank
        x1 = torch.ones(5, 5, device=device) + 2 * self.rank
        x0.requires_grad = True
        x1.requires_grad = True
        y0 = torch.empty_like(x0)
        y1 = torch.empty_like(x1)
        tensors = torch.distributed.nn.all_to_all([y0, y1], [x0, x1])
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i)
        y = torch.sum(torch.stack(tensors), axis=0)
        z = y.sin().sum()
        z.backward()
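        # Each rank receives [1 * ones, 3 * ones], so y == 4 * ones and
        # dz/dy == cos(4); all_to_all's backward routes each output gradient
        # back to the rank that sent the corresponding input, so both inputs
        # end up with cos(4 * ones).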
        x_s = (4 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x0.grad, x_s)
        self.assertEqual(x1.grad, x_s)

    def _test_all_to_all_single(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call into torch.distributed
        # directly and need the world (default process group) to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (self.rank + 1)
        x.requires_grad = True
        y = torch.empty_like(x)
        split_sizes = [(i + 1) * (self.rank + 1) for i in range(self.world_size)]
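        # Rank r sends (i + 1) * (r + 1) rows to rank i, so sum(split_sizes)
        # equals the total row count computed above:
        # (r + 1) * world_size * (world_size + 1) / 2.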
        y = torch.distributed.nn.all_to_all_single(
            y, x, output_split_sizes=split_sizes, input_split_sizes=split_sizes
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)
        z = y.sin().sum()
        z.backward()
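        # Every element of x on rank r has value r + 1 and keeps that value on
        # the receiving rank, so the gradient routed back for each element is
        # cos(r + 1); hence x.grad == ((rank + 1) * ones).cos(), checked below.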
        x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)