# xref: /aosp_15_r20/external/pytorch/torch/_inductor/decomposition.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
import functools
import logging
import math
import sys
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._decomp as decomp
import torch._prims_common as utils
import torch.ao.quantization.fx._decomposed
from torch._decomp import (
    core_aten_decompositions,
    get_decompositions,
    remove_decompositions,
)
from torch._decomp.decompositions import (
    _grid_sampler_2d as decomp_grid_sampler_2d,
    pw_cast_for_opmath,
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._dynamo.utils import counters
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.utils import pad_listlike
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    type_to_dtype,
)
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious

from . import config, inductor_prims
from .utils import (
    is_gpu,
    needs_fallback_due_to_atomic_add_limitations,
    use_scatter_fallback,
)


log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
quantized_decomposed = torch.ops.quantized_decomposed

inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.addmv,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.index_select,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten._batch_norm_with_update,
        aten._batch_norm_with_update_functional,
        aten._batch_norm_no_update,
        aten.batch_norm_backward,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.upsample_bilinear2d.vec,
        quantized.linear_dynamic_fp16_unpacked_weight,
        _quantized.wrapped_quantized_linear,
    ]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._unsafe_masked_index,
    aten._unsafe_masked_index_put_accumulate,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten._softmax_backward_data,
    aten.clamp_max,
    aten.clamp_min,
    aten.glu,  # inductor lowers this directly
    aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
]

remove_decompositions(decompositions, decomps_to_exclude)


def register_decomposition(
    ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
) -> Callable[..., Any]:
    for op in [ops] if callable(ops) else ops:  # type: ignore[attr-defined]
        if op in decompositions:
            log.warning("duplicate decomp: %s", ops)
    return decomp.register_decomposition(ops, decompositions)


# TODO: for now, inductor doesn't handle asserts
# because the condition is symbol -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return

# Following `assert_async_msg_decomp`, implement as a no-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return


@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(
    symbol: torch.SymInt,
    *,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> None:
    return


@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(
    x: torch.Tensor,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    if min is not None:
        x = x.clamp_min(min)
    if max is not None:
        x = x.clamp_max(max)
    return x


@register_decomposition([aten.full])
def full(
    size: List[Union[int, torch.SymInt]],
    fill_value: torch.types.Number,
    **kwargs: Any,
) -> torch.Tensor:
    dtype = kwargs.get("dtype")
    if dtype is None:
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return torch.full(size, fill_value, **kwargs)
    return NotImplemented

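# Illustrative sketch (an addition for exposition, not part of the decomp table): when no
# explicit dtype is passed, the decomposition above infers one from the Python type of
# `fill_value` via `type_to_dtype`; with an explicit dtype it returns NotImplemented and the
# original aten.full is used. Assuming the usual mapping (int -> torch.int64, float -> the
# default float dtype):
def _example_full_dtype_inference() -> None:
    assert type_to_dtype(type(7)) == torch.int64
    assert type_to_dtype(type(7.5)) == torch.get_default_dtype()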

# Not really sure how to put this into the main library.  PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(
    size: List[Union[int, torch.SymInt]],
    physical_layout: List[int],
    **kwargs: Any,
) -> torch.Tensor:
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)

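# Worked example (illustrative only, not used by the lowering): for size = [2, 3, 4] and
# physical_layout = [2, 0, 1], the loop above builds perm = [1, 2, 0], so the decomposition
# allocates the physically contiguous shape [4, 2, 3] and permutes it back to the logical
# shape [2, 3, 4], leaving logical dim 1 (the innermost physical dim) with stride 1.
def _example_empty_permuted() -> None:
    out = empty_permuted([2, 3, 4], [2, 0, 1], dtype=torch.float32)
    assert out.shape == (2, 3, 4)
    assert out.stride(1) == 1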

@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias_sizes: List[int],
    stride: Union[int, List[int]],
    padding: Union[int, List[int]],
    dilation: Union[int, List[int]],
    transposed: bool,
    output_padding: List[int],
    groups: int,
    output_mask: List[bool],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if not output_mask[2] or not is_gpu(grad_output.device.type):
        return NotImplemented
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)


@register_decomposition([aten.round.decimals])
def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
    ten_pow_decimals = 10.0**decimals
    return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)


@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(
    self: torch.Tensor,
    batch2: torch.Tensor,
) -> torch.Tensor:
    if config.coordinate_descent_tuning:
        if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious(
            batch2.shape[2] == 1
        ):
            out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
            return out
    if self.device.type == "cpu":
        if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious(
            batch2.size(-1) == 1
        ):
            counters["inductor"]["decompose_bmm"] += 1
            return torch.sum(
                self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
            ).unsqueeze(1)
    return NotImplemented


@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(
    self: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    beta: torch.types.Number = 1,
    alpha: torch.types.Number = 1,
) -> torch.Tensor:
    if self.device.type == "cpu":
        if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious(
            mat2.size(-1) == 1
        ):
            counters["inductor"]["decompose_addmm"] += 1
            out = torch.sum(
                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
            return alpha * out + beta * self
        if (
            guard_size_oblivious(mat1.size(0) == 1)
            and definitely_true(mat2.size(0) <= 16)
            and definitely_true(mat2.size(1) <= 16)
        ):
            counters["inductor"]["decompose_addmm"] += 1
            out = (mat1.T * mat2).sum(dim=0, keepdim=True)
            return alpha * out + beta * self
    return NotImplemented


@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(
    self: torch.Tensor,
    input2: torch.Tensor,
) -> torch.Tensor:
    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
    if config.coordinate_descent_tuning:
        if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious(
            input2.shape[1] == 1
        ):
            return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        if (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented

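# Illustrative sketch (not part of the decomp table): the matrix-vector rewrite used above,
# (a.unsqueeze(2) * b.unsqueeze(0)).sum(dim=1), is matrix multiplication written as a
# broadcasted multiply followed by a reduction over the shared dimension. A quick eager
# equivalence check on a single-column right-hand side, the shape that triggers the rewrite:
def _example_mm_as_mul_sum() -> None:
    a = torch.randn(4, 8)
    b = torch.randn(8, 1)
    decomposed = (a.unsqueeze(2) * b.unsqueeze(0)).sum(dim=1)
    assert torch.allclose(decomposed, a @ b, atol=1e-5)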

# This pass does two things:
# - Eliminate cat when there is only one tensor input
# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
#   don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(
    tensors: List[torch.Tensor],
    dim: int = 0,
) -> torch.Tensor:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x: torch.Tensor) -> bool:
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # We'd like to eliminate naughtiness like this for downstream passes
        # like split_cat.  The easiest way is to just drop such inputs
        # (guarding that they are non-zero).
        #
        # Is it permissible for this filtering to be size-oblivious?  A case
        # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
        # happened to be zero, we would have liked to have filtered it out.
        # But actually, the ONLY way this could have passed is if u0 == 0,
        # so by the time we get here we have already installed a deferred
        # runtime assert forcing u0 to be zero.  So if this hasn't happened,
        # we know that the unbacked SymInt has appropriate size and there are
        # no problems.
        if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
            return False

        if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
            return False

        return True

    filtered_tensors = list(filter(non_empty_tensor, tensors))

    if len(filtered_tensors) == 1:
        return filtered_tensors[0].clone()
    elif 1 < len(filtered_tensors) < len(tensors):
        # on the first call, when we remove empty tensors, we redispatch recursively
        return aten.cat.default(filtered_tensors, dim)

    # optimization, avoid concat for single, repeated input
    if len(filtered_tensors) > 1 and all(
        t is filtered_tensors[0] for t in filtered_tensors
    ):
        inp = filtered_tensors[0]
        shape = list(inp.shape)
        dim = dim + len(inp.shape) if dim < 0 else dim
        shape.insert(dim, len(filtered_tensors))
        return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()

    # when no 'filtering' has occurred, return NotImplemented to prevent infinite recursion (no more decomposition needed)
    return NotImplemented

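# Illustrative sketch of the filtering above (not part of the decomp table): a legacy 1-D
# empty tensor is a valid eager cat input that contributes nothing, so dropping it and
# cloning the single remaining input preserves semantics.
def _example_cat_empty_filtering() -> None:
    a = torch.randn(2, 3)
    legacy_empty = torch.randn(0)  # the "naughty" 1-D empty tensor mentioned above
    assert torch.equal(torch.cat([a, legacy_empty]), a)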

@register_decomposition([aten.angle])
def angle(x: torch.Tensor) -> torch.Tensor:
    if x.is_complex():
        return torch.where(
            torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
        )

    # when x is real number
    #   if x >= 0, return 0
    #   if x < 0, return pi
    #   if x is nan, return nan
    _, dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
    ret = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), float("nan"), ret)


@register_decomposition([aten.add])
def add(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    alpha: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    # Require both x and y to be complex tensors.
    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
    if not x_is_complex_tensor or not y_is_complex_tensor:
        return NotImplemented
    z = y
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)

    # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problems
    # when broadcasting the add.
    def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor:
411        """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]"""
        # Get the current shape of the tensor
        *initial_dims, last_dim = tensor.shape

        # Check if the last dimension is even. This check should never fail since `x.view(x.real.dtype)`
        # doubles the last dimension for complex numbers.
        if last_dim % 2 != 0:
            raise AssertionError(
                "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]"
            )

        # Reshape the tensor
        new_shape = (*initial_dims, last_dim // 2, 2)
        reshaped_tensor = tensor.view(new_shape)
        return reshaped_tensor

    x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
    z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
    result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
    return result

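# Illustrative sketch (not part of the decomp table): viewing a complex tensor as real pairs,
# adding, and viewing back is equivalent to complex addition, which is what the decomposition
# above relies on. Here the decomposition is called directly as a plain Python function:
def _example_complex_add_via_real_view() -> None:
    x = torch.randn(2, 3, dtype=torch.complex64)
    y = torch.randn(2, 3, dtype=torch.complex64)
    assert torch.allclose(add(x, y, alpha=2.0), x + 2.0 * y)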

@register_decomposition([aten.conj_physical])
def conj_physical(self: torch.Tensor) -> torch.Tensor:
    assert not self.is_complex(), "TODO: implement this"
    return self


@register_decomposition([aten.lift, aten.detach_])
def lift(self: torch.Tensor) -> torch.Tensor:
    return self


@register_decomposition([aten.bernoulli.default])
def bernoulli(
    self: torch.Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    assert generator is None
    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)


@register_decomposition([aten.fmin, prims.fmin])
def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    return torch.where(torch.isnan(other) | (other > self), self, other)


@register_decomposition([aten.fmax, prims.fmax])
def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    return torch.where(torch.isnan(other) | (other < self), self, other)


@register_decomposition(aten.amax)
def amax(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    if self.dtype == torch.bool:
        return torch.any(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition(aten.amin)
def amin(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    if self.dtype == torch.bool:
        return torch.all(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition([aten.narrow_copy])
def narrow_copy(
    self: torch.Tensor,
    dim: int,
    start: int,
    length: int,
) -> torch.Tensor:
    return torch.narrow(self, dim, start, length).clone()


@register_decomposition([aten.view_copy.default])
def view_copy_default(
    self: torch.Tensor,
    size: List[Union[int, torch.SymInt]],
) -> torch.Tensor:
    return aten.view(self, size).clone()


@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(
    self: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    return self.to(dtype).clone()


def get_like_layout(
    tensor: torch.Tensor,
    memory_format: Optional[torch.memory_format] = None,
) -> torch.memory_format:
    # TODO: _to_copy tensor to stride permutation
    if memory_format is torch.preserve_format or memory_format is None:
        return utils.suggest_memory_format(tensor)
    else:
        return memory_format


@register_decomposition(aten.rand_like)
def rand_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return torch.rand(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randn_like)
def randn_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return torch.randn(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.full_like)
def full_like(
    self: torch.Tensor,
    fill_value: Union[int, float],
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
    return torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.default)
def randint_like(
    self: torch.Tensor,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return aten.randint.low(
        0,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self: torch.Tensor,
    low: int,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return aten.randint.low(
        low,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint.default)
def randint(
    high: int,
    size: List[Union[int, torch.SymInt]],
    **kwargs: Any,
) -> torch.Tensor:
    return aten.randint.low(0, high, size, **kwargs)


@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
def linear_dynamic_fp16_unpacked_weight(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
    return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
        input, packed_weight, bias, weight.size()[0]
    )


@register_decomposition(_quantized.wrapped_quantized_linear.default)
def wrapped_quantized_linear(
    input: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    bias: torch.Tensor,
    out_scale: torch.Tensor,
    out_zero_point: torch.Tensor,
    out_channel: int,
) -> torch.Tensor:
    packed_weight = torch.ops._quantized._wrapped_linear_prepack(
        weight, weight_scale, weight_zero_point, bias
    )
    return torch.ops._quantized._wrapped_quantized_linear_prepacked(
        input,
        input_scale,
        input_zero_point,
        packed_weight,
        out_scale,
        out_zero_point,
        out_channel,
    )


@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
    def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
        x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
        if sys.byteorder == "little":
            return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
        else:
            return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    return packed[..., :-8].to(torch.float32) * scales + offsets

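# Worked example (illustrative only): each row of `packed` ends with 8 extra bytes holding a
# float32 scale followed by a float32 offset, which is what the [..., -8:-4] and [..., -4:]
# slices above pick apart. The byte-to-float bitcast can also be seen with a plain dtype view;
# on a little-endian host the bytes [0x00, 0x00, 0x80, 0x3F] are the IEEE-754 encoding of 1.0f:
def _example_bitcast_u8_to_f32() -> None:
    raw = torch.tensor([0x00, 0x00, 0x80, 0x3F], dtype=torch.uint8)
    assert raw.view(torch.float32).item() == 1.0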

@register_decomposition([aten.grid_sampler_2d])
@pw_cast_for_opmath
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    # We do not expand the grid (_expand_grid=False) on cpu for performance reasons
    # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
    # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
    # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
    # Thus we apply this hack to not expand the grid for this case.
    _expand_grid = not (
        a.device == torch.device("cpu")
        and interpolation_mode == 0
        and a.is_contiguous(memory_format=torch.contiguous_format)
    )

    output = decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=_expand_grid,
    )
    return output


@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    return aten._foreach_add.List(
        self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    return aten._foreach_add.List(
        self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
    start_tensors: List[torch.Tensor],
    end_tensors: List[torch.Tensor],
    weight: torch.types.Number,
) -> List[torch.Tensor]:
    return aten._foreach_add.List(
        start_tensors,
        aten._foreach_mul.Scalar(
            aten._foreach_sub.List(end_tensors, start_tensors), weight
        ),
    )

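# Illustrative sketch (not part of the decomp table): the rewrite above is the usual lerp
# identity start + weight * (end - start), applied list-wise with the eager _foreach ops.
# Calling the decomposition directly should therefore match torch.lerp element by element:
def _example_foreach_lerp_scalar() -> None:
    starts = [torch.randn(3), torch.randn(2, 2)]
    ends = [torch.randn(3), torch.randn(2, 2)]
    for out, s, e in zip(_foreach_lerp_scalar(starts, ends, 0.25), starts, ends):
        assert torch.allclose(out, torch.lerp(s, e, 0.25), atol=1e-6)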

@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if training:
        return (a, b, c)
    return (
        a,
        weight.new_zeros((0,)),
        weight.new_zeros((0,)),
    )


@functools.lru_cache(None)
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
    return {**decompositions, **extra_random_decomps}


# TODO(aakhundov): replace this (and the above) Any by more
# specific type and fix all the cascading mypy errors
def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
    """decomps can change based on config"""
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()


@register_decomposition(aten.masked_scatter)
def masked_scatter(
    self: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
) -> torch.Tensor:
    from .codegen.common import BackendFeature, has_backend_feature

    if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX):
        # This two-step algorithm is the same as eager CUDA; for eager CPU we
        # use a 1-shot serial iteration.
        self, mask = aten.broadcast_tensors([self, mask])
        source_idx = mask.reshape(-1).cumsum(0) - 1
        self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source))
        result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0)
        return torch.where(mask_flat, result, self_flat).view(self.shape)
    return NotImplemented

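# Illustrative sketch of the two-step algorithm above (not part of the decomp table), with
# plain indexing standing in for aten._unsafe_masked_index: (cumsum - 1) over the flattened
# mask gives, for every True position, the ordinal of the source element it should receive.
def _example_masked_scatter_cumsum_trick() -> None:
    self_t = torch.zeros(2, 3)
    mask = torch.tensor([[True, False, True], [False, True, False]])
    source = torch.tensor([1.0, 2.0, 3.0, 4.0])
    source_idx = mask.reshape(-1).cumsum(0) - 1
    gathered = source[source_idx.clamp(min=0)]
    result = torch.where(mask.reshape(-1), gathered, self_t.reshape(-1)).view(self_t.shape)
    assert torch.equal(result, self_t.masked_scatter(mask, source))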

@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor,
    quant_min: int,
    quant_max: int,
    eps: float,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.max(scale, torch.Tensor([eps]))
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)

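# Worked example (illustrative only): for values spanning [-1.0, 3.0] mapped onto the uint8
# range [0, 255], the affine parameters computed above are scale = 4 / 255 and
# zero_point = 0 - round(-1.0 / scale) = 64, clamped into [0, 255].
def _example_choose_qparams_tensor() -> None:
    t = torch.tensor([-1.0, 0.0, 3.0])
    scale, zero_point = choose_qparams_tensor(t, 0, 255, 1e-5, torch.uint8)
    assert torch.allclose(scale, torch.tensor(4.0 / 255, dtype=torch.float64))
    assert zero_point.item() == 64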

@register_decomposition(aten.put)
def put(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    flattened = self.flatten()
    flattened = torch.index_put(
        flattened, [index], source.reshape(index.shape), accumulate
    )
    return flattened.reshape(self.shape)


@register_decomposition(aten.put_)
def put_(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    out = aten.put(self, index, source, accumulate=accumulate)
    return self.copy_(out)


@register_decomposition(aten._softmax_backward_data.default)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: torch.Tensor,
    output: torch.Tensor,
    dim: int,
    input_dtype: torch.dtype,
) -> torch.Tensor:
    new_grad_output = grad_output * output
    sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
    # grad_input = new_grad_output - output * sum_new_grad
    grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output)

    # CPU kernel doesn't respect input_dtype, but the following check doesn't work for meta tensors
    #     if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()

    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input.contiguous()

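# Illustrative sketch (not part of the decomp table): the fma above computes
# new_grad_output - output * sum_new_grad, which is the standard softmax backward formula.
# A quick eager check against autograd:
def _example_softmax_backward_identity() -> None:
    x = torch.randn(2, 5, requires_grad=True)
    out = torch.softmax(x, dim=-1)
    grad_out = torch.randn_like(out)
    (expected,) = torch.autograd.grad(out, x, grad_out)
    new_grad_output = grad_out * out
    manual = new_grad_output - out * new_grad_output.sum(dim=-1, keepdim=True)
    assert torch.allclose(manual, expected, atol=1e-6)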

@register_decomposition(aten.index_reduce)
def index_reduce(
    self: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src: torch.Tensor,
    reduction_type: str,
    *,
    include_self: bool = True,
) -> torch.Tensor:
    if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations(
        self.dtype
    ):
        true_division = self.dtype.is_floating_point or self.dtype.is_complex
        ones = torch.ones_like(src)
        if include_self:
            out = self
            counts = torch.ones_like(self).index_add(dim, index, ones)
        else:
            out = self.index_fill(dim, index, 0)
            counts = torch.zeros_like(self).index_add(dim, index, ones)
            counts = counts.masked_fill(counts < 1, 1)
        out = out.index_add(dim, index, src)
        return out / counts if true_division else out // counts

    if use_scatter_fallback(
        aten.scatter_reduce_.two,
        reduction_type,
        self.dtype,
        src.dtype,
        src.device.type,
        True,
    ):
        return NotImplemented

    repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel()
    index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim])
    perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim))
    scatter_index = (
        index.to(torch.int64)
        .repeat_interleave(repeats)
        .reshape(index_shape)
        .permute(perm)
    )
    return self.scatter_reduce(
        dim,
        scatter_index,
        src,
        reduction_type,
        include_self=include_self,
    )

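# Worked example of the "mean" branch above (illustrative only): the per-slot counts are
# accumulated with index_add over a tensor of ones, so with include_self=True each output
# slot averages its own value together with everything scattered into it. For
# self = [0, 0, 0], index = [0, 0, 2] and src = [1, 3, 5] that gives [4/3, 0, 5/2].
def _example_index_reduce_mean() -> None:
    self_t = torch.zeros(3)
    index = torch.tensor([0, 0, 2])
    src = torch.tensor([1.0, 3.0, 5.0])
    counts = torch.ones_like(self_t).index_add(0, index, torch.ones_like(src))
    out = self_t.index_add(0, index, src) / counts
    assert torch.allclose(out, torch.tensor([4.0 / 3.0, 0.0, 2.5]))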

@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
    x: torch.Tensor,
    kernel_size: List[int],
    stride: Optional[Union[int, List[int]]] = None,
    padding: Union[int, List[int]] = 0,
    dilation: Union[int, List[int]] = 1,
    ceil_mode: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if dilation == 1:
        dilation = [1, 1]

    if padding == 0:
        padding = [0, 0]

    if not stride:
        stride = kernel_size

    kernel_size = pad_listlike(kernel_size, 2)
    dilation = pad_listlike(dilation, 2)
    padding = pad_listlike(padding, 2)
    stride = pad_listlike(stride, 2)

    window_size = kernel_size[0] * kernel_size[1]
    # We fall back when using non-default dilation or when the window size is too large
    if (
        torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
            kernel_size, dilation
        )
        or window_size > torch.iinfo(torch.int8).max
    ):
        return NotImplemented

    vals, offsets = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    indices = prims._low_memory_max_pool2d_offsets_to_indices(
        offsets,
        kernel_size[1],
        x.size(-1),
        stride,
        padding,
    )
    return vals, indices