# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import load_tests, NO_MULTIPROCESSING_SPAWN


# torch.distributed.nn is not available on Windows; importing it raises an
# error there. See #42095.
_torch_dist_nn_available = True
try:
    import torch.distributed.nn
except ImportError:
    _torch_dist_nn_available = False

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake8 warnings.
load_tests = load_tests

if not c10d.is_available():
    print("c10d not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print("spawn not available, skipping tests", file=sys.stderr)
    sys.exit(0)


class AbstractProcessGroupShareTensorTest:
    world_size = 2

    def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
        ws = self.world_size
        # The file store will delete the test file on destruction.
        file = tempfile.NamedTemporaryFile(delete=False)
        ctx = mp.get_context("spawn")
        c2p = ctx.Queue(2)
        p2c = ctx.Queue(2)
        ps = []
        for i in range(ws):
            p = ctx.Process(
                target=f, args=(i, file.name, shared_tensors, ws, init_pg, c2p, p2c)
            )
            p.start()
            ps.append(p)

        for _ in range(ws * n_output):
            pid, expected, result = c2p.get()
            self.assertEqual(
                expected,
                result,
                msg=f"Expected rank {pid} to receive tensor {expected} but got {result}.",
            )

        for _ in range(ws):
            p2c.put(0)

        for p in ps:
            p.join(2)

    # Why classmethod? multiprocessing cannot pickle a TestCase subclass in
    # spawn mode. See https://bugs.python.org/issue33884.
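    # Each worker below pushes a (rank, expected, actual) triple onto c2p for
    # the parent to assert on, then blocks on p2c until the parent signals
    # that it may exit.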
    @classmethod
    def _test_broadcast_process(
        cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
    ):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.broadcast(xs).wait()
        c2p.put((rank, torch.zeros(2, 2), xs[0].to("cpu")))
        p2c.get()

    @classmethod
    def _test_allreduce_process(
        cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
    ):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
        c2p.put((rank, torch.ones(2, 2) * 2, xs[0].to("cpu")))
        p2c.get()

    @classmethod
    def _test_allgather_process(
        cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
    ):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        ys = [[torch.zeros_like(xs[0]) for _ in range(world_size)]]
        pg.allgather(ys, xs).wait()
        for i in range(world_size):
            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))

        p2c.get()


class TestDistributedNNFunctions(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 2

    def _test_broadcast(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into
        # torch.distributed and need the default process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.broadcast(x, 1)
        self.assertEqual(y, 1 + torch.ones(5, 5))
        z = y.sin().sum()
        z.backward()
        # We can't check the gradient of a collective numerically, so we
        # verify it analytically: dz/dy = cos(y), and broadcast's backward
        # reduces the gradients from all ranks onto the source rank.
        if self.rank == 1:
            self.assertEqual(x.grad, 2 * torch.cos(x))
        elif self.rank == 0:
            self.assertEqual(x.grad, torch.zeros(5, 5, device=device))

    def _test_reduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into
        # torch.distributed and need the default process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM)

        if self.rank == 1:
            self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
        # reduce's backward broadcasts the gradient back to both ranks.
        x_g = (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_allreduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into
        # torch.distributed and need the default process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM)
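        # With world_size == 2, the ranks hold all-ones and all-twos, so the
        # all-reduced value is 3 * ones. For z = sum(sin(y)), dz/dy = cos(y);
        # the backward pass all-reduces the gradients as well, so each rank's
        # x.grad accumulates world_size * cos(3).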

        self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
        x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_all_gather(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into
        # torch.distributed and need the default process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        tensors = torch.distributed.nn.all_gather(x)
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + i)
        y = torch.sum(torch.stack(tensors), dim=0)
        z = y.sin().sum()
        z.backward()

        # all_gather's backward sums, across ranks, the gradient w.r.t. this
        # rank's contribution, giving world_size * cos(3).
        x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)

    def _test_all_to_all(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into
        # torch.distributed and need the default process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x0 = torch.ones(5, 5, device=device) + 2 * self.rank
        x1 = torch.ones(5, 5, device=device) + 2 * self.rank
        x0.requires_grad = True
        x1.requires_grad = True
        y0 = torch.empty_like(x0)
        y1 = torch.empty_like(x1)
        tensors = torch.distributed.nn.all_to_all([y0, y1], [x0, x1])
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i)
        y = torch.sum(torch.stack(tensors), dim=0)
        z = y.sin().sum()
        z.backward()
        x_s = (4 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x0.grad, x_s)
        self.assertEqual(x1.grad, x_s)

    def _test_all_to_all_single(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into
        # torch.distributed and need the default process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        # Rank r sends (i + 1) * (r + 1) rows to rank i, so it holds
        # world_size * (r + 1) * (world_size + 1) / 2 rows in total.
        row = self.world_size * (self.rank + 1) * (self.world_size + 1) // 2
        x = torch.ones(row, 5, device=device) * (self.rank + 1)
        x.requires_grad = True
        y = torch.empty_like(x)
        split_sizes = [(i + 1) * (self.rank + 1) for i in range(self.world_size)]
        y = torch.distributed.nn.all_to_all_single(
            y, x, output_split_sizes=split_sizes, input_split_sizes=split_sizes
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)
        z = y.sin().sum()
        z.backward()
        x_s = ((self.rank + 1) * torch.ones(row, 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)
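
# A minimal usage sketch (an assumption, not part of this module): a
# backend-specific test file is expected to subclass TestDistributedNNFunctions,
# drive the _test_* helpers with a concrete backend, and run the suite. The
# subclass name and the "gloo" backend below are illustrative only.
#
#     from torch.testing._internal.common_utils import run_tests
#
#     class TestDistributedNNFunctionsGloo(TestDistributedNNFunctions):
#         def test_broadcast(self):
#             self._test_broadcast("gloo")
#
#     if __name__ == "__main__":
#         run_tests()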