xref: /aosp_15_r20/external/pytorch/torch/_inductor/lowering.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import functools
4import itertools
5import logging
6import math
7import operator
8import os
9import warnings
10from collections import defaultdict
11from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
12from unittest.mock import patch
13
14import sympy
15
16import torch
17import torch.ao.quantization.fx._decomposed
18import torch.fx
19import torch.utils._pytree as pytree
20from torch._higher_order_ops.associative_scan import associative_scan_op
21from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
22from torch._prims_common import (
23    canonicalize_dim,
24    canonicalize_dims,
25    check,
26    dtype_to_type,
27    elementwise_dtypes,
28    ELEMENTWISE_TYPE_PROMOTION_KIND,
29    get_computation_dtype,
30    is_boolean_dtype,
31    is_float_dtype,
32    is_integer_dtype,
33    Number,
34)
35from torch.fx.experimental.sym_node import magic_methods, method_to_operator
36from torch.utils._sympy.functions import (
37    CeilDiv,
38    FloorDiv,
39    Identity,
40    IntTrueDiv,
41    ModularIndexing,
42)
43
44from .._dynamo.utils import import_submodule
45from . import config, inductor_prims, ir, test_operators  # NOQA: F401
46from .decomposition import decompositions, get_decompositions
47from .ir import (
48    DtypeView,
49    ExpandView,
50    IndexingConstant,
51    is_triton,
52    ops_wrapper,
53    PermuteView,
54    Pointwise,
55    Reduction,
56    SqueezeView,
57    TensorBox,
58    validate_ir,
59    View,
60)
61from .utils import (
62    ceildiv,
63    decode_device,
64    is_dynamic,
65    is_gpu,
66    is_pointwise_use,
67    needs_fallback_due_to_atomic_add_limitations,
68    pad_listlike,
69    sympy_product,
70    use_scatter_fallback,
71)
72from .virtualized import ops, V
73
74
75log = logging.getLogger(__name__)
76lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
77# Use maybe_layout_constraints to access this dict; we lazily register tag-based layout constraints.
78_maybe_layout_constraints: Dict[
79    torch._ops.OpOverload, Optional[Callable[..., Any]]
80] = {}
81fallbacks: Set[torch._ops.OpOverload] = set()
82aten = torch.ops.aten
83tr_c10d = torch.ops.tr_c10d
84prims = torch.ops.prims
85needs_realized_inputs: Set[torch._ops.OpOverload] = set()
86foreach_ops: Set[torch._ops.OpOverload] = set()
87inplace_foreach_ops: Set[torch._ops.OpOverload] = set()
88inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
89quantized_decomposed = torch.ops.quantized_decomposed
90
91
92def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
93    """Get layout constraints. Returns None if there are no layout constraints."""
94    if not isinstance(fn, torch._ops.OpOverload):
95        # Only OpOverloads have layout constraints.
96        return None
97    if fn in _maybe_layout_constraints:
98        return _maybe_layout_constraints[fn]
99    # OpOverloads with custom lowerings override tag-based layout constraints
100    if fn in lowerings:
101        _maybe_layout_constraints[fn] = None
102        return None
103    # We lazily register tag-based layout constraints.
104
105    def handle_layout_constraint_tag(tag):
106        if tag is torch._C.Tag.needs_fixed_stride_order:
107            _maybe_layout_constraints[fn] = constrain_to_fx_strides
108            return _maybe_layout_constraints[fn]
109        elif tag is torch._C.Tag.flexible_layout:
110            _maybe_layout_constraints[fn] = None
111            return None
112        else:
113            raise AssertionError(f"Unknown layout constraint tag: {tag}")
114
115    tag = get_layout_constraint_tag(fn)
116    return handle_layout_constraint_tag(tag)
117
118
119def get_layout_constraint_tag(fn):
120    tags_by_priority = [
121        torch._C.Tag.needs_fixed_stride_order,
122        torch._C.Tag.flexible_layout,
123    ]
124    for tag in tags_by_priority:
125        if tag in fn.tags:
126            return tag
127    return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
128
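# Illustrative sketch (not part of the upstream module): for a hypothetical custom op
# torch.ops.mylib.my_op.default tagged with torch._C.Tag.needs_fixed_stride_order and
# having no registered lowering, maybe_layout_constraints would lazily cache and return
# constrain_to_fx_strides (defined later in this file); with torch._C.Tag.flexible_layout
# it would cache and return None. Untagged ops fall back to the tag named by
# config.custom_op_default_layout_constraint.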
129
130def assert_nyi(cond, msg):
131    if not cond:
132        raise NotImplementedError(f"inductor does not support {msg}")
133
134
135def add_needs_realized_inputs(fn):
136    if isinstance(fn, (list, tuple, set)):
137        return [add_needs_realized_inputs(x) for x in fn]
138    needs_realized_inputs.add(fn)
139    if isinstance(fn, torch._ops.OpOverloadPacket):
140        needs_realized_inputs.update(
141            getattr(fn, overload) for overload in fn.overloads()
142        )
143
144
145def add_layout_constraint(fn, constraint):
146    if isinstance(fn, torch._ops.OpOverloadPacket):
147        for overload in fn.overloads():
148            _maybe_layout_constraints[getattr(fn, overload)] = constraint
149    else:
150        _maybe_layout_constraints[fn] = constraint
151
152
153add_needs_realized_inputs(
154    [
155        aten.as_strided,
156        aten.as_strided_copy,
157        aten.avg_pool2d,
158        aten.avg_pool2d_backward,
159        aten.bmm,
160        aten.convolution,
161        aten.convolution_backward,
162        aten.max_pool2d_with_indices,
163        aten.max_pool2d_with_indices_backward,
164        aten.mm,
165        aten.upsample_nearest2d,
166        aten._upsample_nearest_exact2d,
167        aten._int_mm,
168    ]
169)
170
171# TODO(jansel): ezyang says we won't need this in the future, try removing it
172# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
173DTYPE_ID_LOOKUP = {
174    0: torch.uint8,
175    1: torch.int8,
176    2: torch.int16,
177    3: torch.int32,
178    4: torch.int64,
179    5: torch.float16,
180    6: torch.float32,
181    7: torch.float64,
182    8: torch.complex32,
183    9: torch.complex64,
184    10: torch.complex128,  # ComplexDouble per ScalarType.h
185    11: torch.bool,
186    15: torch.bfloat16,
187    # TODO(jansel): add quantized types?
188    #  _(c10::qint8, QInt8) /* 12 */
189    # _(c10::quint8, QUInt8) /* 13 */
190    # _(c10::qint32, QInt32) /* 14 */
191    # _(c10::quint4x2, QUInt4x2) /* 16 */
192    # _(c10::quint2x4, QUInt2x4) /* 17 */
193}
194
195
196def decode_dtype(dtype: int):
197    if not isinstance(dtype, int):
198        return dtype
199    assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
200    dtype = DTYPE_ID_LOOKUP[dtype]
201    return dtype
202
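# Example (sketch): decode_dtype(6) returns torch.float32 via DTYPE_ID_LOOKUP above,
# while decode_dtype(torch.float32) returns its argument unchanged because non-int
# inputs are passed through.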
203
204def is_integer_type(x):
205    if isinstance(x, TensorBox):
206        return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
207    elif isinstance(x, sympy.Expr):
208        return x.is_integer is True  # type: ignore[attr-defined]
209    else:
210        return isinstance(x, int)
211
212
213def is_boolean_type(x):
214    if isinstance(x, TensorBox):
215        return is_boolean_dtype(x.get_dtype())
216    else:
217        return isinstance(x, bool)
218
219
220def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
221    def construct_input(inp):
222        if isinstance(inp, (Number, sympy.Basic)):
223            return inp
224        else:
225            assert hasattr(inp, "get_dtype")
226            dim = len(inp.get_size())
227            # construct a tmp tensor to feed into torch.result_type
228            return torch.zeros([1] * dim, dtype=inp.get_dtype())
229
230    inps = [construct_input(arg) for arg in args]
231    _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind)
232    return dtype
233
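# Example (sketch, relying on standard torch.result_type semantics): promoting an
# int64 TensorBox together with a Python float under
# ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT yields torch.float32, since the scalar is
# weighed against a 1-element dummy tensor constructed with the input's dtype.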
234
235def get_overloads(aten_fn):
236    if not isinstance(aten_fn, (list, tuple)):
237        aten_fn = [aten_fn]
238    else:
239        aten_fn = list(aten_fn)
240
241    for fn in list(aten_fn):
242        if isinstance(fn, torch._ops.OpOverloadPacket):
243            for overload in fn.overloads():
244                other_fn = getattr(fn, overload)
245                if other_fn not in lowerings:
246                    aten_fn.append(other_fn)
247
248    return aten_fn
249
250
251def in_namespace(op, namespace):
252    if isinstance(op, torch._ops.OpOverloadPacket):
253        return namespace in op._qualified_op_name
254    elif isinstance(op, torch._ops.OpOverload):
255        return namespace in op.name()
256    return False
257
258
259def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool):
260    indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
261    if (type_promotion_kind or convert_input_to_bool) and indices:
262        if convert_input_to_bool:
263            dtype = torch.bool
264        else:
265            # FIXME that's a crude approximation for promoting args
266            promoting_args = [
267                a
268                for a in args
269                if isinstance(a, (Number, sympy.Basic))
270                or getattr(a, "dtype", None) is not None
271            ]
272            dtype = get_promoted_dtype(
273                *promoting_args, type_promotion_kind=type_promotion_kind
274            )
275
276        # sometimes args are an immutable list so we can't mutate them
277        def promote(arg):
278            if isinstance(arg, TensorBox):
279                return to_dtype(arg, dtype)
280            elif isinstance(arg, ir.Constant):
281                return ir.Constant(arg.value, dtype, args[indices[0]].get_device())
282            else:
283                return arg
284
285        args = [promote(a) for a in args]
286    if broadcast and indices:
287        for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
288            args[i] = x
289        for i in range(len(args)):
290            if isinstance(args[i], ir.Constant):
291                args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size()))
292
293    return args
294
295
296def _register_foreach_lowering(aten_fn, decomp_fn):
297    """
298    Add a foreach lowering to lowerings dict.
299
300    Arguments:
301        aten_fn: torch.ops.aten.* fn we are lowering
302        decomp_fn: alternate implementation on our IR
303        broadcast: True to apply broadcasting to tensor inputs
304        type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
305        convert_input_to_bool: some logical ops require inputs are converted to bool
306    """
307
308    @functools.wraps(decomp_fn)
309    def wrapped(*args, **kwargs):
310        assert len(args) <= 2
311        out = decomp_fn(*args, **kwargs)
312        validate_ir(out)
313        return out
314
315    aten_fns = get_overloads(aten_fn)
316    foreach_ops.update(aten_fns)
317    lowerings.update(dict.fromkeys(aten_fns, wrapped))
318    return wrapped
319
320
321def _register_lowering(
322    aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool
323):
324    """
325    Add a lowering to lowerings dict
326
327    Arguments:
328        aten_fn: torch.ops.aten.* fn we are lowering
329        decomp_fn: alternate implementation on our IR
330        broadcast: True to apply broadcasting to tensor inputs
331        type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
332        convert_input_to_bool: some logical ops require inputs are converted to bool
333    """
334
335    @functools.wraps(decomp_fn)
336    def wrapped(*args, **kwargs):
337        args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args)
338        unpacked = False
339        # TODO maybe we need to use pytrees here
340        if len(args) == 1 and isinstance(args[0], (list, tuple)):
341            unpacked = True
342            args = args[0]
343
344        # kwargs tensors not supported yet unless it's a fallback op
345        if not all(
346            (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
347        ):
348            assert not any(isinstance(x, TensorBox) for x in kwargs.values())
349            # explicitly assert for "out=" ops for better error messages
350            assert not any(
351                x == "out" for x in kwargs.keys()
352            ), "out= ops aren't yet supported"
353
354        args = transform_args(
355            args, broadcast, type_promotion_kind, convert_input_to_bool
356        )
357
358        if unpacked:
359            args = [args]
360
361        out = decomp_fn(*args, **kwargs)
362        validate_ir(out)
363
364        return out
365
366    aten_fn = get_overloads(aten_fn)
367
368    lowerings.update(dict.fromkeys(aten_fn, wrapped))
369    return wrapped
370
371
372def register_lowering(
373    aten_fn,
374    broadcast=False,
375    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
376    convert_input_to_bool=False,
377):
378    """
379    Shim to support decorator syntax.
380    """
381    return functools.partial(
382        _register_lowering,
383        aten_fn,
384        broadcast=broadcast,
385        type_promotion_kind=type_promotion_kind,
386        convert_input_to_bool=convert_input_to_bool,
387    )
388
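# Illustrative decorator usage (a sketch only; real registrations appear throughout
# this file and may differ in their promotion settings):
#
#     @register_lowering(aten.relu)
#     def relu(x):
#         return make_pointwise(ops_wrapper("relu"))(x)
#
# This would register the wrapped lowering for aten.relu and, via get_overloads, for
# each of its overloads, with broadcast=False and DEFAULT type promotion.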
389
390def broadcast_symbolic_shapes(a, b):
391    """
392    Broadcasting logic based on symbolic shapes.
393
394    We give the shapes 0 and 1 concrete values, while all other shapes
395    are symbolic sympy formulas.
396    """
397    output = []
398    for x, y in itertools.zip_longest(
399        reversed(a), reversed(b), fillvalue=sympy.Integer(1)
400    ):
401        if y == 1:
402            output.append(x)
403        elif x == 1:
404            output.append(y)
405        else:
406            V.graph.sizevars.guard_equals(x, y)
407            if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols):
408                output.append(y)  # prefer shorter formula
409            else:
410                output.append(x)
411    return tuple(reversed(output))
412
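# Example (sketch): broadcasting symbolic shapes (s0, 1) and (1, s1) yields (s0, s1);
# broadcasting (s0,) against (s1,) emits a guard that s0 == s1 via
# V.graph.sizevars.guard_equals and keeps whichever formula has fewer free symbols.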
413
414def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
415    assert (
416        override_return_dtype is None or type_promotion_kind is None
417    ), "only one of override_return_dtype or type_promotion_kind may be given"
418
419    if override_return_dtype is None and type_promotion_kind is None:
420        type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
421
422    if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs):
423        return inputs
424    if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs):
425        dtype = override_return_dtype or get_promoted_dtype(
426            *inputs, type_promotion_kind=type_promotion_kind
427        )
428
429        def const_func(x):
430            if isinstance(x, sympy.Basic):
431                return ir.IndexingConstant(x, dtype, decode_device(None))
432            else:
433                return ir.Constant(x, dtype, decode_device(None))
434
435        return [const_func(x) for x in inputs]
436    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant)))
437    out = []
438    for x in inputs:
439        if isinstance(x, (int, float)):
440            out.append(
441                ExpandView.create(
442                    ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size())
443                )
444            )
445        elif isinstance(x, sympy.Basic):
446            out.append(
447                ExpandView.create(
448                    IndexingConstant(x, ex.get_dtype(), ex.get_device()),
449                    list(ex.get_size()),
450                )
451            )
452        else:
453            out.append(x)
454
455    return out
456
457
458def make_pointwise(
459    fn,
460    override_return_dtype=None,
461    override_device=None,
462    override_fn_when_input_bool=None,
463    override_fn_when_gpu_float64=None,
464    allow_alpha=False,
465    triton_fallback=None,
466):
467    def inner(*inputs: List[TensorBox], alpha=None):
468        if triton_fallback is not None and any(map(is_triton, inputs)):
469            assert not allow_alpha  # not implemented
470            return triton_fallback(*inputs)
471
472        inputs = promote_constants(inputs, override_return_dtype)
473        if allow_alpha:
474            if alpha is not None and alpha != 1:
475                inputs = list(inputs)
476                inputs[-1] = mul(inputs[-1], alpha)
477        else:
478            assert alpha is None
479        loaders = [x.make_loader() for x in inputs]
480        ranges = inputs[0].get_size()
481        dtype = override_return_dtype or inputs[0].get_dtype()
482        is_gpu_device = is_gpu(decode_device(inputs[0].get_device()).type)
483
484        for other in inputs[1:]:
485            assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
486                other.get_size()
487            ), f"ndim mismatch {fn} {ranges} {other.get_size()}"
488
489        # in tracing, we will annotate pointwise nodes that correspond to the output of
490        # a pointwise node that would have been run in eager. intermediary pointwise nodes
491        # during decompositions are not annotated.
492        emulate_precision_casts = (
493            V.graph is not None
494            and getattr(V.graph, "current_node", None) is not None
495            and V.graph.current_node.meta is not None
496            and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False)
497            and dtype in (torch.bfloat16, torch.float16)
498        )
499
500        def inner_fn(index):
501            assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
502            if dtype == torch.bool and override_fn_when_input_bool is not None:
503                return override_fn_when_input_bool(*[load(index) for load in loaders])
504            elif (
505                override_fn_when_gpu_float64
506                and is_gpu_device
507                and dtype == torch.float64
508            ):
509                return override_fn_when_gpu_float64(*[load(index) for load in loaders])
510            else:
511                inputs_loaded = []
512                for load in loaders:
513                    out = load(index)
514                    if emulate_precision_casts:
515                        downcast = ops.to_dtype(out, dtype, use_compute_types=False)
516                        out = ops.to_dtype(downcast, dtype)
517                    inputs_loaded.append(out)
518
519                out = fn(*inputs_loaded)
520                if emulate_precision_casts:
521                    # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here,
522                    # then upcasting again, to emulate casts that eager would do.
523                    downcast = ops.to_dtype(out, dtype, use_compute_types=False)
524                    return ops.to_dtype(downcast, dtype)
525                return out
526
527        if not override_device:
528            device = None
529            for i in inputs:
530                if is_gpu(i.get_device().type):
531                    device = i.get_device()
532                    break
533            if not device:
534                device = inputs[0].get_device()
535
536        device = override_device or device
537
538        return Pointwise.create(
539            device=device,
540            dtype=dtype,
541            inner_fn=inner_fn,
542            ranges=ranges,
543        )
544
545    return inner
546
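# Illustrative sketch (no registration is performed here): a unary pointwise lowering
# can be built directly from an ops handler, e.g.
#
#     neg_lowering = make_pointwise(ops_wrapper("neg"))
#
# Calling neg_lowering(x) on a TensorBox returns a Pointwise whose inner_fn loads each
# element and applies ops.neg at codegen time, after constant promotion of the inputs.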
547
548def make_foreach_pointwise(pw_fn, allow_alpha=False):
549    def inner(*inputs: List[List[TensorBox]], alpha=1):
550        # group by device, whether any of the inputs are dynamic, and whether their types match
551        # (proxy for type promotion)
552        def group_args(arg_pairs):
553            out = defaultdict(list)
554            for i, args in enumerate(arg_pairs):
555                use_foreach = (
556                    not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
557                )
558                device = None
559                for t in args:
560                    if isinstance(t, TensorBox):
561                        device = t.data.get_device()
562                        break
563                assert (
564                    device is not None
565                ), "foreach op should have at least one tensor arg"
566                out[(device, use_foreach)].append((i, args))
567            return out
568
569        realize_outputs = (
570            len(V.graph.current_node.users) == 0
571            or V.graph.current_node.target in inplace_foreach_ops
572        )
573        for node in V.graph.current_node.users:
574            for user in node.users:
575                if not (user.op == "call_function" and (user.target in foreach_ops)):
576                    realize_outputs = True
577
578        a_list_input = None
579        for input in inputs:
580            if isinstance(input, (list, tuple)):
581                a_list_input = input
582                break
583        assert (
584            a_list_input is not None
585        ), "at least one input must be a list to a foreach op"
586
587        # broadcast scalar inputs to match length of list inputs
588        broadcast_inputs = []
589        for input in inputs:
590            if not isinstance(input, (list, tuple)):
591                broadcast_inputs.append([input] * len(a_list_input))
592            else:
593                broadcast_inputs.append(input)
594
595        groups = group_args(zip(*broadcast_inputs))
596
597        outputs = [None] * len(a_list_input)
598        for (device, use_foreach), group in groups.items():
599            operation_list: List[str] = []
600            for (
601                output_ind,
602                args,
603            ) in group:
604                if allow_alpha:
605                    output = pw_fn(*args, alpha=alpha)
606                else:
607                    output = pw_fn(*args)
608
609                outputs[output_ind] = output
610
611                if (
612                    V.graph.has_feature(device, BackendFeature.FOREACH)
613                    and use_foreach
614                    and realize_outputs
615                ):
616                    output.realize()
617                    operation_list.append(output.get_operation_name())
618
619            if operation_list:
620                V.graph.register_operation_list(operation_list)
621
622        assert all(x is not None for x in outputs)
623        return outputs
624
625    return inner
626
627
628def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False):
629    src_dtype = x.get_dtype()
630    if src_dtype == dtype:
631        return clone(x) if copy else x
632
633    def _to_dtype(x):
634        return ops.to_dtype(x, dtype, src_dtype=src_dtype)
635
636    return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
637
638
639@register_lowering(prims.convert_element_type, type_promotion_kind=None)
640def _convert_element_type(x: TensorBox, dtype: torch.dtype):
641    if dtype.is_complex or x.get_dtype().is_complex:
642        if x.get_size():
643            # Decompose since an aten fallback is more friendly for c++ codegen.
644            # This decomposition doesn't work for empty tensor, which needs more investigation.
645            dst = empty_like(x, dtype=dtype)
646            ir.InplaceCopyFallback.create(dst, x)
647            return dst
648        else:
649            return fallback_handler(
650                prims.convert_element_type.default, add_to_fallback_set=False
651            )(x, dtype)
652    return to_dtype(x, dtype, copy=True)
653
654
655def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False):
656    x_dtype = x.get_dtype()
657    if x_dtype == dtype:
658        return clone(x) if copy else x
659
660    def _get_primitive_bitwidth(dtype):
661        if dtype.is_floating_point:
662            return torch.finfo(dtype).bits
663        else:
664            return torch.iinfo(dtype).bits
665
666    src_bits = _get_primitive_bitwidth(x_dtype)
667    dst_bits = _get_primitive_bitwidth(dtype)
668    if src_bits != dst_bits:
669        # fallback to aten eager implementation for differing bitwidths
670        return fallback_handler(aten.view.dtype)(x, dtype)
671    else:
672        return TensorBox(DtypeView.create(x, dtype))
673
674
675@register_lowering(aten.view.dtype, type_promotion_kind=None)
676def _view_dtype(x: TensorBox, dtype: torch.dtype):
677    if dtype.is_complex or x.get_dtype().is_complex:
678        return TensorBox.create(
679            ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype)
680        )
681    return to_dtype_bitcast(x, dtype)
682
683
684def to_device(x: TensorBox, device: torch.device, *, copy=False):
685    device = decode_device(device)
686    if x.get_device() == device:
687        return clone(x) if copy else x
688    return TensorBox.create(ir.DeviceCopy.create(x, device))
689
690
691@register_lowering(prims.device_put, type_promotion_kind=None)
692def _device_put(x: TensorBox, device: torch.device):
693    return to_device(x, device, copy=True)
694
695
696def register_pointwise(
697    aten_fn,
698    name=None,
699    broadcast=True,
700    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
701    convert_input_to_bool=False,
702    override_return_dtype=None,
703    override_fn_when_input_bool=None,
704    allow_alpha=False,
705    use_libdevice_for_f64=False,
706    triton_fallback=None,
707):
708    """A pointwise function that maps ops.{name} to inputs"""
709    name = name or aten_fn.__name__
710    fn = ops_wrapper(name)
711    if use_libdevice_for_f64:
712        fn_libdevice = ops_wrapper("libdevice_" + name)
713    if override_fn_when_input_bool is not None:
714        override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)
715
716    fn = make_pointwise(
717        fn,
718        override_return_dtype=override_return_dtype,
719        override_fn_when_input_bool=override_fn_when_input_bool,
720        override_fn_when_gpu_float64=fn_libdevice if use_libdevice_for_f64 else None,  # type: ignore[possibly-undefined]
721        allow_alpha=allow_alpha,
722        triton_fallback=triton_fallback,
723    )
724    fn = register_lowering(
725        aten_fn,
726        broadcast=broadcast,
727        type_promotion_kind=type_promotion_kind,
728        convert_input_to_bool=convert_input_to_bool,
729    )(fn)
730
731    if hasattr(prims, name):
732        register_lowering(
733            getattr(prims, name),
734            type_promotion_kind=None,
735            convert_input_to_bool=convert_input_to_bool,
736        )(fn)
737    return fn
738
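# Illustrative usage (sketch; the concrete registrations of this form appear later in
# the file and may pass different options):
#
#     register_pointwise(aten.abs)
#
# would map aten.abs to ops.abs with broadcasting and DEFAULT type promotion, and,
# when a prims op of the same name exists, register the same lowering for it with
# type_promotion_kind=None.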
739
740def register_frexp():
741    """A pointwise function that maps ops.frexp to inputs"""
742    name = "frexp"
743    frexp = ops_wrapper("frexp")
744
745    def frexp0(*args, **kwargs):
746        return frexp(*args, **kwargs)[0]  # type: ignore[index] # next PR
747
748    def frexp1(*args, **kwargs):
749        return frexp(*args, **kwargs)[1]  # type: ignore[index] # next PR
750
751    pw_fns = [
752        make_pointwise(frexp0),
753        make_pointwise(frexp1, override_return_dtype=torch.int32),
754    ]
755
756    def fn(*args, **kwargs):
757        return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs)
758
759    fn = register_lowering(
760        aten.frexp,
761    )(fn)
762
763    if hasattr(prims, name):
764        register_lowering(
765            getattr(prims, name),
766            type_promotion_kind=None,
767        )(fn)
768    return fn
769
770
771register_frexp()
772
773
774def register_foreach_pointwise(
775    aten_fn,
776    pointwise_lowering_fn,
777    allow_alpha=False,
778):
779    fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha)
780    fn = _register_foreach_lowering(aten_fn, fn)
781    return fn
782
783
784@register_lowering(aten.where, broadcast=False, type_promotion_kind=None)
785def where(cond, a, b):
786    def fn(*args):
787        return ops.where(*args)
788
789    if isinstance(a, (float, int)):
790        a = constant_like(a)(b)
791    if isinstance(b, (float, int)):
792        b = constant_like(b)(a)
793
794    args = [cond, a, b]
795    dtype = get_promoted_dtype(
796        args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
797    )
798    indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
799    for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
800        args[i] = x
801    for i in range(len(args)):
802        if isinstance(args[i], ir.Constant):
803            args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size()))
804    return make_pointwise(fn, override_return_dtype=dtype)(
805        args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype)
806    )
807
808
809@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
810def broadcast_tensors(*inputs):
811    if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
812        return broadcast_tensors(*inputs[0])
813    target: List[sympy.Expr] = functools.reduce(
814        broadcast_symbolic_shapes, [x.get_size() for x in inputs], []
815    )
816    outputs = []
817    for x in inputs:
818        sizes = x.get_size()
819        if len(sizes) != len(target) or any(
820            ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target)
821        ):
822            x = expand(x, target)
823        outputs.append(x)
824    return outputs
825
826
827@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
828def nop(x):
829    return x  # AOT autograd handles this for us
830
831
832if hasattr(aten, "lift_fresh"):
833    register_lowering(aten.lift_fresh)(nop)
834
835
836@register_lowering(aten.squeeze, type_promotion_kind=None)
837def squeeze(x, dim=None):
838    assert isinstance(x, TensorBox)
839    if dim is None:
840        return TensorBox(SqueezeView.create(x.data))
841
842    dim = (
843        V.graph.sizevars.evaluate_static_shape(dim)
844        if isinstance(dim, (int, sympy.Expr))
845        else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim)
846    )
847    dim = canonicalize_dims(len(x.get_size()), dim)  # type: ignore[call-overload]
848    dims = set((dim,) if not isinstance(dim, tuple) else dim)
849
850    new_shape = []
851    for d, s in enumerate(x.get_size()):
852        if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))):
853            new_shape.append(s)
854
855    # squeeze does nothing if the size isn't 1
856    return view(x, new_shape) if new_shape != x.get_size() else x
857
858
859@register_lowering(aten.squeeze_copy, type_promotion_kind=None)
860def squeeze_copy(x, dim=None):
861    return clone(squeeze(x, dim))
862
863
864@register_lowering([aten.squeeze_])
865def squeeze_(x, dim=None):
866    val = squeeze(x, dim)
867    assert isinstance(x, TensorBox)
868    assert isinstance(val, TensorBox)
869    x.data = val.data
870    return x
871
872
873@register_lowering(aten.isinf)
874def isinf(x):
875    if is_integer_type(x):
876        return full_like(x, False, dtype=torch.bool)
877    fn = ops_wrapper("isinf")
878    return make_pointwise(fn, override_return_dtype=torch.bool)(x)
879
880
881@register_lowering(aten.isnan)
882def isnan(x):
883    if is_integer_type(x):
884        return full_like(x, False, dtype=torch.bool)
885    fn = ops_wrapper("isnan")
886    return make_pointwise(fn, override_return_dtype=torch.bool)(x)
887
888
889@register_lowering(aten.ceil)
890def ceil(x):
891    if is_integer_type(x):
892        return clone(x)
893    fn = ops_wrapper("ceil")
894    return make_pointwise(fn)(x)
895
896
897@register_lowering(aten.floor)
898def floor(x):
899    if is_integer_type(x):
900        return clone(x)
901    fn = ops_wrapper("floor")
902    return make_pointwise(fn)(x)
903
904
905@register_lowering(aten.round.default)
906def round(x):
907    if is_integer_type(x):
908        return clone(x)
909    else:
910        fn = ops_wrapper("round")
911        return make_pointwise(fn)(x)
912
913
914@register_lowering(aten.trunc)
915def trunc(x):
916    if is_integer_type(x):
917        return clone(x)
918    fn = ops_wrapper("trunc")
919    return make_pointwise(fn)(x)
920
921
922@register_lowering(aten.expand, type_promotion_kind=None)
923def expand(x, sizes):
924    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
925
926    (x,) = promote_constants([x])
927    if isinstance(x, ir.BaseConstant):
928        return ExpandView.create(x, tuple(sizes))
929    assert isinstance(x, TensorBox)
930    assert isinstance(sizes, (list, tuple))
931    if tuple(x.get_size()) == tuple(sizes):
932        return x
933
934    if not free_unbacked_symbols(x.get_size()):
935        x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size()))
936        # TODO: It would be better to realize the input if any of its sizes
937        # are unbacked, because typically the size will be non-zero.  However,
938        # this cannot be done directly as below, since we'll choke on the size_hint
939        # here
940        if x_size_product > 0 and not free_unbacked_symbols(sizes):
941            # maybe realize input before broadcasting it
942            x.mark_reuse(
943                V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product
944            )
945    return TensorBox(ExpandView.create(x.data, tuple(sizes)))
946
947
948@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
949def broadcast_in_dim(a, shape, broadcast_dimensions):
950    s = list(shape)
951    for broadcast_dimension in broadcast_dimensions:
952        s[broadcast_dimension] = -1
953
954    v = a
955    for idx, x in enumerate(s):
956        if x != -1:
957            v = unsqueeze(v, idx)
958
959    return expand(v, shape)
960
961
962@register_lowering(aten.expand_as, type_promotion_kind=None)
963def expand_as(x, y):
964    return expand(x, y.get_size())
965
966
967@register_lowering(aten.repeat)
968def repeat(x, repeats):
969    old_size = list(x.get_size())
970    if len(repeats) > len(old_size):
971        old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size
972        x = view(x, list(old_size))
973    assert len(repeats) == len(x.get_size())
974
975    new_size = list(x.get_size())
976
977    zero_tensor = False
978    for i in range(len(repeats)):
979        if repeats[i] == 0:
980            zero_tensor = True
981        new_size[i] = new_size[i] * repeats[i]
982
983    if zero_tensor:
984        return empty(new_size, dtype=x.get_dtype(), device=x.get_device())
985    if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
986        return clone(expand(x, new_size))
987
988    x_loader: Callable[[Any], Any]
989
990    def inner_fn(index):
991        assert len(index) == len(repeats)
992        index = list(index)
993        for i in range(len(repeats)):
994            if repeats[i] != 1:
995                if old_size[i] == 1:
996                    index[i] = sympy.Integer(0)
997                else:
998                    index[i] = ModularIndexing(index[i], 1, old_size[i])
999        return x_loader(index)
1000
1001    old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size))
1002    if old_size_product > 0:
1003        # maybe realize the input
1004        x.mark_reuse(
1005            V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product
1006        )
1007
1008    x_loader = x.make_loader()
1009    return Pointwise.create(
1010        device=x.get_device(),
1011        dtype=x.get_dtype(),
1012        inner_fn=inner_fn,
1013        ranges=list(new_size),
1014    )
1015
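# Example (sketch): repeating a tensor of size [2, 3] with repeats=[2, 1] gives size
# [4, 3]; inner_fn maps an output index i along dim 0 back to ModularIndexing(i, 1, 2),
# i.e. i % 2, while dim 1 is passed through because its repeat factor is 1.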
1016
1017@register_lowering(aten._unsafe_view, type_promotion_kind=None)
1018@register_lowering(aten.view, type_promotion_kind=None)
1019@register_lowering(aten.reshape, type_promotion_kind=None)
1020def view(x, sizes):
1021    assert isinstance(x, TensorBox)
1022    assert isinstance(sizes, (list, tuple))
1023    return TensorBox(View.create(x.data, sizes))
1024
1025
1026@register_lowering(aten.permute, type_promotion_kind=None)
1027def permute(x, dims):
1028    assert isinstance(x, TensorBox)
1029    assert isinstance(dims, (list, tuple))
1030    return TensorBox(PermuteView.create(x.data, tuple(dims)))
1031
1032
1033@register_lowering(aten.slice, type_promotion_kind=None)
1034def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
1035    assert isinstance(x, TensorBox)
1036    dim = _validate_dim(x, dim, 0)
1037    return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp))
1038
1039
1040@register_lowering(aten.as_strided, type_promotion_kind=None)
1041def as_strided(x, size, stride, storage_offset=None):
1042    if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
1043        # as_strided ignores views
1044        x = x.data.unwrap_view()
1045    x.realize()
1046    if not ir.is_storage_and_layout(x):
1047        raise NotImplementedError(f"unrealized as_strided({x}, ...)")
1048    storage, old_layout = ir.as_storage_and_layout(x)
1049    new_layout = ir.FixedLayout(
1050        old_layout.device,
1051        old_layout.dtype,
1052        [sympy.expand(s) for s in size],
1053        [sympy.expand(s) for s in stride],
1054        sympy.expand(storage_offset or 0),
1055    )
1056    return TensorBox(ir.ReinterpretView(storage, new_layout))
1057
1058
1059@register_lowering(aten.as_strided_, type_promotion_kind=None)
1060def as_strided_(x, size, stride, storage_offset=None):
1061    assert isinstance(x, TensorBox)
1062    x.data = as_strided(x, size, stride, storage_offset).data
1063    return x
1064
1065
1066@register_lowering(aten.as_strided_copy, type_promotion_kind=None)
1067def as_strided_copy(x, size, stride, storage_offset=None):
1068    result = as_strided(x, size, stride, storage_offset)
1069    return clone(result)
1070
1071
1072def pointwise_cat(inputs, dim=0):
1073    # (inclusive, exclusive)
1074    inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = []
1075    prev_end = 0
1076    for inp in inputs:
1077        inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim]))  # type: ignore[arg-type]
1078        prev_end = inputs_ranges[-1][-1]  # type: ignore[assignment]
1079
1080    inputs_loaders = [inp.make_loader() for inp in inputs]
1081
1082    def inner_fn(idx):
1083        idx_dim = ops.index_expr(idx[dim], torch.int64)
1084
1085        masks = []
1086        masked_loads = []
1087        for i in range(len(inputs)):
1088            start = (
1089                ops.constant(0, torch.int64)
1090                if i == 0
1091                else ops.index_expr(inputs_ranges[i][0], torch.int64)
1092            )
1093            end = ops.index_expr(inputs_ranges[i][1], torch.int64)
1094
1095            start_cond = ops.ge(idx_dim, start)
1096            end_cond = ops.lt(idx_dim, end)
1097            if i == 0:
1098                mask = end_cond
1099            elif i == len(inputs) - 1:
1100                mask = start_cond
1101            else:
1102                mask = ops.and_(start_cond, end_cond)
1103
1104            masks.append(mask)
1105            idx_load = list(idx)
1106
1107            # if we're concatenating sizes [4] and [2],
1108            # an overall index of 5 should index the second tensor at 5 - 4.
1109            # Use Identity to prevent expansion of index * stride, keeping the expression
1110            # in the same int bitwidth as the shape
1111            idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0])
1112
1113            masked_loads.append(
1114                ops.masked(
1115                    mask,
1116                    lambda: inputs_loaders[i](idx_load),
1117                    0.0,  # this value should be unused
1118                ),
1119            )
1120
1121        next_val = masked_loads[-1]
1122        for i in range((len(inputs)) - 2, -1, -1):
1123            next_val = ops.where(
1124                masks[i],
1125                masked_loads[i],
1126                next_val,
1127            )
1128        return next_val
1129
1130    new_size = list(inputs[0].get_size())
1131    new_size[dim] = inputs_ranges[-1][-1]
1132
1133    return Pointwise.create(
1134        device=inputs[0].get_device(),
1135        dtype=inputs[0].get_dtype(),
1136        inner_fn=inner_fn,
1137        ranges=new_size,
1138    )
1139
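# Worked example (sketch): concatenating tensors of sizes [4] and [2] along dim 0 gives
# inputs_ranges = [(0, 4), (4, 6)]. For output index 5 the first mask (idx < 4) is
# false, so the ops.where chain selects the second masked load, which reads the second
# input at Identity(5 - 4) = 1.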
1140
1141@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None)
1142def quantized_decomposed_quantize_per_channel(
1143    input: TensorBox,
1144    scales: TensorBox,
1145    zero_points: TensorBox,
1146    axis: int,
1147    quant_min: int,
1148    quant_max: int,
1149    dtype: torch.dtype,
1150) -> TensorBox:
1151    assert len(scales.get_size()) == 1, "expect scales 1 dim"
1152    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
1153
1154    if input.get_dtype() == torch.bfloat16:
1155        input = to_dtype(input, torch.float32)
1156    assert (
1157        input.get_dtype() == torch.float32
1158    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
1159    assert axis < len(
1160        input.get_size()
1161    ), f"Expecting axis to be < {len(input.get_size())}"
1162
1163    input_loader = input.make_loader()
1164    scales_loader = scales.make_loader()
1165    zero_points_loader = zero_points.make_loader()
1166
1167    def inner_fn(idx):
1168        channel_idx = (idx[axis],)
1169
1170        input = input_loader(idx)
1171        scale = scales_loader(channel_idx)
1172        zero_point = zero_points_loader(channel_idx)
1173        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
1174
1175        if scales.dtype != torch.float32:
1176            scale = ops.to_dtype(scale, torch.float32)
1177        if zero_points.dtype != torch.int32:
1178            zero_point = ops.to_dtype(zero_point, torch.int32)
1179        inv_scale = ops.reciprocal(scale)
1180        val = ops.round(input * inv_scale) + zero_point
1181        clamped = ops.maximum(qmin, ops.minimum(qmax, val))
1182        return ops.to_dtype(clamped, dtype)
1183
1184    return Pointwise.create(
1185        device=input.get_device(),
1186        dtype=dtype,
1187        inner_fn=inner_fn,
1188        ranges=input.get_size(),
1189    )
1190
1191
1192@register_lowering(
1193    quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
1194)
1195def quantized_decomposed_dequantize_per_channel(
1196    input: TensorBox,
1197    scales: TensorBox,
1198    zero_points: TensorBox,
1199    axis: int,
1200    quant_min: int,
1201    quant_max: int,
1202    dtype: torch.dtype,
1203) -> TensorBox:
1204    assert len(scales.get_size()) == 1, "expect scales 1 dim"
1205    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
1206    assert (
1207        input.get_dtype() == dtype
1208    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
1209    assert axis < len(
1210        input.get_size()
1211    ), f"Expecting axis to be < {len(input.get_size())}"
1212
1213    input_loader = input.make_loader()
1214    scales_loader = scales.make_loader()
1215    zero_points_loader = zero_points.make_loader()
1216
1217    def inner_fn(idx):
1218        channel_idx = (idx[axis],)
1219
1220        input = input_loader(idx)
1221        scale = scales_loader(channel_idx)
1222        zero_point = zero_points_loader(channel_idx)
1223
1224        if scales.dtype != torch.float32:
1225            scale = ops.to_dtype(scale, torch.float32)
1226        if zero_points.dtype != torch.float32:
1227            zero_point = ops.to_dtype(zero_point, torch.float32)
1228        val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
1229        return val
1230
1231    return Pointwise.create(
1232        device=input.get_device(),
1233        dtype=torch.float32,
1234        inner_fn=inner_fn,
1235        ranges=input.get_size(),
1236    )
1237
1238
1239@register_lowering(
1240    quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None
1241)
1242def quantized_decomposed_quantize_per_tensor_default(
1243    input: TensorBox,
1244    scale: float,
1245    zero_point: int,
1246    quant_min: int,
1247    quant_max: int,
1248    dtype: torch.dtype,
1249) -> TensorBox:
1250    if input.get_dtype() == torch.bfloat16:
1251        input = to_dtype(input, torch.float32)
1252    assert (
1253        input.get_dtype() == torch.float32
1254    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
1255
1256    input_loader = input.make_loader()
1257
1258    def inner_fn(idx, scale, zero_point):
1259        input = input_loader(idx)
1260        inv_scale, zero_point = _create_constants(
1261            1.0 / scale, zero_point, dtype=torch.float32
1262        )
1263        val = ops.round(input * inv_scale) + zero_point
1264        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
1265        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
1266        return ops.to_dtype(clamped, dtype)
1267
1268    return Pointwise.create(
1269        device=input.get_device(),
1270        dtype=dtype,
1271        inner_fn=functools.partial(
1272            inner_fn, scale=float(scale), zero_point=int(zero_point)
1273        ),
1274        ranges=input.get_size(),
1275    )
1276
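# Worked example (sketch): with scale=0.1, zero_point=10, quant_min=0, quant_max=255,
# dtype=torch.uint8, an input value of 1.23 quantizes to
# round(1.23 * (1.0 / 0.1)) + 10 = 22, which already lies inside [0, 255] and is then
# cast to uint8.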
1277
1278@register_lowering(
1279    quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None
1280)
1281def quantized_decomposed_dequantize_per_tensor_default(
1282    input: TensorBox,
1283    scale: float,
1284    zero_point: int,
1285    quant_min: int,
1286    quant_max: int,
1287    dtype: torch.dtype,
1288) -> TensorBox:
1289    assert (
1290        input.get_dtype() == dtype
1291    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
1292
1293    input_loader = input.make_loader()
1294
1295    def inner_fn(idx, scale, zero_point):
1296        input = input_loader(idx)
1297        scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32)
1298        val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
1299        return val
1300
1301    return Pointwise.create(
1302        device=input.get_device(),
1303        dtype=torch.float32,
1304        inner_fn=functools.partial(
1305            inner_fn, scale=float(scale), zero_point=int(zero_point)
1306        ),
1307        ranges=input.get_size(),
1308    )
1309
1310
1311@register_lowering(
1312    quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None
1313)
1314def quantized_decomposed_quantize_per_tensor_tensor(
1315    input: TensorBox,
1316    scale: TensorBox,
1317    zero_point: TensorBox,
1318    quant_min: int,
1319    quant_max: int,
1320    dtype: torch.dtype,
1321) -> TensorBox:
1322    if input.get_dtype() == torch.bfloat16:
1323        input = to_dtype(input, torch.float32)
1324    assert (
1325        input.get_dtype() == torch.float32
1326    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
1327    assert len(scale.get_size()) == 0 or (
1328        len(scale.get_size()) == 1 and scale.get_size()[0] == 1
1329    ), "expect scale as scalar tensor"
1330    assert len(zero_point.get_size()) == 0 or (
1331        len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
1332    ), "expect zero_point as scalar tensor"
1333
1334    input_loader = input.make_loader()
1335    scale_loader = scale.make_loader()
1336    zero_point_loader = zero_point.make_loader()
1337
1338    def inner_fn(idx):
1339        input = input_loader(idx)
1340        _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
1341        _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
1342        if scale.dtype != torch.float32:
1343            _scale = ops.to_dtype(_scale, torch.float32)
1344        if zero_point.dtype != torch.float32:
1345            _zero_point = ops.to_dtype(_zero_point, torch.float32)
1346        val = ops.round(input * ops.reciprocal(_scale)) + _zero_point
1347        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
1348        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
1349        return ops.to_dtype(clamped, dtype)
1350
1351    return Pointwise.create(
1352        device=input.get_device(),
1353        dtype=dtype,
1354        inner_fn=inner_fn,
1355        ranges=input.get_size(),
1356    )
1357
1358
1359@register_lowering(
1360    quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None
1361)
1362def quantized_decomposed_dequantize_per_tensor_tensor(
1363    input: TensorBox,
1364    scale: TensorBox,
1365    zero_point: TensorBox,
1366    quant_min: int,
1367    quant_max: int,
1368    dtype: torch.dtype,
1369) -> TensorBox:
1370    assert len(scale.get_size()) == 0 or (
1371        len(scale.get_size()) == 1 and scale.get_size()[0] == 1
1372    ), "expect scale as scalar tensor"
1373    assert len(zero_point.get_size()) == 0 or (
1374        len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
1375    ), "expect zero_point as scalar tensor"
1376    assert (
1377        input.get_dtype() == dtype
1378    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
1379
1380    input_loader = input.make_loader()
1381    scale_loader = scale.make_loader()
1382    zero_point_loader = zero_point.make_loader()
1383
1384    def inner_fn(idx):
1385        input = input_loader(idx)
1386        _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
1387        _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
1388        if scale.dtype != torch.float32:
1389            _scale = ops.to_dtype(_scale, torch.float32)
1390        if zero_point.dtype != torch.float32:
1391            _zero_point = ops.to_dtype(_zero_point, torch.float32)
1392        val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale
1393        return val
1394
1395    return Pointwise.create(
1396        device=input.get_device(),
1397        dtype=torch.float32,
1398        inner_fn=inner_fn,
1399        ranges=input.get_size(),
1400    )
1401
1402
1403@register_lowering(aten.cat)
1404def cat(inputs, dim=0):
1405    cpu_device = inputs[0].get_device().type == "cpu"
1406    if cpu_device and all(
1407        input.get_dtype() in [torch.int8, torch.uint8] for input in inputs
1408    ):
1409        # TODO <leslie> Remove this fallback when we support vectorization
1410        # code gen with uint8 data type directly.
1411        for input in inputs:
1412            input.realize()
1413        if all(len(input.get_size()) == 4 for input in inputs):
1414            inputs, _ = require_channels_last(aten.cat, *inputs)
1415        return fallback_handler(aten.cat.default)(inputs, dim)
1416
1417    if len(inputs) == 1:
1418        return clone(inputs[0])
1419
1420    dim = _validate_dim(inputs[0], dim, 0)
1421    dtype = get_promoted_dtype(
1422        *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
1423    )
1424    inputs = [to_dtype(inp, dtype) for inp in inputs]
1425
1426    def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode:
1427        if isinstance(x, TensorBox):
1428            if isinstance(x.data, ir.BaseView):
1429                return x.data.unwrap_view()
1430            else:
1431                return x.data
1432
1433        if isinstance(x, ir.StorageBox):
1434            return x.data
1435
1436        return x
1437
1438    def is_reduction(t):
1439        return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction)
1440
1441    def can_fuse_reduction(t):
1442        if isinstance(t, (TensorBox, ir.StorageBox)):
1443            return can_fuse_reduction(unwrap_tensor(t))
1444        return (
1445            is_reduction(t)
1446            or isinstance(t, ir.Pointwise)
1447            and any(
1448                can_fuse_reduction(V.graph.get_buffer(read))
1449                for read in t.get_read_names()
1450            )
1451        )
1452
1453    # fusing reductions into a computed concat buffer can cause regressions.
1454    fusable_reduction = any(can_fuse_reduction(t) for t in inputs)
1455
1456    def should_lower_cat_input(x) -> bool:
1457        # Unrealized inputs will not be storage-and-layout, and we don't want to realize
1458        # them in case we want to fuse
1459        if ir.is_storage_and_layout(x):
1460            storage, _ = ir.as_storage_and_layout(x, freeze=False)
1461            return not ir.ConcatKernel.can_realize_into_without_copy(storage)
1462
1463        if isinstance(x, (TensorBox, ir.StorageBox)):
1464            return should_lower_cat_input(unwrap_tensor(x))
1465
1466        if isinstance(x, ir.Pointwise):
1467            return True
1468
1469        return False
1470
1471    # TODO: We observed a negative performance impact of the pointwise_cat optimization on CPU, so we disabled it there.
1472    # We will revisit this later after enabling vectorization on index_expr.
1473    if cpu_device:
1474        return TensorBox(ir.ConcatKernel.create(inputs, dim))
1475
1476    def op_count(x):
1477        if isinstance(x, (TensorBox, ir.StorageBox)):
1478            return op_count(unwrap_tensor(x))
1479
1480        # this will correspond to a direct memory read
1481        if not isinstance(x, ir.Pointwise):
1482            return 0
1483
1484        count = x.inner_fn_opcount().num_ops
1485        for read in x.get_read_names():
1486            count += op_count(V.graph.get_buffer(read))
1487
1488        return count
1489
1490    # as the number of inputs increases, the possibility of register spilling also increases;
1491    # past a certain threshold of inputs we only fuse if the input kernels
1492    # are simple
1493    # not sure if we want to expose this to users via config since the logic may change in the future
1494    MAX_COMPLEX_POINTWISE_CAT = 8
1495    MAX_SIMPLE_OP_COUNT = 2
1496
1497    def additional_pointwise_ops(op: torch._ops.OpOverload):
1498        return op in (aten.cat.default, aten.constant_pad_nd.default)
1499
1500    if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
1501        (len(inputs) <= config.max_pointwise_cat_inputs)
1502        and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
1503    ):
1504        pointwise_uses = all(
1505            is_pointwise_use(use, additional_pointwise_ops)
1506            for use in V.current_node.users
1507        )
1508        # fuse in case we will be used in a pointwise node, and there are any inputs
1509        # we can prevent materialization of.
1510        fuse_pointwise_use = (
1511            any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses
1512        )
1513
1514        # horizontal fuse in case all inputs will require a copy kernel anyway.
1515        # only horizontally fuse pointwise kernels
1516        horizontal_fuse_cat = all(
1517            should_lower_cat_input(inp) for inp in inputs
1518        ) and not any(can_fuse_reduction(t) for t in inputs)
1519        if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction):
1520            return pointwise_cat(inputs, dim)
1521
1522    return TensorBox(ir.ConcatKernel.create(inputs, dim))
1523
1524
1525@register_lowering(aten.diagonal, type_promotion_kind=None)
1526def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
1527    original_shape = input.get_size()
1528    num_dims = len(original_shape)
1529    dim1 = canonicalize_dim(idx=dim1, rank=num_dims)
1530    dim2 = canonicalize_dim(idx=dim2, rank=num_dims)
1531
1532    check(
1533        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
1534    )
1535
1536    offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0))
1537    if offset_negative:
1538        diag_size = V.graph.sizevars.evaluate_max(
1539            V.graph.sizevars.evaluate_min(
1540                original_shape[dim1] + offset, original_shape[dim2]
1541            ),
1542            0,  # type: ignore[arg-type]
1543        )
1544    else:
1545        diag_size = V.graph.sizevars.evaluate_max(
1546            V.graph.sizevars.evaluate_min(
1547                original_shape[dim1], original_shape[dim2] - offset
1548            ),
1549            0,  # type: ignore[arg-type]
1550        )
1551
1552    base_idx = (0, 0)
1553    if offset_negative:
1554        base_idx = (-offset, 0)
1555    else:
1556        base_idx = (0, offset)
1557
1558    sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)]
1559    sizes.append(diag_size)
1560
1561    def reindexer(idx):
1562        diag_idx = idx[-1]
1563        original_idx = [0] * len(original_shape)
1564        cur_dim = 0
1565        for d in range(num_dims):
1566            if d == dim1:
1567                original_idx[d] = diag_idx + base_idx[0]
1568            elif d == dim2:
1569                original_idx[d] = diag_idx + base_idx[1]
1570            else:
1571                original_idx[d] = idx[cur_dim]
1572                cur_dim += 1
1573
1574        assert cur_dim == len(original_shape) - 2
1575        return original_idx
1576
1577    return TensorBox(ir.GenericView.create(input, sizes, reindexer))
1578
1579
1580@register_lowering(aten.diagonal_copy, type_promotion_kind=None)
1581def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
1582    return clone(diagonal(input, offset, dim1, dim2))
1583
1584
1585@register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
1586def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
1587    output = clone(input)
1588    target = diagonal(output, offset, dim1, dim2)
1589    mutate_to(target, src)
1590    return output
1591
1592
1593@register_lowering(aten.select, type_promotion_kind=None)
1594def select(x, dim, idx):
1595    idx = View.handle_negative_index(idx, x.get_size()[dim])
1596    return squeeze(slice_(x, dim, idx, idx + 1), dim)
1597
1598
1599@register_lowering(aten.split, type_promotion_kind=None)
1600def split(x, sizes, dim=0, clamp=True):
1601    dim = _validate_dim(x, dim, 0)
1602    if isinstance(sizes, sympy.Expr):
1603        # TODO: We don't have to guard on sizes per se, but the number
1604        # of splits must stay constant
1605        sizes = V.graph.sizevars.evaluate_static_shape(sizes)
1606    if isinstance(sizes, (int, sympy.Integer)):
1607        x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
1608        sizes = [sizes] * ((x_size + sizes - 1) // sizes)
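        # An integer `sizes` means equal chunks of that size; e.g. (illustrative)
        # x_size=10, sizes=3 gives slices [0:3], [3:6], [6:9], [9:12], with the
        # last one clamped to [9:10] by slice_ below.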
1609    result = []
1610    start = 0
1611    for size in sizes:
1612        end = start + size
1613        result.append(slice_(x, dim, start, end, clamp=clamp))
1614        start = end
1615    return result
1616
1617
1618@register_lowering(aten.split_with_sizes, type_promotion_kind=None)
1619def split_with_sizes(x, sizes, dim=0):
1620    return split(x, sizes, dim, clamp=False)
1621
1622
1623@register_lowering(aten.unbind, type_promotion_kind=None)
1624def unbind(x, dim=0):
1625    dim = _validate_dim(x, dim, 0)
1626    x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
1627    result = []
1628    for i in range(x_size):
1629        result.append(select(x, dim, i))
1630    return result
1631
1632
1633@register_lowering(aten.unfold, type_promotion_kind=None)
1634def unfold(x, dimension, size, step):
1635    sizes = x.get_size()
1636    ndim = len(sizes)
1637    dim = canonicalize_dim(ndim, dimension)
1638
1639    if ndim == 0:
1640        return slice_(unsqueeze(x, 0), end=size)
1641
1642    dim_size = sizes[dim]
1643    sizevars = V.graph.sizevars
1644    sizevars.guard_leq(size, dim_size)
1645    sizevars.guard_lt(0, step)  # type: ignore[arg-type]
1646
1647    new_dim_size = FloorDiv(dim_size - size, step) + 1
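    # Illustrative example: dim_size=7, size=2, step=3 yields windows [0, 1] and
    # [3, 4], so new_dim_size = (7 - 2) // 3 + 1 = 2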
1648    if sizevars.size_hint(dim_size) > 0:
1649        x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size)))
1650
1651    out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size]
1652
1653    def reindexer(idx):
1654        dim_idx = idx[-1] + idx[dim] * step
1655        return (*idx[:dim], dim_idx, *idx[dim + 1 : -1])
1656
1657    return TensorBox(ir.GenericView.create(x, out_size, reindexer))
1658
1659
1660@register_lowering(aten.unsqueeze, type_promotion_kind=None)
1661def unsqueeze(x, dim):
1662    dim = _validate_dim(x, dim, 1)
1663    new_shape = list(x.get_size())
1664    new_shape.insert(dim, sympy.Integer(1))
1665    return view(x, new_shape)
1666
1667
1668@register_lowering(aten.unsqueeze_, type_promotion_kind=None)
1669def unsqueeze_(x, dim):
1670    val = unsqueeze(x, dim)
1671    assert isinstance(x, TensorBox)
1672    assert isinstance(val, TensorBox)
1673    x.data = val.data
1674    return x
1675
1676
1677def _validate_dim(x, dim, offset=0):
1678    dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim))
1679    ndim = len(x.get_size())
1680    if dim < 0:
1681        dim += ndim + offset
1682    assert 0 <= dim < ndim + offset
1683    return dim
1684
1685
1686@register_lowering(aten.glu)
1687def glu(x, dim=-1):
1688    dim = _validate_dim(x, dim, 0)
1689    # TODO: don't guard on static shape here
1690    new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2
1691    a = slice_(x, dim, 0, new_len)
1692    b = slice_(x, dim, new_len, new_len * 2)
1693    return mul(a, sigmoid(b))
1694
1695
1696def fallback_handler(kernel, add_to_fallback_set=True):
1697    if add_to_fallback_set:
1698        fallbacks.add(kernel)
1699
1700    def handler(*args, **kwargs):
1701        def wrap_tensors(x):
1702            return TensorBox.create(x) if isinstance(x, ir.IRNode) else x
1703
1704        return pytree.tree_map(
1705            wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs)
1706        )
1707
1708    return handler
1709
1710
1711@functools.lru_cache(None)
1712def _warn_complex_not_supported():
1713    warnings.warn(
1714        "Torchinductor does not support code generation for complex operators. Performance may be worse than eager."
1715    )
1716
1717
1718# There are some types (CPU) which we accept as input but not as
1719# output.
1720def unsupported_input_tensor(t: torch.Tensor, parent=None):
1721    "Do not support reading or writing to this tensor"
1722    if t.is_complex():
1723        # Complex views are supported with IR ComplexView
1724        if parent and parent.target in (
1725            torch.ops.aten.view.dtype,
1726            torch.ops.prims.convert_element_type.default,
1727        ):
1728            return False
1729        _warn_complex_not_supported()
1730        return True
1731    return False
1732
1733
1734def unsupported_output_tensor(t: torch.Tensor, parent=None):
1735    "Do not support writing tensor but can read from it"
1736    if unsupported_input_tensor(t, parent):
1737        return True
1738    return t.is_cpu and config.disable_cpp_codegen
1739
1740
1741def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True):
1742    # Custom fallback lowering
1743    if node.target is aten.view_as_complex.default:
1744        return False
1745
1746    # We should be able to remove this special case once `disable_cpp_codegen` is killed.
1747    if node.target is aten.lift_fresh_copy.default:
1748        return False
1749
1750    def check_skip_condition(node, parent, is_output):
1751        if not isinstance(node, torch.fx.Node):
1752            return False
1753
1754        if "val" not in node.meta:
1755            return False
1756
1757        for meta in pytree.tree_leaves(node.meta["val"]):
1758            if not isinstance(meta, torch._subclasses.FakeTensor):
1759                continue
1760
1761            if is_output:
1762                if unsupported_output_tensor(meta, parent):
1763                    return True
1764            else:
1765                if unsupported_input_tensor(meta, parent):
1766                    return True
1767
1768        return False
1769
1770    # only skip codegen if there is a cpu output, not input
1771    for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs):
1772        if check_skip_condition(arg, node, is_output=False):
1773            return True
1774
1775    return check_skip_condition(node, node, is_output=True)
1776
1777
1778def make_fallback(op, layout_constraint=None, warn=True):
1779    assert op not in decompositions, f"both a fallback and a decomp for same op: {op}"
1780    if (
1781        warn
1782        and bool(os.getenv("CI"))
1783        and get_decompositions([op])
1784        # if fallback_random, we allow not decomposing random
1785        and not (
1786            config.fallback_random
1787            and op in torch._decomp.decompositions_for_rng.extra_random_decomps
1788        )
1789    ):
1790        # Note: 'warn' is a holdover from when this was a warning, but for ops that previously
1791        # set warn=False we do not want a CI error.
1792        # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not
1793        # likely to be triggered preferentially on one CI config over another.
1794        if torch._dynamo.config.suppress_errors:
1795            torch._dynamo.config.suppress_errors = False
1796            log.warning(
1797                "A make_fallback error occurred in suppress_errors config,"
1798                " and suppress_errors is being disabled to surface it."
1799            )
1800        raise AssertionError(
1801            f"make_fallback({op}): a decomposition exists, we should switch to it."
1802            " To fix this error, either add a decomposition to core_aten_decompositions (preferred)"
1803            " or inductor_decompositions, and delete the corresponding `make_fallback` line."
1804            " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.",
1805        )
1806
1807    def register_fallback(op_overload):
1808        add_needs_realized_inputs(op_overload)
1809        if layout_constraint is not None:
1810            add_layout_constraint(op_overload, layout_constraint)
1811        return register_lowering(op_overload, type_promotion_kind=None)(
1812            fallback_handler(op_overload)
1813        )
1814
1815    if isinstance(op, torch._ops.OpOverloadPacket):
1816        for ol in op.overloads():
1817            op_overload = getattr(op, ol)
1818            register_fallback(op_overload)
1819    elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
1820        register_fallback(op)
1821    else:
1822        raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
1823
1824
1825def philox_rand_offset(shape):
1826    """
1827    TorchInductor offset calculation differs from PyTorch eager offset
1828    calculation for random ops (tl.rand vs torch.rand). In the future, we should
1829    strive for the same impl for tl.rand and torch.rand.
1830    """
1831    numel = 1
1832    for s in shape:
1833        numel = numel * s
1834    return tensor(numel, dtype=torch.int64)
1835
1836
1837@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None)
1838def philox_rand(size, seed, offset, stride, device, dtype):
1839    # stride arg is optional and will be used in the future for distributed random
1840    # ops. Currently, it's unused.
1841    random_pos = ir.FixedLayout(
1842        device,
1843        dtype,
1844        size,
1845        ir.FlexibleLayout.contiguous_strides(size),
1846    ).make_indexer()
1847    seed_loader = seed.make_loader()
1848    offset_loader = offset.make_loader()
1849
1850    def inner_fn(index):
1851        # Both seed and offset in the philox_rand op are tensors.
1852        # torch seed and offsets are of type int64, but tl.rand accepts int32
1853        seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32)
1854        offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32)
1855        # Get the offset'd position
1856        rand_index_expr = ops.add(
1857            ops.index_expr(random_pos(index), torch.int32), offset_index_expr
1858        )
1859        result = ops.rand(
1860            seed_index_expr,
1861            rand_index_expr,
1862        )
1863        return ops.to_dtype(result, dtype)
1864
1865    random_values_node = Pointwise.create(
1866        device=device,
1867        dtype=dtype,
1868        inner_fn=inner_fn,
1869        ranges=list(size),
1870    )
1871
1872    offset_node = philox_rand_offset(size)
1873    return random_values_node, offset_node
1874
1875
1876@register_lowering(aten.native_dropout, type_promotion_kind=None)
1877def native_dropout(x, p, train):
1878    if config.fallback_random:
1879        return pytree.tree_map(
1880            TensorBox.create,
1881            ir.FallbackKernel.create(aten.native_dropout.default, x, p, train),
1882        )
1883    else:
1884        raise AssertionError("should be handled in replace_random.py")
1885
1886
1887@register_lowering(aten.bernoulli_, type_promotion_kind=None)
1888def bernoulli_(x, *args):
1889    assert config.fallback_random or x.get_device() == torch.device(
1890        "cpu"
1891    ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
1892    x.realize()
1893    op_overload = (
1894        aten.bernoulli_.float
1895        if len(args) == 0 or isinstance(args[0], float)
1896        else aten.bernoulli_.Tensor
1897    )
1898    ir.InplaceBernoulliFallback(op_overload, x, *args)
1899    return x
1900
1901
1902@register_lowering(aten.bernoulli.p, type_promotion_kind=None)
1903def bernoulli_p(x, *args):
1904    assert config.fallback_random or x.get_device() == torch.device(
1905        "cpu"
1906    ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
1907    return bernoulli_(clone(x), *args)
1908
1909
1910# This shouldn't be called in general
1911@register_lowering(aten._foobar)
1912def _foobar(_):
1913    raise AssertionError
1914
1915
1916@functools.lru_cache(1)
1917def _warn_triton_random(salt):
1918    log.info("using triton random, expect difference from eager")
1919
1920
1921def warn_triton_random():
1922    # only warn once per graph
1923    _warn_triton_random(V.graph.creation_time)
1924
1925
1926fallback_rand_default = fallback_handler(aten.rand.default)
1927fallback_rand_generator = fallback_handler(aten.rand.generator)
1928fallback_randn_default = fallback_handler(aten.randn.default)
1929fallback_randn_generator = fallback_handler(aten.randn.generator)
1930make_fallback(aten.randint)
1931
1932
1933@register_lowering(aten.rand)
1934def rand(*args, **kwargs):
1935    if kwargs.get("generator", None) is not None:
1936        return fallback_rand_generator(*args, **kwargs)
1937    elif config.fallback_random:
1938        kwargs.pop("generator", None)
1939        return fallback_rand_default(*args, **kwargs)
1940    raise AssertionError("should have been handled in replace_random.py")
1941
1942
1943@register_lowering(aten.randn)
1944def randn(*args, **kwargs):
1945    if kwargs.get("generator", None) is not None:
1946        return fallback_randn_generator(*args, **kwargs)
1947    elif config.fallback_random:
1948        kwargs.pop("generator", None)
1949        return fallback_randn_default(*args, **kwargs)
1950    raise AssertionError("should have been handled in replace_random.py")
1951
1952
1953@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None)
1954def inductor_force_stride_order(input_tensor, stride):
1955    stride_order = ir.get_stride_order(stride)
1956    return ir.ExternKernel.require_stride_order(input_tensor, stride_order)
1957
1958
1959@register_lowering(inductor_prims.seed, type_promotion_kind=None)
1960def inductor_seed(device: torch.device):
1961    raise AssertionError("should be handled in fuse_seed_creation_pass()")
1962
1963
1964@register_lowering(inductor_prims.seeds, type_promotion_kind=None)
1965def inductor_seeds(count, device):
1966    warn_triton_random()
1967    return TensorBox.create(ir.RandomSeeds(count, decode_device(device)))
1968
1969
1970@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None)
1971def inductor_lookup_seed(seeds, index):
1972    def inner_fn(_):
1973        return ops.load_seed(seeds.get_name(), index)
1974
1975    return Pointwise.create(
1976        device=seeds.get_device(),
1977        dtype=seeds.get_dtype(),
1978        inner_fn=inner_fn,
1979        ranges=[],
1980    )
1981
1982
1983@register_lowering(inductor_prims.random, type_promotion_kind=None)
1984def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0):
1985    assert not config.fallback_random
1986    assert mode in ("rand", "randn")
1987    size = [*size]
1988    dtype = torch.float32
1989    device = seed.get_device()
1990    random_pos = ir.FixedLayout(
1991        device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
1992    ).make_indexer()
1993    seed_loader = seed.make_loader()
1994
1995    def inner_fn(index):
1996        return getattr(ops, mode)(
1997            seed_loader([]),
1998            ops.index_expr(random_pos(index), torch.int32),
1999        )
2000
2001    result = Pointwise.create(
2002        device=device,
2003        dtype=dtype,
2004        inner_fn=inner_fn,
2005        ranges=[*size],
2006    )
2007    result.realize()
2008    return result
2009
2010
2011@register_lowering(inductor_prims.randint, type_promotion_kind=None)
2012def inductor_randint(
2013    low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0
2014):
2015    assert not config.fallback_random
2016    size = [*size]
2017    dtype = torch.int64
2018    device = seed.get_device()
2019    random_pos = ir.FixedLayout(
2020        device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
2021    ).make_indexer()
2022    seed_loader = seed.make_loader()
2023
2024    def inner_fn(index):
2025        return ops.randint64(
2026            seed_loader([]),
2027            ops.index_expr(random_pos(index), torch.int32),
2028            ops.index_expr(low, torch.int64),
2029            ops.index_expr(high, torch.int64),
2030        )
2031
2032    return Pointwise.create(
2033        device=device,
2034        dtype=dtype,
2035        inner_fn=inner_fn,
2036        ranges=[*size],
2037    )
2038
2039
2040@register_lowering(aten.bucketize, type_promotion_kind=None)
2041def bucketize(
2042    input: TensorBox,
2043    boundaries: TensorBox,
2044    *,
2045    out_int32: bool = False,
2046    right: bool = False,
2047):
2048    assert len(boundaries.get_size()) == 1
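    # Semantics follow torch.bucketize: for each element of `input`, return the
    # index of the first boundary that is >= it (or strictly > it when right=True).
    # e.g. (illustrative) with boundaries = [1, 3, 5], an input value of 3 maps
    # to 1 when right=False and to 2 when right=True.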
2049
2050    if not (
2051        V.graph.has_feature(input, BackendFeature.BUCKETIZE)
2052        and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE)
2053    ):
2054        return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)(
2055            input, boundaries, out_int32=out_int32, right=right
2056        )
2057
2058    # The entire boundaries tensor needs to be used by ops.bucketize, so we
2059    # need to realize it into global memory; or in other words, we can't
2060    # guarantee that boundaries.get_name() (used below) will exist unless
2061    # we call boundaries.realize().
2062    boundaries.realize()
2063    boundaries_size = boundaries.get_size()[0]
2064    device = input.get_device()
2065    input_loader = input.make_loader()
2066
2067    index_dtype = torch.int32 if out_int32 else torch.int64
2068
2069    def inner_fn(index):
2070        val = input_loader(index)
2071        indices = ops.bucketize(
2072            val,
2073            boundaries.get_name(),
2074            boundaries_size,
2075            index_dtype,
2076            right,
2077        )
2078
2079        return indices
2080
2081    return Pointwise.create(
2082        device=device,
2083        dtype=index_dtype,
2084        inner_fn=inner_fn,
2085        ranges=input.get_size(),
2086    )
2087
2088
2089def require_dense(_, *args, **kwargs):
2090    args, kwargs = pytree.tree_map_only(
2091        ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs)
2092    )
2093    return args, kwargs
2094
2095
2096def require_contiguous(_, *args, **kwargs):
2097    args, kwargs = pytree.tree_map_only(
2098        ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs)
2099    )
2100    return args, kwargs
2101
2102
2103def require_channels_last(_, *args, **kwargs):
2104    args, kwargs = pytree.tree_map_only(
2105        ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
2106    )
2107    return args, kwargs
2108
2109
2110def constrain_to_fx_strides(fx_node, *args, ignore_mutated_args_FIXME=False, **kwargs):
2111    def apply_constraint(arg, fx_arg):
2112        if isinstance(arg, ir.IRNode):
2113            stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
2114            return ir.ExternKernel.require_stride_order(arg, stride_order)
2115        return arg
2116
2117    # There's a silent incorrectness bug where, if we constrain a mutated arg,
2118    # we may end up cloning it, writing in-place to the clone, and then using
2119    # the original value (instead of the cloned value). Our short-term fix for this
2120    # is to never constrain mutated args; longer term we do want to fix this.
2121    # https://github.com/pytorch/pytorch/issues/128084
2122    if ignore_mutated_args_FIXME:
2123        assert isinstance(fx_node.target, torch._ops.OpOverload)
2124        schema = fx_node.target._schema
2125
2126        def maybe_apply_constraint(schema_arg, arg, fx_arg):
2127            if schema_arg.alias_info is not None and schema_arg.alias_info.is_write:
2128                return arg
2129            return apply_constraint(arg, fx_arg)
2130
2131        new_args = []
2132        new_kwargs = {}
2133
2134        for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args)):
2135            schema_arg = schema.arguments[idx]
2136            new_args.append(maybe_apply_constraint(schema_arg, arg, fx_arg))
2137
2138        schema_kwargs = {arg.name: arg for arg in schema.arguments}
2139
2140        for key in kwargs.keys():
2141            arg = kwargs[key]
2142            fx_arg = fx_node.kwargs[key]
2143            schema_arg = schema_kwargs[key]
2144            new_kwargs[key] = maybe_apply_constraint(schema_arg, arg, fx_arg)
2145
2146        return tuple(new_args), new_kwargs
2147
2148    args = tuple(
2149        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
2150    )
2151    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
2152    return args, kwargs
2153
2154
2155# TODO(jansel): we should implement decomps or lowerings for these
2156# https://github.com/pytorch/torchdynamo/issues/327
2157FALLBACK_ALLOW_LIST = {
2158    "torchvision::roi_align",
2159}
2160
2161
2162def sdpa_constraint(fx_node, *args, **kwargs):
2163    # sdpa requires a dense last dimension
2164
2165    def apply_constraint(arg, fx_arg):
2166        if not isinstance(arg, ir.IRNode):
2167            return arg
2168
2169        meta_val = fx_arg.meta["val"]
2170        meta_stride = meta_val.stride()
2171
2172        stride_order = ir.get_stride_order(meta_stride)
2173        if stride_order and stride_order[-1] != 0:
2174            # contiguous stride order
2175            stride_order = list(reversed(range(len(arg.get_size()))))
2176
2177        if not meta_val.is_cuda:
2178            return ir.ExternKernel.require_stride_order(arg, stride_order)
2179
2180        # This is the minimum alignment required by SDPA kernels for attention_bias.
2181        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
2182        ALIGNMENT = 8
2183
2184        assert isinstance(arg, TensorBox)
2185        if len(arg.get_size()) not in (3, 4):
2186            return arg
2187
2188        def is_aligned_realized_tensor(x):
2189            aligned_strides = all(
2190                (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
2191                for i in range(len(x.get_stride()) - 1)
2192            )
2193            return (
2194                V.graph.sizevars.size_hint(x.get_stride()[-1])
2195            ) == 1 and aligned_strides
2196
2197        try:
2198            arg.get_stride()
2199            if is_aligned_realized_tensor(arg):
2200                return V.graph.try_match_insignificant_strides(
2201                    ir.ExternKernel.realize_input(arg), meta_stride
2202                )
2203        except AttributeError:
2204            pass
2205
2206        def is_aligned(x):
2207            return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
2208
2209        if isinstance(arg.data, ir.BaseView):
2210            if not is_aligned(arg):
2211                if is_aligned(arg.unwrap_view()):
2212                    return V.graph.try_match_insignificant_strides(
2213                        ir.ExternKernel.realize_input(arg), meta_stride
2214                    )
2215
2216        return ir.ExternKernel.require_stride_order(arg, stride_order)
2217
2218    args = tuple(
2219        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
2220    )
2221    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
2222    return args, kwargs
2223
2224
2225# WIP
2226make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
2227make_fallback(aten.adaptive_max_pool3d)  # @isuruf
2228make_fallback(aten.fractional_max_pool3d)  # @isuruf
2229make_fallback(aten.max_pool3d_with_indices)  # @isuruf (can this one be implemented?)
2230
2231
2232# 1) Easy
2233make_fallback(aten.uniform, warn=False)
2234make_fallback(aten.exponential.default, warn=False)  # (fails accuracy on test_torch.py)
2235make_fallback(aten._pdist_forward)  # Has decomp. Needs benchmarks
2236make_fallback(aten.soft_margin_loss_backward, warn=False)  # py_impl?
2237make_fallback(aten.searchsorted)  # bucketize is implemented (see eager impl)
2238
2239
2240# 1.5) Easy or Impossible
2241make_fallback(aten._cdist_forward)  # p=2 should be feasible
2242make_fallback(aten._cdist_backward)
2243
2244# 2) Medium
2245make_fallback(aten.max_unpool2d)
2246make_fallback(aten.max_unpool3d)
2247make_fallback(aten._trilinear)
2248
2249
2250# 3) Difficult
2251# Scans
2252# See the discussion at
2253# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19
2254make_fallback(aten.segment_reduce.default)
2255make_fallback(aten._segment_reduce_backward.default)
2256
2257# Histogram (need to implement Histogram IR)
2258make_fallback(aten.histc)
2259make_fallback(aten.histogram.bin_ct)
2260make_fallback(aten._histogramdd_bin_edges.default)
2261make_fallback(aten._histogramdd_from_bin_cts.default)
2262
2263# Need templated kernel
2264make_fallback(aten.addbmm)
2265make_fallback(aten._addmm_activation, warn=False)
2266
2267# Need templated kernel. Probably impossible to write efficiently
2268make_fallback(aten.convolution_backward, constrain_to_fx_strides)
2269make_fallback(aten._cudnn_rnn, require_dense)
2270make_fallback(aten._cudnn_rnn_backward, require_contiguous)
2271
2272# Haven't checked but sounds difficult / impossible
2273make_fallback(aten._embedding_bag, require_contiguous)
2274make_fallback(aten._embedding_bag_forward_only, require_contiguous)
2275make_fallback(aten._embedding_bag_backward)
2276make_fallback(aten._embedding_bag_per_sample_weights_backward)
2278make_fallback(aten._fused_moving_avg_obs_fq_helper)
2279make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
2280
2281
2282# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
2283make_fallback(aten.max_pool3d_with_indices_backward)
2284make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
2285make_fallback(aten._adaptive_avg_pool3d_backward)
2286make_fallback(aten.adaptive_max_pool2d_backward)
2287make_fallback(aten.adaptive_max_pool3d_backward)
2288make_fallback(aten.fractional_max_pool2d_backward)
2289make_fallback(aten.fractional_max_pool3d_backward)
2290make_fallback(aten.replication_pad1d_backward)
2291make_fallback(aten.replication_pad2d_backward)
2292make_fallback(aten.upsample_linear1d_backward)
2293make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
2294make_fallback(aten.upsample_trilinear3d_backward)
2295make_fallback(aten.grid_sampler_2d_backward, require_dense)
2296make_fallback(aten._pdist_backward)
2297
2298
2299# 5) Impossible (missing triton/CPU features)
2300
2301# Sorting / Sorting-like
2302make_fallback(aten.sort)
2303make_fallback(aten.sort.stable)
2304make_fallback(aten.kthvalue)
2305make_fallback(aten.topk)
2306make_fallback(aten.mode)
2307make_fallback(aten.median)
2308make_fallback(aten.nanmedian)
2309make_fallback(aten.randperm)
2310# see: https://github.com/pytorch/pytorch/pull/121354
2311make_fallback(aten.resize_)
2312make_fallback(aten.resize_as_)
2313
2314# Linalg
2315make_fallback(aten._linalg_det)
2316make_fallback(aten.linalg_householder_product)
2317make_fallback(aten.linalg_inv_ex)
2318make_fallback(aten.linalg_ldl_factor_ex)
2319make_fallback(aten.linalg_ldl_solve)
2320make_fallback(aten.linalg_lu)
2321make_fallback(aten.linalg_lu_factor_ex)
2322make_fallback(aten.linalg_lu_solve)
2323make_fallback(aten.linalg_matrix_exp)
2324make_fallback(aten.linalg_qr)
2325make_fallback(aten._linalg_slogdet)
2326make_fallback(aten._linalg_solve_ex)
2327make_fallback(aten.linalg_solve_triangular)
2328make_fallback(aten._linalg_svd)
2329make_fallback(aten.lu_unpack)
2330make_fallback(aten.ormqr)
2331make_fallback(aten._linalg_check_errors)
2332make_fallback(aten.linalg_pinv.atol_rtol_tensor)
2333make_fallback(aten._linalg_eigh)
2334make_fallback(aten.triangular_solve)
2335make_fallback(aten.linalg_cholesky_ex)
2336make_fallback(aten.cholesky_inverse)
2337make_fallback(aten.cholesky_solve)
2338make_fallback(aten.geqrf)
2339make_fallback(aten._fft_r2c)  # needs complex as well
2340
2341# Data dependent (are these necessary?)
2342make_fallback(aten.nonzero.default)
2343
2344# Misc
2345make_fallback(aten.gcd.default, warn=False)
2346make_fallback(aten._thnn_fused_lstm_cell, require_dense)
2347make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
2348make_fallback(torch._prims.rng_prims.run_with_rng_state)
2349
2350# Implemented / Half implemented
2351# Scans. Implemented for CUDA, missing CPU
2352make_fallback(aten.masked_scatter)
2353make_fallback(aten.masked_scatter_backward)
2354
2355# Complex number support
2356make_fallback(aten.view_as_complex, require_contiguous)
2357make_fallback(aten.angle)  # needs complex
2358
2359# Needs efficientzerotensor
2360make_fallback(aten._efficientzerotensor)
2361
2362# Needs Sparse
2363make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
2364make_fallback(aten.to_sparse)
2365make_fallback(aten._to_sparse)
2366
2367# Needs dimname support
2368make_fallback(aten.zeros.names)
2369
2370# 6) Pattern-matched
2371make_fallback(
2372    aten._scaled_dot_product_efficient_attention.default,
2373    sdpa_constraint,
2374    warn=False,
2375)
2376make_fallback(
2377    aten._scaled_dot_product_efficient_attention_backward.default,
2378    sdpa_constraint,
2379    warn=False,
2380)
2381make_fallback(
2382    aten._scaled_dot_product_flash_attention.default,
2383    sdpa_constraint,
2384    warn=False,
2385)
2386make_fallback(
2387    aten._scaled_dot_product_flash_attention_backward.default,
2388    sdpa_constraint,
2389    warn=False,
2390)
2391make_fallback(
2392    aten._scaled_dot_product_cudnn_attention.default,
2393    sdpa_constraint,
2394    warn=False,
2395)
2396make_fallback(
2397    aten._scaled_dot_product_cudnn_attention_backward.default,
2398    sdpa_constraint,
2399    warn=False,
2400)
2401make_fallback(
2402    aten._scaled_dot_product_flash_attention_for_cpu.default,
2403    sdpa_constraint,
2404    warn=False,
2405)
2406make_fallback(
2407    aten._scaled_dot_product_flash_attention_for_cpu_backward.default,
2408    sdpa_constraint,
2409    warn=False,
2410)
2411make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
2412make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
2413make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
2414make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
2415
2416# index_reduce requires fallback when use_scatter_fallback(...) returns True
2417make_fallback(aten.index_reduce)
2418
2419
2420# Register with type_promotion_kind None.
2421# For example, fp16.copy_(fp32) should **not** promote the first input's dtype.
2422@register_lowering(aten.copy, type_promotion_kind=None)
2423def copy(self, src, non_blocking=False):
2424    x = src
2425    if self.get_device() != src.get_device():
2426        x = to_device(x, self.get_device())
2427    if self.get_dtype() != src.get_dtype():
2428        x = to_dtype(x, self.get_dtype())
2429
2430    if self.get_size() != src.get_size():
2431        out = expand(x, self.get_size())
2432        return clone(out)
2433    return clone(x)
2434
2435
2436@register_lowering(aten.clone)
2437def clone(x, *, memory_format=None):
2438    # TODO(jansel): memory format
2439    return Pointwise.create(
2440        device=x.get_device(),
2441        dtype=x.get_dtype(),
2442        inner_fn=x.make_loader(),
2443        ranges=list(x.get_size()),
2444    )
2445
2446
2447def clone_preserve_reinterpret_view(x):
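    # clone() would drop any ReinterpretView wrappers, so record their layouts
    # while unwrapping and re-apply them in the original nesting order around the
    # cloned buffer.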
2448    reinterpret_view_layouts = []
2449    if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView):
2450        x = x.data  # unwrap TensorBox
2451        while isinstance(x, ir.ReinterpretView):
2452            reinterpret_view_layouts.append(x.get_layout())
2453            x = x.data
2454        x = TensorBox(x)
2455
2456    x = clone(x)
2457
2458    if reinterpret_view_layouts:
2459        x = x.data  # unwrap TensorBox
2460        for layout in reinterpret_view_layouts[::-1]:
2461            x = ir.ReinterpretView(x, layout)
2462        x = TensorBox(x)
2463
2464    return x
2465
2466
2467if hasattr(aten, "lift_fresh_copy"):
2468    register_lowering(aten.lift_fresh_copy)(clone)
2469
2470
2471@register_lowering(prims.iota)
2472def iota(
2473    length,
2474    *,
2475    start,
2476    step,
2477    dtype,
2478    device,
2479    requires_grad,
2480):
2481    def fn(index):
2482        return ops.index_expr(step * index[0] + start, dtype=dtype)
2483
2484    return Pointwise.create(
2485        device=decode_device(device),
2486        dtype=dtype,
2487        inner_fn=fn,
2488        ranges=[length],
2489    )
2490
2491
2492@register_lowering(aten.select_scatter, type_promotion_kind=None)
2493def select_scatter(x, src, dim: int, index: int):
2494    assert x.get_dtype() == src.get_dtype()
2495    x_loader = x.make_loader()
2496    dim = _validate_dim(x, dim, 0)
2497    if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
2498        index = index + x.get_size()[dim]
2499    V.graph.sizevars.guard_leq(0, index)  # type: ignore[arg-type]
2500    V.graph.sizevars.guard_lt(index, x.get_size()[dim])  # type: ignore[arg-type]
2501    src = expand(unsqueeze(src, dim), x.get_size())
2502    src_loader = src.make_loader()
2503
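    # inner_fn picks the value from src where the coordinate along `dim` equals
    # `index`, and keeps the original value of x everywhere else.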
2504    def inner_fn(idx):
2505        return ops.where(
2506            ops.eq(
2507                ops.index_expr(idx[dim], torch.int32),
2508                ops.index_expr(index, torch.int32),
2509            ),
2510            src_loader(idx),
2511            x_loader(idx),
2512        )
2513
2514    return Pointwise.create(
2515        device=x.get_device(),
2516        dtype=x.get_dtype(),
2517        inner_fn=inner_fn,
2518        ranges=list(x.get_size()),
2519    )
2520
2521
2522@register_lowering(aten.slice_scatter, type_promotion_kind=None)
2523def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
2524    assert x.get_dtype() == src.get_dtype()
2525    x_loader = x.make_loader()
2526    dim = _validate_dim(x, dim, 0)
2527    dim_size = x.get_size()[dim]
2528
2529    start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
2530
2531    src_size = list(x.get_size())
2532    src_size[dim] = FloorDiv(end - start + (step - 1), step)
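    # e.g. (illustrative) dim_size=10, start=2, end=8, step=3 writes only to
    # indices 2 and 5, so src must have (or broadcast to)
    # ceil((8 - 2) / 3) = 2 elements along `dim`.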
2533    src = expand(src, src_size)
2534    src_loader = src.make_loader()
2535
2536    def inner_fn(idx):
2537        if start == 0 and end == dim_size and step == 1:
2538            # selecting every element is the same as just src.clone()
2539            return src_loader(idx)
2540
2541        idx_dim = ops.index_expr(idx[dim], torch.int64)
2542        src_idx = list(idx)
2543        src_idx[dim] = FloorDiv(idx[dim] - start, step)
2544
2545        mask = []
2546        if start != 0:
2547            mask.append(
2548                ops.ge(
2549                    idx_dim,
2550                    ops.index_expr(sympy.expand(start), torch.int64),
2551                )
2552            )
2553        if end != dim_size:
2554            mask.append(
2555                ops.lt(
2556                    idx_dim,
2557                    ops.index_expr(sympy.expand(end), torch.int64),
2558                )
2559            )
2560        if step != 1:
2561            mask.append(
2562                ops.eq(
2563                    ops.index_expr(
2564                        ModularIndexing(idx[dim] - start, 1, step), torch.int64
2565                    ),
2566                    ops.constant(0, torch.int64),
2567                )
2568            )
2569        assert mask
2570        mask = functools.reduce(ops.and_, mask)
2571        src_val = ops.masked(
2572            mask,
2573            lambda: src_loader(src_idx),
2574            0 if is_integer_type(x) else 0.0,
2575        )
2576        return ops.where(
2577            mask,
2578            src_val,
2579            x_loader(idx),
2580        )
2581
2582    return Pointwise.create(
2583        device=x.get_device(),
2584        dtype=x.get_dtype(),
2585        inner_fn=inner_fn,
2586        ranges=list(x.get_size()),
2587    )
2588
2589
2590def _unwrap(x):
2591    if isinstance(x, (list, tuple)) and len(x) > 0:
2592        return _unwrap(x[0])
2593    return x
2594
2595
2596@register_lowering([torch.tensor, aten.scalar_tensor])
2597def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
2598    assert_nyi(layout in (None, torch.strided), f"layout={layout}")
2599    assert_nyi(not pin_memory, "pin_memory")
2600    if isinstance(_unwrap(data), int):
2601        dtype = dtype or torch.int64
2602    else:
2603        dtype = dtype or torch.get_default_dtype()
2604
2605    ranges: List[sympy.Expr] = []
2606
2607    if isinstance(data, sympy.Basic):
2608
2609        def inner_fn(index):
2610            return ops.index_expr(data, dtype)
2611
2612    elif isinstance(data, (float, int)):
2613
2614        def inner_fn(index):
2615            return ops.constant(data, dtype)
2616
2617    elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8:
2618        # inline small tensors
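        # e.g. (illustrative) data = [5, 7, 9] is emitted as nested ops.where
        # selects that binary-search index[0], rather than materializing a
        # constant buffer for such a small tensor.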
2619        ranges.append(sympy.Integer(len(data)))
2620
2621        def inner_fn(index):
2622            def binary_search(start, end):
2623                assert start < end
2624                if end - start == 1:
2625                    return ops.constant(data[start], dtype)
2626                mid = (end - start) // 2 + start
2627                return ops.where(
2628                    ops.lt(
2629                        ops.index_expr(index[0], torch.int64),
2630                        ops.constant(mid, torch.int64),
2631                    ),
2632                    binary_search(start, mid),
2633                    binary_search(mid, end),
2634                )
2635
2636            if len(data) == 0:
2637                return ops.constant(0, dtype)
2638            return binary_search(0, len(data))
2639
2640    else:
2641        return V.graph.add_tensor_constant(
2642            torch.tensor(data, dtype=dtype, device=device)
2643        )
2644
2645    return Pointwise.create(
2646        device=decode_device(device),
2647        dtype=dtype,
2648        inner_fn=inner_fn,
2649        ranges=ranges,
2650    )
2651
2652
2653@register_lowering(torch.as_tensor)
2654def as_tensor(data, dtype=None, device=None):
2655    if isinstance(data, TensorBox):
2656        if dtype is not None:
2657            data = to_dtype(data, dtype)
2658        if device is not None:
2659            data = to_device(data, device)
2660        return data
2661    return tensor(data, dtype=dtype, device=device)
2662
2663
2664@register_lowering(torch.LongTensor)
2665def long_tensor(data):
2666    return tensor(data, dtype=torch.int64)
2667
2668
2669@register_lowering(aten._local_scalar_dense)
2670def _local_scalar_dense(data):
2671    from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings
2672
2673    # This is interesting!  Most lowerings return tensors, so you can just
2674    # return the buffer you allocated and it will get used (or not used, if
2675    # it's dead.)  But _local_scalar_dense (aka item) returns an int,
2676    # not a Tensor, so you would have a type mismatch if you return a buffer;
2677    # we are obligated to return a sympy expression instead.  However,
2678    # we need to actually codegen the .item() call somehow.  We do this
2679    # by registering a faux buffer for the DynamicScalar IR node, which is
2680    # solely responsible for generating this .item().  The buffer is
2681    # not used for anything (notice we discard it); at codegen time,
2682    # the "buffer" just gets assigned None.
2683    unbacked_bindings = resolve_unbacked_bindings(
2684        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
2685    )
2686    assert len(unbacked_bindings) == 1, unbacked_bindings
2687    # NB: Have to be very careful here.  V.graph.current_node.meta["val"]
2688    # seemingly also contains a symbol which you want to do binding for,
2689    # but it actually isn't.  In particular, if we have later performed
2690    # a deferred runtime assert saying that u0 == s0, you will actually
2691    # see s0 from expr!  This is bad because we need to actually generate
2692    # the assert that says u0 == s0, so we need to know where to get u0
2693    # from (this call).  In particular, we must use unbacked_bindings, which
2694    # is guaranteed to have the original, unreplaced symbol in question.
2695    #
2696    # NB2: Another thing we have to be very careful about is symbol bindings
2697    # that require nontrivial refinement, e.g., when you have a binding site
2698    # x: Sym(u0 * 4) = y.item().  Here, the code generation must do a division
2699    # in order to appropriately bind u0.  This is communicated via the keypath
2700    # in unbacked_bindings, and we need to hold onto it in order to generate
2701    # code appropriately for this case.
2702    binding_sym, keypath = next(iter(unbacked_bindings.items()))
2703    buffer = ir.DynamicScalar(binding_sym, keypath, data)
2704    buffer.name = V.graph.register_buffer(buffer)
2705    V.graph.register_operation(buffer)
2706    # NB: the replaced expr is OK to use directly downstream, we want
2707    # simplifications in this case!
2708    val = V.graph.current_node.meta["val"]
2709    if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
2710        return val.node.expr
2711    else:
2712        return sympy.sympify(val)
2713
2714
2715@register_lowering(aten._assert_scalar)
2716def _assert_scalar(data, msg):
2717    # NB: These will be handled at codegen time
2718    # Not sure if we are guaranteed to be able to serve out truth from the
2719    # deferred_runtime_asserts, TODO: try this assert out
2720    # assert bool(data.scalar), data
2721    return None
2722
2723
2724def _full(fill_value, device, dtype, size):
2725    value = fill_value
2726    if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
2727        value = value.value
2728
2729    if isinstance(value, (int, float)):
2730
2731        def inner_fn(index):
2732            return ops.constant(value, dtype)
2733
2734    elif isinstance(value, sympy.Basic):
2735
2736        def inner_fn(index):
2737            return ops.index_expr(value, dtype)
2738
2739    else:
2740        assert len(value.get_size()) == 0
2741        value_loader = value.make_loader()
2742
2743        def inner_fn(index):
2744            return value_loader([])
2745
2746    return Pointwise.create(
2747        device=device,
2748        dtype=dtype,
2749        inner_fn=inner_fn,
2750        ranges=list(size),
2751    )
2752
2753
2754@register_lowering(aten.full_like, type_promotion_kind=None)
2755def full_like(x, fill_value, **kwargs):
2756    return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
2757
2758
2759def tensor_constructor(fill_value):
2760    # torch.zeros, torch.ones, etc
2761    def inner(
2762        *size,
2763        names=None,
2764        dtype=None,
2765        device=None,
2766        layout=None,
2767        pin_memory=False,
2768        memory_format=None,
2769    ):
2770        assert_nyi(names is None, "named tensors")
2771        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
2772        assert_nyi(not pin_memory, "pin_memory")
2773        device = decode_device(device)
2774        dtype = dtype or torch.get_default_dtype()
2775        if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
2776            size = tuple(size[0])
2777        # See https://github.com/pytorch/pytorch/issues/118102
2778        # All sizes at lowering time should be sympy.Symbol, not SymInt!
2779        for s in size:
2780            assert not isinstance(s, torch.SymInt)
2781        size = [sympy.expand(s) for s in size]
2782        return _full(fill_value, device, dtype, size)
2783
2784    return inner
2785
2786
2787@register_lowering([torch.empty, aten.empty])
2788def empty(
2789    *size,
2790    names=None,
2791    dtype=None,
2792    layout=None,
2793    device=None,
2794    pin_memory=None,
2795    memory_format=None,
2796):
2797    assert_nyi(names is None, "named tensors")
2798    device = decode_device(device)
2799    if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
2800        size = tuple(size[0])
2801    return empty_strided(
2802        size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
2803    )
2804
2805
2806def create_tensor_like(creation_fn):
2807    """
2808    Shim to convert X_like(...) into X(...).  For example zeros_like() into zeros().
2809    """
2810
2811    def _constant_like(
2812        x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None
2813    ):
2814        assert_nyi(not pin_memory, "pin_memory")
2815        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
2816        if dtype is None:
2817            dtype = x.get_dtype()
2818        else:
2819            dtype = decode_dtype(dtype)
2820        device = device or x.get_device()
2821        size = list(x.get_size())
2822        return creation_fn(
2823            size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
2824        )
2825
2826    return _constant_like
2827
2828
2829def constant_like(fill_value):
2830    return create_tensor_like(tensor_constructor(fill_value))
2831
2832
2833empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
2834ones_like = create_tensor_like(tensor_constructor(1))
2835zeros_like = create_tensor_like(tensor_constructor(0))
2836
2837
2838def new_constant(fill_value):
2839    def _new_constant(
2840        x, size, *, dtype=None, layout=None, device=None, pin_memory=None
2841    ):
2842        assert isinstance(size, (list, tuple))
2843        assert_nyi(not pin_memory, "pin_memory")
2844        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
2845        dtype = decode_dtype(dtype) or x.get_dtype()
2846        device = device or x.get_device()
2847        size = [sympy.Integer(s) for s in size]
2848        return _full(fill_value, device, dtype, size)
2849
2850    return _new_constant
2851
2852
2853@register_lowering(aten.new_empty)
2854def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None):
2855    if dtype is None:
2856        dtype = x.get_dtype()
2857    if device is None:
2858        device = x.get_device()
2859    return empty_strided(
2860        size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
2861    )
2862
2863
2864@register_lowering(aten.empty_strided)
2865def empty_strided(
2866    size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
2867):
2868    assert isinstance(size, (list, tuple))
2869    assert isinstance(stride, (list, tuple, type(None)))
2870    assert_nyi(not pin_memory, "pin_memory")
2871    assert_nyi(layout in (None, torch.strided), f"layout={layout}")
2872    dtype = decode_dtype(dtype) or torch.get_default_dtype()
2873    device = device or torch.tensor(0.0).device
2874    pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
2875    pointwise.realize()
2876    buffer = pointwise.data.data
2877    # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
2878    buffer.data.ranges = [0] * len(size)
2879    assert isinstance(buffer, ir.ComputedBuffer)
2880    size = [sympy.expand(s) for s in size]
2881    stride = (
2882        [sympy.expand(s) for s in stride]
2883        if stride
2884        else ir.FlexibleLayout.contiguous_strides(size)
2885    )
2886    buffer.layout = ir.FixedLayout(
2887        device=device,
2888        dtype=dtype,
2889        size=size,
2890        stride=stride,
2891    )
2892    return pointwise
2893
2894
2895@register_lowering(aten.new_empty_strided)
2896def new_empty_strided(
2897    x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
2898):
2899    if dtype is None:
2900        dtype = x.get_dtype()
2901    if device is None:
2902        device = x.get_device()
2903    return empty_strided(
2904        size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
2905    )
2906
2907
2908@register_lowering(prims.copy_strided.default)
2909def copy_strided(x, stride):
2910    stride = [V.graph.sizevars.size_hint(s) for s in stride]
2911    stride_order = sorted(range(len(stride)), key=stride.__getitem__)
2912    return ir.ExternKernel.require_stride_order(x, stride_order)
2913
2914
2915@register_lowering([torch.full, aten.full])
2916def full(size, fill_value, **kwargs):
2917    assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition"
2918    return tensor_constructor(fill_value)(size, **kwargs)
2919
2920
2921@register_lowering(aten.gather, type_promotion_kind=None)
2922def gather(x, dim, index, sparse_grad=False):
2923    # sparse_grad doesn't affect forward computation,
2924    # and backward tracing is taken care of by AOT Autograd
2925    assert isinstance(x, TensorBox)
2926    if index.get_numel() == 0:
2927        # Empty index case. Return an empty tensor with the same shape as the index
2928        return new_empty(x, index.get_size())
2929
2930    assert index.get_dtype() == torch.int64
2931    size = x.get_size()
2932    offset = len(size) == 0
2933    dim = _validate_dim(x, dim, offset)
2934
2935    if offset:
2936        x = expand(x, [1])
2937        size = [1]
2938
2939    x_loader = x.make_loader()
2940    index_loader = index.make_loader()
2941
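    # Gather semantics for a 2-D example with dim=1: out[i][j] = x[i][index[i][j]].
    # fn below replaces only the coordinate along `dim` with the indirectly
    # indexed (bounds-checked) value loaded from `index`.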
2942    def fn(idx):
2943        idx = list(idx)
2944        gather_idx = ops.indirect_indexing(index_loader(idx), size[dim])
2945        if len(idx) == 0:
2946            idx = [gather_idx]
2947        else:
2948            idx[dim] = gather_idx
2949        return x_loader(idx)
2950
2951    return Pointwise.create(
2952        device=x.get_device(),
2953        dtype=x.get_dtype(),
2954        inner_fn=fn,
2955        ranges=index.get_size(),
2956    )
2957
2958
2959@register_lowering(aten.embedding, type_promotion_kind=None)
2960def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
2961    assert not sparse
2962    assert isinstance(weight, TensorBox)
2963    assert isinstance(indices, TensorBox)
2964    assert "int" in str(indices.get_dtype())
2965
2966    weight_loader = weight.make_loader()
2967    indices_loader = indices.make_loader()
2968    indices_ndim = len(indices.get_size())
2969    weight_size = weight.get_size()
2970    new_size = [*indices.get_size(), *weight_size[1:]]
2971
2972    def fn(idx):
2973        assert len(idx) == len(new_size), f"{idx} != {new_size}"
2974        var_index = indices_loader(idx[:indices_ndim])
2975        weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [
2976            *idx[indices_ndim:]
2977        ]
2978        return weight_loader(weight_idx)
2979
2980    return Pointwise.create(
2981        device=weight.get_device(),
2982        dtype=weight.get_dtype(),
2983        inner_fn=fn,
2984        ranges=new_size,
2985    )
2986
2987
2988def check_and_broadcast_indices(indices, device):
2989    assert all(
2990        i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
2991        for i in indices
2992        if i is not None
2993    ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
2994    if any(
2995        i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
2996    ):
2997        raise NotImplementedError("Fallback for bool indices")
2998
2999    valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)]
3000    assert len(valid_idxs) > 0, "requires at least 1 non-None index"
3001    new_indices = [None] * len(indices)
3002    for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])):
3003        # Eager allows indices to be a CPU tensor when running on CUDA
3004        # FIXME: Calling to_device(x, device) should work but
3005        # test_advancedindex_mixed_cpu_devices still fails
3006        if x.get_device() != device:
3007            raise NotImplementedError("Fallback when indices is on a different device")
3008        new_indices[i] = x
3009    return new_indices, valid_idxs
3010
3011
3012def index_output_size_and_inner_fn(
3013    x_size,
3014    indices,
3015    tensor_indices,
3016    tensor_size,
3017    indices_loaders,
3018    indexed_size,
3019    x_loader,
3020    check,
3021):
3022    # Note that the behavior of indexing differs when there are non-consecutive
3023    # tensor indices. In this case, the tensor index is pulled to the beginning.
3024    #
3025    # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7)
3026    #         x = torch.tensor([1, 2])
3027    # Then, a[:, x, :, x, :] will have shape (2, 3, 5, 7): since the tensor indices
3028    # (x, :, x) are non-consecutive, the indexed dimension of size 2 is pulled to the front.
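    # Eager-mode check of the rule above (illustrative):
    #   a[:, x, :, x, :].shape == (2, 3, 5, 7)  # non-consecutive: index dims pulled to front
    #   a[:, x, x, :, :].shape == (3, 2, 6, 7)  # consecutive: index dims stay at dim 1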
3029    non_consecutive_tensors = False
3030    for previous, current in zip(tensor_indices, tensor_indices[1:]):
3031        if current - previous != 1:
3032            non_consecutive_tensors = True
3033
3034    output_size = [x_size[i] for i, val in enumerate(indices) if val is None]
3035    output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]]
3036
3037    first_tensor_index = tensor_indices[0]
3038    if non_consecutive_tensors:
3039        output_size = tensor_size + output_size
3040    else:
3041        output_size = (
3042            output_size[:first_tensor_index]
3043            + tensor_size
3044            + output_size[first_tensor_index:]
3045        )
3046
3047    def fn(idx):
3048        assert len(idx) == len(output_size)
3049        assert len(indices_loaders) == len(indexed_size)
3050
3051        rank = len(tensor_size)
3052        new_index = []
3053        first_tensor_index = tensor_indices[0]
3054        start_offset = 0 if non_consecutive_tensors else first_tensor_index
3055        next_idx = 0
3056        for i in range(tensor_indices[-1] + 1):
3057            if i == start_offset:
3058                next_idx += rank
3059            if indices[i] is None:
3060                assert next_idx < len(idx)
3061                new_index.append(idx[next_idx])
3062                next_idx += 1
3063            else:
3064                loader = indices_loaders[i]
3065                assert loader is not None
3066                size = indexed_size[i]
3067                new_index.append(
3068                    ops.indirect_indexing(
3069                        loader(idx[start_offset : start_offset + rank]),
3070                        size,
3071                        check=check,
3072                    )
3073                )
3074        new_index = [
3075            *new_index,
3076            *idx[next_idx:],
3077        ]
3078        return new_index if x_loader is None else x_loader(new_index)
3079
3080    return output_size, fn
3081
3082
3083def index_impl(x, indices, check):
3084    output_size, inner_fn, _ = index_impl_helper(x, indices, check)
3085
3086    return Pointwise.create(
3087        device=x.get_device(),
3088        dtype=x.get_dtype(),
3089        inner_fn=inner_fn,
3090        ranges=output_size,
3091    )
3092
3093
3094def index_impl_helper(x, indices, check):
3095    assert isinstance(indices, (list, tuple))
3096    x_loader = x.make_loader()
3097    indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device())
3098    assert len(tensor_indices) > 0, "Must have at least one valid idx"
3099
3100    indices_loaders = [i.make_loader() if i is not None else None for i in indices]
3101    # no guards on output size, all the guards are set in broadcast_tensors
3102
3103    # We can use the first one since they are all required to be the same size
3104    tensor_size = list(indices[tensor_indices[0]].get_size())
3105
3106    x_size = x.get_size()
3107
3108    indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None]
3109    if check and 0 in indexed_size and 0 not in tensor_size:
3110        raise IndexError("index is out of bounds for dimension with size 0")
3111
3112    indexed_size = [x_size[i] for i in range(len(indices))]
3113    output_size, index_inner_fn = index_output_size_and_inner_fn(
3114        x_size,
3115        indices,
3116        tensor_indices,
3117        tensor_size,
3118        indices_loaders,
3119        indexed_size,
3120        None,
3121        check=check,
3122    )
3123
3124    def inner_fn(idx):
3125        return x_loader(index_inner_fn(idx))
3126
3127    return output_size, inner_fn, index_inner_fn
3128
3129
3130@register_lowering(aten.index, type_promotion_kind=None)
3131def index(x, indices):
3132    try:
3133        return index_impl(x, indices, check=True)
3134    except NotImplementedError:
3135        # Fallback to ATen for boolean indexing
3136        x.realize()
3137        return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)(
3138            x, indices
3139        )
3140
3141
3142@register_lowering(aten._unsafe_index, type_promotion_kind=None)
3143def _unsafe_index(x, indices):
3144    return index_impl(x, indices, check=False)
3145
3146
3147# All the indexing decompositions are written in terms of index, index_put, and index_put_
3148# We cannot have this lowering as a decomposition as it introduces
3149# mutation in the graph, which is bad for AOT Autograd. AOT Autograd runs dead
3150# code elimination and common subexpression elimination optimizations, which
3151# assume the graph to be side-effect free. More details at
3152# https://github.com/pytorch/torchdynamo/issues/1235
3153# and
3154# https://github.com/pytorch/torchdynamo/issues/1863
3155@register_lowering(aten.index_put)
3156def index_put(x, indices, values, accumulate=False):
3157    return index_put_(clone(x), indices, values, accumulate)
3158
3159
3160@register_lowering(aten._unsafe_index_put)
3161def _unsafe_index_put(x, indices, values, accumulate=False):
3162    return index_put_impl_(clone(x), indices, values, accumulate, check=False)
3163
3164
3165def index_put_as_masked_fill(self, indices, value, accumulate):
3166    if value.get_device() != self.get_device():
3167        value = to_device(value, self.get_device())
3168    if accumulate:
3169        value = add(self, value)
3170    return mutate_to(self, where(indices[0], value, self))
3171
3172
3173def index_put_fallback(self, indices, values, accumulate):
3174    deterministic = torch.are_deterministic_algorithms_enabled()
3175    if is_triton(values) and (accumulate or deterministic):
3176        msg = (
3177            "index put with accumulate."
3178            if not deterministic
3179            else "deterministic index put."
3180        )
3181        if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
3182            msg = f"{msg} Found from : \n {stack_trace}"
3183        V.graph.disable_cudagraphs_reason = msg
3184
3185    ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate)
3186    return self
3187
3188
3189@register_lowering(aten.index_put_, type_promotion_kind=None)
3190def index_put_(self, indices, values, accumulate=False):
3191    return index_put_impl_(self, indices, values, accumulate, check=True)
3192
3193
3194@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None)
3195def _unsafe_index_put_(self, indices, values, accumulate=False):
3196    return index_put_impl_(self, indices, values, accumulate, check=False)
3197
3198
3199def index_put_impl_(self, indices, values, accumulate, check):
3200    # Dispatch to masked fill for single boolean index with single value
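    # e.g. x[mask] = 3.0 lowers to where(mask, 3.0, x) (with the mask unsqueezed
    # to x's rank); with accumulate=True the value is first added to self,
    # giving x[mask] += 3.0 semantics.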
3201    if (
3202        values.get_numel() == 1
3203        and len(indices) == 1
3204        and indices[0].get_dtype() in {torch.bool, torch.uint8}
3205    ):
3206        mask = indices[0]
3207        for _ in range(len(mask.get_size()), len(self.get_size())):
3208            mask = unsqueeze(mask, -1)
3209        return index_put_as_masked_fill(self, [mask], values, accumulate)
3210
3211    # Fallback in torch deterministic mode
3212    if torch.are_deterministic_algorithms_enabled():
3213        return index_put_fallback(self, indices, values, accumulate)
3214
3215    # Fallback if there is a boolean index
3216    for index in indices:
3217        if index is not None and index.get_dtype() in {torch.bool, torch.uint8}:
3218            return index_put_fallback(self, indices, values, accumulate)
3219
3220    x_size = self.get_size()
3221    x_ndim = len(x_size)
3222
3223    if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
3224        # self is a scalar Tensor
3225        if x_ndim == 0:
3226            self = view(self, [1])
3227        self = index_put_fallback(self, indices, values, accumulate)
3228        if x_ndim == 0:
3229            self = view(self, [])
3230        return self
3231
3232    values = to_dtype(values, self.get_dtype())
3233
3234    try:
3235        # check_and_broadcast_indices raises NotImplementedError for indices it cannot handle (e.g. on another device)
3236        indices, tensor_indices = check_and_broadcast_indices(
3237            indices, self.get_device()
3238        )
3239    except NotImplementedError:
3240        return index_put_fallback(self, indices, values, accumulate)
3241
3242    indices_loaders = [i.make_loader() if i is not None else None for i in indices]
3243
3244    assert isinstance(self, TensorBox)
3245    self.realize()
3246
3247    # self is a scalar Tensor
3248    if x_ndim == 0:
3249        self = view(self, [1])
3250
3251    # We can use the first one since they are all required to be the same size
3252    tensor_size = list(indices[tensor_indices[0]].get_size())
3253    indexed_size = [x_size[i] for i in range(len(indices))]
3254
3255    expected_vals_size, inner_fn = index_output_size_and_inner_fn(
3256        x_size,
3257        indices,
3258        tensor_indices,
3259        tensor_size,
3260        indices_loaders,
3261        indexed_size,
3262        None,
3263        check=check,
3264    )
3265
3266    values = expand(values, expected_vals_size)
3267    # all guards are set above during broadcast_tensors and expand
3268
3269    scatter = ir.Scatter(
3270        device=self.get_device(),
3271        dtype=self.get_dtype(),
3272        inner_fn=values.make_loader(),
3273        ranges=expected_vals_size,  # iter_ranges,
3274        output_indexer=inner_fn,
3275        scatter_mode="atomic_add" if accumulate else None,
3276    )
3277    buffer = ir.ComputedBuffer(
3278        None,
3279        ir.MutationLayoutSHOULDREMOVE(self),
3280        scatter,
3281    )
3282    buffer.name = V.graph.register_buffer(buffer)
3283    V.graph.register_operation(buffer)
3284
3285    if x_ndim == 0:
3286        self = view(self, [])
3287    return self
3288
3289
3290fallback__unsafe_masked_index = fallback_handler(
3291    aten._unsafe_masked_index.default, add_to_fallback_set=False
3292)
3293
3294fallback__unsafe_masked_index_put_accumulate = fallback_handler(
3295    aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False
3296)
3297
3298
3299@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None)
3300def _unsafe_masked_index(self, mask, indices, fill):
3301    ranges, _, _unsafe_index_fn = index_impl_helper(self, indices, check=False)
3302    mask_loader = mask.make_loader()
3303    self_loader = self.make_loader()
3304
3305    def inner_fn(idx):
3306        if mask.dtype != torch.bool:
3307            mask_val = ops.to_dtype(mask_loader(idx), torch.bool)
3308        else:
3309            mask_val = mask_loader(idx)
3310        return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill)
3311
3312    return Pointwise.create(
3313        device=self.get_device(),
3314        dtype=self.get_dtype(),
3315        inner_fn=inner_fn,
3316        ranges=ranges,
3317    )
3318
3319
3320@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None)
3321def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
3322    masked_value = where(mask, values, 0)
3323    shape = x.get_size()
3324    clamped_indices = [
3325        clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None
3326        for i in range(len(indices))
3327    ]
3328    # TODO: use a masked store for this. Currently only the Triton backend
3329    # supports masked stores; the C++ backend does not.
3330    return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True)
3331
3332
3333@make_pointwise
3334def clamp(a, min, max):
3335    return ops.maximum(min, ops.minimum(max, a))
3336
3337
3338@register_lowering(aten.as_strided_scatter, type_promotion_kind=None)
3339def as_strided_scatter(self, src, size, stride, storage_offset=None):
3340    output = clone(self)
3341    output_view = as_strided(output, size, stride, storage_offset)
3342    copy_(output_view, src)
3343    return output
3344
3345
3346@register_lowering(aten.scatter, type_promotion_kind=None)
3347def scatter(x, dim: int, index, src, **kwargs):
3348    return scatter_(clone(x), dim, index, src, **kwargs)
3349
3350
3351def scatter_fallback(
3352    op_overload: torch._ops.OpOverload,
3353    self,
3354    dim: int,
3355    index,
3356    src,
3357    *,
3358    reduce: Optional[str] = None,
3359    include_self: bool = True,
3360):
3361    src_is_tensor = isinstance(src, TensorBox)
3362    if use_scatter_fallback(
3363        op_overload,
3364        reduce,
3365        self.get_dtype(),
3366        src.get_dtype() if src_is_tensor else type(src),
3367        src.get_device().type if src_is_tensor else "not impl",
3368        src_is_tensor,
3369    ):
3370        ir.ScatterFallback(
3371            op_overload,
3372            self,
3373            dim,
3374            index,
3375            src,
3376            reduce=reduce,
3377            include_self=include_self,
3378        )
3379        return self
3380
3381    return None
3382
3383
3384@register_lowering(aten.scatter_, type_promotion_kind=None)
3385def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None):
3386    assert reduce in {None, "add", "multiply"}
3387    if reduce is None:
3388        op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname)  # type: ignore[union-attr]
3389        fallback_result = scatter_fallback(
3390            op_overload, self, dim, index, src, reduce=reduce
3391        )
3392        if fallback_result is not None:
3393            return fallback_result
3394
3395    if reduce == "add":
3396        reduce = "sum"
3397    elif reduce == "multiply":
3398        reduce = "prod"
3399    return scatter_reduce_(self, dim, index, src, reduce)
3400
3401
3402@register_lowering(aten.scatter_add, type_promotion_kind=None)
3403def scatter_add(x, dim: int, index, src):
3404    return scatter_add_(clone(x), dim, index, src)
3405
3406
3407@register_lowering(aten.scatter_add_, type_promotion_kind=None)
3408def scatter_add_(x, dim: int, index, src):
3409    return scatter_reduce_(x, dim, index, src, "sum")
3410
3411
3412@register_lowering(aten.scatter_reduce, type_promotion_kind=None)
3413def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
3414    return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs)
3415
3416
3417@register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
3418def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
3419    assert reduce in {None, "sum", "prod", "mean", "amax", "amin"}
3420    assert (
3421        len(aten.scatter_reduce_.overloads()) == 1
3422        and "two" in aten.scatter_reduce_.overloads()
3423    ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_"
3424
3425    if isinstance(src, Number):
3426        src = full_like(self, src)
3427
3428    fallback_result = scatter_fallback(
3429        aten.scatter_reduce_.two,
3430        self,
3431        dim,
3432        index,
3433        src,
3434        reduce=reduce,
3435        include_self=include_self,
3436    )
3437
3438    if fallback_result:
3439        return fallback_result
3440
3441    assert isinstance(self, TensorBox)
3442    assert "int" in str(index.get_dtype())
3443
3444    ndim = len(self.get_size())
3445    if ndim == 0:
3446        self = view(self, [1])
3447
3448    if isinstance(src, TensorBox) and len(src.get_size()) == 0:
3449        src = view(src, [1])
3450
3451    if isinstance(index, TensorBox) and len(index.get_size()) == 0:
3452        index = view(index, [1])
3453
3454    if index.get_numel() == 0:
3455        return self
3456
3457    dim = _validate_dim(self, dim)
3458
3459    self.realize()
3460    index_loader = index.make_loader()
3461    src_loader = src.make_loader() if isinstance(src, TensorBox) else None
3462
3463    def output_indexer(idx):
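    # output_indexer maps an iteration index over `index`/`src` to the write
    # position in `self`: the same coordinates with the `dim` entry replaced by
    # the value loaded from `index` (bounds-checked via indirect_indexing).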
3464        # self is captured from the end of the function, so it may have 0 dim
3465        shape = self.get_size()
3466        ndim = len(shape)
3467        indirect_idx = list(idx)
3468        indirect_idx[dim] = ops.indirect_indexing(
3469            index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
3470        )
3471        return indirect_idx
3472
3473    def fn(idx):
3474        if src_loader:
3475            return src_loader(idx)
3476        else:
3477            # src is a scalar
3478            return ops.constant(src, self.get_dtype())
3479
3480    def backend_reduce_str(reduce):
3481        if reduce == "sum":
3482            return "atomic_add"
3483        else:
3484            # TODO: Need to support more reduction type
3485            assert reduce is None
3486            return None
3487
3488    if not include_self:
3489        # zero out the corresponding elements first
3490        zero_out = ir.Scatter(
3491            device=self.get_device(),
3492            dtype=self.get_dtype(),
3493            inner_fn=lambda index: ops.constant(0, self.get_dtype()),
3494            ranges=index.get_size(),
3495            output_indexer=output_indexer,
3496            scatter_mode=None,
3497        )
3498        buffer = ir.ComputedBuffer(
3499            None,
3500            ir.MutationLayoutSHOULDREMOVE(self),
3501            zero_out,
3502        )
3503        buffer.name = V.graph.register_buffer(buffer)
3504        V.graph.register_operation(buffer)
3505
3506    # self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
3507    # self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
3508    # self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
3509    scatter = ir.Scatter(
3510        device=self.get_device(),
3511        dtype=self.get_dtype(),
3512        inner_fn=fn,
3513        ranges=index.get_size(),
3514        output_indexer=output_indexer,
3515        scatter_mode=backend_reduce_str(reduce),
3516    )
3517    buffer = ir.ComputedBuffer(
3518        None,
3519        ir.MutationLayoutSHOULDREMOVE(self),
3520        scatter,
3521    )
3522    buffer.name = V.graph.register_buffer(buffer)
3523    V.graph.register_operation(buffer)
3524
3525    if ndim == 0:
3526        self = view(self, [])
3527    return self
3528
3529
3530def upsample_nearestnd(
3531    x,
3532    output_size,
3533    scales_x: Tuple[Optional[float], ...],
3534    n: int = 2,
3535    exact: bool = False,
3536):
3537    x.realize_hint()  # elements are reused
3538    x_loader = x.make_loader()
3539    i_sizes = x.get_size()[-n:]
3540    batch = x.get_size()[:-n]
3541    i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes]
3542
3543    assert len(scales_x) == n
3544    o_sizes = output_size
3545
3546    inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)]
3547    for i, scale in enumerate(scales_x):
3548        if scale is not None:
3549            inv_scales[i] = 1.0 / scale
3550
3551    def scale_fn(x, scale, size):
3552        # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5)
3553        #                            = floor(scale * (output_index + 0.5))
3554        # Nearest: input_index = floor(scale * output_index)
3555        x = ops.index_expr(x, torch.float32)
3556        if exact:
3557            x = ops.add(x, ops.constant(0.5, torch.float32))
3558        x = ops.mul(x, ops.constant(scale, torch.float32))
3559        x = ops.to_dtype(x, torch.int32)
3560        return ops.indirect_indexing(x, size, check=False)
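    # Illustrative: upsampling a size-2 dim to size 5 gives inv_scale = 0.4, so
    # output index 2 maps to input floor(0.4 * 2) = 0 in "nearest" mode and to
    # floor(0.4 * 2.5) = 1 in "nearest-exact" mode.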
3561
3562    def fn(idx):
3563        x = idx[-n:]
3564        b = idx[:-n]
3565        return x_loader(
3566            [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]]
3567        )
3568
3569    return Pointwise.create(
3570        device=x.get_device(),
3571        dtype=x.get_dtype(),
3572        inner_fn=fn,
3573        ranges=[*batch, *o_sizes],
3574    )
3575
3576
3577@register_lowering(aten.upsample_nearest1d.default)
3578def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
3579    return upsample_nearestnd(x, output_size, (scales,), n=1)
3580
3581
3582@register_lowering(aten._upsample_nearest_exact1d.default)
3583def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None):
3584    return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True)
3585
3586
3587@register_lowering(aten.upsample_nearest2d.default)
3588def upsample_nearest2d(
3589    x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
3590):
3591    return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2)
3592
3593
3594@register_lowering(aten._upsample_nearest_exact2d.default)
3595def _upsample_nearest_exact2d(
3596    x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
3597):
3598    return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True)
3599
3600
3601@register_lowering(aten.upsample_nearest3d.default)
3602def upsample_nearest3d(
3603    x,
3604    output_size,
3605    scales_d: Optional[float] = None,
3606    scales_h: Optional[float] = None,
3607    scales_w: Optional[float] = None,
3608):
3609    return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3)
3610
3611
3612@register_lowering(aten._upsample_nearest_exact3d.default)
3613def _upsample_nearest_exact3d(
3614    x,
3615    output_size,
3616    scales_d: Optional[float] = None,
3617    scales_h: Optional[float] = None,
3618    scales_w: Optional[float] = None,
3619):
3620    return upsample_nearestnd(
3621        x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True
3622    )
3623
3624
3625def _create_constants(*args, dtype):
3626    return tuple(ops.constant(a, dtype) for a in args)
3627
3628
3629@register_lowering(prims.rev.default)
3630def rev(x, dims):
3631    # note - dims pre-canonicalized
3632    x_loader = x.make_loader()
3633    sizes = x.get_size()
3634
3635    def loader(idx):
3636        idx = list(idx)
3637        assert len(idx) == len(sizes)
3638        for dim in dims:
3639            idx[dim] = (sizes[dim] - 1) - idx[dim]
3640
3641        return x_loader(idx)
3642
3643    return Pointwise.create(
3644        device=x.get_device(),
3645        dtype=x.get_dtype(),
3646        inner_fn=loader,
3647        ranges=sizes,
3648    )
3649
3650
3651@register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
3652def constant_pad_nd(x, padding, fill_value=0):
3653    assert (len(padding) % 2) == 0
3654    if all(p == 0 for p in padding):
3655        return clone(x)
3656
3657    sizes = x.get_size()
3658
3659    bounds = list(reversed(list(zip(padding[::2], padding[1::2]))))
3660    n = len(sizes) - len(bounds)
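    # padding lists pairs for the last dims first, e.g. for a 2D input
    # padding == [left, right, top, bottom], so after the reversal above
    # bounds == [(top, bottom), (left, right)], ordered from the outermost
    # padded dim to the innermost.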
3661
3662    # if padding is a complicated expression, hoist it
3663    bounds_precomp: List[Tuple[sympy.Symbol, Any]] = []
3664    for l, h in bounds:
3665        bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h))  # type: ignore[arg-type]
3666
3667    output_size = list(sizes[:n])
3668    mask_sizes = []
3669    for (low, high), size in zip(bounds, sizes[n:]):
3670        mask_sizes.append(size)
3671        output_size.append(sympy.expand(size + low + high))
3672    assert len(output_size) == len(sizes)
3673    fill_value = dtype_to_type(x.get_dtype())(fill_value)
3674
3675    def mask(index):
3676        mask = []
3677        for idx, (low, high), length in zip(index[n:], bounds, mask_sizes):
3678            if low != 0:
3679                mask.append(range_mask_low(idx, 0))
3680            if high != 0:
3681                mask.append(range_mask_high(idx, length))
3682        mask = functools.reduce(ops.and_, mask)
3683        return ops.masked(mask, lambda: x_loader(index), fill_value)
3684
3685    def offset_fn(index):
3686        new_index = list(index[:n])
3687        for idx, (low, high) in zip(index[n:], bounds_precomp):
3688            new_index.append(idx - low)
3689        assert len(new_index) == len(index)
3690        return mask(new_index)
3691
3692    x_loader = x.make_loader()
3693    return Pointwise.create(
3694        device=x.get_device(),
3695        dtype=x.get_dtype(),
3696        inner_fn=offset_fn,
3697        ranges=output_size,
3698    )
3699
3700
3701def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]):
3702    return ops.ge(
3703        ops.index_expr(i, torch.int64),
3704        ops.index_expr(sympy.Integer(low), torch.int64),
3705    )
3706
3707
3708def range_mask_high(i: sympy.Expr, high: sympy.Expr):
3709    return ops.lt(
3710        ops.index_expr(i, torch.int64),
3711        ops.index_expr(high, torch.int64),
3712    )
3713
3714
3715def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr):
3716    return ops.and_(
3717        range_mask_low(i, low),
3718        range_mask_high(i, high),
3719    )
3720
3721
3722def constant_boundary_condition(
3723    x, fill_value, padding=None, pad_fill_value=1.0, dim=None
3724):
3725    h = x.get_size()[-dim:]
3726    x_loader = x.make_loader()
3727    padding_h = padding or [0] * dim
3728
3729    def load(index):
3730        prefix = index[:-dim]
3731        ih = index[-dim:]
3732
3733        mask = functools.reduce(
3734            ops.and_,
3735            [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)],
3736        )
3737        return (
3738            ops.masked(
3739                mask,
3740                lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)(
3741                    [*prefix, *ih]
3742                ),
3743                fill_value,
3744            )
3745            if padding
3746            else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value)
3747        )
3748
3749    return load
3750
3751
3752def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
3753    x_out = FloorDiv(
3754        x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
3755    )
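    # Illustrative: x = 6, kernel = 3, stride = 2, padding = 0 gives
    # x_out = floor((6 - 2 + 1) / 2) = 2; with ceil_mode the alternative below
    # is x_alt = floor((6 - 2 + 2) / 2) = 3, and since the last window starts
    # at index 4 (inside the input) it is kept, so the output size becomes 3.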
3756
3757    if ceil_mode:
3758        x_alt = FloorDiv(
3759            x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i]
3760        )
3761        if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
3762            # Sliding windows must start within the input or left padding
3763            x_alt -= 1  # type: ignore[assignment]
3764            V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i])  # type: ignore[arg-type]
3765        if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
3766            # ceil mode is actually a no-op, lets guard on that
3767            V.graph.sizevars.guard_equals(x_out, x_alt)
3768            ceil_mode = False
3769        else:
3770            x_out = x_alt
3771    return x_out, ceil_mode
3772
3773
3774def should_fallback_max_pool2d_with_indices(kernel_size, dilation):
3775    kernel_size = pad_listlike(kernel_size, 2)
3776    window_size = kernel_size[0] * kernel_size[1]
3777    return (window_size > 25) or any(d > 1 for d in dilation)
3778
3779
3780def max_pool2d_checks(
3781    x, kernel_size, stride, padding, dilation, *, assert_fallback=None
3782):
3783    if padding == 0:
3784        padding = [0, 0]
3785    if dilation == 1:
3786        dilation = [1, 1]
3787    if not stride:
3788        stride = kernel_size
3789
3790    kernel_size = pad_listlike(kernel_size, 2)
3791    stride = pad_listlike(stride, 2)
3792    padding = pad_listlike(padding, 2)
3793    dilation = pad_listlike(dilation, 2)
3794
3795    assert isinstance(x, TensorBox)
3796    assert len(kernel_size) == 2
3797    assert len(stride) == 2
3798    assert len(padding) == 2
3799    assert len(dilation) == 2
3800    assert len(x.get_size()) in (3, 4)
3801
3802    use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation)
3803    if assert_fallback is not None:
3804        assert use_fallback == assert_fallback
3805
3806    return kernel_size, stride, padding, dilation, use_fallback
3807
3808
3809@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None)
3810def _low_memory_max_pool2d_with_offsets(
3811    x,
3812    kernel_size,
3813    stride,
3814    padding,
3815    dilation,
3816    ceil_mode=False,
3817):
3818    # assert we are not on a fallback path; the inductor decomp should have guaranteed this
3819    kernel_size, stride, padding, dilation, _ = max_pool2d_checks(
3820        x, kernel_size, stride, padding, dilation, assert_fallback=False
3821    )
3822
3823    x.realize_hint()
3824    *batch, h, w = x.get_size()
3825
3826    h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
3827    w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
3828
3829    dtype = x.dtype
3830    min_value = (
3831        False
3832        if dtype is torch.bool
3833        else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min)
3834    )
3835
3836    new_size = list(batch) + [h_out, w_out]
3837    if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
3838        x_loader = constant_boundary_condition(x, min_value, dim=2)
3839    else:
3840        x_loader = x.make_loader()
3841
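    # fn walks the kernel window once; with return_index=True it returns the
    # argmax position within the window, flattened as h_inc * kernel_w + w_inc
    # and stored as int8 (window_size <= 25 on this path, so it fits).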
3842    def fn(idx, return_index):
3843        *prefix, bh, bw = idx
3844        maxval = None
3845        maxindex = None
3846        for h_inc, w_inc in itertools.product(
3847            range(kernel_size[0]), range(kernel_size[1])
3848        ):
3849            ih = bh * stride[0] + h_inc - padding[0]
3850            iw = bw * stride[1] + w_inc - padding[1]
3851            val = x_loader([*prefix, ih, iw])
3852            if return_index:
3853                index = ops.index_expr(h_inc * kernel_size[1] + w_inc, torch.int8)
3854                if maxindex is None:
3855                    maxindex = index
3856                else:
3857                    maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
3858            if maxval is None:
3859                maxval = val
3860            else:
3861                maxval = ops.maximum(val, maxval)
3862        if return_index:
3863            return maxindex
3864        else:
3865            return maxval
3866
3867    out = Pointwise.create(
3868        device=x.get_device(),
3869        dtype=x.get_dtype(),
3870        inner_fn=functools.partial(fn, return_index=False),
3871        ranges=new_size,
3872    )
3873    offsets = Pointwise.create(
3874        device=x.get_device(),
3875        dtype=torch.int8,
3876        inner_fn=functools.partial(fn, return_index=True),
3877        ranges=new_size,
3878    )
3879    return out, offsets
3880
3881
3882@register_lowering(
3883    prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None
3884)
3885def _low_memory_max_pool2d_offsets_to_indices(
3886    offsets, kernel_width, input_width, stride, padding
3887):
3888    # TODO: Generalize to other max pooling flavors, and arbitrary dim
3889
3890    offsets_loader = offsets.make_loader()
3891
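    # Each offset encodes the argmax position within its kernel window as
    # h_inc * kernel_width + w_inc; decoding it and adding the window origin
    # (b * stride - padding) recovers the flat input index ih * input_width + iw.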
3892    def increments_to_index(h_inc, w_inc, bh, bw):
3893        w_in = ops.index_expr(input_width, torch.int64)
3894        hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64)
3895        wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64)
3896        ih = hbase + h_inc
3897        iw = wbase + w_inc
3898        return ih * w_in + iw
3899
3900    def offsets_to_indices(idx):
3901        *prefix, bh, bw = idx
3902        offset = offsets_loader([*prefix, bh, bw])
3903        kw_const = ops.constant(kernel_width, torch.int32)
3904        h_inc = offset // kw_const
3905        w_inc = offset - (h_inc * kw_const)
3906        return increments_to_index(h_inc, w_inc, bh, bw)
3907
3908    indices = Pointwise.create(
3909        device=offsets.get_device(),
3910        dtype=torch.int64,
3911        inner_fn=offsets_to_indices,
3912        ranges=offsets.get_size(),
3913    )
3914    return indices
3915
3916
3917# Fallback selected when we do not decompose to the low-memory path.
3918make_fallback(aten.max_pool2d_with_indices)
3919
3920
3921fallback_max_pool2d_with_indices_backward = fallback_handler(
3922    aten.max_pool2d_with_indices_backward.default,
3923    add_to_fallback_set=False,
3924)
3925
3926
3927@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
3928def max_pool2d_with_indices_backward(
3929    grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
3930):
3931    if padding == 0:
3932        padding = [0, 0]
3933    if dilation == 1:
3934        dilation = [1, 1]
3935    if not stride:
3936        stride = kernel_size
3937
3938    assert isinstance(x, TensorBox)
3939    assert len(kernel_size) == 2
3940    assert len(stride) == 2
3941    assert len(padding) == 2
3942    assert len(dilation) == 2
3943    assert len(x.get_size()) in (3, 4)
3944
3945    # we will read this many times, so make sure it is computed
3946    grad_output.realize_hint()
3947    try:
3948        gO_stride = grad_output.get_stride()
3949    except AttributeError:
3950        # some classes don't have `get_stride`
3951        # TODO will need a better way of determining if inputs are channels-last
3952        gO_stride = None
3953    if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):  # type: ignore[attr-defined]
3954        data = x.data.data  # type: ignore[attr-defined]
3955        x_buffer = ir.ComputedBuffer(
3956            name=None,
3957            layout=ir.FlexibleLayout(
3958                device=data.get_device(),
3959                dtype=data.get_dtype(),
3960                size=data.get_size(),
3961            ),
3962            data=data,
3963        )
3964        x_buffer.decide_layout()
3965        x_stride = x_buffer.get_stride()
3966    else:
3967        try:
3968            x_stride = x.get_stride()
3969        except AttributeError:
3970            x_stride = None
3971
3972    is_channels_last = (x_stride is not None and x_stride[1] == 1) or (
3973        gO_stride is not None and gO_stride[1] == 1
3974    )
3975    if any(d != 1 for d in dilation):
3976        # dilation NYI
3977        return fallback_max_pool2d_with_indices_backward(
3978            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
3979        )
3980
3981    *batch, height, width = x.get_size()
3982    *_, pooled_height, pooled_width = grad_output.get_size()
3983
3984    indices_loader = indices.make_loader()
3985    grad_loader = grad_output.make_loader()
3986    new_size = list(x.get_size())
3987
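    # Rough upper bound on how many pooling windows can overlap a single input
    # element along each dim; the loops in fn below visit at most
    # h_window_size * w_window_size candidate windows per input element.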
3988    h_window_size = max(
3989        max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
3990        for h in range(kernel_size[0] * 2)
3991    )
3992    w_window_size = max(
3993        max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
3994        for w in range(kernel_size[1] * 2)
3995    )
3996
3997    window_size = h_window_size * w_window_size
3998
3999    if window_size > 25:
4000        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
4001        return fallback_max_pool2d_with_indices_backward(
4002            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
4003        )
4004
4005    indices_size = indices.get_size()
4006
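    # For each input element (h, w), gradient flows from every covering pooling
    # window in [phstart, phend) x [pwstart, pwend) whose stored argmax (a
    # flattened h * width + w index) matches this element.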
4007    def fn(idx):
4008        *prefix, h, w = idx
4009        index_test = ops.index_expr(h * width + w, torch.int32)
4010        h = h + padding[0]
4011        w = w + padding[1]
4012        phstart = ops.index_expr(
4013            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
4014        )
4015        pwstart = ops.index_expr(
4016            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
4017        )
4018        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
4019        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)
4020
4021        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
4022        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
4023        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
4024        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
4025
4026        gradient = None
4027        for ph_ in range(h_window_size):
4028            for pw_ in range(w_window_size):
4029                ph = ops.add(phstart, ops.constant(ph_, torch.int32))
4030                pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
4031                grad_index = [
4032                    *prefix,
4033                    ops.indirect_indexing(
4034                        ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))),
4035                        indices_size[-2],
4036                        check=False,
4037                    ),
4038                    ops.indirect_indexing(
4039                        ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))),
4040                        indices_size[-1],
4041                        check=False,
4042                    ),
4043                ]
4044
4045                index_actual = indices_loader(grad_index)
4046                grad_part = grad_loader(grad_index)
4047                check = ops.eq(index_actual, index_test)
4048
4049                if gradient is None:
4050                    # don't need mask for 0, 0
4051                    gradient = ops.where(
4052                        check, grad_part, ops.constant(0.0, torch.float32)
4053                    )
4054                else:
4055                    mask = ops.and_(
4056                        ops.and_(
4057                            ops.lt(ph, phend),
4058                            ops.lt(pw, pwend),
4059                        ),
4060                        check,
4061                    )
4062                    gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)
4063        assert gradient is not None
4064        return gradient
4065
4066    out = Pointwise.create(
4067        device=grad_output.get_device(),
4068        dtype=grad_output.get_dtype(),
4069        inner_fn=fn,
4070        ranges=new_size,
4071    )
4072    if is_channels_last:
4073        return ir.ExternKernel.require_channels_last(out)
4074    else:
4075        return out
4076
4077
4078def pad_adaptive_loader(x, pad_val=0.0):
4079    *_, h, w = x.get_size()
4080    x_loader = x.make_loader()
4081
4082    def load(prefix, increments, start_indices, end_indices):
4083        ih, iw = increments
4084        h_start_index, w_start_index = start_indices
4085        h_end_index, w_end_index = end_indices
4086
4087        mask = ops.and_(
4088            ops.lt(
4089                ops.index_expr(h_start_index + ih, torch.int64),
4090                ops.index_expr(h_end_index, torch.int64),
4091            ),
4092            ops.lt(
4093                ops.index_expr(w_start_index + iw, torch.int64),
4094                ops.index_expr(w_end_index, torch.int64),
4095            ),
4096        )
4097
4098        return ops.masked(
4099            mask,
4100            lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]),
4101            pad_val,
4102        )
4103
4104    return load
4105
4106
4107def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
4108    h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
4109    h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
4110
4111    w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
4112    w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
4113
4114    return h_start_index, h_end_index, w_start_index, w_end_index
4115
4116
4117def _adaptive_pooling_fn(
4118    start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
4119):
4120    h_in, w_in = in_sizes
4121    h_out, w_out = out_sizes
4122
4123    (
4124        h_start_index_fn,
4125        h_end_index_fn,
4126        w_start_index_fn,
4127        w_end_index_fn,
4128    ) = compute_indices_adaptive_pooling(
4129        start_index, end_index, h_in, w_in, h_out, w_out
4130    )
4131
4132    def fn(idx, loader):
4133        *prefix, bh, bw = idx
4134
4135        h_start_index = h_start_index_fn(bh)
4136        h_end_index = h_end_index_fn(bh)
4137
4138        w_start_index = w_start_index_fn(bw)
4139        w_end_index = w_end_index_fn(bw)
4140
4141        result = None
4142        for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
4143            val = loader(
4144                prefix,
4145                [ih, iw],
4146                [h_start_index, w_start_index],
4147                [h_end_index, w_end_index],
4148            )
4149            if result is None:
4150                result = val
4151            else:
4152                result = pooling_fn(val, result)
4153        return result
4154
4155    return fn
4156
4157
4158def _adaptive_pooling_fn_with_idx(
4159    start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
4160):
4161    h_in, w_in = in_sizes
4162    h_out, w_out = out_sizes
4163
4164    (
4165        h_start_index_fn,
4166        h_end_index_fn,
4167        w_start_index_fn,
4168        w_end_index_fn,
4169    ) = compute_indices_adaptive_pooling(
4170        start_index, end_index, h_in, w_in, h_out, w_out
4171    )
4172
4173    def fn(idx, loader):
4174        *prefix, bh, bw = idx
4175
4176        h_start_index = h_start_index_fn(bh)
4177        h_end_index = h_end_index_fn(bh)
4178
4179        w_start_index = w_start_index_fn(bw)
4180        w_end_index = w_end_index_fn(bw)
4181
4182        maxval = None
4183        maxindex = None
4184        for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
4185            val = loader(
4186                prefix,
4187                [ih, iw],
4188                [h_start_index, w_start_index],
4189                [h_end_index, w_end_index],
4190            )
4191
4192            index = ops.index_expr(
4193                (h_start_index + ih) * w_in + w_start_index + iw, torch.int64
4194            )
4195
4196            if maxindex is None:
4197                maxindex = index
4198            else:
4199                maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
4200
4201            if maxval is None:
4202                maxval = val
4203            else:
4204                maxval = pooling_fn(val, maxval)
4205
4206        return maxindex
4207
4208    return fn
4209
4210
4211fallback_adaptive_avg_pool2d = fallback_handler(
4212    aten._adaptive_avg_pool2d.default, add_to_fallback_set=False
4213)
4214
4215
4216@register_lowering(aten._adaptive_avg_pool2d)
4217def _adaptive_avg_pool2d(x, output_size):
4218    assert isinstance(x, TensorBox)
4219    assert len(output_size) == 2
4220    x.realize_hint()
4221
4222    *batch, h_in, w_in = x.get_size()
4223
4224    h_in = V.graph.sizevars.evaluate_static_shape(h_in)
4225    w_in = V.graph.sizevars.evaluate_static_shape(w_in)
4226
4227    h_out, w_out = output_size
4228
4229    # no-op if the same input and output
4230    if h_in == h_out and w_in == w_out:
4231        return clone(x)
4232
4233    if h_out == 0 or w_out == 0:
4234        o_size = [*batch, h_out, w_out]
4235        return empty(o_size, dtype=x.get_dtype(), device=x.get_device())
4236    if h_in % h_out == 0 and w_in % w_out == 0:
4237        kernel_size = [h_in // h_out, w_in // w_out]
4238        return avg_pool2d(x, kernel_size)
4239
4240    h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
4241    w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
4242
4243    new_size = list(batch) + [h_out, w_out]
4244    dtype = x.get_dtype()
4245
4246    window_size = h_kernel_max * w_kernel_max
4247    if window_size > 25:
4248        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
4249        return fallback_adaptive_avg_pool2d(x, output_size)
4250
4251    def start_index(index, out_dim, inp_dim):
4252        return FloorDiv((index * inp_dim), out_dim)
4253
4254    def end_index(index, out_dim, inp_dim):
4255        return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
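    # Illustrative: h_in = 10, h_out = 4 gives windows [0, 3), [2, 5), [5, 8),
    # [7, 10), i.e. start = floor(i * 10 / 4) and end = ceil((i + 1) * 10 / 4).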
4256
4257    fn_sum = _adaptive_pooling_fn(
4258        start_index=start_index,
4259        end_index=end_index,
4260        kernel_maxes=[h_kernel_max, w_kernel_max],
4261        in_sizes=[h_in, w_in],
4262        out_sizes=[h_out, w_out],
4263        pooling_fn=ops.add,
4264    )
4265
4266    ones_loader = pad_adaptive_loader(ones_like(x))
4267
4268    def fn(idx):
4269        return ops.truediv(
4270            fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader)
4271        )
4272
4273    rv = Pointwise.create(
4274        device=x.get_device(),
4275        dtype=dtype,
4276        inner_fn=fn,
4277        ranges=new_size,
4278    )
4279    # TODO: should we force these to be realized?
4280    return rv
4281
4282
4283fallback_adaptive_max_pool2d = fallback_handler(
4284    aten.adaptive_max_pool2d.default, add_to_fallback_set=False
4285)
4286
4287
4288@register_lowering(aten.adaptive_max_pool2d)
4289def adaptive_max_pool2d(x, output_size):
4290    assert isinstance(x, TensorBox)
4291    assert len(output_size) == 2
4292    x.realize_hint()
4293
4294    *batch, h_in, w_in = x.get_size()
4295
4296    h_in = V.graph.sizevars.evaluate_static_shape(h_in)
4297    w_in = V.graph.sizevars.evaluate_static_shape(w_in)
4298
4299    h_out, w_out = output_size
4300
4301    if h_out == 0 or w_out == 0:
4302        o_size = [*batch, h_out, w_out]
4303        return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty(
4304            o_size, dtype=torch.int64, device=x.get_device()
4305        )
4306    if h_in % h_out == 0 and w_in % w_out == 0:
4307        kernel_size = [h_in // h_out, w_in // w_out]
4308        if should_fallback_max_pool2d_with_indices(kernel_size, dilation=[1, 1]):
4309            return max_pool2d_with_indices(x, kernel_size)  # type: ignore[name-defined]   # noqa: F821
4310        else:
4311            v, offsets = _low_memory_max_pool2d_with_offsets(
4312                x,
4313                kernel_size,
4314                stride=kernel_size,
4315                padding=[0, 0],
4316                dilation=[1, 1],
4317                ceil_mode=False,
4318            )
4319            indices = _low_memory_max_pool2d_offsets_to_indices(
4320                offsets, kernel_size[1], w_in, kernel_size, padding=[0, 0]
4321            )
4322            return v, indices
4323
4324    h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
4325    w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
4326
4327    new_size = list(batch) + [h_out, w_out]
4328    dtype = x.get_dtype()
4329
4330    window_size = h_kernel_max * w_kernel_max
4331    if window_size > 25:
4332        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
4333        return fallback_adaptive_max_pool2d(x, output_size)
4334
4335    def start_index(index, out_dim, inp_dim):
4336        return FloorDiv((index * inp_dim), out_dim)
4337
4338    def end_index(index, out_dim, inp_dim):
4339        return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
4340
4341    inner_func_max_val = _adaptive_pooling_fn(
4342        start_index=start_index,
4343        end_index=end_index,
4344        kernel_maxes=[h_kernel_max, w_kernel_max],
4345        in_sizes=[h_in, w_in],
4346        out_sizes=[h_out, w_out],
4347        pooling_fn=ops.maximum,
4348    )
4349
4350    inner_func_max_idx = _adaptive_pooling_fn_with_idx(
4351        start_index=start_index,
4352        end_index=end_index,
4353        kernel_maxes=[h_kernel_max, w_kernel_max],
4354        in_sizes=[h_in, w_in],
4355        out_sizes=[h_out, w_out],
4356        pooling_fn=ops.maximum,
4357    )
4358
4359    def inner_fn_max_val(idx):
4360        return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))
4361
4362    def inner_fn_max_idx(idx):
4363        return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))
4364
4365    rv = Pointwise.create(
4366        device=x.get_device(),
4367        dtype=dtype,
4368        inner_fn=inner_fn_max_val,
4369        ranges=new_size,
4370    )
4371    ri = Pointwise.create(
4372        device=x.get_device(),
4373        dtype=torch.int64,
4374        inner_fn=inner_fn_max_idx,
4375        ranges=new_size,
4376    )
4377    return rv, ri
4378
4379
4380fallback_fractional_max_pool2d = fallback_handler(
4381    aten.fractional_max_pool2d.default, add_to_fallback_set=False
4382)
4383
4384
4385def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim):
4386    out_sz = out_sz[dim]
4387    in_sz = in_sz[dim]
4388    kernel_sz = kernel_sz[dim]
4389    alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1)
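    # Fractional max pooling draws window starts pseudo-randomly: for a sample
    # u (expected in [0, 1)), start_i = floor((i + u) * alpha) - floor(u * alpha),
    # with the last output pinned to in_sz - kernel_sz so its window ends
    # exactly at the input boundary.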
4390    samples_loader = samples.make_loader()
4391
4392    def load(prefix, i):
4393        sample = samples_loader([*prefix, dim])
4394        i_expr = ops.index_expr(i, samples.get_dtype())
4395        alpha_expr = ops.index_expr(alpha, samples.get_dtype())
4396        seq_i = ops.floor((i_expr + sample) * alpha_expr) - ops.floor(
4397            sample * alpha_expr
4398        )
4399        seq_i = ops.to_dtype(seq_i, torch.int64)
4400
4401        mask = ops.lt(
4402            i_expr,
4403            ops.index_expr(out_sz - 1, torch.int64),
4404        )
4405        return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64))
4406
4407    return load
4408
4409
4410@register_lowering(aten.fractional_max_pool2d)
4411def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
4412    x.realize_hint()
4413    *batch, inp_h, inp_w = x.get_size()
4414    kernel_h, kernel_w = kernel_size
4415    h_out, w_out = output_size
4416
4417    if kernel_h * kernel_w >= 25:
4418        return fallback_fractional_max_pool2d(
4419            x, kernel_size, output_size, random_samples
4420        )
4421
4422    gen_offsets_for_dim = functools.partial(
4423        _fractional_pooling_offsets,
4424        samples=random_samples,
4425        in_sz=[inp_h, inp_w],
4426        out_sz=output_size,
4427        kernel_sz=kernel_size,
4428    )
4429
4430    h_index_fn = gen_offsets_for_dim(dim=0)
4431    w_index_fn = gen_offsets_for_dim(dim=1)
4432    x_loader = x.make_loader()
4433
4434    def fn(idx, return_index):
4435        *prefix, bh, bw = idx
4436
4437        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
4438        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
4439
4440        maxval = None
4441        maxindex = None
4442        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
4443            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
4444            if return_index:
4445                index = ops.index_expr(
4446                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
4447                )
4448                if maxindex is None:
4449                    maxindex = index
4450                else:
4451                    maxindex = ops.where(
4452                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
4453                    )
4454            if maxval is None:
4455                maxval = val
4456            else:
4457                maxval = ops.maximum(val, maxval)
4458        if return_index:
4459            return maxindex
4460        else:
4461            return maxval
4462
4463    new_size = list(batch) + [h_out, w_out]
4464    rv = Pointwise.create(
4465        device=x.get_device(),
4466        dtype=x.get_dtype(),
4467        inner_fn=functools.partial(fn, return_index=False),
4468        ranges=new_size,
4469    )
4470
4471    ri = Pointwise.create(
4472        device=x.get_device(),
4473        dtype=torch.int64,
4474        inner_fn=functools.partial(fn, return_index=True),
4475        ranges=new_size,
4476    )
4477    return rv, ri
4478
4479
4480@register_lowering(aten.upsample_nearest2d_backward.default)
4481def upsample_nearest2d_backward(
4482    x, output_size=None, input_size=None, scales_h=None, scales_w=None
4483):
4484    x.realize_hint()
4485
4486    *batch, inp_h, inp_w = x.get_size()
4487    inp_h = V.graph.sizevars.evaluate_static_shape(inp_h)
4488    inp_w = V.graph.sizevars.evaluate_static_shape(inp_w)
4489
4490    *batch, out_h, out_w = input_size
4491
4492    if inp_h % out_h == 0 and inp_w % out_w == 0:
4493        return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1)
4494
4495    h_kernel_max = ceildiv(inp_h, out_h)
4496    w_kernel_max = ceildiv(inp_w, out_w)
4497
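    # Nearest upsampling maps each input pixel to a contiguous block of output
    # pixels, so its gradient is the sum of grad_output over that block; this
    # reuses the adaptive-pooling sum helpers with ceil-div window boundaries.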
4498    def start_index(index, out_dim, inp_dim):
4499        return CeilDiv(index * inp_dim, sympy.sympify(out_dim))
4500
4501    def end_index(index, out_dim, inp_dim):
4502        return start_index((index + 1), out_dim, inp_dim)
4503
4504    fn_sum = _adaptive_pooling_fn(
4505        start_index=start_index,
4506        end_index=end_index,
4507        kernel_maxes=[h_kernel_max, w_kernel_max],
4508        in_sizes=[inp_h, inp_w],
4509        out_sizes=[out_h, out_w],
4510        pooling_fn=ops.add,
4511    )
4512
4513    def fn(idx):
4514        return fn_sum(idx, pad_adaptive_loader(x))
4515
4516    rv = Pointwise.create(
4517        device=x.get_device(),
4518        dtype=x.get_dtype(),
4519        inner_fn=fn,
4520        ranges=list(input_size),
4521    )
4522
4523    return rv
4524
4525
4526fallback_avg_pool2d = fallback_handler(
4527    aten.avg_pool2d.default, add_to_fallback_set=False
4528)
4529fallback_avg_pool3d = fallback_handler(
4530    aten.avg_pool3d.default, add_to_fallback_set=False
4531)
4532
4533
4534@register_lowering(aten.avg_pool2d, type_promotion_kind=None)
4535def avg_pool2d(
4536    x,
4537    kernel_size,
4538    stride=(),
4539    padding=0,
4540    ceil_mode=False,
4541    count_include_pad=True,
4542    divisor_override=None,
4543):
4544    return _avg_poolnd(
4545        x,
4546        kernel_size,
4547        stride,
4548        padding,
4549        ceil_mode,
4550        count_include_pad,
4551        divisor_override,
4552        dim=2,
4553    )
4554
4555
4556@register_lowering(aten.avg_pool3d, type_promotion_kind=None)
4557def avg_pool3d(
4558    x,
4559    kernel_size,
4560    stride=(),
4561    padding=0,
4562    ceil_mode=False,
4563    count_include_pad=True,
4564    divisor_override=None,
4565):
4566    return _avg_poolnd(
4567        x,
4568        kernel_size,
4569        stride,
4570        padding,
4571        ceil_mode,
4572        count_include_pad,
4573        divisor_override,
4574        dim=3,
4575    )
4576
4577
4578def _avg_poolnd(
4579    x,
4580    kernel_size,
4581    stride,
4582    padding,
4583    ceil_mode,
4584    count_include_pad,
4585    divisor_override,
4586    dim,
4587):
4588    if not stride:
4589        stride = kernel_size
4590    if not padding:
4591        padding = [0] * dim
4592    kernel_size = pad_listlike(kernel_size, dim)
4593    stride = pad_listlike(stride, dim)
4594    padding = pad_listlike(padding, dim)
4595
4596    assert isinstance(x, TensorBox)
4597    assert len(kernel_size) == dim
4598    assert len(stride) == dim
4599    assert len(padding) == dim
4600    assert len(x.get_size()) in (dim + 1, dim + 2)
4601
4602    x.realize_hint()
4603    batch = x.get_size()[:-dim]
4604    h = x.get_size()[-dim:]
4605
4606    h_out, ceil_modes = zip(
4607        *[
4608            pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode)
4609            for i in range(dim)
4610        ]
4611    )
4612
4613    if any(padding) or any(ceil_modes):
4614        x_loader = constant_boundary_condition(x, 0.0, dim=dim)
4615        had_padding = True
4616    else:
4617        x_loader = x.make_loader()
4618        had_padding = False
4619
4620    new_size = list(batch) + list(h_out)
4621    dtype = x.get_dtype()
4622
4623    window_size = functools.reduce(operator.mul, kernel_size)
4624    if window_size > 25:
4625        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
4626        if dim == 2:
4627            fallback = fallback_avg_pool2d
4628        elif dim == 3:
4629            fallback = fallback_avg_pool3d
4630        else:
4631            raise ValueError(f"Unknown dim: {dim}")
4632
4633        return fallback(
4634            x,
4635            kernel_size,
4636            stride,
4637            padding,
4638            ceil_mode,
4639            count_include_pad,
4640            divisor_override,
4641        )
4642
4643    def fn_sum(idx, loader):
4644        prefix = idx[:-dim]
4645        b = idx[-dim:]
4646        total = None
4647        for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]):
4648            inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)]
4649            val = loader([*prefix, *inp])
4650            if total is None:
4651                total = val
4652            else:
4653                total = ops.add(val, total)
4654        return total
4655
4656    if not had_padding or divisor_override:
4657        if divisor_override:
4658            scale = 1 / divisor_override
4659        else:
4660            scale = 1.0 / window_size
4661
4662        def fn(idx):
4663            return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype))
4664
4665    else:
4666
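        # With padding (and no divisor_override), each output element divides
        # by the number of inputs its window covers: the window clipped to the
        # padded extent if count_include_pad, otherwise only its in-bounds part.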
4667        def fn(idx):
4668            prefix = idx[:-dim]
4669            bh = idx[-dim:]
4670
4671            divide_factors = []
4672            for i in range(dim):
4673                hstart = bh[i] * stride[i] - padding[i]
4674                hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i])
4675                if not count_include_pad:
4676                    hstart = sympy.Max(hstart, 0)
4677                    hend = sympy.Min(hend, h[i])
4678                factor = ops.index_expr(hend - hstart, torch.int32)
4679                divide_factors.append(factor)
4680            divide_factor = functools.reduce(ops.mul, divide_factors)
4681            return ops.truediv(fn_sum(idx, x_loader), divide_factor)
4682
4683    rv = Pointwise.create(
4684        device=x.get_device(),
4685        dtype=dtype,
4686        inner_fn=fn,
4687        ranges=new_size,
4688    )
4689    # TODO(jansel): should we force these to be realized?
4690    return rv
4691
4692
4693fallback_avg_pool2d_backward = fallback_handler(
4694    aten.avg_pool2d_backward.default, add_to_fallback_set=False
4695)
4696
4697
4698@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
4699def avg_pool2d_backward(
4700    grad_output,
4701    x,
4702    kernel_size,
4703    stride,
4704    padding,
4705    ceil_mode,
4706    count_include_pad,
4707    divisor_override=None,
4708):
4709    assert divisor_override is None or divisor_override != 0, "divisor must not be zero"
4710    if not stride:
4711        stride = kernel_size
4712    if not padding:
4713        padding = [0, 0]
4714
4715    assert isinstance(grad_output, TensorBox)
4716    assert isinstance(x, TensorBox)
4717    assert len(kernel_size) == 2
4718    assert len(stride) == 2
4719    assert len(padding) == 2
4720    assert len(x.get_size()) in (3, 4)
4721
4722    grad_output.realize_hint()  # we will read this many times, so make sure it is computed
4723
4724    *batch, height, width = x.get_size()
4725
4726    h_out, ceil_mode1 = pooling_size(height, 0, kernel_size, stride, padding, ceil_mode)
4727    w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode)
4728
4729    grad_loader = grad_output.make_loader()
4730
4731    had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2
4732
4733    *_, pooled_height, pooled_width = grad_output.get_size()
4734    new_size = list(x.get_size())
4735    dtype = x.get_dtype()
4736
4737    h_window_size = max(
4738        max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
4739        for h in range(kernel_size[0] * 2)
4740    )
4741    w_window_size = max(
4742        max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
4743        for w in range(kernel_size[1] * 2)
4744    )
4745
4746    window_size = h_window_size * w_window_size
4747    if window_size > 25:
4748        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
4749        return fallback_avg_pool2d_backward(
4750            grad_output,
4751            x,
4752            kernel_size,
4753            stride,
4754            padding,
4755            ceil_mode,
4756            count_include_pad,
4757            divisor_override,
4758        )
4759
4760    def compute_pool_size_without_padding(ph, pw):
4761        """
4762        This computes the scaling factor that we will divide an element
4763        by when `count_include_pad=False`
4764        """
4765        stride_h = ops.constant(stride[0], torch.int32)
4766        stride_w = ops.constant(stride[1], torch.int32)
4767        pad_h = ops.constant(padding[0], torch.int32)
4768        pad_w = ops.constant(padding[1], torch.int32)
4769        kernel_h = ops.constant(kernel_size[0], torch.int32)
4770        kernel_w = ops.constant(kernel_size[1], torch.int32)
4771        hstart = ops.sub(ops.mul(ph, stride_h), pad_h)
4772        wstart = ops.sub(ops.mul(pw, stride_w), pad_w)
4773        hend = ops.minimum(
4774            ops.add(hstart, kernel_h),
4775            ops.add(ops.index_expr(height, torch.int32), pad_h),
4776        )
4777        wend = ops.minimum(
4778            ops.add(wstart, kernel_w),
4779            ops.add(ops.index_expr(width, torch.int32), pad_w),
4780        )
4781        hstart = ops.maximum(hstart, ops.constant(0, torch.int32))
4782        wstart = ops.maximum(wstart, ops.constant(0, torch.int32))
4783        hend = ops.minimum(hend, ops.index_expr(height, torch.int32))
4784        wend = ops.minimum(wend, ops.index_expr(width, torch.int32))
4785        divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart))
4786        return divide_factor
4787
4788    def fn(idx):
4789        *prefix, h, w = idx
4790        h = h + padding[0]
4791        w = w + padding[1]
4792        phstart = ops.index_expr(
4793            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
4794        )
4795        pwstart = ops.index_expr(
4796            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
4797        )
4798        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
4799        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)
4800
4801        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
4802        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
4803        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
4804        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
4805
4806        gradient = None
4807        for ph_ in range(h_window_size):
4808            for pw_ in range(w_window_size):
4809                ph = ops.add(phstart, ops.constant(ph_, torch.int32))
4810                pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
4811
4812                if divisor_override is not None:
4813                    scale = divisor_override
4814                elif count_include_pad or not had_padding:
4815                    scale = kernel_size[0] * kernel_size[1]
4816                else:
4817                    scale = compute_pool_size_without_padding(ph, pw)
4818
4819                part = ops.truediv(
4820                    grad_loader(
4821                        [
4822                            *prefix,
4823                            ops.indirect_indexing(
4824                                ops.minimum(
4825                                    ph, ops.sub(phend, ops.constant(1, torch.int32))
4826                                ),
4827                                pooled_height,
4828                                check=False,
4829                            ),
4830                            ops.indirect_indexing(
4831                                ops.minimum(
4832                                    pw, ops.sub(pwend, ops.constant(1, torch.int32))
4833                                ),
4834                                pooled_width,
4835                                check=False,
4836                            ),
4837                        ]
4838                    ),
4839                    scale,
4840                )
4841
4842                mask = ops.and_(
4843                    ops.lt(ph, phend),
4844                    ops.lt(pw, pwend),
4845                )
4846                if gradient is None:
4847                    gradient = ops.where(mask, part, ops.constant(0.0, torch.float32))
4848                else:
4849                    gradient = ops.where(mask, ops.add(gradient, part), gradient)
4850        assert gradient is not None
4851        return gradient
4852
4853    rv = Pointwise.create(
4854        device=grad_output.get_device(),
4855        dtype=dtype,
4856        inner_fn=fn,
4857        ranges=new_size,
4858    )
4859    return rv
4860
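# A worked example of the window bookkeeping in avg_pool2d_backward above
# (illustration only, one dimension): kernel_size=3, stride=2, padding=0,
# input row h=4.
#   phstart = floor((4 - 3 + 2) / 2) = 1
#   phend   = floor(4 / 2) + 1      = 3
# so only pooled rows 1 and 2 (windows [2, 5) and [4, 7)) contain input row 4,
# and each contributes grad_output / scale to the gradient.  h_window_size is a
# static upper bound on phend - phstart (here 2), which is what lets the inner
# loops over ph_ / pw_ be fully unrolled; the mask ops.lt(ph, phend) zeroes out
# the unused iterations.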
4861
4862fallback_avg_pool3d_backward = fallback_handler(
4863    aten.avg_pool3d_backward.default, add_to_fallback_set=False
4864)
4865
4866
4867@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
4868def avg_pool3d_backward(
4869    grad_output,
4870    x,
4871    kernel_size,
4872    stride,
4873    padding,
4874    ceil_mode,
4875    count_include_pad,
4876    divisor_override=None,
4877):
4878    assert divisor_override is None or divisor_override != 0, "divisor must not be zero"
4879    if not stride:
4880        stride = kernel_size
4881    if not padding:
4882        padding = [0, 0, 0]
4883
4884    assert isinstance(grad_output, TensorBox)
4885    assert isinstance(x, TensorBox)
4886    assert len(kernel_size) == 3
4887    assert len(stride) == 3
4888    assert len(padding) == 3
4889    assert len(x.get_size()) in (4, 5)
4890
4891    grad_output.realize_hint()
4892
4893    *batch, depth, height, width = x.get_size()
4894
4895    d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode)
4896    h_out, ceil_mode_h = pooling_size(
4897        height, 1, kernel_size, stride, padding, ceil_mode
4898    )
4899    w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode)
4900
4901    grad_loader = grad_output.make_loader()
4902    had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
4903
4904    *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
4905    new_size = list(x.get_size())
4906    dtype = x.get_dtype()
4907
4908    d_window_size, h_window_size, w_window_size = (
4909        max(
4910            max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
4911            for d in range(kernel_size[i] * 2)
4912        )
4913        for i in range(3)
4914    )
4915
4916    window_size = d_window_size * h_window_size * w_window_size
4917    if window_size > 125:
4918        # Kernel size too big. Results in hard-to-optimize Triton code.
4919        return fallback_avg_pool3d_backward(
4920            grad_output,
4921            x,
4922            kernel_size,
4923            stride,
4924            padding,
4925            ceil_mode,
4926            count_include_pad,
4927            divisor_override,
4928        )
4929
4930    def compute_pool_size_without_padding(pd, ph, pw):
4931        stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
4932        pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
4933        kernel_d, kernel_h, kernel_w = (
4934            ops.constant(k, torch.int32) for k in kernel_size
4935        )
4936
4937        dstart, hstart, wstart = (
4938            ops.sub(ops.mul(p, s), pad)
4939            for p, s, pad in zip(
4940                [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
4941            )
4942        )
4943        dend, hend, wend = (
4944            ops.minimum(
4945                ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
4946            )
4947            for start, k, dim, pad in zip(
4948                [dstart, hstart, wstart],
4949                [kernel_d, kernel_h, kernel_w],
4950                [depth, height, width],
4951                [pad_d, pad_h, pad_w],
4952            )
4953        )
4954        dstart, hstart, wstart = (
4955            ops.maximum(start, ops.constant(0, torch.int32))
4956            for start in [dstart, hstart, wstart]
4957        )
4958        dend, hend, wend = (
4959            ops.minimum(end, ops.index_expr(dim, torch.int32))
4960            for end, dim in zip([dend, hend, wend], [depth, height, width])
4961        )
4962        divide_factor = ops.mul(
4963            ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
4964        )
4965        return divide_factor
4966
4967    def fn(idx):
4968        *prefix, d, h, w = idx
4969        d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
4970
4971        pdstart, phstart, pwstart = (
4972            ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
4973            for v, k, s in zip([d, h, w], kernel_size, stride)
4974        )
4975
4976        pdend, phend, pwend = (
4977            ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
4978            for v, s in zip([d, h, w], stride)
4979        )
4980
4981        pdstart, phstart, pwstart = (
4982            ops.maximum(pstart, ops.constant(0, torch.int32))
4983            for pstart in [pdstart, phstart, pwstart]
4984        )
4985        pdend, phend, pwend = (
4986            ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
4987            for pend, pooled_dim in zip(
4988                [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
4989            )
4990        )
4991
4992        gradient = None
4993        # Iterate over the 3D region to accumulate gradients
4994        for pd_ in range(d_window_size):
4995            for ph_ in range(h_window_size):
4996                for pw_ in range(w_window_size):
4997                    pd, ph, pw = (
4998                        ops.add(pstart, ops.constant(p_, torch.int32))
4999                        for pstart, p_ in zip(
5000                            [pdstart, phstart, pwstart], [pd_, ph_, pw_]
5001                        )
5002                    )
5003
5004                    if divisor_override is not None:
5005                        scale = divisor_override
5006                    elif count_include_pad or not had_padding:
5007                        scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
5008                    else:
5009                        scale = compute_pool_size_without_padding(pd, ph, pw)
5010
5011                    part = ops.truediv(
5012                        grad_loader(
5013                            [
5014                                *prefix,
5015                                ops.indirect_indexing(
5016                                    ops.minimum(
5017                                        pd, ops.sub(pdend, ops.constant(1, torch.int32))
5018                                    ),
5019                                    pooled_depth,
5020                                    check=False,
5021                                ),
5022                                ops.indirect_indexing(
5023                                    ops.minimum(
5024                                        ph, ops.sub(phend, ops.constant(1, torch.int32))
5025                                    ),
5026                                    pooled_height,
5027                                    check=False,
5028                                ),
5029                                ops.indirect_indexing(
5030                                    ops.minimum(
5031                                        pw, ops.sub(pwend, ops.constant(1, torch.int32))
5032                                    ),
5033                                    pooled_width,
5034                                    check=False,
5035                                ),
5036                            ]
5037                        ),
5038                        scale,
5039                    )
5040
5041                    mask = ops.and_(
5042                        ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
5043                        ops.lt(pw, pwend),
5044                    )
5045                    if gradient is None:
5046                        gradient = ops.where(
5047                            mask, part, ops.constant(0.0, torch.float32)
5048                        )
5049                    else:
5050                        gradient = ops.where(mask, ops.add(gradient, part), gradient)
5051        assert gradient is not None
5052        return gradient
5053
5054    rv = Pointwise.create(
5055        device=grad_output.get_device(),
5056        dtype=dtype,
5057        inner_fn=fn,
5058        ranges=new_size,
5059    )
5060    return rv
5061
5062
5063def _validate_reduction_axis(x, axis):
5064    size = x.get_size()
5065    if isinstance(axis, int):
5066        axis = [axis]
5067    elif not axis:
5068        axis = range(len(size))
5069    if len(size) == 0:
5070        assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}"
5071        return []
5072    axis = list(axis)
5073    for i in range(len(axis)):
5074        if axis[i] < 0:
5075            axis[i] += len(size) if len(size) else 1
5076        assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0)
5077    assert len(set(axis)) == len(axis), "reduction axis not unique"
5078    return axis
5079
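# Illustrative behaviour of _validate_reduction_axis (comment-only sketch): for
# a rank-3 input, axis=-1 normalizes to [2], axis=None (or []) expands to
# [0, 1, 2], and a repeated axis such as [0, 0] trips the uniqueness assert.
# For a 0-d input only (), (0,) and (-1,) are accepted and the result is [].
# A hypothetical pure-Python mirror of the normalization, for intuition only:
#
#     def _normalize_axis(rank, axis):
#         axis = [axis] if isinstance(axis, int) else list(axis or range(rank))
#         return [a + rank if a < 0 else a for a in axis]
#
#     _normalize_axis(3, -1)    # -> [2]
#     _normalize_axis(3, None)  # -> [0, 1, 2]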
5080
5081def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype):
5082    if dtype is not None:
5083        x = to_dtype(x, dtype)
5084    size = x.get_size()
5085    axis = set(_validate_reduction_axis(x, axis))
5086
5087    kept_sizes = []
5088    kept_idx = []
5089    reduced_sizes = []
5090    reduced_idx = []
5091    for i in range(len(size)):
5092        if i in axis:
5093            reduced_idx.append(i)
5094            reduced_sizes.append(size[i])
5095        else:
5096            kept_idx.append(i)
5097            kept_sizes.append(size[i])
5098
5099    def loader(index, reduction_index):
5100        assert len(reduction_index) == len(reduced_idx)
5101        if keepdims:
5102            assert len(index) == len(size)
5103            index = [index[i] for i in kept_idx]
5104        assert len(index) == len(kept_idx)
5105        new_index = [None] * (len(index) + len(reduction_index))
5106        for idx, var in itertools.chain(
5107            zip(kept_idx, index), zip(reduced_idx, reduction_index)
5108        ):
5109            new_index[idx] = var
5110        return inner_loader(new_index)
5111
5112    if keepdims:
5113        new_size = list(size)
5114        for i in reduced_idx:
5115            new_size[i] = sympy.Integer(1)
5116    else:
5117        new_size = kept_sizes
5118
5119    inner_loader = x.make_loader()
5120    return dict(
5121        device=x.get_device(),
5122        dst_dtype=override_return_dtype or x.get_dtype(),
5123        src_dtype=x.get_dtype(),
5124        inner_fn=loader,
5125        ranges=new_size,
5126        reduction_ranges=reduced_sizes,
5127    )
5128
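# Example of how _make_reduction_inner splits indices (illustration only):
# for size = (2, 3, 4) and axis = {1}, kept_idx = [0, 2] and reduced_idx = [1],
# so ranges is [2, 4] (or [2, 1, 4] with keepdims=True) and reduction_ranges is
# [3].  The wrapped loader then stitches an output index (i0, i2) and a
# reduction index (r1,) back into the original order [i0, r1, i2] before
# calling the input's loader.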
5129
5130def make_reduction(reduction_type: str, override_return_dtype=None):
5131    def inner(x, axis=None, keepdims=False, *, dtype=None):
5132        kwargs = _make_reduction_inner(
5133            x,
5134            axis=axis,
5135            keepdims=keepdims,
5136            dtype=dtype,
5137            override_return_dtype=override_return_dtype,
5138        )
5139        result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
5140        if isinstance(
5141            result.data.data, Reduction
5142        ):  # Only realize if reduction isn't unrolled
5143            result.realize()
5144        return result
5145
5146    return inner
5147
5148
5149def _make_scan_inner(x, *, axis, dtype):
5150    if dtype is not None:
5151        x = to_dtype(x, dtype)
5152    size = x.get_size()
5153    axis = _validate_dim(x, axis)
5154
5155    return dict(
5156        device=x.get_device(),
5157        dtypes=(x.get_dtype(),),
5158        inner_fns=(x.make_loader(),),
5159        size=x.get_size(),
5160        axis=axis,
5161    )
5162
5163
5164@register_lowering(aten.mean)
5165def mean(x, axis=None, keepdim=False, *, dtype=None):
5166    if dtype is not None:
5167        x = to_dtype(x, dtype)
5168    size = x.get_size()
5169    axis = _validate_reduction_axis(x, axis)
5170    # compute in higher precision until the end of the mean lowering
5171    output_dtype = x.get_dtype()
5172    if output_dtype in (torch.float16, torch.bfloat16):
5173        x = to_dtype(x, torch.float)
5174    sum_result = sum_(x, axis, keepdim)
5175    denom = sympy_product(size[i] for i in axis)
5176    denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device())
5177    denom = ExpandView.create(denom, list(sum_result.get_size()))
5178    return to_dtype(div(sum_result, denom), output_dtype)
5179
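# In other words, the mean lowering computes
#     mean(x, axis) = sum_(x, axis) / prod(size[i] for i in axis)
# with the divisor expressed as an ExpandView of an IndexingConstant so it
# broadcasts against the sum, and with fp16/bf16 inputs accumulated in fp32 and
# cast back to the original dtype at the end.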
5180
5181def var_mean_sum_(x, axis, correction, keepdim, return_mean):
5182    if correction is None:
5183        correction = 1
5184
5185    size = x.get_size()
5186    axis = _validate_reduction_axis(x, axis)
5187    x_mean = mean(x, axis, keepdim=True)
5188    if return_mean:
5189        x_mean.realize()
5190
5191    diffs = square(sub(x, x_mean))
5192    sum_result = sum_(diffs, axis, keepdim)
5193
5194    denom = sympy_product(size[i] for i in axis)
5195    if correction:
5196        denom = sympy.Max(denom - correction, 0)
5197    denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device())
5198    denom = ExpandView.create(denom, list(sum_result.get_size()))
5199    x_var = div(sum_result, denom)
5200    if not return_mean:
5201        return (x_var,)
5202
5203    x_mean = x_mean if keepdim else squeeze(x_mean, axis)
5204    return x_var, x_mean
5205
5206
5207def use_two_step_variance(x, axis, keepdim):
5208    # Instead of unrolling Welford, just unroll the simpler two-step variance
5209    axis = _validate_reduction_axis(x, axis)
5210    kwargs = _make_reduction_inner(
5211        x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
5212    )
5213
5214    ranges = kwargs["ranges"]
5215    reduction_numel = sympy_product(kwargs["reduction_ranges"])
5216    return (
5217        isinstance(reduction_numel, sympy.Integer)
5218        and int(reduction_numel) < config.unroll_reductions_threshold
5219        and sympy_product(ranges) != 1
5220    )
5221
5222
5223def var_mean_welford_(x, axis, *, correction, keepdim, return_mean):
5224    if correction is None:
5225        correction = 1
5226
5227    kwargs = _make_reduction_inner(
5228        x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
5229    )
5230    loader = kwargs.pop("inner_fn")
5231    kwargs.pop("dst_dtype")
5232    kwargs.pop("src_dtype")
5233
5234    mean, m2, _ = ir.WelfordReduction.create(
5235        inner_fns=(loader,),
5236        reduction_type="welford_reduce",
5237        dtype=x.get_dtype(),
5238        **kwargs,
5239    )
5240    m2.realize()
5241
5242    dtype = x.get_dtype()
5243    size = x.get_size()
5244    axis = _validate_reduction_axis(x, axis)
5245    rnumel = sympy_product(size[i] for i in axis)
5246
5247    def get_constant_or_index_expr(x, dtype):
5248        if isinstance(x, sympy.Expr) and not x.is_number:
5249            return ops.to_dtype(ops.index_expr(x, torch.int64), dtype)
5250        return ops.constant(x, dtype)
5251
5252    def scale_fn(data):
5253        c = get_constant_or_index_expr(correction, dtype)
5254        N = get_constant_or_index_expr(rnumel, dtype)
5255        zero = ops.constant(0, dtype)
5256        return data / ops.maximum(zero, N - c)
5257
5258    var = make_pointwise(scale_fn)(m2)
5259
5260    if return_mean:
5261        mean.realize()
5262        return var, mean
5263    return (var,)
5264
5265
5266def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
5267    out_dtype = x.get_dtype()
5268    compute_dtype = get_computation_dtype(out_dtype)
5269    x = to_dtype(x, compute_dtype, copy=False)
5270    kwargs = dict(
5271        x=x,
5272        axis=axis,
5273        correction=correction,
5274        keepdim=keepdim,
5275        return_mean=return_mean,
5276    )
5277    output = (
5278        var_mean_sum_(**kwargs)
5279        if use_two_step_variance(x, axis=axis, keepdim=keepdim)
5280        else var_mean_welford_(**kwargs)
5281    )
5282    output = tuple(to_dtype(x, out_dtype, copy=False) for x in output)
5283    return output[0] if not return_mean else output
5284
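# Rough summary of the two variance paths above (comment-only sketch), with
# N = prod(size[i] for i in axis):
#   two-step: var = sum((x - mean(x))**2) / max(N - correction, 0)
#   Welford:  var = m2 / max(0, N - correction), where m2 comes from the
#             "welford_reduce" reduction
# e.g. x = [1., 2., 3.], correction = 1: mean = 2, sum of squared diffs = 2,
# var = 2 / (3 - 1) = 1.  The two-step path is only taken when the reduction is
# small enough to unroll (see use_two_step_variance).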
5285
5286@register_lowering([aten.var, prims.var])
5287def var_(x, axis=None, *, correction=None, keepdim=False):
5288    return var_mean_helper_(
5289        x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False
5290    )
5291
5292
5293@register_lowering(aten.var_mean)
5294def var_mean(x, axis=None, *, correction=None, keepdim=False):
5295    return var_mean_helper_(
5296        x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True
5297    )
5298
5299
5300def pow_recursive(x, y, dtype):
5301    if y < 0:
5302        return pow_recursive(ops.reciprocal(x), -y, dtype)
5303    if y == 0:
5304        return ops.constant(1, dtype)
5305    if y == 1:
5306        return x
5307
5308    result = pow_recursive(x, y // 2, dtype)
5309    result = ops.mul(result, result)
5310    if (y % 2) == 1:
5311        result = ops.mul(result, x)
5312    return result
5313
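# pow_recursive is plain square-and-multiply, so an exponent y costs
# O(log2(y)) multiplies instead of y - 1.  For example (illustration only):
#     x**10 = (x**5)**2,  x**5 = (x**2)**2 * x
# which the recursion emits as 4 ops.mul calls; a negative exponent is handled
# by recursing on ops.reciprocal(x) with -y.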
5314
5315@make_pointwise
5316def pow_native(a, b):
5317    return ops.pow(a, b)
5318
5319
5320fallback_pow_tensor_tensor = fallback_handler(
5321    aten.pow.Tensor_Tensor, add_to_fallback_set=False
5322)
5323fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False)
5324fallback_pow_tensor_scalar = fallback_handler(
5325    aten.pow.Tensor_Scalar, add_to_fallback_set=False
5326)
5327
5328
5329@register_lowering(aten.pow, broadcast=True)
5330def pow(a, b):
5331    if isinstance(b, float) and b == int(b):
5332        return pow(a, int(b))
5333    elif isinstance(b, float) and b == 0.5:
5334        return sqrt(a)
5335    elif isinstance(b, int) and b == 1:
5336        return clone(a)
5337
5338    # Type promotion ensures all tensor arguments have the same type
5339    dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox))
5340    is_integer_pow = is_integer_dtype(dtype)
5341
5342    # Optimize away small fixed powers, or, for integers, avoid falling back to ATen
5343    embed_exponent = isinstance(b, int) and (
5344        -32 < b < 32 or (is_integer_pow and b >= 0)
5345    )
5346    if embed_exponent:
5347        loader = a.make_loader()
5348
5349        def fn(idx):
5350            return pow_recursive(loader(idx), b, a.get_dtype())
5351
5352        return Pointwise.create(
5353            device=a.get_device(),
5354            dtype=a.get_dtype(),
5355            inner_fn=fn,
5356            ranges=a.get_size(),
5357        )
5358
5359    if isinstance(a, Number):
5360        if a == 1:
5361            return full_like(b, 1)
5362        if a == 2 and is_float_dtype(b.get_dtype()):
5363            return exp2(b)
5364
5365    if is_integer_pow:
5366        # ops.pow doesn't work for integers
5367        if isinstance(a, Number):
5368            return fallback_pow_scalar(a, b)
5369        elif isinstance(b, Number):
5370            return fallback_pow_tensor_scalar(a, b)
5371        else:
5372            return fallback_pow_tensor_tensor(a, b)
5373
5374    return pow_native(a, b)
5375
5376
5377def mutate_to(changed, val, unsafe_alias=False):
5378    if isinstance(changed, TensorBox):
5379        changed_data = changed.data
5380    else:
5381        changed_data = changed
5382    if isinstance(val, TensorBox):
5383        val = val.data
5384
5385    if not isinstance(val, ir.StorageBox):
5386        # introduce a copy to handle views
5387        val = Pointwise.create(
5388            device=changed.get_device(),
5389            dtype=changed.get_dtype(),
5390            inner_fn=val.make_loader(),
5391            ranges=changed.get_size(),
5392        ).data
5393        assert isinstance(val, ir.StorageBox)
5394
5395    if isinstance(changed_data, ir.StorageBox) and not (
5396        changed_data.is_input_buffer()
5397        # In AOTI, module parameters and buffers are not lifted as graph inputs
5398        or changed_data.is_module_buffer()
5399        or isinstance(changed_data.data, ir.NopKernel)
5400    ):
5401        # Fast path, just swing the data pointer
5402        val.realize()
5403        changed_data.data = val.data
5404        return changed
5405
5406    ir.MutationLayoutSHOULDREMOVE.realize_into(
5407        val, changed_data, unsafe_alias=unsafe_alias
5408    )
5409    return changed
5410
5411
5412@register_lowering(aten.fill_)
5413def fill_(x, fill_value):
5414    return mutate_to(x, full_like(x, fill_value))
5415
5416
5417@register_lowering(aten.copy_, type_promotion_kind=None)
5418def copy_(dst, src, non_blocking=False):
5419    if dst is src:
5420        # dst.copy_(dst) can happen from the reinplacing pass
5421        return dst
5422    src = to_device(src, dst.get_device())
5423    src = to_dtype(src, dst.get_dtype())
5424    src = expand(src, dst.get_size())
5425    return mutate_to(dst, src)
5426
5427
5428@make_pointwise
5429def floordiv(a, b):
5430    return ops.floordiv(a, b)
5431
5432
5433@make_pointwise
5434def truncdiv(a, b):
5435    return ops.truncdiv(a, b)
5436
5437
5438@register_lowering(aten.div, broadcast=True)
5439def div_mode(a, b, rounding_mode=None):
5440    both_integer = is_integer_type(a) and is_integer_type(b)
5441    both_boolean = is_boolean_type(a) and is_boolean_type(b)
5442
5443    # floordiv and truncdiv need special handling for integer tensors on Triton;
5444    # see the discussion at https://github.com/openai/triton/issues/605
5445    if rounding_mode == "floor":
5446        assert not both_boolean, "floordiv operands cannot both be boolean"
5447        return floordiv(a, b) if both_integer else floor(div(a, b))
5448    if rounding_mode == "trunc":
5449        assert not both_boolean, "truncdiv operands cannot both be boolean"
5450        return truncdiv(a, b) if both_integer else trunc(div(a, b))
5451    return div(a, b)
5452
5453
5454@register_lowering([aten.mul], broadcast=True)
5455def mul(a, b):
5456    both_bool = is_boolean_type(a) and is_boolean_type(b)
5457    if both_bool:
5458        return logical_and(a, b)
5459    else:
5460        fn = ops_wrapper(aten.mul.__name__)
5461        return make_pointwise(fn)(a, b)
5462
5463
5464def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
5465    """Try to convert an arbitrary IR node into an ir.Constant value."""
5466
5467    # First try unwrapping the IRNode to see if it is already an ir.Constant
5468    # Optional step, but avoids unnecessary inner_fn evaluation.
5469    if isinstance(x, ir.MutableBox):
5470        return get_constant_value(x.data)
5471    if isinstance(x, ir.BaseView):
5472        return get_constant_value(x.unwrap_view())
5473    if isinstance(x, ir.Constant):
5474        return x
5475
5476    # If the unwrapped node is not an ir.Constant, try evaluating inner_fn
5477    # to see if the returned value is from an `ops.constant` call
5478    if not isinstance(x, ir.Loops):
5479        return None
5480
5481    handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
5482    with V.set_ops_handler(handler), patch.object(
5483        ir.FlexibleLayout, "allow_indexing", True
5484    ):
5485        out = x.inner_fn(*x.inner_fn_args())
5486
5487    assert isinstance(out, torch._inductor.virtualized.OpsValue)
5488    if isinstance(out.value, ir.Constant):
5489        return out.value
5490    return None
5491
5492
5493# NOTE: prims.div maps to a / b in C, so performs truncation division on
5494#   integer inputs and true division for floating and complex inputs.
5495@register_lowering([prims.div], broadcast=True)
5496def div_prim(a, b):
5497    is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b])
5498
5499    if is_integral:
5500        return truncdiv(a, b)
5501
5502    if (divisor := get_constant_value(b)) is not None:
5503        # Replace divide by constant with multiply by reciprocal
5504        if divisor.value == 0:
5505            reciprocal = math.copysign(float("inf"), divisor.value)
5506        else:
5507            reciprocal = 1.0 / divisor.value
5508        return mul(a, reciprocal)
5509
5510    def fn(*args):
5511        return ops.truediv(*args)
5512
5513    return make_pointwise(fn)(a, b)
5514
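# Example of the constant-divisor rewrite in div_prim (illustration only):
# dividing by a value that folds to an ir.Constant becomes a multiply by its
# reciprocal, e.g. prims.div(a, 4.0) lowers like mul(a, 0.25), and a constant
# divisor of 0.0 becomes mul(a, inf) with the sign taken from the zero via
# math.copysign, matching IEEE true division.  Integer/boolean inputs never hit
# this path; they take truncdiv above instead.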
5515
5516@register_lowering(
5517    [aten.true_divide, aten.div.Tensor],
5518    broadcast=True,
5519    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
5520)
5521def div(a, b):
5522    a, b = promote_constants(
5523        (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
5524    )
5525    return div_prim(a, b)
5526
5527
5528@register_lowering([aten.fmod, prims.fmod], broadcast=True)
5529def fmod(a, b):
5530    is_integral = is_boolean_type(a) or is_integer_type(a)
5531
5532    if is_integral:
5533
5534        def fn(a, b):
5535            return ops.mod(a, b)
5536
5537    else:
5538
5539        def fn(a, b):
5540            return ops.fmod(a, b)
5541
5542    return make_pointwise(fn)(a, b)
5543
5544
5545@register_lowering(aten.rsqrt)
5546def rsqrt(x):
5547    dtype = x.get_dtype()
5548    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
5549        x = to_dtype(x, torch.get_default_dtype())
5550
5551    def _rsqrt(x):
5552        return ops.rsqrt(x)
5553
5554    return make_pointwise(_rsqrt)(x)
5555
5556
5557@register_lowering([aten.sum, prims.sum])
5558def sum_(x, axis=None, keepdims=False, *, dtype=None):
5559    if (
5560        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
5561    ) and dtype is None:
5562        dtype = torch.int64
5563
5564    fn = make_reduction("sum", override_return_dtype=dtype)
5565    return fn(x, axis, keepdims, dtype=dtype)
5566
5567
5568fallback_cumsum = fallback_handler(aten.cumsum.default)
5569fallback_cumprod = fallback_handler(aten.cumprod.default)
5570fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default)
5571fallback_cummax = fallback_handler(aten.cummax.default)
5572fallback_cummin = fallback_handler(aten.cummin.default)
5573
5574
5575@register_lowering(aten.cumsum)
5576def cumsum(x, axis=None, dtype=None):
5577    if (
5578        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
5579    ) and dtype is None:
5580        dtype = torch.int64
5581
5582    if len(x.get_size()) == 0:
5583        assert axis in [0, -1]
5584        dtype = dtype or x.get_dtype()
5585        return to_dtype(x, dtype, copy=True)
5586
5587    def combine_fn(a_tuple, b_tuple):
5588        (a,) = a_tuple
5589        (b,) = b_tuple
5590        return (ops.add(a, b),)
5591
5592    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
5593    (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
5594    if result is None:
5595        return fallback_cumsum(x, dim=axis, dtype=dtype)
5596    return result
5597
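# The scan lowering above builds cumsum from a pairwise-add combine_fn, e.g.
# (illustration only) cumsum([1, 2, 3, 4]) == [1, 3, 6, 10], with integer and
# boolean inputs accumulated in int64 to match eager semantics.  ir.Scan.create
# may return None (e.g. when the backend cannot codegen the scan), in which
# case we fall back to the ATen kernel.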
5598
5599@register_lowering(aten.cumprod)
5600def cumprod(x, axis=None, dtype=None):
5601    if (
5602        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
5603    ) and dtype is None:
5604        dtype = torch.int64
5605
5606    if len(x.get_size()) == 0:
5607        assert axis in [0, -1]
5608        dtype = dtype or x.get_dtype()
5609        return to_dtype(x, dtype, copy=True)
5610
5611    def combine_fn(a_tuple, b_tuple):
5612        (a,) = a_tuple
5613        (b,) = b_tuple
5614        return (ops.mul(a, b),)
5615
5616    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
5617    (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
5618    if result is None:
5619        return fallback_cumprod(x, dim=axis, dtype=dtype)
5620    return result
5621
5622
5623@register_lowering(aten.logcumsumexp)
5624def logcumsumexp(x, dim):
5625    def log_add_exp_helper(a_tuple, b_tuple):
5626        (a,) = a_tuple
5627        (b,) = b_tuple
5628        min_v = ops.minimum(a, b)
5629        max_v = ops.maximum(a, b)
5630        mask = (min_v != max_v) | (~ops.isinf(min_v))
5631        return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),)
5632
5633    dtype = x.get_dtype()
5634    if len(x.get_size()) == 0:
5635        assert dim in [0, -1]
5636        return clone(x)
5637
5638    kwargs = _make_scan_inner(x, axis=dim, dtype=dtype)
5639    (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper)
5640    if result is None:
5641        return fallback_logcumsumexp(x, dim=dim)
5642    return result
5643
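# The combine function above uses the numerically stable identity (sketch):
#     logaddexp(a, b) = max(a, b) + log1p(exp(min(a, b) - max(a, b)))
# which never exponentiates a large positive value.  The mask exists because
# when a == b and both are infinite, min_v - max_v is inf - inf = nan; in that
# case we return `a` directly instead of evaluating the formula.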
5644
5645@register_lowering(aten.cummax, type_promotion_kind=None)
5646def cummax(x, axis=None):
5647    if len(x.get_size()) == 0:
5648        assert axis in [0, -1]
5649        return clone(x), empty_like(x, dtype=torch.int64)
5650
5651    dtype = x.get_dtype()
5652    combine_fn = ir.get_reduction_combine_fn(
5653        "argmax", dtype=dtype, arg_break_ties_left=False
5654    )
5655
5656    min_value = (
5657        False
5658        if dtype is torch.bool
5659        else (
5660            torch.finfo(dtype).min
5661            if dtype.is_floating_point
5662            else torch.iinfo(dtype).min
5663        )
5664    )
5665
5666    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
5667    kwargs["dtypes"] = (dtype, torch.int64)
5668    kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex")
5669    values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn)  # type: ignore[arg-type] # next PR
5670    if values is None:
5671        return fallback_cummax(x, dim=axis)
5672    return values, indices
5673
5674
5675@register_lowering(aten.cummin, type_promotion_kind=None)
5676def cummin(x, axis=None):
5677    if len(x.get_size()) == 0:
5678        assert axis in [0, -1]
5679        return clone(x), empty_like(x, dtype=torch.int64)
5680
5681    dtype = x.get_dtype()
5682    combine_fn = ir.get_reduction_combine_fn(
5683        "argmin", dtype=dtype, arg_break_ties_left=False
5684    )
5685
5686    max_value = (
5687        True
5688        if dtype is torch.bool
5689        else (
5690            torch.finfo(dtype).max
5691            if dtype.is_floating_point
5692            else torch.iinfo(dtype).max
5693        )
5694    )
5695
5696    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
5697    kwargs["dtypes"] = (dtype, torch.int64)
5698    kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex")
5699    values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn)  # type: ignore[arg-type] # next PR
5700    if values is None:
5701        return fallback_cummin(x, dim=axis)
5702    return values, indices
5703
5704
5705@register_lowering(aten.prod)
5706def prod(x, axis=None, keepdims=False, *, dtype=None):
5707    if (
5708        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
5709    ) and dtype is None:
5710        dtype = torch.int64
5711
5712    fn = make_reduction("prod", override_return_dtype=dtype)
5713    return fn(x, axis, keepdims, dtype=dtype)
5714
5715
5716@register_lowering(aten.any)
5717def reduce_any(x, dim=None, keepdim=False):
5718    x = to_dtype(x, torch.bool)
5719    return make_reduction("any")(x, axis=dim, keepdims=keepdim)
5720
5721
5722@register_lowering(aten.max, type_promotion_kind=None)
5723def reduce_max(x, dim=None, keepdim=False):
5724    if dim is not None:
5725        return (
5726            reduce_amax(x, axis=dim, keepdims=keepdim),
5727            reduce_argmax(x, axis=dim, keepdims=keepdim),
5728        )
5729
5730    return reduce_amax(x, axis=None, keepdims=keepdim)
5731
5732
5733@register_lowering(aten.min, type_promotion_kind=None)
5734def reduce_min(x, dim=None, keepdim=False):
5735    if dim is not None:
5736        return (
5737            reduce_amin(x, axis=dim, keepdims=keepdim),
5738            reduce_argmin(x, axis=dim, keepdims=keepdim),
5739        )
5740
5741    return reduce_amin(x, axis=None, keepdims=keepdim)
5742
5743
5744register_lowering(prims.xor_sum)(make_reduction("xor_sum"))
5745reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
5746reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
5747reduce_argmax = register_lowering(aten.argmax)(
5748    make_reduction("argmax", override_return_dtype=torch.int64)
5749)
5750reduce_argmin = register_lowering(aten.argmin)(
5751    make_reduction("argmin", override_return_dtype=torch.int64)
5752)
5753
5754add = register_pointwise(
5755    aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or"
5756)
5757
5758sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False)
5759
5760
5761@register_lowering(aten.sort.stable, type_promotion_kind=None)
5762def sort_stable(x, *, stable=None, dim=-1, descending=False):
5763    if stable is None:
5764        stable = False
5765
5766    shape = x.get_size()
5767    device = x.get_device()
5768    dim = canonicalize_dim(len(shape), dim)
5769    if len(shape) == 0:
5770        return clone(x), _full(0, device, torch.int64, shape)
5771
5772    dim_size = shape[dim] if len(shape) else 1
5773    if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max):
5774        return sort_fallback(x, stable=stable, dim=dim, descending=descending)
5775
5776    indices = iota(
5777        dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False
5778    )
5779    view_shape = [1] * len(shape)
5780    if len(shape):
5781        view_shape[dim] = dim_size
5782    indices = view(indices, view_shape)
5783    indices = expand(indices, shape)
5784
5785    values, indices = ir.Sort.create(
5786        device=device,
5787        dtypes=(x.dtype, indices.dtype),
5788        inner_fns=(x.make_loader(), indices.make_loader()),
5789        size=shape,
5790        axis=dim,
5791        stable=stable,
5792        descending=descending,
5793    )
5794    if values is None:
5795        return sort_fallback(x, stable=stable, dim=dim, descending=descending)
5796
5797    assert indices is not None
5798    return values, to_dtype(indices, torch.int64)
5799
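# Note on the index trick in sort_stable above: indices are materialized as an
# int16 iota along the sort dimension (hence the guard that dim_size is
# statically known to be below torch.iinfo(torch.int16).max), sorted alongside
# the values, and only widened to int64 at the end to match aten.sort's return
# type.  E.g. (illustration only) x = [3, 1, 2] sorts to values [1, 2, 3] with
# indices [1, 2, 0].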
5800
5801@register_lowering(aten.sort.default, type_promotion_kind=None)
5802def sort(x, dim=-1, descending=False):
5803    return sort_stable(x, stable=False, dim=dim, descending=descending)
5804
5805
5806def register_pointwise_numeric(op, name=None, triton_fallback=None):
5807    return register_pointwise(
5808        op,
5809        name=name,
5810        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
5811        triton_fallback=triton_fallback,
5812    )
5813
5814
5815def register_pointwise_numeric_ldf64(op):
5816    return register_pointwise(
5817        op,
5818        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
5819        use_libdevice_for_f64=True,
5820    )
5821
5822
5823exp = register_pointwise_numeric_ldf64(aten.exp)
5824exp2 = register_pointwise_numeric(aten.exp2)
5825expm1 = register_pointwise_numeric(aten.expm1)
5826relu = register_pointwise(aten.relu)
5827sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
5828sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
5829square = register_pointwise(aten.square)
5830sub = register_pointwise(aten.sub, allow_alpha=True)
5831register_pointwise_numeric_ldf64(aten.cos)
5832register_pointwise_numeric_ldf64(aten.sin)
5833abs = register_pointwise(aten.abs)
5834bitwise_and = register_pointwise(aten.bitwise_and)
5835bitwise_left_shift = register_pointwise(aten.bitwise_left_shift)
5836bitwise_not = register_pointwise(
5837    aten.bitwise_not, override_fn_when_input_bool="logical_not"
5838)
5839bitwise_or = register_pointwise(aten.bitwise_or)
5840bitwise_right_shift = register_pointwise(aten.bitwise_right_shift)
5841bitwise_xor = register_pointwise(aten.bitwise_xor)
5842register_pointwise_numeric(aten.lgamma)
5843erf = register_pointwise_numeric(aten.erf)
5844register_lowering(
5845    aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
5846)(erf)
5847
5848register_pointwise_numeric(aten.log1p)
5849register_pointwise_numeric(aten.tan)
5850register_pointwise_numeric(aten.tanh)
5851register_pointwise_numeric_ldf64(aten.log)
5852logical_and = register_pointwise(
5853    aten.logical_and,
5854    type_promotion_kind=None,
5855    convert_input_to_bool=True,
5856    override_return_dtype=torch.bool,
5857)
5858logical_not = register_pointwise(
5859    aten.logical_not,
5860    type_promotion_kind=None,
5861    convert_input_to_bool=True,
5862    override_return_dtype=torch.bool,
5863)
5864logical_or = register_pointwise(
5865    aten.logical_or,
5866    type_promotion_kind=None,
5867    convert_input_to_bool=True,
5868    override_return_dtype=torch.bool,
5869)
5870logical_xor = register_pointwise(
5871    aten.logical_xor,
5872    type_promotion_kind=None,
5873    convert_input_to_bool=True,
5874    override_return_dtype=torch.bool,
5875)
5876maximum = register_pointwise(aten.maximum)
5877minimum = register_pointwise(aten.minimum)
5878register_lowering(aten.clamp_min)(maximum)
5879register_lowering(aten.clamp_max)(minimum)
5880neg = register_pointwise(aten.neg)
5881abs = register_pointwise(aten.abs)
5882reciprocal = register_pointwise_numeric(aten.reciprocal)
5883register_pointwise(aten.remainder)
5884sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity")
5885register_pointwise(aten.ceil)
5886register_pointwise(aten.signbit, override_return_dtype=torch.bool)
5887
5888register_lowering(aten._neg_view)(neg)
5889
5890register_pointwise(aten.le, override_return_dtype=torch.bool)
5891register_pointwise(aten.lt, override_return_dtype=torch.bool)
5892register_pointwise(aten.ge, override_return_dtype=torch.bool)
5893gt = register_pointwise(aten.gt, override_return_dtype=torch.bool)
5894register_pointwise(aten.eq, override_return_dtype=torch.bool)
5895register_pointwise(aten.ne, override_return_dtype=torch.bool)
5896
5897register_pointwise_numeric(aten.cosh)
5898register_pointwise_numeric(aten.sinh)
5899register_pointwise_numeric(aten.acos)
5900register_pointwise_numeric(aten.acosh)
5901register_pointwise_numeric(aten.asin)
5902register_pointwise_numeric(aten.asinh)
5903register_pointwise_numeric(aten.atan2)
5904register_pointwise_numeric(aten.atan)
5905register_pointwise_numeric(aten.atanh)
5906register_pointwise_numeric(aten.copysign)
5907register_pointwise_numeric(aten.erfc)
5908register_pointwise_numeric(aten.erfinv)
5909register_pointwise_numeric(aten.hypot)
5910register_pointwise_numeric(aten.log10)
5911register_pointwise_numeric(aten.log2)
5912register_pointwise_numeric(aten.nextafter)
5913
5914from .codegen.common import BackendFeature, pointwise_overrides_data
5915
5916
5917def _get_pointwise_overrides(ns, name):
5918    data = pointwise_overrides_data[name]
5919    op = getattr(ns, data.name, None)
5920    if op is None:
5921        return
5922
5923    def make_triton_fallback(op):
5924        if data.triton is None:
5925            return fallback_handler(op)
5926
5927    if isinstance(op, torch._ops.OpOverloadPacket):
5928        for olname in op.overloads():
5929            ol = getattr(op, olname)
5930            yield ol, data.type_promotion_kind, make_triton_fallback(ol)
5931    else:
5932        yield op, data.type_promotion_kind, make_triton_fallback(op)
5933
5934
5935for name in pointwise_overrides_data:
5936    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
5937        aten, name
5938    ):
5939        register_pointwise(
5940            op,
5941            name=name,
5942            type_promotion_kind=type_promotion_kind,
5943            triton_fallback=triton_fallback,
5944        )
5945
5946    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
5947        prims, name
5948    ):
5949        register_pointwise(
5950            op,
5951            name=name,
5952            type_promotion_kind=type_promotion_kind,
5953            triton_fallback=triton_fallback,
5954        )
5955
5956
5957foreach_add_list = register_foreach_pointwise(
5958    aten._foreach_add.List, add, allow_alpha=True
5959)
5960foreach_add_scalar = register_foreach_pointwise(
5961    aten._foreach_add.Scalar, add, allow_alpha=True
5962)
5963register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)
5964foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
5965foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
5966register_foreach_pointwise(aten._foreach_sub.List, sub)
5967register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
5968register_foreach_pointwise(aten._foreach_neg.default, neg)
5969register_foreach_pointwise(aten._foreach_abs.default, abs)
5970register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
5971register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
5972foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
5973foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
5974register_foreach_pointwise(aten._foreach_sqrt, sqrt)
5975register_foreach_pointwise(aten._foreach_maximum.List, maximum)
5976register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
5977register_foreach_pointwise(aten._foreach_minimum.List, minimum)
5978register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum)
5979register_foreach_pointwise(aten._foreach_clamp_min.List, maximum)
5980register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum)
5981register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
5982register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
5983register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
5984register_foreach_pointwise(aten._foreach_sign, sign)
5985register_foreach_pointwise(aten._foreach_copy, copy)
5986
5987
5988# These ops are only encountered as outputs of the graph;
5989# reinplacing their epilogue copies improves compile time
5990# by removing extra buffers sent to the scheduler.
5991def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op):
5992    inplaceable_foreach_ops[outplace_aten_op] = aten_op
5993    inplace_foreach_ops.add(aten_op)
5994
5995    def fn(*args, **kwargs):
5996        results = outplace_op(*args, **kwargs)
5997        mut_results = []
5998        for arg, result in zip(args[0], results):
5999            mut_results.append(mutate_to(arg, result, unsafe_alias=True))
6000
6001        return mut_results
6002
6003    _register_foreach_lowering(aten_op, fn)
6004
6005
6006register_foreach_inplace(
6007    aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list
6008)
6009register_foreach_inplace(
6010    aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar
6011)
6012register_foreach_inplace(
6013    aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list
6014)
6015register_foreach_inplace(
6016    aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar
6017)
6018register_foreach_inplace(
6019    aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list
6020)
6021register_foreach_inplace(
6022    aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
6023)
6024
6025
6026def register_inplace(aten_op, outplace_op):
6027    @register_lowering(aten_op, type_promotion_kind=None)
6028    def fn(*args, **kwargs):
6029        result = outplace_op(*args, **kwargs)
6030        result = to_dtype(result, args[0].get_dtype())
6031        return mutate_to(args[0], result)
6032
6033    return fn
6034
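# register_inplace reuses the out-of-place lowering and then writes the result
# back into the first argument, so e.g. (illustrative sketch) aten.add_(x, y)
# lowers roughly as mutate_to(x, to_dtype(add(x, y), x.get_dtype())).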
6035
6036register_inplace(aten.add_, add)
6037register_inplace(aten.bitwise_and_, bitwise_and)
6038register_inplace(aten.bitwise_left_shift_, bitwise_left_shift)
6039register_inplace(aten.bitwise_not_, bitwise_not)
6040register_inplace(aten.bitwise_or_, bitwise_or)
6041register_inplace(aten.bitwise_right_shift_, bitwise_right_shift)
6042register_inplace(aten.bitwise_xor_, bitwise_xor)
6043register_inplace(aten.mul_, mul)
6044register_inplace(aten.div_.Tensor, div)
6045register_inplace(aten.div_.Tensor_mode, div_mode)
6046register_inplace(aten.logical_and_, logical_and)
6047register_inplace(aten.logical_not_, logical_not)
6048register_inplace(aten.logical_or_, logical_or)
6049register_inplace(aten.logical_xor_, logical_xor)
6050register_inplace(aten.sub_, sub)
6051register_inplace(aten.relu_, relu)
6052register_inplace(aten.sigmoid_, sigmoid)
6053
6054
6055register_lowering(aten.__and__)(bitwise_and)
6056register_lowering(aten.__lshift__)(bitwise_left_shift)
6057register_lowering(aten.__or__)(bitwise_or)
6058register_lowering(aten.__rshift__)(bitwise_right_shift)
6059register_lowering(aten.__xor__)(bitwise_xor)
6060
6061register_inplace(aten.__iand__, aten.__and__)
6062register_inplace(aten.__ilshift__, aten.__lshift__)
6063register_inplace(aten.__ior__, aten.__or__)
6064register_inplace(aten.__irshift__, aten.__rshift__)
6065register_inplace(aten.__ixor__, aten.__xor__)
6066
6067
6068@register_lowering(aten.sym_constrain_range)
6069def sym_constrain_range(a, min=None, max=None):
6070    return None
6071
6072
6073@register_lowering(aten.sym_size.int)
6074def sym_size(a, dim):
6075    val = V.graph.current_node.meta["val"]
6076    # Note [Can val be an int?]
6077    # ~~~~~~~~~~~~~~~~~~~~~~~~~
6078    # In principle, someone could construct an FX graph where
6079    # a call to size/stride has a val that is a plain int (not
6080    # SymInt).  However, we will maintain the invariant that
6081    # this is not possible: if you are constructing an FX graph
6082    # where there is a call to size/stride that returns an
6083    # int, but you KNOW that int must always be a constant,
6084    # then you do not need to trace that call at all (and can just
6085    # constant-propagate the integer as is).
6086    assert isinstance(val, torch.SymInt)
6087    return val.node.expr
6088
6089
6090@register_lowering(aten.sym_stride.int)
6091def sym_stride(a, dim):
6092    val = V.graph.current_node.meta["val"]
6093    # See Note [Can val be an int?]
6094    assert isinstance(val, torch.SymInt)
6095    return val.node.expr
6096
6097
6098@register_lowering(aten.sym_numel)
6099def sym_numel(a):
6100    return a.get_numel()
6101
6102
6103for method, func in magic_methods.items():
6104    register_lowering(method_to_operator(method))(func)
6105
6106
6107@register_lowering(aten._foobar)
6108def foobar(self, *args, **kwargs):
6109    raise NotImplementedError("Helpful for debugging")
6110
6111
6112@register_lowering(torch.ops._inductor_test.realize)
6113def _realize(x):
6114    x.realize()
6115    return clone(x)
6116
6117
6118@register_lowering(torch.ops.inductor.resize_storage_bytes_)
6119def resize_storage_bytes_(variable, new_size):
6120    variable.realize()
6121    ir.ResizeStorageBytes(variable, new_size)
6122    return variable
6123
6124
6125@register_lowering(torch.ops.aten.set_.source_Tensor)
6126def set__source_tensor(self, source_tensor):
6127    self.realize()
6128    source_tensor.realize()
6129    return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor))
6130
6131
6132if hasattr(torch.ops.fsdp, "set_"):
6133
6134    @register_lowering(torch.ops.fsdp.set_.default)
6135    def fsdp_set_(self, source_tensor):
6136        self.realize()
6137        source_tensor.realize()
6138        ir.SetSourceTensorKernel(self, source_tensor)
6139
6140
6141@register_lowering(torch.ops.aten.resize)
6142def resize(x, size, *, memory_format=None):
6143    assert isinstance(x, TensorBox)
6144    assert isinstance(size, (list, tuple))
6145
6146    if memory_format is None:
6147        memory_format = torch.contiguous_format
6148    if memory_format == torch.preserve_format:
6149        raise RuntimeError(f"unsupported memory format: {memory_format}")
6150
6151    if memory_format == torch.channels_last:
6152        assert len(size) == 4
6153    if memory_format == torch.channels_last_3d:
6154        assert len(size) == 5
6155
6156    old_numel = x.get_numel()
6157    dtype = x.get_dtype()
6158    device = x.get_device()
6159
6160    if isinstance(x.data, ir.BaseView):
6161        x.data = x.data.unwrap_view()
6162
6163    if (
6164        torch.are_deterministic_algorithms_enabled()
6165        and torch.utils.deterministic.fill_uninitialized_memory  # type: ignore[attr-defined]
6166    ):
6167        if is_float_dtype(dtype):
6168            uninitialized_val = float("nan")
6169        elif is_integer_dtype(dtype):
6170            uninitialized_val = torch.iinfo(dtype).max
6171        else:
6172            uninitialized_val = True
6173    else:
6174        # use zero, as that is what empty does
6175        uninitialized_val = 0.0
6176
6177    if V.graph.sizevars.statically_known_equals(old_numel, 0):  # type: ignore[arg-type]
6178        return full(size, uninitialized_val, dtype=dtype, device=device)
6179
6180    x_flat = as_strided(
6181        x,
6182        [
6183            old_numel,
6184        ],
6185        [
6186            1,
6187        ],
6188    )
6189    flat_loader = x_flat.make_loader()
6190    out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format)
6191    out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer()
6192
6193    def inner_fn(idx):
6194        flat_index = out_indexer(idx)
6195        flat_index_expr = ops.index_expr(flat_index, torch.int64)
6196        limit = ops.index_expr(old_numel, torch.int64)
6197        mask = ops.lt(flat_index_expr, limit)
6198        return ops.masked(mask, lambda: flat_loader([flat_index]), uninitialized_val)
6199
6200    out = Pointwise.create(
6201        device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size)
6202    )
6203    return out
6204
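# Rough picture of the resize lowering above (illustration only): the output is
# a Pointwise over the new size whose indexer linearizes each output index
# under the requested memory_format's strides; linear offsets below the old
# numel read from the flattened old storage, and the rest get the
# uninitialized fill value (nan / iinfo.max / True under deterministic fill,
# 0.0 otherwise).  E.g. resizing a 6-element float tensor to (2, 4) keeps the
# first 6 linear positions and fills the last 2 with that value.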
6205
6206from torch._higher_order_ops.auto_functionalize import auto_functionalized
6207
6208
6209make_fallback(auto_functionalized)
6210
6211
6212@register_lowering(triton_kernel_wrapper_mutation)
6213def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs):
6214    from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
6215
6216    constant_args = kernel_side_table.get_constant_args(constant_args_idx)
6217    ir.UserDefinedTritonKernel(
6218        kernel_idx=kernel_idx,
6219        grid=grid,
6220        kernel_args={**kwargs, **constant_args},
6221    )
6222    return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)}
6223
6224
6225@register_lowering(torch.ops.higher_order.cond)
6226def cond(pred, true_fn, false_fn, operands):
6227    if is_triton(pred) or any(map(is_triton, operands)):
6228        msg = "control flow operator: torch.cond."
6229        if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
6230            msg = f"{msg} Found from:\n{stack_trace}"
6231        V.graph.disable_cudagraphs_reason = msg
6232
6233    result = ir.Conditional.create(pred, true_fn, false_fn, operands)
6234    return list(map(TensorBox.create, result))
6235
6236
6237@register_lowering(torch.ops.higher_order.while_loop)
6238def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
6239    if any(map(is_triton, carried_inputs + additional_inputs)):
6240        msg = "control flow operator: torch.while_loop."
6241        if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
6242            msg = f"{msg} Found from:\n{stack_trace}"
6243        V.graph.disable_cudagraphs_reason = msg
6244
6245    result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs)
6246    return list(map(TensorBox.create, result))
6247
6248
@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn: ir.Subgraph, input, dim: int):
    from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph

    subgraph_inputs = [
        InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
        for x in itertools.chain(input, input)
    ]
    lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs)  # type: ignore[var-annotated]

    def wrapped_combine_fn(lhs, rhs):
        return lowered_combine_fn(
            *pytree.tree_leaves(lhs),
            *pytree.tree_leaves(rhs),
        )

    kwargs = _make_scan_inner(input[0], axis=dim, dtype=None)
    kwargs["dtypes"] = tuple(x.get_dtype() for x in input)
    kwargs["inner_fns"] = tuple(x.make_loader() for x in input)
    result = ir.Scan.create(
        combine_fn=wrapped_combine_fn,
        can_fallback_to_aten=False,
        **kwargs,
    )
    if result[0] is None:
        raise RuntimeError("Unable to generate code for associative_scan op")
    return result


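# Effect tokens only exist to order side-effectful ops during tracing; by lowering
# time the ordering is already encoded in the graph, so sinking the tokens emits
# nothing.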
@register_lowering(torch.ops.prims._sink_tokens.default)
def _sink_tokens(tokens):
    return None


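# Lowering for the with_effects higher-order op: run the wrapped op as an
# EffectfulKernel and return the graph's token for its effect type alongside the
# op's outputs, mirroring the (token, *results) calling convention of with_effects.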
@register_lowering(torch.ops.higher_order.with_effects)
def with_effects(token, op, *args, **kwargs):
    result = ir.EffectfulKernel.create(op, *args, **kwargs)

    from torch._higher_order_ops.effects import get_effect_key

    effect_type = get_effect_key(op, args, kwargs)
    assert effect_type is not None
    effectful_kernel = V.graph.effectful_ops[effect_type]

    if result is None:
        return (effectful_kernel,)

    result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
    if not isinstance(result, (list, tuple)):
        return (effectful_kernel, result)
    else:
        return (effectful_kernel, *result)


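# Lowerings for the functional collective ops. These are only registered when
# torch.distributed (and its _functional_collectives frontend) is available;
# otherwise the ImportError/AttributeError below is caught and the collectives are
# simply left without Inductor lowerings.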
try:
    import torch.distributed._functional_collectives

    _c10d_functional = torch.ops._c10d_functional

    @register_lowering(_c10d_functional.all_reduce)
    def _all_reduce(inp, reduce_op, group_name):
        inp = clone(inp)
        if config.reorder_for_compute_comm_overlap:
            # Horizontally fusing this clone often severely delays scheduling
            # of the all_reduce_ node and almost never outperforms scheduling
            # the all_reduce_ earlier. In most cases the clone is eliminated
            # via in-place reuse anyway, so tell the scheduler not to fuse it.
            inp.realize()
            V.graph.no_fuse_buffer_names.add(inp.get_name())
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
        )
        return inp

    @register_lowering(_c10d_functional.all_reduce_)
    def _all_reduce_(inp, reduce_op, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
        )
        return inp

    @register_lowering(_c10d_functional.all_reduce_coalesced)
    def _all_reduce_coalesced(inputs, reduce_op, group_name):
        inputs = [clone(inp) for inp in inputs]
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

    @register_lowering(_c10d_functional.all_reduce_coalesced_)
    def _all_reduce_coalesced_(inputs, reduce_op, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

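    # The out-of-place collectives below allocate a fresh output buffer via
    # create_out_of_place and wrap the resulting IR node(s) in TensorBoxes, in
    # contrast to the in-place variants above, which mutate (a clone of) their
    # input and return it directly.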
    @register_lowering(_c10d_functional.all_gather_into_tensor)
    def _all_gather_into_tensor(inp, group_size, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_gather_into_tensor.default,
                inp,
                group_size,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced)
    def _all_gather_into_tensor_coalesced(inputs, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_gather_into_tensor_coalesced.default,
                inputs,
                group_size,
                group_name,
            ),
        )

    @register_lowering(_c10d_functional.all_gather_into_tensor_out)
    def _all_gather_into_tensor_out(inp, group_size, group_name, *, out):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_gather_into_tensor_out.default,
            inp,
            group_size,
            group_name,
            out=out,
        )
        return out

    @register_lowering(_c10d_functional.reduce_scatter_tensor)
    def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.reduce_scatter_tensor.default,
                inp,
                reduce_op,
                group_size,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced)
    def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.reduce_scatter_tensor_coalesced.default,
                inputs,
                reduce_op,
                group_size,
                group_name,
            ),
        )

    @register_lowering(_c10d_functional.all_to_all_single)
    def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_to_all_single.default,
                inp,
                output_split_sizes,
                input_split_sizes,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.broadcast)
    def _broadcast(inp, src, group_name):
        inp = clone(inp)
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.broadcast_.default, inp, src, group_name
        )
        return inp

    @register_lowering(_c10d_functional.broadcast_)
    def _broadcast_(inp, src, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.broadcast_.default, inp, src, group_name
        )
        return inp

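    # wait_tensor blocks until the pending collective that produced `inp` has
    # completed; the lowering emits a _WaitKernel and returns the same buffer.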
    @register_lowering(_c10d_functional.wait_tensor)
    def _wait_tensor(inp):
        ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp)
        return inp

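    # DTensor's shard_dim_alltoall: an all-to-all that gathers along gather_dim and
    # re-shards along shard_dim within the given process group.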
    @register_lowering(torch.ops._dtensor.shard_dim_alltoall)
    def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                torch.ops._dtensor.shard_dim_alltoall.default,
                inp,
                gather_dim,
                shard_dim,
                group_name,
            )
        )

except (AttributeError, ImportError):
    log.info(
        "Inductor support for distributed collectives depends on building torch.distributed"
    )

# populate lowerings defined in kernel/*
from . import kernel


import_submodule(kernel)

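# quantized_lowerings registers lowerings for quantized ops and for
# weight-only-quantized (WOQ) matmuls.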
from . import quantized_lowerings


quantized_lowerings.register_quantized_ops()
quantized_lowerings.register_woq_mm_ops()

from . import mkldnn_lowerings


mkldnn_lowerings.register_onednn_fusion_ops()

from . import jagged_lowerings


jagged_lowerings.register_jagged_ops()
