# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch._decomp
from torch import Tensor
from torch._prims_common.wrappers import _maybe_remove_out_wrapper


decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten

# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType:
#   IF any inputs have forward grad
#      AND there is no forward AD formula implemented
#      AND the function is actually differentiable
#   run the decomposition
#      See run_jit_decomposition_with_args_for_jvp
#      We currently use python decompositions that we torchscript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp, so an open question is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) instead of erroring out earlier
# and printing that the forward AD formula is not implemented.
#
# The solution to this may be an explicit allowlist that controls when
# the decomposition is enabled.
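#
# A minimal sketch of when this path is taken (assuming the op, here aten::trace,
# has no dedicated forward-AD formula and relies on the decomposition registered
# at the bottom of this file):
#
#   import torch
#
#   x = torch.randn(3, 3)
#   t = torch.randn(3, 3)
#   # Forward-mode AD through trace: VariableType finds no jvp formula and
#   # falls back to the scripted decomposition instead of raising.
#   y, y_dot = torch.func.jvp(torch.trace, (x,), (t,))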


def maybe_register_decomposition(op):
    def decorator(f):
        try:
            return register_decomposition(op)(f)
        except Exception:
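            # Registration may fail (e.g. if the op already has a decomposition
            # registered in the core table); in that case fall back to the plain
            # Python function.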
            return f

    return decorator


# Functions that need a special decomposition for jvp even though another version
# should be used more generally (e.g. for jvp we need to recompute the mean and
# variance in the backward of a normalization function, whereas without jvp the
# saved values should be used)
decomposition_table_for_jvp = {}


def register_decomposition_for_jvp(fn):
    return register_decomposition(fn, registry=decomposition_table_for_jvp)


def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]

    # `out_wrapper` extends a decomposition's signature with an `out` parameter.
    # However, jit will use the unwrapped function's signature instead, so we
    # need to unwrap here to prevent an error.
    decomp_fn = _maybe_remove_out_wrapper(decomp_fn)

    if use_python:
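        # torch.jit.ignore keeps decomp_fn as a plain Python call site rather than
        # scripting its body; the generated wrapper below gives jit a graph with
        # the right signature that calls back into the Python function.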
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)

        # Create a string wrapping the function from the signature
        # example output:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #   return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]

            return f"def wrapped_decomp({', '.join(param_def)}):\n  return decomp_fn({', '.join(param_use)})\n"

        f_str = get_function_def(sig)
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)


# The only decompositions here are temporary or hacks for the purposes of jvp


# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))


@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
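    # Numerically stable log-sigmoid: log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)).
    # The second output mirrors the eager op's buffer convention: z = exp(-|x|) is
    # saved on CPU, while CUDA gets an empty buffer (its backward does not use it).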
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer


def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    # For most norm decompositions, the jvp version is identical to the core
    # version except here: we recompute the mean and variance so that they track
    # gradients through input.

    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
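    # Recover eps from the saved rstd: since rstd = 1 / sqrt(var + eps),
    # eps = 1 / rstd**2 - var. It is detached so gradients only flow through
    # the recomputed mean and var.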
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd


@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))
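    # e.g. for input of shape (N, S, H) with normalized_shape = [H]: axis = 2,
    # inner_dims = (H,), outer_dims = (N, S), inner_dim_indices = [2],
    # outer_dim_indices = [0, 1]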

    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
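    # Standard layer-norm input gradient, assembled below as (rstd_ / N) * inner:
    #   d_input = (rstd / N) * (N * grad_x_hat
    #                           - sum(grad_x_hat)
    #                           - x_hat * sum(grad_x_hat * x_hat))
    # with the sums taken over the normalized (inner) dimensions.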
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
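    # Number of elements reduced per channel, i.e. everything except the channel
    # axis (batch and spatial dims).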
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

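    # Training-mode input gradient (per channel, reductions over batch/spatial dims):
    #   d_input = (grad_out - proj - mean(grad_out)) * invstd * weight
    # where proj = (input - mean) * invstd**2 * mean(grad_out * (input - mean))
    # re-centers grad_out along x_hat. In eval mode the statistics are constants,
    # so d_input is just grad_out * invstd * weight.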
    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)


@register_decomposition_for_jvp(aten.batch_norm_backward)
def batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    update: bool,
    eps: float,
    output_mask: List[bool],
    reserve: Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    return native_batch_norm_backward(
        grad_out,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        update,
        eps,
        output_mask,
    )


_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default)