# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch import nn


def _quantize_per_tensor_cuda(x, scale, zero_point):
    # Affine-quantize ``x`` to uint8 with a single scale and zero point.
    y = torch.round(x / scale) + zero_point
    y = torch.clamp(y, 0, 255).to(torch.uint8)
    return y


def _dequantize_per_tensor_cuda(y, scale, zero_point):
    # Invert ``_quantize_per_tensor_cuda``: map uint8 values back to float32.
    x = scale * (y.to(torch.float32) - zero_point)
    return x


def _quantize_per_channel_cuda(x, scale, zero_point):
    # Affine-quantize each row (channel) of ``x`` with its own scale and zero point.
    y = torch.zeros(x.size(), device=x.device)
    for i in range(x.size()[0]):
        y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i]
    y = torch.clamp(y, 0, 255).to(torch.uint8)
    return y


def _dequantize_per_channel_cuda(y, scale, zero_point):
    # Invert ``_quantize_per_channel_cuda`` row by row.
    y = y.to(torch.float32).cuda(y.device)
    x = torch.zeros_like(y, device=y.device)
    for i in range(x.size()[0]):
        x[i, :] = scale[i] * (y[i, :] - zero_point[i])
    return x


def _get_allgather_out_list(all_gather_in_list, world_size):
    # Preallocate one output tensor per rank for ``dist.all_gather``.
    out_list = [
        torch.zeros_like(
            all_gather_in_list,
            device=all_gather_in_list.device,
            dtype=all_gather_in_list.dtype,
        )
        for _ in range(world_size)
    ]
    return out_list


def quantization_pertensor_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Apply ``torch.quantize_per_tensor`` logic to DDP using the ``allgather`` protocol.

    Workers first allgather the scale and zero point of their own
    ``GradBucket`` prior to quantization. After all workers have that information,
    the first ``then`` callback, ``quantize_and_allgather``, quantizes the worker's
    own gradient tensor and uses ``allgather`` to communicate it across all workers.
    The final ``then`` callback, ``dequantize_and_aggregate``, dequantizes and
    aggregates each quantized gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than the ``allreduce`` protocol. It works only with
        flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    myObserver = torch.ao.quantization.MinMaxObserver().cuda(tensor.device)
    myObserver(tensor)

    s, z = myObserver.calculate_qparams()
    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their own ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_tensor_cuda(
            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1]
        )
        # Allgather quantized tensors.
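        # Each rank contributes its own uint8 tensor and receives one tensor
        # per rank back; the output list is preallocated with matching shape
        # and dtype by ``_get_allgather_out_list``.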
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        return aggregated_dequantized_tensor / world_size

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)


def quantization_perchannel_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
) -> torch.futures.Future[torch.Tensor]:
    """
    Apply ``torch.quantize_per_channel`` logic to DDP using the ``allgather`` protocol.

    Compared to per-tensor quantization, the main motivation for per-channel
    quantization is considerably large tensors: for a tensor containing, say,
    6 million elements, quantizing in buckets of 512 (or 128) elements may
    significantly increase the resolution.

    It first splits the ``GradBucket`` tensor into multiple chunks (channels) of
    ``bucket_size`` elements. Then, workers allgather the scales and zero points
    of their own ``GradBucket`` prior to quantization. After all workers have that
    information, the first ``then`` callback, ``quantize_and_allgather``, quantizes
    the worker's own gradient tensor and uses ``allgather`` to communicate it across
    all workers. The final ``then`` callback, ``dequantize_and_aggregate``,
    dequantizes, flattens, and aggregates each quantized gradient tensor locally
    and returns the mean.

    .. warning ::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than the ``allreduce`` protocol. It works only with
        flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    tensor_in_channels = (
        nn.functional.pad(
            input=tensor,
            pad=(0, bucket_size - len(tensor) % bucket_size),
            mode="constant",
            value=0,
        )
        .view(-1, bucket_size)
        .cuda(tensor.device)
    )

    myPerChannelObserver = torch.ao.quantization.PerChannelMinMaxObserver().cuda(
        tensor.device
    )
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their corresponding ``GradBucket`` tensors.
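        # Each rank indexes its own stacked qparams from the gathered result:
        # row 0 holds the per-channel scales, row 1 the per-channel zero points.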
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        # Flatten back to 1-D and drop the zero padding added above before
        # averaging across ranks.
        return (
            torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[
                : tensor.size()[0]
            ]
            / world_size
        )

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
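

# A minimal usage sketch (not part of the original module): it shows how one of
# the hooks above could be registered on a DDP model. The model shape, NCCL
# backend, and ``torchrun`` launch are illustrative assumptions only; passing
# ``None`` as the hook state makes the hook fall back to the default group.
if __name__ == "__main__":
    import os

    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes one GPU per process, launched via e.g.
    # ``torchrun --nproc_per_node=<N> quantization_hooks.py``.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(1000, 1000).cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])

    # ``None`` state -> the hook uses ``dist.group.WORLD``.
    ddp_model.register_comm_hook(None, quantization_pertensor_hook)

    # One forward/backward step; bucket gradients now travel through the hook.
    ddp_model(torch.randn(32, 1000, device=f"cuda:{local_rank}")).sum().backward()

    dist.destroy_process_group()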