xref: /aosp_15_r20/external/pytorch/test/distributed/algorithms/quantization/test_quantization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: distributed"]
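#
# Multi-process tests for torch.distributed.algorithms._quantization: each test
# wraps a collective (dist.all_gather on gloo, dist.all_to_all /
# dist.all_to_all_single on nccl) with quant.auto_quantize using
# DQuantType.FP16 or DQuantType.BFP16 and checks that the dequantized results
# match the plain full-precision expectation.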

import os
import sys

import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms._quantization.quantization as quant
from torch.distributed.algorithms._quantization.quantization import DQuantType
from torch.testing._internal.common_distributed import (
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
)
from torch.testing._internal.common_utils import (
    NO_MULTIPROCESSING_SPAWN,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


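# TF32 matmuls are disabled globally here; cuDNN TF32 is disabled per-process
# in DistQuantizationTests.setUp below.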
torch.backends.cuda.matmul.allow_tf32 = False

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


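# Builds a `size`-shaped tensor filled with `value` (defaulting to `size`
# itself), optionally placed on the given CUDA device.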
def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
    if value is None:
        value = size
    if device_id is None:
        return torch.empty(size, dtype=dtype).fill_(value)
    else:
        return torch.empty(size, dtype=dtype).fill_(value).cuda(device_id)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print("Spawn not available, skipping tests.", file=sys.stderr)
    sys.exit(0)

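# BACKEND and WORLD_SIZE are expected in the environment; the quantization
# tests below are only defined for the gloo and nccl backends.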
BACKEND = os.environ["BACKEND"]
if BACKEND == "gloo" or BACKEND == "nccl":

    class DistQuantizationTests(MultiProcessTestCase):
        def setUp(self):
            super().setUp()
            self._spawn_processes()
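            # Enter the cuDNN flags context without exiting it so that TF32
            # stays disabled in cuDNN for the rest of the test process.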
            torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()

        def tearDown(self):
            super().tearDown()
            try:
                os.remove(self.file_name)
            except OSError:
                pass

        @property
        def op_timeout_sec(self):
            return 1

        @property
        def world_size(self):
            return int(os.environ["WORLD_SIZE"])

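        # Quantized all_gather runs on the gloo backend with CPU tensors: each
        # rank contributes a tensor filled with its own rank and expects the
        # gathered, dequantized results to match the unquantized values.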
        @requires_gloo()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports all_gather_fp16"
        )
        def test_all_gather_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.group.WORLD
            self._test_all_gather(
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.FP16
            )

        @requires_gloo()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports all_gather_bfp16"
        )
        def test_all_gather_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.group.WORLD
            self._test_all_gather(
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.BFP16
            )

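        # Quantized all_to_all requires the nccl backend and at least
        # WORLD_SIZE GPUs; tensors are moved to each rank's GPU before the
        # collective runs.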
        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @skip_if_rocm_multiprocess
        def test_all_to_all_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.FP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_bfp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @skip_if_rocm_multiprocess
        def test_all_to_all_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.BFP16,
            )

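        # all_to_all_single performs the same exchange through one flattened
        # input/output tensor with explicit split sizes, again quantized to
        # FP16 or BFP16 via auto_quantize.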
        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_single_fp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_all_to_all_single_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all_single(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.FP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_single_bfp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_all_to_all_single_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all_single(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.BFP16,
            )

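        # Wraps dist.all_gather with quant.auto_quantize and, for each
        # destination size in `group`, checks that the tensor gathered from
        # rank `i` holds the value `i`.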
        def _test_all_gather(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=None,
        ):
            for dest in group:
                tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype)
                tensors = [
                    _build_tensor([dest + 1, dest + 1], -1, dtype=dtype) for i in group
                ]
                expected_tensors = [
                    _build_tensor([dest + 1, dest + 1], i, dtype=dtype) for i in group
                ]
                if cuda:
                    tensor = tensor.cuda(rank_to_GPU[rank][0])
                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
                if tensors[0].dtype == torch.complex64:
                    tensor_shapes = [torch.view_as_real(tensors[0]).shape]
                else:
                    tensor_shapes = [tensors[0].shape]
                allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
                allgather(tensors, tensor, group=group_id, async_op=False)

                for t1, t2 in zip(tensors, expected_tensors):
                    self.assertEqual(t1, t2)

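        # Wraps dist.all_to_all with quant.auto_quantize: rank `r` sends
        # `i + 1` rows of value `r` to each rank `i` and expects to receive
        # `r + 1` rows of value `i` back from every rank `i`.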
        def _test_all_to_all(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=None,
        ):
            if group_id is not None:
                size = len(group)
                in_splits = [i + 1 for i in group]
                in_tensors = [
                    torch.ones([in_splits[i], size], dtype=dtype) * rank
                    for i, _ in enumerate(group)
                ]
                out_tensors = [
                    torch.ones([(rank + 1), size], dtype=dtype) for _ in group
                ]
                expected_tensors = [
                    torch.ones([rank + 1, size], dtype=dtype) * i for i in group
                ]
                if cuda:
                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
                    expected_tensors = [
                        t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
                    ]
                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
                quantize_alltoall = quant.auto_quantize(
                    dist.all_to_all, qtype, quant_loss=None
                )
                quantize_alltoall(out_tensors, in_tensors, group=group_id)
                for t1, t2 in zip(out_tensors, expected_tensors):
                    self.assertEqual(t1, t2)

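        # Same exchange as _test_all_to_all, but through dist.all_to_all_single
        # with explicit in/out split sizes on a single flat tensor. Note that
        # the quantized collective and the assertion only run when cuda=True,
        # which is how every caller above invokes this helper.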
        def _test_all_to_all_single(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=DQuantType.FP16,
        ):
            if group_id is not None:
                size = len(group)
                in_splits = [i + 1 for i in group]
                out_splits = [rank + 1 for _ in group]
                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
                expected_tensor = torch.cat(
                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
                )
                if cuda:
                    rank_to_GPU = rank_to_GPU[rank][0]
                    in_tensor = in_tensor.cuda(rank_to_GPU)
                    expected_tensor = expected_tensor.cuda(rank_to_GPU)
                    out_tensor = out_tensor.cuda(rank_to_GPU)
                    quantize_alltoall_single = quant.auto_quantize(
                        dist.all_to_all_single, qtype, quant_loss=None
                    )
                    quantize_alltoall_single(
                        out_tensor,
                        in_tensor,
                        out_splits=out_splits,
                        in_splits=in_splits,
                        group=group_id,
                    )
                    self.assertEqual(out_tensor, expected_tensor)


if __name__ == "__main__":
    run_tests()