# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch import nn


# Quantize a tensor to uint8 with an affine (scale, zero_point) mapping,
# clamping to the representable [0, 255] range.
def _quantize_per_tensor_cuda(x, scale, zero_point):
    y = torch.round(x / scale) + zero_point
    y = torch.clamp(y, 0, 255).to(torch.uint8)
    return y


# Invert ``_quantize_per_tensor_cuda``: map uint8 values back to fp32 using
# the same (scale, zero_point) pair.
def _dequantize_per_tensor_cuda(y, scale, zero_point):
    x = scale * (y.to(torch.float32) - zero_point)
    return x
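
# A minimal round-trip sketch for the two helpers above (hypothetical values;
# assumes a CUDA build). Rounding to the nearest step of ``scale`` bounds the
# reconstruction error by ``scale / 2`` as long as no clamping occurs:
#
#   x = torch.randn(8, device="cuda").clamp(-5, 5)
#   scale, zero_point = 0.05, 128
#   q = _quantize_per_tensor_cuda(x, scale, zero_point)
#   x_hat = _dequantize_per_tensor_cuda(q, scale, zero_point)
#   assert torch.allclose(x, x_hat, atol=0.5 * scale + 1e-6)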


# Per-channel (row-wise) variant: each chunk of ``bucket_size`` elements is
# quantized with its own (scale, zero_point) pair.
def _quantize_per_channel_cuda(x, scale, zero_point):
    y = torch.zeros(x.size(), device=x.device)
    for i in range(x.size()[0]):
        y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i]
    y = torch.clamp(y, 0, 255).to(torch.uint8)
    return y


# Per-channel inverse: dequantize each row with its own (scale, zero_point).
def _dequantize_per_channel_cuda(y, scale, zero_point):
    y = y.to(torch.float32).cuda(y.device)
    x = torch.zeros_like(y, device=y.device)
    for i in range(x.size()[0]):
        x[i, :] = scale[i] * (y[i, :] - zero_point[i])
    return x
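
# A matching round-trip sketch for the per-channel helpers (hypothetical
# values; assumes a CUDA build). Each row has its own quantization step, so
# the per-row error is bounded by ``scale[i] / 2``:
#
#   x = torch.randn(2, 512, device="cuda").clamp(-5, 5)
#   scale = torch.tensor([0.05, 0.1], device="cuda")
#   zero_point = torch.tensor([128.0, 128.0], device="cuda")
#   q = _quantize_per_channel_cuda(x, scale, zero_point)
#   x_hat = _dequantize_per_channel_cuda(q, scale, zero_point)
#   assert torch.allclose(x, x_hat, atol=0.5 * scale.max().item() + 1e-6)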


# Pre-allocate the per-rank output list that ``dist.all_gather`` fills in.
def _get_allgather_out_list(all_gather_in_list, world_size):
    out_list = [
        torch.zeros_like(
            all_gather_in_list,
            device=all_gather_in_list.device,
            dtype=all_gather_in_list.dtype,
        )
        for _ in range(world_size)
    ]
    return out_list
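
# For example (hypothetical shapes), with ``world_size=2`` and a length-4
# uint8 input, this returns two zero-initialized uint8 tensors of length 4
# on the input's device:
#
#   buf = torch.zeros(4, dtype=torch.uint8, device="cuda")
#   out = _get_allgather_out_list(buf, 2)
#   assert len(out) == 2 and out[0].dtype == torch.uint8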


def quantization_pertensor_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Apply ``torch.quantize_per_tensor`` logic to DDP using the ``allgather`` protocol.

    Workers first allgather the scale and zero point of their own
    ``GradBucket`` prior to the quantization. After all workers have that
    information, the first ``then`` callback, ``quantize_and_allgather``,
    quantizes the worker's own gradient tensor and uses ``allgather`` to
    communicate the quantized tensors across all workers. The final ``then``
    callback, ``dequantize_and_aggregate``, dequantizes and aggregates each
    quantized gradient tensor locally and returns the mean.

    .. warning::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    myObserver = torch.ao.quantization.MinMaxObserver().cuda(tensor.device)
    myObserver(tensor)

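    # ``MinMaxObserver`` records the running min/max of the bucket;
    # ``calculate_qparams`` turns that range into an affine (scale, zero_point)
    # pair targeting the 8-bit range [0, 255].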
    s, z = myObserver.calculate_qparams()
    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their own ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_tensor_cuda(
            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1]
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # ``all_ranks_s_and_z`` here is the outer list that the first
        # ``all_gather`` populated in place. Using those scales and zeros,
        # dequantize gradient tensors locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        return aggregated_dequantized_tensor / world_size

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)


def quantization_perchannel_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
) -> torch.futures.Future[torch.Tensor]:
    """
    Apply ``torch.quantize_per_channel`` logic to DDP using the ``allgather`` protocol.

    Compared to per-tensor quantization, the main motivation of per-channel
    quantization is resolution: for considerably large tensors, such as one
    containing 6 million elements, quantizing per chunk of 512 (or 128)
    elements may significantly increase the resolution.

    It first splits the ``GradBucket`` tensor into multiple chunks (channels) of
    ``bucket_size`` elements. Then, workers allgather the scales and zero points
    of their own ``GradBucket`` prior to the quantization. After all workers have
    that information, the first ``then`` callback, ``quantize_and_allgather``,
    quantizes the worker's own gradient tensor and uses ``allgather`` to
    communicate the quantized tensors across all workers. The final ``then``
    callback, ``dequantize_and_aggregate``, dequantizes, flattens, and aggregates
    each quantized gradient tensor locally and returns the mean.

    .. warning::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
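
    To pass a non-default ``bucket_size``, wrap the hook in ``functools.partial``
    so that DDP still sees a two-argument callable (a sketch, not part of the
    original example)::

        >>> # xdoctest: +SKIP
        >>> from functools import partial
        >>> ddp_model.register_comm_hook(
        ...     process_group, partial(quantization_perchannel_hook, bucket_size=128)
        ... )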
147    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    # Pad the flat gradient so its length is a multiple of ``bucket_size`` and
    # view it as a (num_channels, bucket_size) matrix; ``-len(tensor) %
    # bucket_size`` is 0 when no padding is needed. The final callback slices
    # the aggregate back to the original length.
    tensor_in_channels = (
        nn.functional.pad(
            input=tensor,
            pad=(0, -len(tensor) % bucket_size),
            mode="constant",
            value=0,
        )
        .view(-1, bucket_size)
        .cuda(tensor.device)
    )

    myPerChannelObserver = torch.ao.quantization.PerChannelMinMaxObserver().cuda(
        tensor.device
    )
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)
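    # ``s_and_z`` has shape (2, num_channels): row 0 holds the per-channel
    # scales and row 1 the per-channel zero points.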

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their corresponding ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # ``all_ranks_s_and_z`` here is the outer list that the first
        # ``all_gather`` populated in place. Using those scales and zeros,
        # dequantize gradient tensors locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        # Flatten the channel view, drop the padding, and return the mean.
        return (
            torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[
                : tensor.size()[0]
            ]
            / world_size
        )

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)