# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch.autograd.function import Function


class SyncBatchNorm(Function):
    @staticmethod
    def forward(
        self,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        eps,
        momentum,
        process_group,
        world_size,
    ):
        if not (
            input.is_contiguous(memory_format=torch.channels_last)
            or input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(
                f"Expected more than 1 value per channel when training, got input size {size}"
            )

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device,
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # for empty input, set stats and the count to zero. The stats with
            # zero count will be filtered out later when computing the global mean
            # & invstd, but they still need to participate in the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1, dtype=input.dtype, device=input.device
            )

        # Use all_gather instead of all_reduce because the count could differ across
        # ranks; a simple all_reduce op cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates the global mean & invstd based
        # on all gathered mean, invstd and count.
        # For the NCCL backend, use the optimized version of all_gather.
        # The Gloo backend does not support `all_gather_into_tensor`.
        if process_group._get_backend_name() != "gloo":
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(
                1,
                combined_size * world_size,
                dtype=combined.dtype,
                device=combined.device,
            )
            dist.all_gather_into_tensor(
                combined_flat, combined, process_group, async_op=False
            )
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [torch.empty_like(combined) for _ in range(world_size)]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in the mask
            # tensor. Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        counts = count_all.view(-1)
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last)
            or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            (
                sum_dy,
                sum_dy_xmu,
                grad_weight,
                grad_bias,
            ) = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2],
            )

            if self.needs_input_grad[0]:
                # synchronizing stats used to calculate input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                if weight is not None and weight.dtype != mean.dtype:
                    weight = weight.to(mean.dtype)
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor,
                )
            # synchronizing grad_weight / grad_bias is not needed, as distributed
            # training would handle the all_reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch all_reduce to unblock other peer processes
                combined = torch.zeros(
                    2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class CrossMapLRN2d(Function):
    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        if input.dim() != 4:
            raise ValueError(
                f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead."
            )

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = min(pre_pad, channels)

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute first feature map normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for next feature maps normalization
        # by adding the next feature map and removing the previous
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad, channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1),
                0,
                keepdim=False,
                out=accum_ratio,
            )
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(
                    input[n][c], accum_ratio, value=-cache_ratio_value
                )
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None


class BackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args
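
# A minimal usage sketch, not part of the module's public surface: it exercises
# CrossMapLRN2d directly through ``Function.apply`` in an eager, single-process
# setting. The tensor shape and LRN hyperparameters below are arbitrary values
# chosen for illustration. SyncBatchNorm.apply additionally requires an
# initialized process group, so it is not demonstrated here.
if __name__ == "__main__":
    _x = torch.randn(2, 8, 4, 4, dtype=torch.double, requires_grad=True)
    # forward: cross-channel local response normalization with a window of 3
    _y = CrossMapLRN2d.apply(_x, 3, 1e-4, 0.75, 1.0)
    # backward: gradients flow through the custom backward() defined above
    _y.sum().backward()
    print(_y.shape, _x.grad.shape)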