# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch.autograd.function import Function


class SyncBatchNorm(Function):
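    """Autograd function backing ``torch.nn.SyncBatchNorm``.

    The forward pass computes per-process batch statistics, gathers them from
    every rank in ``process_group``, and normalizes the input with the global
    mean/invstd; the backward pass all-reduces the per-process gradient sums so
    every rank computes the same input gradient. It is normally reached via
    ``torch.nn.SyncBatchNorm``, which calls ``SyncBatchNorm.apply(input, weight,
    bias, running_mean, running_var, eps, momentum, process_group, world_size)``.
    """
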
    @staticmethod
    def forward(
        self,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        eps,
        momentum,
        process_group,
        world_size,
    ):
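        """Normalize ``input`` with batch statistics gathered from all ranks in ``process_group``.

        Note that ``self`` here is the autograd context object (what is usually
        named ``ctx``); the mean/invstd and per-rank counts are saved on it for
        the backward pass.
        """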
        if not (
            input.is_contiguous(memory_format=torch.channels_last)
            or input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(
                f"Expected more than 1 value per channel when training, got input size {size}"
            )

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device,
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # For an empty input, set the stats and the count to zero. Stats with
            # a zero count will be filtered out later when computing the global mean
            # & invstd, but they still need to participate in the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1, dtype=input.dtype, device=input.device
            )

        # Use all_gather instead of all_reduce because the count can differ across
        # ranks; a simple all_reduce op cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates the global mean & invstd
        # based on all gathered means, invstds and counts.
        # For the NCCL backend, use the optimized version of all_gather.
        # The Gloo backend does not support `all_gather_into_tensor`.
        if process_group._get_backend_name() != "gloo":
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(
                1,
                combined_size * world_size,
                dtype=combined.dtype,
                device=combined.device,
            )
            dist.all_gather_into_tensor(
                combined_flat, combined, process_group, async_op=False
            )
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [torch.empty_like(combined) for _ in range(world_size)]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in the mask tensor.
            # Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        counts = count_all.view(-1)
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
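        """Compute grad_input (and local grad_weight / grad_bias) for synced batch norm.

        The per-rank reductions ``sum_dy`` and ``sum_dy_xmu`` are all-reduced
        across ``process_group`` before the element-wise backward, so every rank
        uses the same global statistics it used in the forward pass.
        """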
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last)
            or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            (
                sum_dy,
                sum_dy_xmu,
                grad_weight,
                grad_bias,
            ) = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2],
            )

            if self.needs_input_grad[0]:
                # synchronizing stats used to calculate input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                if weight is not None and weight.dtype != mean.dtype:
                    weight = weight.to(mean.dtype)
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor,
                )
            # Synchronization of grad_weight / grad_bias is not needed here, as
            # distributed training already handles that all-reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch all_reduce to unblock other peer processes
                combined = torch.zeros(
                    2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class CrossMapLRN2d(Function):
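    """Cross-map (across-channel) local response normalization.

    For each spatial location, every channel is normalized by a window of
    ``size`` neighboring channels:

        output = input / (k + alpha / size * sum(input_neighbors ** 2)) ** beta

    The normalization denominator is cached in ``ctx.scale`` for the backward pass.
    """
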
    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        if input.dim() != 4:
            raise ValueError(
                f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead."
            )

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = min(pre_pad, channels)

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute first feature map normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for the next feature maps' normalization
        # by adding the next feature map and removing the previous one
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
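        """Compute grad_input for CrossMapLRN2d.

        Reuses the ``ctx.scale`` buffer cached in forward. The layer has no
        learnable parameters, so only the input gradient is returned, plus
        ``None`` for the non-tensor arguments.
        """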
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad, channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1),
                0,
                keepdim=False,
                out=accum_ratio,
            )
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(
                    input[n][c], accum_ratio, value=-cache_ratio_value
                )
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None


class BackwardHookFunction(torch.autograd.Function):
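    """Identity autograd function used by the module backward-hook machinery.

    Forward returns its arguments unchanged (marking tensors that do not
    require grad as non-differentiable) and backward passes gradients through
    untouched, giving the hooks a dedicated node in the autograd graph.
    """
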
    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args