1# Owner(s): ["oncall: distributed"] 2 3import os 4import sys 5 6import torch 7import torch.cuda 8import torch.distributed as dist 9import torch.distributed.algorithms._quantization.quantization as quant 10from torch.distributed.algorithms._quantization.quantization import DQuantType 11from torch.testing._internal.common_distributed import ( 12 init_multigpu_helper, 13 MultiProcessTestCase, 14 requires_gloo, 15 requires_nccl, 16 skip_if_lt_x_gpu, 17 skip_if_rocm_multiprocess, 18) 19from torch.testing._internal.common_utils import ( 20 NO_MULTIPROCESSING_SPAWN, 21 run_tests, 22 skip_but_pass_in_sandcastle_if, 23 TEST_WITH_DEV_DBG_ASAN, 24) 25 26 27torch.backends.cuda.matmul.allow_tf32 = False 28 29if not dist.is_available(): 30 print("Distributed not available, skipping tests", file=sys.stderr) 31 sys.exit(0) 32 33 34def _build_tensor(size, value=None, dtype=torch.float, device_id=None): 35 if value is None: 36 value = size 37 if device_id is None: 38 return torch.empty(size, dtype=dtype).fill_(value) 39 else: 40 return torch.empty(size, dtype=dtype).fill_(value).cuda(device_id) 41 42 43if TEST_WITH_DEV_DBG_ASAN: 44 print( 45 "Skip dev-asan as torch + multiprocessing spawn have known issues", 46 file=sys.stderr, 47 ) 48 sys.exit(0) 49 50if NO_MULTIPROCESSING_SPAWN: 51 print("Spawn not available, skipping tests.", file=sys.stderr) 52 sys.exit(0) 53 54BACKEND = os.environ["BACKEND"] 55if BACKEND == "gloo" or BACKEND == "nccl": 56 57 class DistQuantizationTests(MultiProcessTestCase): 58 def setUp(self): 59 super().setUp() 60 self._spawn_processes() 61 torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__() 62 63 def tearDown(self): 64 super().tearDown() 65 try: 66 os.remove(self.file_name) 67 except OSError: 68 pass 69 70 @property 71 def op_timeout_sec(self): 72 return 1 73 74 @property 75 def world_size(self): 76 return int(os.environ["WORLD_SIZE"]) 77 78 @requires_gloo() 79 @skip_but_pass_in_sandcastle_if( 80 BACKEND != "gloo", "Only gloo backend supports all_gather_fp16" 81 ) 82 def test_all_gather_fp16(self): 83 store = dist.FileStore(self.file_name, self.world_size) 84 dist.init_process_group( 85 store=store, rank=self.rank, world_size=self.world_size, backend="gloo" 86 ) 87 device = torch.device(f"cuda:{self.rank}") 88 group = list(range(0, self.world_size)) 89 group_id = dist.group.WORLD 90 self._test_all_gather( 91 group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.FP16 92 ) 93 94 @requires_gloo() 95 @skip_but_pass_in_sandcastle_if( 96 BACKEND != "gloo", "Only gloo backend supports all_gather_fp16" 97 ) 98 def test_all_gather_bfp16(self): 99 store = dist.FileStore(self.file_name, self.world_size) 100 dist.init_process_group( 101 store=store, rank=self.rank, world_size=self.world_size, backend="gloo" 102 ) 103 device = torch.device(f"cuda:{self.rank}") 104 group = list(range(0, self.world_size)) 105 group_id = dist.group.WORLD 106 self._test_all_gather( 107 group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.BFP16 108 ) 109 110 @requires_nccl() 111 @skip_but_pass_in_sandcastle_if( 112 BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16" 113 ) 114 @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) 115 @skip_if_rocm_multiprocess 116 def test_all_to_all_fp16(self): 117 store = dist.FileStore(self.file_name, self.world_size) 118 dist.init_process_group( 119 store=store, rank=self.rank, world_size=self.world_size, backend="nccl" 120 ) 121 device = torch.device(f"cuda:{self.rank}") 122 group = list(range(0, self.world_size)) 123 group_id = 
dist.new_group(range(self.world_size)) 124 rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND) 125 self._test_all_to_all( 126 group, 127 group_id, 128 self.rank, 129 cuda=True, 130 rank_to_GPU=rank_to_GPU, 131 dtype=torch.float32, 132 qtype=DQuantType.FP16, 133 ) 134 135 @requires_nccl() 136 @skip_but_pass_in_sandcastle_if( 137 BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16" 138 ) 139 @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) 140 @skip_if_rocm_multiprocess 141 def test_all_to_all_bfp16(self): 142 store = dist.FileStore(self.file_name, self.world_size) 143 dist.init_process_group( 144 store=store, rank=self.rank, world_size=self.world_size, backend="nccl" 145 ) 146 device = torch.device(f"cuda:{self.rank}") 147 group = list(range(0, self.world_size)) 148 group_id = dist.new_group(range(self.world_size)) 149 rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND) 150 self._test_all_to_all( 151 group, 152 group_id, 153 self.rank, 154 cuda=True, 155 rank_to_GPU=rank_to_GPU, 156 dtype=torch.float32, 157 qtype=DQuantType.BFP16, 158 ) 159 160 @requires_nccl() 161 @skip_but_pass_in_sandcastle_if( 162 BACKEND != "nccl", "Only nccl backend supports all_to_all_single_fp16" 163 ) 164 @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) 165 def test_all_to_all_single_fp16(self): 166 store = dist.FileStore(self.file_name, self.world_size) 167 dist.init_process_group( 168 store=store, rank=self.rank, world_size=self.world_size, backend="nccl" 169 ) 170 device = torch.device(f"cuda:{self.rank}") 171 group = list(range(0, self.world_size)) 172 group_id = dist.new_group(range(self.world_size)) 173 rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND) 174 self._test_all_to_all_single( 175 group, 176 group_id, 177 self.rank, 178 cuda=True, 179 rank_to_GPU=rank_to_GPU, 180 dtype=torch.float32, 181 qtype=DQuantType.FP16, 182 ) 183 184 @requires_nccl() 185 @skip_but_pass_in_sandcastle_if( 186 BACKEND != "nccl", "Only nccl backend supports all_to_all_single_bfp16" 187 ) 188 @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) 189 def test_all_to_all_single_bfp16(self): 190 store = dist.FileStore(self.file_name, self.world_size) 191 dist.init_process_group( 192 store=store, rank=self.rank, world_size=self.world_size, backend="nccl" 193 ) 194 device = torch.device(f"cuda:{self.rank}") 195 group = list(range(0, self.world_size)) 196 group_id = dist.new_group(range(self.world_size)) 197 rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND) 198 self._test_all_to_all_single( 199 group, 200 group_id, 201 self.rank, 202 cuda=True, 203 rank_to_GPU=rank_to_GPU, 204 dtype=torch.float32, 205 qtype=DQuantType.BFP16, 206 ) 207 208 def _test_all_gather( 209 self, 210 group, 211 group_id, 212 rank, 213 cuda=False, 214 rank_to_GPU=None, 215 dtype=torch.float, 216 qtype=None, 217 ): 218 for dest in group: 219 tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype) 220 tensors = [ 221 _build_tensor([dest + 1, dest + 1], -1, dtype=dtype) for i in group 222 ] 223 expected_tensors = [ 224 _build_tensor([dest + 1, dest + 1], i, dtype=dtype) for i in group 225 ] 226 if cuda: 227 tensor = tensor.cuda(rank_to_GPU[rank][0]) 228 tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] 229 if tensors[0].dtype == torch.complex64: 230 tensor_shapes = [torch.view_as_real(tensors[0]).shape] 231 else: 232 tensor_shapes = [tensors[0].shape] 233 allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None) 234 allgather(tensors, tensor, group=group_id, async_op=False) 235 236 for t1, t2 
in zip(tensors, expected_tensors): 237 self.assertEqual(t1, t2) 238 239 def _test_all_to_all( 240 self, 241 group, 242 group_id, 243 rank, 244 cuda=False, 245 rank_to_GPU=None, 246 dtype=torch.float, 247 qtype=None, 248 ): 249 if group_id is not None: 250 size = len(group) 251 in_splits = [i + 1 for i in group] 252 in_tensors = [ 253 torch.ones([in_splits[i], size], dtype=dtype) * rank 254 for i, _ in enumerate(group) 255 ] 256 out_tensors = [ 257 torch.ones([(rank + 1), size], dtype=dtype) for _ in group 258 ] 259 expected_tensors = [ 260 torch.ones([rank + 1, size], dtype=dtype) * i for i in group 261 ] 262 if cuda: 263 in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors] 264 expected_tensors = [ 265 t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors 266 ] 267 out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors] 268 quantize_alltoall = quant.auto_quantize( 269 dist.all_to_all, qtype, quant_loss=None 270 ) 271 quantize_alltoall(out_tensors, in_tensors, group=group_id) 272 for t1, t2 in zip(out_tensors, expected_tensors): 273 self.assertEqual(t1, t2) 274 275 def _test_all_to_all_single( 276 self, 277 group, 278 group_id, 279 rank, 280 cuda=False, 281 rank_to_GPU=None, 282 dtype=torch.float, 283 qtype=DQuantType.FP16, 284 ): 285 if group_id is not None: 286 size = len(group) 287 in_splits = [i + 1 for i in group] 288 out_splits = [rank + 1 for _ in group] 289 in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank 290 out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype) 291 expected_tensor = torch.cat( 292 [torch.ones([rank + 1, size], dtype=dtype) * i for i in group] 293 ) 294 if cuda: 295 rank_to_GPU = rank_to_GPU[rank][0] 296 in_tensor = in_tensor.cuda(rank_to_GPU) 297 expected_tensor = expected_tensor.cuda(rank_to_GPU) 298 out_tensor = out_tensor.cuda(rank_to_GPU) 299 quantize_alltoall_single = quant.auto_quantize( 300 dist.all_to_all_single, qtype, quant_loss=None 301 ) 302 quantize_alltoall_single( 303 out_tensor, 304 in_tensor, 305 out_splits=out_splits, 306 in_splits=in_splits, 307 group=group_id, 308 ) 309 self.assertEqual(out_tensor, expected_tensor) 310 311 312if __name__ == "__main__": 313 run_tests() 314
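
# A minimal usage sketch (the file name "test_quantization.py" is an assumption,
# not stated in this file). BACKEND and WORLD_SIZE must both be set in the
# environment before the module is imported, because BACKEND is read at module
# level and the world_size property reads WORLD_SIZE; a missing variable raises
# KeyError before any test runs. The NCCL variants are skipped unless at least
# WORLD_SIZE GPUs are visible, per skip_if_lt_x_gpu above. For example:
#
#   BACKEND=gloo WORLD_SIZE=2 python test_quantization.py
#   BACKEND=nccl WORLD_SIZE=2 python test_quantization.py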