# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import itertools
import numbers
import operator
import sys
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union

import torch
import torch._meta_registrations
import torch._prims as prims
import torch._prims_common as utils
import torch.nn.functional as F
from torch import sym_float, sym_int, Tensor
from torch._decomp import register_decomposition
from torch._higher_order_ops.out_dtype import out_dtype
from torch._prims_common import (
    IntLike,
    NumberType,
    suggest_memory_format,
    TensorLike,
    TensorSequenceType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    out_wrapper,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map


DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]

# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []

aten = torch._ops.ops.aten


class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2


# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
# Will need to validate the non-elementwise uses
def type_casts(
    f: Callable,
    type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
    compute_dtype_only: bool = False,
):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        flat_args = [
            x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
        ]
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *flat_args, type_promotion_kind=type_promotion
        )

        # TODO: pretty sure this is not quite right
        def increase_prec(x):
            if isinstance(x, Tensor):
                return x.to(computation_dtype)
            else:
                return x

        def decrease_prec(x):
            if isinstance(x, Tensor):
                return x.to(result_dtype)
            else:
                return x

        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        if compute_dtype_only:
            return r
        else:
            return tree_map(decrease_prec, r)

    return inner


compute_only_pw_cast_for_opmath = partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    compute_dtype_only=True,
)
pw_cast_for_opmath = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
pw_cast_for_int_to_real = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)

# This unsqueezes x until x.dim() == dim. Might be useful as an operator
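# E.g. (illustrative): _unsqueeze_to_dim(torch.ones(2, 3), 4) returns a view of
# shape (2, 3, 1, 1); if x already has at least `dim` dimensions it is returned
# unchanged.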
def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor:
    for _ in range(dim - x.dim()):
        x = x.unsqueeze(-1)
    return x


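# For y = tanh(x), dy/dx = 1 - tanh(x)^2 = 1 - y*y, so the decomposition below
# just scales the incoming gradient by that factor; conj_physical() makes the
# same formula valid for complex dtypes.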
@register_decomposition(aten.tanh_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y).conj_physical()


@register_decomposition(aten.sigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y)).conj_physical()


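# softplus(x) = (1/beta) * log(1 + exp(beta * x)), whose derivative is
# sigmoid(beta * x) = z / (z + 1) with z = exp(beta * x); past the threshold
# the op is treated as linear, so the gradient passes through unchanged.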
@register_decomposition(aten.softplus_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    z = (x * beta).exp()
    return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))


@register_decomposition(aten.elu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * (self_or_result + negcoef),
            grad_output * poscoef,
        )
    else:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
            grad_output * poscoef,
        )


@register_decomposition([aten.fill.Scalar])
def fill_scalar(self, value):
    return torch.full_like(self, value)


@register_decomposition([aten.fill.Tensor])
def fill_tensor(self, value: Tensor):
    torch._check(
        value.dim() == 0,
        lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
    )
    return aten.copy(self, value)


@register_decomposition(aten.hardsigmoid)
@out_wrapper()
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardsigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    return torch.where(
        (self > -3.0) & (self < 3.0),
        grad_output * (1.0 / 6.0),
        0.0,
    )


@register_decomposition(aten.hardtanh_backward)
@out_wrapper("grad_input")
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output)


@register_decomposition(aten.hardswish)
@out_wrapper()
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardswish_backward)
@out_wrapper()
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    return torch.where(
        self < -3,
        0.0,
        torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
    )


@register_decomposition(aten.threshold_backward)
@out_wrapper("grad_input")
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    return torch.where(self <= threshold, 0, grad_output)


@register_decomposition(aten.leaky_relu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    return torch.where(self > 0, grad_output, grad_output * negative_slope)


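# In the exact ("none") branch below, kAlpha = 1/sqrt(2) and kBeta = 1/sqrt(2*pi),
# so cdf is the standard normal CDF Phi(x) and pdf its density, giving
# d/dx [x * Phi(x)] = Phi(x) + x * phi(x).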
@register_decomposition(aten.gelu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)

        left = 0.5 * self
        right = 1 + tanh_inner

        left_derivative = 0.5 * right

        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative

        return grad * (left_derivative + right_derivative)
    else:
        kAlpha = M_SQRT1_2
        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * kAlpha))
        pdf = kBeta * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)


@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    input_tanh_softplus = torch.tanh(F.softplus(input))
    input_sigmoid = torch.sigmoid(input)
    out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
    return grad_output * (input_tanh_softplus + out)


@register_decomposition(aten.silu)
@out_wrapper()
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    return self * torch.sigmoid(self)


@register_decomposition(aten.silu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + torch.exp(-self))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))


@register_decomposition(aten._prelu_kernel)
def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor:
    return torch.where(self > 0, self, weight * self)


@register_decomposition(aten._prelu_kernel_backward)
def _prelu_kernel_backward(
    grad_output: Tensor,
    self: Tensor,
    weight: Tensor,
) -> Tuple[Tensor, Tensor]:
    input_grad = torch.where(self > 0, grad_output, weight * grad_output)
    weight_grad = torch.where(self > 0, 0.0, self * grad_output)
    return (input_grad, weight_grad)


@register_decomposition(aten.rrelu_with_noise)
@aten.rrelu_with_noise.default.py_impl(DispatchKey.AutogradCUDA)
@out_wrapper()
@pw_cast_for_opmath
def rrelu_with_noise(
    self: Tensor,
    noise: Tensor,
    lower: float = 0.125,
    upper: float = 0.3333333333333333,
    training: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    assert generator is None
    if training:
        not_positive = self <= 0
        r = aten.uniform(self, lower, upper)
        output = torch.where(not_positive, self * r, self)
        noise.copy_(torch.where(not_positive, r, 1))
        return output
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu(self, negative_slope)


@register_decomposition(aten.rrelu_with_noise_)
@aten.rrelu_with_noise_.default.py_impl(DispatchKey.AutogradCUDA)
@pw_cast_for_opmath
def rrelu_with_noise_(
    self: Tensor,
    noise: Tensor,
    lower: float = 0.125,
    upper: float = 0.3333333333333333,
    training: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator))


@register_decomposition(aten.rrelu_with_noise_backward)
@out_wrapper()
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    if training and upper - lower > 1e-6:
        return grad_output.mul(noise)
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu_backward(
            grad_output, self, negative_slope, self_is_result
        )


@register_decomposition(aten.log_sigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    in_negative = self < 0
    max_deriv = torch.where(in_negative, 1, 0)
    sign = torch.where(in_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but it is disabled for convenience
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output


def apply_loss_reduction(loss: Tensor, reduction: int):
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    elif reduction == Reduction.SUM.value:
        return torch.sum(loss)
    else:
        return loss


def to_real_dtype(dtype: torch.dtype):
    if dtype == torch.complex32:
        return torch.float16
    elif dtype == torch.complex64:
        return torch.float32
    elif dtype == torch.complex128:
        return torch.float64


# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction


@register_decomposition(aten.mse_loss)
@out_wrapper()
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target) ** 2
    return apply_loss_reduction(loss, reduction)


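# d/dx (x - t)^2 = 2 * (x - t); with mean reduction the loss is averaged over
# input.numel() elements, hence the 1/numel factor folded into `norm` below.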
@register_decomposition(aten.mse_loss_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
    return norm * (input - target) * grad_output


@register_decomposition(aten._safe_softmax)
def safe_softmax(self, dim, dtype=None):
    out = torch.softmax(self, dim=dim, dtype=dtype)
    masked = self.eq(float("-inf"))
    masked_rows = torch.all(masked, dim=dim, keepdim=True)
    zeros = torch.zeros_like(out)
    return torch.where(masked_rows, zeros, out)


@register_decomposition(aten.smooth_l1_loss)
@out_wrapper()
@pw_cast_for_opmath
def smooth_l1_loss(
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
    beta: float = 1.0,
):
    loss = (self - target).abs()
    loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.smooth_l1_loss_backward.default)
@pw_cast_for_opmath
def smooth_l1_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    abs_x = torch.abs(x)
    norm_grad = norm * grad_output
    return torch.where(
        abs_x < beta,
        norm_grad * x / beta,
        norm_grad * torch.sign(x),
    )


@register_decomposition(aten.smooth_l1_loss_backward.grad_input)
@pw_cast_for_opmath
def smooth_l1_loss_backward_out(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int,
    beta: float,
    grad_input: Tensor,
):
    result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta)
    _maybe_resize_out(grad_input, result.shape)
    return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)


@register_decomposition(aten.huber_loss_backward.default)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    return torch.where(
        x < -delta,
        -norm * grad_output * delta,
        torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
    )


# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input'
@register_decomposition(aten.huber_loss_backward.out)
@pw_cast_for_opmath
def huber_loss_backward_out(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int,
    delta: float,
    grad_input: Tensor,
):
    result = huber_loss_backward(grad_output, self, target, reduction, delta)
    _maybe_resize_out(grad_input, result.shape)
    return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)


def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    channel_dim = 0 if self.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    safe_target = torch.where(target != ignore_index, target, 0)
    grad_input = torch.zeros_like(self)
    grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(self.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        grad_output = grad_output * weight

    grad_output = torch.where(target != ignore_index, grad_output, 0)

    return grad_input * grad_output


@register_decomposition(aten.glu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
    assert self.dim() > 0, "glu does not support 0-dimensional tensors"
    wrap_dim = utils.canonicalize_dim(self.dim(), dim)
    nIn = self.size(wrap_dim)
    assert (
        nIn % 2 == 0
    ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
    inputSize = nIn // 2
    firstHalf = self.narrow(wrap_dim, 0, inputSize)
    secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
    gradInputFirstHalf = torch.sigmoid(secondHalf)
    gradInputSecondHalf = (
        (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
    )
    gradInputFirstHalf = gradInputFirstHalf * grad_output
    return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)


@register_decomposition(aten.nll_loss_backward)
@out_wrapper("grad_input")
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )

    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"

    if reduction == Reduction.NONE.value and self.dim() == 2:
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.nll_loss2d_backward)
@out_wrapper("grad_input")
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"

    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"

    assert (
        self.shape[0] == target.shape[0]
        and self.shape[2] == target.shape[1]
        and self.shape[3] == target.shape[2]
628    ), f"size mismatch (got input: {self.shape}, target: {target.shape}"
629
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.binary_cross_entropy)
@out_wrapper()
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    loss = (target - 1) * torch.maximum(
        torch.log1p(-self), self.new_full((), -100)
    ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
    if weight is not None:
        loss = loss * weight
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.binary_cross_entropy_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    EPSILON = 1e-12
    result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
    if weight is not None:
        result = result * weight
    if reduction == Reduction.MEAN.value:
        result = result / self.numel()
    return result


@register_decomposition(aten.soft_margin_loss)
@out_wrapper()
@pw_cast_for_opmath
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    loss = torch.log1p(torch.exp(-input * target))
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.soft_margin_loss_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def soft_margin_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
    if reduction == Reduction.MEAN.value:
        grad_input = grad_input / self.numel()
    return grad_input


@register_decomposition(aten.dist)
@out_wrapper()
def dist(input: Tensor, other: Tensor, p: float = 2):
    return aten.norm(input - other, p=p)


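# The decomposition below relies on ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2 and
# evaluates all three terms in one matmul by augmenting each row of x1 with
# [squared norm, 1] and each row of x2 with [1, squared norm]; clamp_min(0)
# guards against small negative values caused by floating-point rounding.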
@register_decomposition(aten._euclidean_dist)
@out_wrapper()
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    x1_norm = x1.pow(2).sum(-1, True)
    x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
    x2_norm = x2.pow(2).sum(-1, True)
    x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
    result = x1_.matmul(x2_.mT)
    return result.clamp_min(0).sqrt()


@register_decomposition(aten.slice_backward)
@out_wrapper()
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)


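# The decomposition below implements Python slicing semantics as a view: e.g.
# slice_forward(t, dim=0, start=1, end=7, step=2) matches t[1:7:2], yielding
# ceil((end - start) / step) elements along dim via as_strided.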
@register_decomposition(aten.slice.Tensor)
def slice_forward(
    # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
    self: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
):
    from torch.fx.experimental.symbolic_shapes import (
        guard_size_oblivious,
        statically_known_true,
    )

    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
    dim = utils.canonicalize_dim(self.dim(), dim)
    sizes = list(self.size())
    strides = list(self.stride())

    if step <= 0:
        raise RuntimeError("slice step must be positive")

    start_val = start if start is not None else 0
    end_val = end if end is not None else sys.maxsize  # 2^63 - 1

    if guard_size_oblivious(start_val < 0):
        start_val += sizes[dim]

    if guard_size_oblivious(end_val < 0):
        end_val += sizes[dim]

    if guard_size_oblivious(start_val < 0):
        start_val = 0
    elif guard_size_oblivious(start_val > sizes[dim]):
        start_val = sizes[dim]

    if guard_size_oblivious(end_val < start_val):
        end_val = start_val
    elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
        end_val > sizes[dim]
    ):
        end_val = sizes[dim]

    storage_offset = self.storage_offset() + start_val * strides[dim]
    length = end_val - start_val
    sizes[dim] = (length + step - 1) // step
    strides[dim] *= step

    if self.is_quantized:
        raise NotImplementedError(
            "Slice decomposition for quantized tensors isn't implemented"
        )
    else:
        return self.as_strided(sizes, strides, storage_offset)


def _normalize_start_end(
    x: Tensor, dim: int, start: Optional[int], end: Optional[int]
) -> Tuple[int, int]:
    """
    Normalize start and end such that both are in the range
    [0, x.get_size()[dim]] and start <= end.
    """
    dim_size = x.shape[dim]

    def clamp_wrap(val, lower, upper, default) -> int:
        if val is None:
            return default
        if val < 0:
            val = val + dim_size
        return min(max(val, lower), upper)

    start = clamp_wrap(start, 0, dim_size, 0)
    end = clamp_wrap(end, start, dim_size, dim_size)
    return start, end


# This is not in torch._refs because aten.index used by
# aten._unsafe_masked_index does not have a decomposition.
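# Sketch of the approach below: build a boolean mask that is True exactly at
# the positions input[start:end:step] selects along dim, gather the matching
# elements of the expanded src with _unsafe_masked_index, and then pick between
# src and input elementwise with where().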
@register_decomposition(aten.slice_scatter)
@out_wrapper()
def slice_scatter(
    input: Tensor,
    src: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
):
    dim = utils.canonicalize_dim(input.ndim, dim)
    dim_size = input.shape[dim]
    start, end = _normalize_start_end(input, dim, start, end)

    src_size = list(input.shape)
    src_size[dim] = (end - start + (step - 1)) // step
    src = src.expand(src_size)

    if start == 0 and end == dim_size and step == 1:
        return src.clone()

    indices = [None] * input.dim()
    idx = torch.arange(dim_size, device=input.device)
    indices[dim] = (idx - start) // step

    mask = torch.ones(dim_size, device=input.device, dtype=torch.bool)
    if start != 0:
        mask = torch.logical_and(mask, idx >= start)

    if end != dim_size:
        mask = torch.logical_and(mask, idx < end)

    if step != 1:
        mask = torch.logical_and(mask, (idx - start) % step == 0)

    mask_shape = [1] * input.dim()
    mask_shape[dim] = -1
    mask = mask.view(mask_shape)
    return aten.where(mask, aten._unsafe_masked_index(src, mask, indices, 0), input)


@register_decomposition(aten.select_backward)
@out_wrapper()
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(grad_input, grad_output, dim, index)


@register_decomposition(aten.diagonal_backward)
@out_wrapper()
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)


def _cast_grad_to_input_dtype(
    grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype
):
    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input


@register_decomposition(aten._softmax_backward_data)
@out_wrapper("grad_input")
@compute_only_pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    new_grad_output = grad_output * output
    grad_input = new_grad_output - output * torch.sum(
        new_grad_output, dim=dim, keepdim=True
    )

    # The CPU kernel doesn't respect input_dtype, but the following check doesn't work for meta tensors
    # if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()

    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous()


@register_decomposition(aten._log_softmax_backward_data)
@out_wrapper()
@compute_only_pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    grad_input = grad_output - torch.exp(output) * torch.sum(
        grad_output, dim=dim, keepdim=True
    )
    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)


def _im2col_col2im_indices_along_dim(
    input_d, kernel_d, dilation_d, padding_d, stride_d, device
):
    """Utility function to implement im2col and col2im"""
    blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)

    arange_kw = partial(torch.arange, dtype=torch.int64, device=device)

    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)

    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)

    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    return blocks_d_indices + kernel_grid


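# Shape contract (as in ATen): input (N, C, H, W) maps to output
# (N, C * kernel_h * kernel_w, L), where L is the number of sliding-block
# positions implied by padding, dilation, and stride.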
@register_decomposition(aten.im2col)
@out_wrapper()
def im2col(
    input: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
    torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
    torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
    torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        torch._check(
            cond,
            lambda: f"{param_name} should be greater "
            f"{'than' if strict else 'than or equal to'} zero, but got {param}",
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")

    shape = input.shape
    ndim = len(shape)
    torch._check(
        ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
        lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    output_size = tuple(
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            shape[-2:], padding, dilation, kernel_size, stride
        )
    )
    torch._check(
        all(c > 0 for c in output_size),
976        lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
977        f"kernel_size={kernel_size}, dilation={dilation}, "
978        f"padding={padding}, stride={stride}, "
979        "the calculated shape of the array of sliding blocks "
980        f"is {output_size}, but its components must be at least one.",
981    )
982    batched_input = ndim == 4
983    if not batched_input:
984        input = input.unsqueeze(0)
985
986    batch_dim, channel_dim, input_h, input_w = input.shape
987
988    stride_h, stride_w = stride
989    padding_h, padding_w = padding
990    dilation_h, dilation_w = dilation
991    kernel_h, kernel_w = kernel_size
992
993    blocks_row_indices = _im2col_col2im_indices_along_dim(
994        input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
995    )
996    blocks_col_indices = _im2col_col2im_indices_along_dim(
997        input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
998    )
999
1000    # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom)
1001    # ugh
1002    padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))
1003
1004    blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
1005    output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
1006    output = output.permute(0, 1, 2, 4, 3, 5)
1007    num_blocks_row = blocks_row_indices.size(1)
1008    num_blocks_col = blocks_col_indices.size(1)
1009    output = output.reshape(
1010        batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
1011    )
1012
1013    if not batched_input:
1014        output = output.squeeze(0)
1015    return output
1016
1017
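# col2im is the adjoint of im2col: input (N, C * kernel_h * kernel_w, L) maps
# to output (N, C, out_h, out_w), with overlapping block contributions summed
# via an accumulating index_put.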
@register_decomposition(aten.col2im)
@out_wrapper()
@pw_cast_for_opmath
def col2im(
    input: Tensor,
    output_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
    torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
    torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
    torch._check(len(padding) == 2, lambda: "only 2D padding supported")
    torch._check(len(stride) == 2, lambda: "only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        torch._check(
            cond, lambda: f"{param_name} should be greater than zero, but got {param}"
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")
    check_positive(output_size, "output_size")

    shape = input.shape
    ndim = len(shape)
    torch._check(
        ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
        lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    prod_kernel_size = kernel_size[0] * kernel_size[1]
    torch._check(
        shape[-2] % prod_kernel_size == 0,
        lambda: "Expected size of input's first non-batch dimension to be divisible by the "
        f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
        f"kernel_size={kernel_size}",
    )
    col = [
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            output_size, padding, dilation, kernel_size, stride
        )
    ]
    L = col[0] * col[1]
    torch._check(
        shape[-1] == L,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    torch._check(
        L > 0,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    batched_input = ndim == 3
    if not batched_input:
        input = input.unsqueeze(0)

    shape = input.shape

    out_h, out_w = output_size
    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    kernel_h, kernel_w = kernel_size

    # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
    input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
    input = input.permute(0, 1, 2, 4, 3, 5)

    indices_row = _im2col_col2im_indices_along_dim(
        out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
    )
    indices_row = _unsqueeze_to_dim(indices_row, 4)
    indices_col = _im2col_col2im_indices_along_dim(
        out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
    )

    output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
    output = input.new_zeros(
        [shape[0], shape[1] // prod(kernel_size)] + output_padded_size
    )
    idx = (None, None, indices_row, indices_col)
    output = aten._unsafe_index_put(output, idx, input, accumulate=True)
    output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h))

    if not batched_input:
        output = output.squeeze(0)
    return output


@register_decomposition(aten.native_dropout_backward)
@out_wrapper()
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    # According to the CUDA kernel implementation we should have this test;
    # but it seems to fail tests!
    # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")

    # Mimicking the CUDA kernel's behavior for output stride: the output follows
    # the input's memory format. This is different from TensorIterator's behavior.
    r = (grad_output * (mask.type_as(grad_output) * scale)).clone(
        memory_format=utils.suggest_memory_format(grad_output)
    )
    return r


@register_decomposition(aten.unfold_backward)
@out_wrapper()
def unfold_backward(
    grad: Tensor, input_size: List[int], dimension: int, size: int, step: int
) -> Tensor:
    if len(input_size) == 0:
        return torch.squeeze_copy(grad, 0)
    dim = utils.canonicalize_dim(len(input_size), dimension)
    idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32)
    idx = idx.unfold(0, size, step).flatten()
    grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1)
    # NB: at the moment this generates two kernels in Triton.
    # It could potentially be fused into one call to scatter_reduce,
    # in the case step <= size, provided scatter_reduce generates 1 kernel
    grad_input = grad.new_zeros(input_size)
    index = (None,) * dim + (idx,)
    return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous()


@register_decomposition(aten.logit_backward.default)
@pw_cast_for_opmath
def logit_backward(
    grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
    if eps is not None:
        lo = eps
        hi = 1.0 - lo
        return torch.where(
            torch.logical_and(self >= lo, self <= hi),
            grad_output / (self * (1.0 - self)),
            0.0,
        )
    else:
        return torch.where(
            torch.logical_and(self >= 0.0, self <= 1.0),
            grad_output / (self * (1.0 - self)),
            self.new_full((), float("nan")),
        )


@register_decomposition(aten.dropout)
@aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.dropout.default.py_impl(DispatchKey.Autograd)
def dropout(input: Tensor, p: float, train: Optional[bool]):
    if train and p != 0:
        return aten.native_dropout(input, p, train)[0]
    else:
        return input.clone()


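# Inverted dropout: surviving elements are scaled by 1 / (1 - p) so the output
# matches the input in expectation, and the boolean keep-mask is returned
# alongside the result.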
@register_decomposition(aten.native_dropout)
@out_wrapper("out0", "out1")
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
    if train and p != 0:
        if p == 1:
            return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
        if not input.dtype.is_floating_point:
            raise RuntimeError(
                "result type Float can't be cast to the desired output type Long"
            )
        bool_mask = torch.rand_like(input) > p
        res = bool_mask * input * float(1.0 / (1.0 - p))
        return (res, bool_mask)
    else:
        return (input, torch.ones_like(input, dtype=torch.bool))


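# Numerically stable softmax: subtracting the per-dim max before exp() keeps
# the exponent non-positive so exp() cannot overflow, and the result is
# unchanged because exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)).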
@register_decomposition(aten._softmax)
@out_wrapper()
def _softmax(x: Tensor, dim: int, half_to_float: bool):
    # eager softmax returns a contiguous tensor. Ensure that decomp also returns
    # a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        unnormalized = torch.exp(x)
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        unnormalized = torch.exp(x - x_max)
    result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
    if not half_to_float:
        result = result.to(result_dtype)
    return result


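# Same max-shift trick as _softmax, using
# log_softmax(x) = (x - m) - log(sum(exp(x - m))) with m = max(x).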
@register_decomposition(aten._log_softmax)
@out_wrapper()
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
    # eager log_softmax returns a contiguous tensor. Ensure that decomp also
    # returns a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        shifted = x
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        shifted = x - x_max
    shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
    result = shifted - shifted_logsumexp
    if not half_to_float:
        result = result.to(result_dtype)
    return result


@register_decomposition(aten.embedding)
@out_wrapper()
def embedding(
    weight: Tensor,
    indices: Tensor,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    assert weight.dim() == 2, "'weight' must be 2-D"
    # Nb. scale_grad_by_freq is not used in the forward
    if indices.ndim <= 1:
        # We need this one as weight[indices] calls item() in these cases
        out = weight.index_select(0, indices)
        if indices.ndim == 0:
            out = out.squeeze(0)
        return out
    else:
        return weight[indices]


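# The dense embedding gradient scatter-adds each row of grad_output into the
# row of grad_weight selected by the corresponding index, zeroing the
# contribution at padding_idx and optionally rescaling by inverse index
# frequency.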
@register_decomposition(aten.embedding_dense_backward)
@out_wrapper()
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    grad_output = grad_output.to(computation_dtype)
    indices = _maybe_convert_to_dtype(indices, torch.long)  # type: ignore[assignment]
    if scale_grad_by_freq:
        counts = indices.new_zeros((num_weights,))
        ones = torch.ones_like(indices)
        counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True)
        grad_weights_scale = counts[indices]
        grad_output = grad_output / grad_weights_scale.unsqueeze(-1)

    mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
    grad = grad_output.masked_fill(mask, 0)
    grad_weight = grad_output.new_zeros(
        (num_weights,) + grad_output.shape[indices.ndim :]
    )
    return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to(
        result_dtype
    )


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


def _pad_chunk(
    tensors: List[Tensor],
    dim: int,
    num_chunks: int,
) -> List[Tensor]:
    padded_tensors = []
    for tensor in tensors:
        tensor_size = tensor.size()
        pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks
        if pad_along_dim != tensor_size[dim]:
            # Use aten.constant_pad_nd instead of copy_ for functionalization
            pad = [0] * 2 * (tensor.ndim - dim - 1) + [
                0,
                pad_along_dim - tensor_size[dim],
            ]
            tensor = aten.constant_pad_nd(tensor, pad, 0)
        view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
        padded_tensors.append(tensor.view(view_size))
    return padded_tensors


def have_same_ndims(tensors: List[Tensor]):
    ndim = tensors[0].ndim
    for tensor in tensors:
        if tensor.ndim != ndim:
            return False
    return True


def leading_dimension_matches(tensors: List[Tensor], dim: int):
    leading_dim_sizes = tensors[0].size()[:dim]
    for tensor in tensors:
        torch._check(
            tensor.size()[:dim] == leading_dim_sizes,
1339            lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors",
        )


def _preprocess_chunk_cat_inputs(
    tensors: List[Tensor],
    dim: int,
    num_chunks: int,
):
    torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks")
    torch._check(
        len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list"
    )
    expected_dtype = tensors[0].dtype
    expected_device = tensors[0].device
    for tensor in tensors:
        torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor")
        torch._check(
            tensor.dtype == expected_dtype,
            lambda: "_chunk_cat expects all input tensors to have the same dtype",
        )
        torch._check(
            tensor.device == expected_device,
            lambda: "_chunk_cat expects all input tensors to be on the same device",
        )
    if have_same_ndims(tensors):
        dim = utils.canonicalize_dim(tensors[0].dim(), dim)
    else:
        torch._check(
            dim >= 0,
            lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims",
        )
        for tensor in tensors:
            torch._check(
                dim < tensor.ndim,
                lambda: "_chunk_cat expects dim < ndim for all input tensors",
            )
    leading_dimension_matches(tensors, dim)
    return dim


@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
def _chunk_cat(
    tensors: List[Tensor],
    dim: int,
    num_chunks: int,
    out: Optional[Tensor] = None,
) -> Tensor:
    dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks)
    padded_tensors = _pad_chunk(tensors, dim, num_chunks)
    if out is None:
        return torch.cat(padded_tensors, dim + 1)
    else:
        torch.cat(padded_tensors, dim + 1, out=out)
        return out


@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    # NB: Perform the check_is_size tests first so that the
    # sum test does not try to do a replacement
    for i in range(len(split_sizes)):
        torch._check_is_size(
            split_sizes[i],
1405            lambda: "split_with_sizes expects split_sizes have only non-negative entries",
        )
    torch._check_with(
        ValueError,
        sum(split_sizes) == self.shape[dim],
1410        lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
    )
    num_splits = len(split_sizes)
    splits = []
    start_idx = 0

    for i in range(num_splits):
        length = split_sizes[i]
        splits.append(self.narrow(dim, start_idx, length))
        start_idx += length
    return splits


# out_wrapper currently does not allow optional outputs
@register_decomposition(
    [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out]
)
def split_with_sizes_copy(
    self: Tensor,
    split_sizes: List[int],
    dim: int = 0,
    out: Optional[List[Tensor]] = None,
) -> Optional[List[Tensor]]:
    splits = split_with_sizes(self, split_sizes, dim=dim)
    if out is None:
        return [s.clone(memory_format=torch.contiguous_format) for s in splits]
    else:
        for output, split in zip(out, splits):
            _maybe_resize_out(output, split.shape)
            _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True)
        return None


@register_decomposition(aten.unsafe_split.Tensor)
def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
    return aten.split.Tensor(input, split_size, dim)


@register_decomposition(aten.unsafe_split_with_sizes.default)
def unsafe_split_with_sizes(
    input: Tensor, split_sizes: List[int], dim: int = 0
) -> Tuple[Tensor, ...]:
    return aten.split_with_sizes.default(input, split_sizes, dim)


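# E.g. splitting a dimension of size 10 with split_size=3 yields chunks of
# sizes [3, 3, 3, 1]: chunks = ceil(10 / 3) = 4, and the last entry below
# absorbs the remainder.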
@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return (self,)
    chunks = (dim_size + split_size - 1) // split_size

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import guard_int

    chunks = guard_int(chunks)
    split_sizes = [split_size for i in range(chunks)]
    split_sizes[-1] = split_size - (split_size * chunks - dim_size)
    return torch.split(self, split_sizes, dim)


@aten.tensor_split.tensor_indices_or_sections.py_impl(
    DispatchKey.CompositeImplicitAutograd
)
def tensor_split_tensor_indices_or_sections_py_impl(
    self: Tensor,
    tensor_indices_or_sections: Tensor,
    dim: int = 0,
) -> Tuple[Tensor, ...]:
    assert tensor_indices_or_sections.device.type == "cpu"
    assert tensor_indices_or_sections.dtype == torch.int64
    split_dim = tensor_indices_or_sections.dim()
    torch._check(
        split_dim == 1 or split_dim == 0,
        lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional "
        f"or one-dimensional tensor, but got a tensor with {split_dim} dims",
    )
    if split_dim == 0:
        sections = tensor_indices_or_sections.item()
        assert isinstance(sections, IntLike)
        return self.tensor_split(sections, dim)
    else:
        indices = [i.item() for i in tensor_indices_or_sections]
        # WARNING: Tempted to torch._check_is_size on the indices here?  You
        # can't: tensor_split works with negative values in indices:
        #
        # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5]))
        # (tensor([ 0.3540,  2.1074, -0.8507,  1.1639,  0.3055]), tensor([]),
        # tensor([-0.4285,  1.0692, -0.1776,  0.9362,  1.6143]))
        #
        # Sorry, I don't make the rules.  Explicitly do the item call in user
        # code if you KNOW that they are non-negative.
        return self.tensor_split(indices, dim)


# TODO: this doesn't appear to have enough precision in bfloat16
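# addmm(self, mat1, mat2, beta, alpha) computes beta * self + alpha * (mat1 @ mat2);
# the beta == 0 early return below matches the BLAS convention that beta == 0
# ignores `self` entirely (even if it contains nan or inf).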
@register_decomposition(aten.addmm)
@out_wrapper()
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    out = alpha * torch.mm(mat1, mat2)
    if beta == 0:
        return out

    # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition.
    # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
    # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
    # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input.
    # Alternatively, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
    # This implementation is not ideal, and we should revisit this when we have a better solution.
    return out + beta * self


@register_decomposition(aten._addmm_activation)
@out_wrapper()
@pw_cast_for_opmath
def _addmm_activation(
    self: Tensor,
    mat1: Tensor,
    mat2: Tensor,
    beta: int = 1,
    alpha: int = 1,
    use_gelu: bool = False,
):
    out = addmm(self, mat1, mat2, beta, alpha)
    if use_gelu:
        if self.is_cuda:
            return aten.gelu(out, approximate="tanh")
        else:
            return aten.gelu(out)
    return aten.relu(out)


@register_decomposition(aten.addmv)
@out_wrapper()
@pw_cast_for_opmath
def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    out = alpha * torch.mv(mat1, vec)
    if beta == 0:
        return out
    return out + beta * self


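# Sketch of the math below: per (sample, group), with ds = sum(grad_out * x)
# and db = sum(grad_out), the input gradient takes the affine form
# c1 * grad_out + c2 * x + c3, where c1..c3 collect gamma, rstd, mean, and the
# two reductions.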
@register_decomposition(aten.native_group_norm_backward.default)
@pw_cast_for_opmath
def native_group_norm_backward(
    grad_output: Tensor,
    input: Tensor,
    mean: Tensor,
    rstd: Tensor,
    gamma: Optional[Tensor],
    N: int,
    C: int,
    HxW: int,
    group: int,
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    utils.check_same_device(
        grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
    )
    utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
    utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
    torch._check(
        input.numel() == N * C * HxW,
        lambda: f"Expect input to have {N * C * HxW} elements",
    )
    torch._check(
        mean.shape == (N, group),
        lambda: f"Expect mean to have shape ({N}, {group}), but got {mean.shape}",
1587    )
1588    torch._check(
1589        gamma is None or gamma.numel() == C,
1590        lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
1591    )
1592
1593    cpg, _rem = divmod(C, group)
1594    torch._check(
1595        _rem == 0,
1596        lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
1597    )
1598
    # Compute internal gradients
1600    ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
1601    db = grad_output.view(N, C, HxW).sum(dim=[2])
1602
1603    d_input: Optional[Tensor] = None
1604    d_gamma: Optional[Tensor] = None
1605    d_bias: Optional[Tensor] = None
1606    if output_mask[0]:
1607        s = 1.0 / (HxW * cpg)
1608        if gamma is not None:
1609            ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
1610            db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
1611            c1 = torch.mul(
1612                rstd.unsqueeze(-1),
1613                gamma.reshape(1, group, cpg),
1614            )
1615        else:
1616            ds_val = ds.reshape(N, group, cpg).sum(2)
1617            db_val = db.reshape(N, group, cpg).sum(2)
1618            c1 = torch.mul(
1619                rstd.unsqueeze(-1),
1620                torch.ones((1, group, cpg), device=rstd.device),
1621            )
1622        c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
1623        c3 = -c2 * mean - db_val * rstd * s
1624
1625        c1 = c1.unsqueeze(-1)
1626        c2 = _unsqueeze_to_dim(c2, 4)
1627        c3 = _unsqueeze_to_dim(c3, 4)
1628        d_input = (
1629            torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
1630            + torch.mul(input.reshape(N, group, cpg, HxW), c2)
1631            + c3
1632        )
1633        d_input = d_input.reshape(input.shape).to(input.dtype)
1634    if output_mask[1]:
1635        d_gamma = (
1636            (
1637                (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
1638                * rstd.unsqueeze(-1)
1639            )
1640            .sum(dim=[0])
1641            .reshape(C)
1642        )
1643    if output_mask[2]:
1644        d_bias = db.sum(dim=[0])
1645
1646    return (d_input, d_gamma, d_bias)
1647
1648
1649# out_wrapper currently does not allow optional outputs
1650@register_decomposition(aten.native_group_norm_backward.out)
1651def native_group_norm_backward_out(
1652    grad_output: Tensor,
1653    input: Tensor,
1654    mean: Tensor,
1655    rstd: Tensor,
1656    gamma: Optional[Tensor],
1657    N: int,
1658    C: int,
1659    HxW: int,
1660    group: int,
1661    output_mask: List[bool],
1662    *,
1663    out0: torch.Tensor,
1664    out1: torch.Tensor,
1665    out2: torch.Tensor,
1666) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
1667    result = native_group_norm_backward(
1668        grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
1669    )
1670    grad_input = (out0, out1, out2)
1671    for i, r in enumerate(result):
1672        if r is not None:
1673            _maybe_resize_out(grad_input[i], r.shape)
1674            _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
1675
1676    return grad_input
1677
1678
1679def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
1680    if x is not None:
1681        return x.to(dtype)
1682    return x
1683
1684
1685# TODO: Take a closer look at the type promotion semantics
1686@register_decomposition(aten.native_layer_norm_backward.default)
1687def native_layer_norm_backward(
1688    grad_out: Tensor,
1689    input: Tensor,
1690    normalized_shape: List[int],
1691    mean: Tensor,
1692    rstd: Tensor,
1693    weight: Optional[Tensor],
1694    bias: Optional[Tensor],
1695    output_mask: List[bool],
1696) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
1697    input_shape = input.shape
1698    input_ndim = input.dim()
1699    computation_dtype = utils.get_computation_dtype(input.dtype)
1700    grad_out_cast, input_cast, weight_cast, bias_cast = (
1701        x.to(computation_dtype).contiguous() if x is not None else x
1702        for x in (grad_out, input, weight, bias)
1703    )
1704    assert grad_out_cast is not None
1705
1706    axis = input_ndim - len(normalized_shape)
1707    inner_dims = input_shape[axis:]
1708    outer_dims = input_shape[:axis]
1709    inner_dim_indices: List[int] = []
1710    outer_dim_indices: List[int] = []
1711    for i in range(input_ndim):
1712        if i >= axis:
1713            inner_dim_indices.append(i)
1714        else:
1715            outer_dim_indices.append(i)
1716
1717    N = prod(inner_dims)  # type: ignore[arg-type]
1718    M = prod(outer_dims)  # type: ignore[arg-type]
1719    if M <= 0 or N <= 0:
1720        return (
1721            input.new_zeros(input_shape) if output_mask[0] else None,
1722            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
1723            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
1724        )
1725    mean = _unsqueeze_to_dim(mean, input_cast.dim())  # type: ignore[union-attr]
1726    rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
1727    x_hat = (input_cast - mean) * rstd
1728    if weight_cast is not None:
1729        grad_x_hat = grad_out_cast * weight_cast
1730    else:
1731        grad_x_hat = grad_out_cast
1732    a = grad_x_hat * N
1733    b = torch.sum(grad_x_hat, inner_dim_indices, True)
1734    c1 = torch.mul(grad_x_hat, x_hat)
1735    c2 = torch.sum(c1, inner_dim_indices, True)
1736    c3 = torch.mul(x_hat, c2)
1737
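    # Standard layer norm input gradient: with x_hat = (x - mean) * rstd and
    # grad_x_hat = grad_out * weight, we have
    # d_input = (rstd / N) * (N * grad_x_hat - sum(grad_x_hat) - x_hat * sum(grad_x_hat * x_hat)),
    # where the sums reduce over the normalized (inner) dimensions.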
1738    inner = a - b - c3
1739    d_input: Optional[Tensor] = None
1740    d_weight: Optional[Tensor] = None
1741    d_bias: Optional[Tensor] = None
1742    if output_mask[0]:
1743        d_input = (rstd / N) * inner
1744
1745    if output_mask[1] and weight_cast is not None:
1746        if len(outer_dim_indices) > 0:
1747            d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
1748        else:
1749            d_weight = grad_out_cast * x_hat
1750
1751    if output_mask[2] and bias_cast is not None:
1752        if len(outer_dim_indices) > 0:
1753            d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
1754        else:
1755            d_bias = grad_out_cast.clone()
1756
1757    return (
1758        _maybe_cast(d_input, input.dtype),
1759        _maybe_cast(d_weight, input.dtype),
1760        _maybe_cast(d_bias, input.dtype),
1761    )
1762
1763
1764# out_wrapper currently does not allow optional outputs
1765@register_decomposition(aten.native_layer_norm_backward.out)
1766def native_layer_norm_backward_out(
1767    grad_out: Tensor,
1768    input: Tensor,
1769    normalized_shape: List[int],
1770    mean: Tensor,
1771    rstd: Tensor,
1772    weight: Optional[Tensor],
1773    bias: Optional[Tensor],
1774    output_mask: List[bool],
1775    *,
1776    out0: torch.Tensor,
1777    out1: torch.Tensor,
1778    out2: torch.Tensor,
1779) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
1780    result = native_layer_norm_backward(
1781        grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
1782    )
1783    grad_input = (out0, out1, out2)
1784    for i, r in enumerate(result):
1785        if r is not None:
1786            _maybe_resize_out(grad_input[i], r.shape)
1787            _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
1788
1789    return grad_input
1790
1791
1792def native_batch_norm_helper(
1793    input: Tensor,
1794    weight: Optional[Tensor],
1795    bias: Optional[Tensor],
1796    running_mean: Optional[Tensor],
1797    running_var: Optional[Tensor],
1798    training: bool,
1799    momentum: float,
1800    eps: float,
1801    functional: bool,
1802) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
1803    reduction_dims = [0] + list(range(2, input.dim()))
1804    computation_dtype = utils.get_computation_dtype(input.dtype)
1805    new_running_mean = running_mean
1806    new_running_var = running_var
1807    if training:
        input_acc = input.to(dtype=computation_dtype)
1810        biased_var, mean = torch.var_mean(
1811            input_acc, dim=reduction_dims, correction=0, keepdim=True
1812        )
1813        rstd = torch.rsqrt(biased_var + eps)
1814
1815        output = (input - mean) * rstd
1816
1817        save_mean = torch.squeeze(mean, reduction_dims)
1818        save_rstd = torch.squeeze(rstd, reduction_dims)
1819        if running_mean is not None:
1820            new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
1821            if not functional:
1822                running_mean.copy_(new_running_mean)
1823        if running_var is not None:
1824            n = input.numel() / input.shape[1]
1825            # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
1826            # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
1827            # numerics probably don't matter.
1828            squeezed_var = torch.squeeze(biased_var, reduction_dims)
1829            unbiased_var = squeezed_var * (n / (n - 1))
1830            new_running_var = momentum * unbiased_var + (1 - momentum) * running_var
1831            if not functional:
1832                running_var.copy_(new_running_var)
1833    else:
1834        assert running_mean is not None and running_var is not None
1835        running_mean = running_mean.to(dtype=computation_dtype, copy=True)
1836        new_running_mean = running_mean
1837        running_var = running_var.to(dtype=computation_dtype, copy=True)
1838        new_running_var = running_var
1839        mean = running_mean
1840        invstd = 1 / (torch.sqrt(running_var + eps))
1841        # Very annoying inconsistency where CPU and CUDA give different shapes
1842        if input.device.type != "cpu":
1843            save_mean = running_mean
1844            save_rstd = invstd
1845        else:
1846            save_mean = input.new_zeros((0,))
1847            save_rstd = input.new_zeros((0,))
1848        mean = _unsqueeze_to_dim(mean, input.dim() - 1)
1849        invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
1850        output = (input - mean) * invstd
1851
1852    if weight is not None:
1853        weight = weight.flatten()
1854        weight = _unsqueeze_to_dim(weight, input.dim() - 1)
1855        output = output * weight
1856
1857    if bias is not None:
1858        bias = bias.flatten()
1859        bias = _unsqueeze_to_dim(bias, input.dim() - 1)
1860        output = output + bias
1861
1862    if input.device.type == "cpu":
1863        save_mean = save_mean.to(dtype=input.dtype)
1864        save_rstd = save_rstd.to(dtype=input.dtype)
1865    return (
1866        output.to(dtype=input.dtype),
1867        save_mean,
1868        save_rstd,
1869        new_running_mean,
1870        new_running_var,
1871    )
1872
1873
1874@register_decomposition(aten.native_batch_norm)
1875@out_wrapper("out", "save_mean", "save_invstd")
1876def native_batch_norm(
1877    input: Tensor,
1878    weight: Optional[Tensor],
1879    bias: Optional[Tensor],
1880    running_mean: Optional[Tensor],
1881    running_var: Optional[Tensor],
1882    training: bool,
1883    momentum: float,
1884    eps: float,
1885) -> Tuple[Tensor, Tensor, Tensor]:
1886    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
1887        input, weight, bias, running_mean, running_var, training, momentum, eps, False
1888    )
1889    return output, save_mean, save_rstd
1890
1891
1892# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm
1893# with our new correctly schema'd _native_batch_norm_legit and its variants, but
1894# we cannot do that immediately in the C++ because it would be forwards incompatible
1895# with some mobile use cases.
1896#
1897# Since this change is most impactful for aot autograd/functionalization, we simply
1898# register this decomposition on the Autograd key for the python dispatcher (which is
1899# currently only used by aot autograd/functionalization and no one else, really).
1900# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm
1901# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations).
1902@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
1903@aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
1904def native_batch_norm_decomposition(
1905    input: Tensor,
1906    weight: Optional[Tensor],
1907    bias: Optional[Tensor],
1908    running_mean: Optional[Tensor],
1909    running_var: Optional[Tensor],
1910    training: bool,
1911    momentum: float,
1912    eps: float,
1913) -> Tuple[Tensor, Tensor, Tensor]:
1914    if running_mean is None and running_var is None:
1915        return aten._native_batch_norm_legit(
1916            input, weight, bias, training, momentum, eps
1917        )
1918    if running_mean is None:
1919        raise RuntimeError(
1920            "running_mean is None, but running_var is provided. "
1921            "They should both be None or both be provided."
1922        )
1923    if running_var is None:
1924        raise RuntimeError(
1925            "running_var is None, but running_mean is provided. "
1926            "They should both be None or both be provided."
1927        )
1928    if training:
1929        # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg.
1930        return aten._native_batch_norm_legit(
1931            input, weight, bias, running_mean, running_var, training, momentum, eps
1932        )
1933    else:
1934        return aten._native_batch_norm_legit_no_training(
1935            input, weight, bias, running_mean, running_var, momentum, eps
1936        )
1937
1938
1939@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
1940def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
1941    dim_size = tensor.size(dim)
1942    split_size = (dim_size + chunks - 1) // chunks
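    # For example (illustrative): dim_size=5, chunks=2 gives
    # split_size = (5 + 2 - 1) // 2 = 3, i.e. splits of sizes [3, 2].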
1943
1944    if split_size == 0 and dim_size == 0:
        split_sizes = [split_size for _ in range(chunks)]
1946        split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
1947        return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
1948    return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)
1949
1950
1951@register_decomposition(aten._native_batch_norm_legit_no_training.default)
1952def _native_batch_norm_legit_no_training(
1953    input: Tensor,
1954    weight: Optional[Tensor],
1955    bias: Optional[Tensor],
1956    running_mean: Tensor,
1957    running_var: Tensor,
1958    momentum: float,
1959    eps: float,
1960) -> Tuple[Tensor, Tensor, Tensor]:
1961    return aten._native_batch_norm_legit.default(
1962        input,
1963        weight,
1964        bias,
1965        running_mean,
1966        running_var,
1967        False,  # training
1968        momentum,
1969        eps,
1970    )
1971
1972
1973@register_decomposition(aten._native_batch_norm_legit.default)
1974def _native_batch_norm_legit(
1975    input: Tensor,
1976    weight: Optional[Tensor],
1977    bias: Optional[Tensor],
1978    running_mean: Tensor,
1979    running_var: Tensor,
1980    training: bool,
1981    momentum: float,
1982    eps: float,
1983) -> Tuple[Tensor, Tensor, Tensor]:
1984    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
1985        input, weight, bias, running_mean, running_var, training, momentum, eps, False
1986    )
1987    return output, save_mean, save_rstd
1988
1989
1990@register_decomposition(aten._native_batch_norm_legit.no_stats)
1991def _native_batch_norm_legit_no_stats(
1992    input: Tensor,
1993    weight: Optional[Tensor],
1994    bias: Optional[Tensor],
1995    training: bool,
1996    momentum: float,
1997    eps: float,
1998) -> Tuple[Tensor, Tensor, Tensor]:
1999    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
2000        input, weight, bias, None, None, training, momentum, eps, False
2001    )
2002    return output, save_mean, save_rstd
2003
2004
2005@register_decomposition(aten._native_batch_norm_legit_functional.default)
2006def _native_batch_norm_legit_functional(
2007    input: Tensor,
2008    weight: Optional[Tensor],
2009    bias: Optional[Tensor],
2010    running_mean: Tensor,
2011    running_var: Tensor,
2012    training: bool,
2013    momentum: float,
2014    eps: float,
2015) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
2016    (
2017        output,
2018        save_mean,
2019        save_rstd,
2020        new_running_mean,
2021        new_running_var,
2022    ) = native_batch_norm_helper(
2023        input, weight, bias, running_mean, running_var, training, momentum, eps, True
2024    )
2025    assert new_running_mean is not None, "new_running_mean should not be None"
2026    assert new_running_var is not None, "new_running_var should not be None"
2027    return output, save_mean, save_rstd, new_running_mean, new_running_var
2028
2029
2030def _get_batch_norm_reserve_tensor(
2031    input: Tensor,
2032    weight: Optional[Tensor],
2033    bias: Optional[Tensor],
2034    running_mean: Tensor,
2035    running_var: Tensor,
2036    eps: float,
2037    training: bool,
2038) -> Tensor:
2039    """
2040    Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
2041    backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
2042    which support a variety of backends including cudnn. We create this tensor here to get
    the correct shape in the traced graph if we detect that we will call the cudnn kernel,
2044    and rely on DCE to avoid materializing this tensor.
2045    """
2046    backend = torch._C._select_batch_norm_backend(  # type: ignore[attr-defined]
2047        input, weight, bias, running_mean, running_var, True, eps
2048    )
2049    reserve_size = 0
2050    if backend == torch._C._BatchNormBackend.Cudnn:  # type: ignore[attr-defined]
2051        reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training)  # type: ignore[attr-defined]
2052    return torch.empty(
2053        reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
2054    )
2055
2056
2057@register_decomposition(aten._batch_norm_with_update.default)
2058def _batch_norm_with_update(
2059    input: Tensor,
2060    weight: Optional[Tensor],
2061    bias: Optional[Tensor],
2062    running_mean: Tensor,
2063    running_var: Tensor,
2064    momentum: float,
2065    eps: float,
2066) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
2067    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
2068        input,
2069        weight,
2070        bias,
2071        running_mean,
2072        running_var,
2073        True,  # training
2074        momentum,
2075        eps,
2076        False,  # functional
2077    )
2078    reserve = _get_batch_norm_reserve_tensor(
2079        input, weight, bias, running_mean, running_var, eps, training=True
2080    )
2081    return output, save_mean, save_rstd, reserve
2082
2083
2084@register_decomposition(aten._batch_norm_with_update_functional.default)
2085def _batch_norm_with_update_functional(
2086    input: Tensor,
2087    weight: Optional[Tensor],
2088    bias: Optional[Tensor],
2089    running_mean: Tensor,
2090    running_var: Tensor,
2091    momentum: float,
2092    eps: float,
2093) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
2094    (
2095        output,
2096        save_mean,
2097        save_rstd,
2098        new_rm,
2099        new_rv,
2100    ) = native_batch_norm_helper(
2101        input, weight, bias, running_mean, running_var, True, momentum, eps, True
2102    )
2103    reserve = _get_batch_norm_reserve_tensor(
2104        input, weight, bias, running_mean, running_var, eps, training=True
2105    )
2106    assert new_rm is not None, "new_running_mean should not be None"
2107    assert new_rv is not None, "new_running_var should not be None"
2108    return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
2109
2110
2111@register_decomposition(aten._batch_norm_no_update.default)
2112def _batch_norm_no_update(
2113    input: Tensor,
2114    weight: Optional[Tensor],
2115    bias: Optional[Tensor],
2116    running_mean: Tensor,
2117    running_var: Tensor,
2118    momentum: float,
2119    eps: float,
2120) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
2121    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
2122        input,
2123        weight,
2124        bias,
2125        running_mean,
2126        running_var,
2127        False,  # training
2128        momentum,
2129        eps,
2130        False,  # functional
2131    )
2132    reserve = _get_batch_norm_reserve_tensor(
2133        input, weight, bias, running_mean, running_var, eps, training=False
2134    )
2135    return output, save_mean, save_rstd, reserve
2136
2137
2138@register_decomposition(aten._fused_dropout)
2139@out_wrapper("out0", "out1")
2140@pw_cast_for_opmath
2141def _fused_dropout_decomposition(input, p, generator=None):
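    # NB: here `p` is the probability that an element is *kept* (hence the
    # 1 / p rescaling below), unlike torch.dropout, where `p` is the
    # probability of zeroing an element.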
2142    assert generator is None
2143    mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
2144    res = mask.type_as(input) * input * (1.0 / p)
2145    return (res, mask)
2146
2147
2148@register_decomposition(aten._to_copy)
2149@out_wrapper()
2150def _to_copy(
2151    x: Union[Tensor, NumberType],
2152    *,
2153    dtype: Optional[torch.dtype] = None,
2154    layout=None,
2155    device: Optional[torch.device] = None,
2156    pin_memory: bool = False,
2157    non_blocking: bool = False,
2158    memory_format: Optional[torch.memory_format] = None,
2159):
2160    assert not layout or layout == torch.strided, "TODO"
2161    assert not pin_memory, "TODO"
2162    assert isinstance(x, (torch.Tensor, int, float, bool, complex))
2163    if device is None and dtype is None and memory_format is None:
2164        if isinstance(x, torch.Tensor):
2165            return x.clone()
2166        else:
2167            return x
2168    dtype_converted = False
2169
2170    if isinstance(x, torch.Tensor):
2171        x_tensor = x
2172    else:
2173        x_tensor = torch.scalar_tensor(x)
2174
2175    if device is not None and device != x_tensor.device:
2176        # avoid conversions on cpu
2177        if dtype is not None and device.type == "cpu":
2178            x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
2179            dtype_converted = True
2180        x_tensor = torch._prims.device_put(x_tensor, device)
2181
2182    if dtype is not None and not dtype_converted:
2183        x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
2184        dtype_converted = True
2185
2186    if memory_format is not None:  # no ref/prim for memory format
2187        return torch.clone(x_tensor, memory_format=memory_format)
2188    return x_tensor
2189
2190
2191# Questionable decompositions
2192# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
2193# Note that this decomposition causes issues with in-place ops
2194@register_decomposition([aten.detach, aten.lift, aten.lift_fresh])
2195@out_wrapper()
2196def nop_decomposition(x):
2197    return aten.alias(x)
2198
2199
2200# Also register to the Autograd dispatch key, so this decomp can run above autograd.
2201# native_batch_norm needs to decompose into other ops before autograd.
2202@aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd)
2203@register_decomposition(aten.cudnn_batch_norm)
2204@out_wrapper("out0", "out1", "out2", "out3")
2205def cudnn_batch_norm(
2206    input: Tensor,
2207    weight: Tensor,
2208    bias: Optional[Tensor],
2209    running_mean: Optional[Tensor],
2210    running_var: Optional[Tensor],
2211    training: bool,
2212    exponential_average_factor: float,
2213    epsilon: float,
2214):
2215    a, b, c = aten.native_batch_norm(
2216        input,
2217        weight,
2218        bias,
2219        running_mean,
2220        running_var,
2221        training,
2222        exponential_average_factor,
2223        epsilon,
2224    )
    # Cudnn returns running mean and variance when training is True
2226    if training:
2227        return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
2228    return (
2229        a,
2230        weight.new_zeros((0,)),
2231        weight.new_zeros((0,)),
2232        input.new_zeros((0,), dtype=torch.uint8),
2233    )
2234
2235
2236def _broadcast_batch_norm_backward(x, broadcast_mask):
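    # Unsqueeze size-1 axes so that `x` broadcasts against the batch norm input.
    # For example (illustrative): x of shape (C,) with broadcast_mask
    # [1, C, 1, 1] becomes shape (1, C, 1, 1).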
2237    for axis, mask in enumerate(broadcast_mask):
2238        if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask):
2239            x = x.unsqueeze(axis)
2240    return x
2241
2242
2243@register_decomposition(aten.batch_norm_backward.default)
2244def batch_norm_backward(
2245    grad_out: Tensor,
2246    input: Tensor,
2247    weight: Optional[Tensor],
2248    running_mean: Optional[Tensor],
2249    running_var: Optional[Tensor],
2250    save_mean: Optional[Tensor],
2251    save_invstd: Optional[Tensor],
2252    train: bool,
2253    eps: float,
2254    output_mask: List[bool],
2255    reserve: Tensor,
2256) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
2257    return native_batch_norm_backward(
2258        grad_out,
2259        input,
2260        weight,
2261        running_mean,
2262        running_var,
2263        save_mean,
2264        save_invstd,
2265        train,
2266        eps,
2267        output_mask,
2268    )
2269
2270
2271@register_decomposition(aten.native_batch_norm_backward.default)
2272def native_batch_norm_backward(
2273    grad_out: Tensor,
2274    input: Tensor,
2275    weight: Optional[Tensor],
2276    running_mean: Optional[Tensor],
2277    running_var: Optional[Tensor],
2278    save_mean: Optional[Tensor],
2279    save_invstd: Optional[Tensor],
2280    train: bool,
2281    eps: float,
2282    output_mask: List[bool],
2283) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
2284    input_dtype = input.dtype
2285    if weight is not None:
2286        weight_dtype = weight.dtype
2287    else:
2288        weight_dtype = input_dtype
2289    computation_dtype = utils.get_computation_dtype(input.dtype)
2290    (
2291        grad_out_cast,
2292        input_cast,
2293        weight_cast,
2294        running_mean_cast,
2295        running_var_cast,
2296        save_mean_cast,
2297        save_invstd_cast,
2298    ) = (
2299        x.to(computation_dtype) if x is not None else x
2300        for x in (
2301            grad_out,
2302            input,
2303            weight,
2304            running_mean,
2305            running_var,
2306            save_mean,
2307            save_invstd,
2308        )
2309    )
2310    input_shape = input.shape
2311    input_rank = input.dim()
2312    assert input_rank >= 2, "rank of the input must be at least 2"
2313
2314    axis = 1
2315    num_features = prod(list(input_shape)) / input_shape[axis]
2316    mean = save_mean_cast
2317    invstd = save_invstd_cast
2318    if train:
2319        assert save_mean_cast is not None and save_invstd_cast is not None
2320    else:
2321        assert running_mean_cast is not None and running_var_cast is not None
2322        mean = running_mean_cast
2323        invstd = torch.rsqrt(running_var_cast + eps)
2324
2325    broadcast_mask: List[int] = [1] * input_rank
2326    broadcast_mask[axis] = input_shape[axis]
2327
2328    reduction_axes: List[int] = []
2329    for i in range(input_rank):
2330        if i != axis:
2331            reduction_axes.append(i)
2332
2333    mean = _broadcast_batch_norm_backward(mean, broadcast_mask)  # type: ignore[arg-type]
2334    norm = 1.0 / num_features
2335    grad_output_sum = torch.sum(grad_out_cast, reduction_axes)  # type: ignore[arg-type]
2336    dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes)  # type: ignore[operator]
2337
2338    grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask)
2339    proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)  # type: ignore[operator]
2340
2341    if weight_cast is None:
2342        grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0  # type: ignore[arg-type]
2343    else:
2344        grad_scale = _broadcast_batch_norm_backward(
2345            invstd * weight_cast, broadcast_mask
2346        )
2347
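    # In training mode the saved batch statistics depend on the input, so the
    # gradient picks up the projection and mean terms; in eval mode the running
    # statistics are constants and only the scale term applies.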
2348    if train:
2349        proj = (input_cast - mean) * proj_scale  # type: ignore[operator]
2350        grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale
2351    else:
2352        grad_input = grad_out_cast * grad_scale
2353
2354    if output_mask[1]:
2355        grad_weight = dot_p * invstd
2356    else:
2357        grad_weight = None  # "None" doesn't work with vjp, should use zeros for vjp
2358
2359    if output_mask[2]:
2360        grad_bias = grad_output_sum
2361    else:
2362        grad_bias = None  # "None" doesn't work with vjp, should use zeros for vjp
2363
2364    return (
2365        grad_input.to(input_dtype),
2366        _maybe_cast(grad_weight, weight_dtype),
2367        _maybe_cast(grad_bias, weight_dtype),
2368    )
2369
2370
2371# out_wrapper currently does not allow optional outputs
2372@register_decomposition(aten.native_batch_norm_backward.out)
2373def native_batch_norm_backward_out(
2374    grad_out: Tensor,
2375    input: Tensor,
2376    weight: Optional[Tensor],
2377    running_mean: Optional[Tensor],
2378    running_var: Optional[Tensor],
2379    save_mean: Optional[Tensor],
2380    save_invstd: Optional[Tensor],
2381    train: bool,
2382    eps: float,
2383    output_mask: List[bool],
2384    *,
2385    out0: torch.Tensor,
2386    out1: torch.Tensor,
2387    out2: torch.Tensor,
2388) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
2389    result = native_batch_norm_backward(
2390        grad_out,
2391        input,
2392        weight,
2393        running_mean,
2394        running_var,
2395        save_mean,
2396        save_invstd,
2397        train,
2398        eps,
2399        output_mask,
2400    )
2401    grad_input = (out0, out1, out2)
2402    for i, r in enumerate(result):
2403        if r is not None:
2404            _maybe_resize_out(grad_input[i], r.shape)
2405            _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
2406
2407    return grad_input
2408
2409
2410@register_decomposition(aten.miopen_batch_norm_backward)
2411@out_wrapper("out0", "out1", "out2")
2412def miopen_batch_norm_backward(
2413    input: Tensor,
2414    grad_output: Tensor,
2415    weight: Tensor,
2416    running_mean: Optional[Tensor],
2417    running_var: Optional[Tensor],
2418    save_mean: Optional[Tensor],
2419    save_var: Optional[Tensor],
2420    epsilon: float,
2421):
2422    return aten.native_batch_norm_backward(
2423        grad_output,
2424        input,
2425        weight,
2426        running_mean,
2427        running_var,
2428        save_mean,
2429        save_var,
2430        True,
2431        epsilon,
2432        [True, True, True],
2433    )
2434
2435
2436@register_decomposition(aten.cudnn_batch_norm_backward)
2437@out_wrapper("out0", "out1", "out2")
2438def cudnn_batch_norm_backward(
2439    input: Tensor,
2440    grad_output: Tensor,
2441    weight: Tensor,
2442    running_mean: Optional[Tensor],
2443    running_var: Optional[Tensor],
2444    save_mean: Optional[Tensor],
2445    save_var: Optional[Tensor],
2446    epsilon: float,
2447    reserveSpace: Tensor,
2448):
2449    return aten.native_batch_norm_backward(
2450        grad_output,
2451        input,
2452        weight,
2453        running_mean,
2454        running_var,
2455        save_mean,
2456        save_var,
2457        True,
2458        epsilon,
2459        [True, True, True],
2460    )
2461
2462
2463@register_decomposition(aten._adaptive_avg_pool2d)
2464@out_wrapper()
2465@pw_cast_for_opmath
2466def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
2467    # Preconditions
2468    device = input.device
2469    shape = input.shape
2470    ndim = len(shape)
2471    torch._check(
2472        ndim in (3, 4),
2473        lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
2474    )
2475    for d in input.shape[-2:]:
2476        torch._check(
2477            d != 0,
2478            lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
2479            f"non-batch dimensions, but input has shape {tuple(shape)}.",
2480        )
2481
2482    # Optimisation (we should also do this in the kernel implementation)
2483    if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
2484        stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
2485        kernel = tuple(
2486            i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
2487        )
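        # For example (illustrative): a 6x6 input pooled to 3x3 gives
        # stride = (2, 2) and kernel = (6 - (3 - 1) * 2,) * 2 = (2, 2),
        # i.e. a plain 2x2 average pool.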
2488        return torch.nn.functional.avg_pool2d(input, kernel, stride)
2489
2490    def start_index(a, b, c):
2491        return torch.div(a * c, b, rounding_mode="trunc")
2492
2493    def end_index(a, b, c):
2494        return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")
2495
2496    def compute_idx(in_size, out_size):
2497        orange = torch.arange(out_size, device=device, dtype=torch.int64)
2498        i0 = start_index(orange, out_size, in_size)
2499        # Let length = end_index - start_index, i.e. the length of the pooling kernels
2500        # length.max() can be computed analytically as follows:
2501        maxlength = in_size // out_size + 1
2502        in_size_mod = in_size % out_size
2503        # adaptive = True iff there are kernels with different lengths
2504        adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
2505        if adaptive:
2506            maxlength += 1
2507        elif in_size_mod == 0:
2508            maxlength -= 1
2509
2510        range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
2511        idx = i0.unsqueeze(-1) + range_max
2512        if adaptive:
2513            # Need to clamp to avoid accessing out-of-bounds memory
2514            # TODO make minimum accept scalars
2515            maxval = torch.scalar_tensor(
2516                in_size - 1, dtype=idx.dtype, device=idx.device
2517            )
2518            idx = torch.minimum(idx, maxval)
2519
2520            # Compute the length
2521            i1 = end_index(orange, out_size, in_size)
2522            length = i1 - i0
2523        else:
2524            length = maxlength
2525        return idx, length, range_max, adaptive
2526
    # length is a plain int when all pooling windows have the same size;
    # otherwise it is a per-window tensor that maybe_mask handles below
2528    idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
2529    idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])
2530
2531    vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw]
2532    # Shortcut for the simpler case
2533    if not adaptive_h and not adaptive_w:
2534        return torch.mean(vals, dim=(-3, -1))
2535
2536    def maybe_mask(vals, length, range_max, adaptive, dim):
2537        if isinstance(length, IntLike):
2538            return vals, length
2539        else:
2540            # zero-out the things we didn't really want to select
2541            assert dim < 0
2542            # hack
2543            mask = range_max >= length.unsqueeze(-1)
2544            if dim == -2:
2545                mask = _unsqueeze_to_dim(mask, 4)
2546            vals = torch.masked_fill(vals, mask, 0.0)
2547            # Compute the length of each window
2548            length = _unsqueeze_to_dim(length, -dim)
2549            return vals, length
2550
2551    vals, length_h = maybe_mask(
2552        vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
2553    )
2554    vals, length_w = maybe_mask(
2555        vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
2556    )
2557
2558    # We unroll the sum as we assume that the kernels are going to be small
2559    ret = None
2560    for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
2561        if ret is None:
2562            ret = vals[..., i, :, j]
2563        else:
2564            ret = ret + vals[..., i, :, j]
2565    return ret / (length_h * length_w)
2566
2567
2568@register_decomposition(aten.index_add_)
2569def index_add_(
2570    x: TensorLike,
2571    dim: int,
2572    index: TensorLike,
2573    tensor: TensorLike,
2574    *,
2575    alpha: NumberType = 1,
2576):
2577    return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha)
2578
2579
2580@register_decomposition(aten.index_add)
2581@out_wrapper()
2582def index_add(
2583    x: TensorLike,
2584    dim: int,
2585    index: TensorLike,
2586    tensor: TensorLike,
2587    *,
2588    alpha: NumberType = 1,
2589):
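    # For example (illustrative):
    #
    # >>> index_add(torch.zeros(3), 0, torch.tensor([0, 2]), torch.tensor([1.0, 2.0]))
    # tensor([1., 0., 2.])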
2590    return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
2591
2592
2593def _index_add(
2594    x: TensorLike,
2595    dim: int,
2596    index: TensorLike,
2597    tensor: TensorLike,
2598    *,
2599    inplace: bool,
2600    alpha: NumberType = 1,
2601):
2602    dim = utils.canonicalize_dims(x.ndim, dim)
2603    torch._check(
2604        index.ndim <= 1,
2605        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
2606    )
2607    index_size = index.size(0) if index.ndim == 1 else 1
2608    tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1
2609    torch._check(
2610        tensor_size == index_size,
2611        lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}",
2612    )
2613    if alpha != 1:
2614        python_type = utils.dtype_to_type(x.dtype)
2615        torch._check(
2616            python_type == bool
2617            or utils.is_weakly_lesser_type(type(alpha), python_type),
2618            lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
2619        )
2620        tensor = tensor * alpha
2621    # Treat scalars as elements of \R^1
2622    zero_dim = x.ndim == 0
2623    x1 = x.unsqueeze(0) if zero_dim else x
2624    idx = (None,) * dim + (index,)
2625    index_put = aten.index_put_ if inplace else aten.index_put
2626    out = index_put(x1, idx, tensor, accumulate=True)
2627    if inplace:
2628        return x
2629    else:
2630        return out.squeeze(0) if zero_dim else out.contiguous()
2631
2632
2633@register_decomposition(aten.pad_sequence.default)
2634@aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2635def pad_sequence(sequences, batch_first=False, padding_value=0.0):
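    # For example (illustrative):
    #
    # >>> pad_sequence([torch.tensor([1, 2, 3]), torch.tensor([4])], batch_first=True)
    # tensor([[1, 2, 3],
    #         [4, 0, 0]])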
2636    torch._check(len(sequences) > 0, lambda: "received an empty list of sequences")
2637    sequences_size = len(sequences)
2638    max_size = sequences[0].size()
2639    trailing_dims = max_size[1:]
2640    max_len = max(x.size(0) for x in sequences)
2641    if batch_first:
2642        out_dims = (sequences_size, max_len)
2643    else:
2644        out_dims = (max_len, sequences_size)
2645    out_dims = out_dims + trailing_dims
2646    out = sequences[0].new_full(out_dims, padding_value)
2647    dim_paddings = (0, 0) * len(trailing_dims)
2648    for i in range(sequences_size):
2649        currseq = sequences[i]
2650        row = aten.constant_pad_nd(
2651            currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value
2652        )
2653        if batch_first:
2654            out = aten.select_scatter(out, row, dim=0, index=i)
2655        else:
2656            out = aten.select_scatter(out, row, dim=1, index=i)
2657    return out
2658
2659
2660@register_decomposition(aten.index_copy_)
2661def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
2662    return _index_copy(x, dim, index, tensor, inplace=True)
2663
2664
2665@register_decomposition(aten.index_copy)
2666@out_wrapper()
2667def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
2668    return _index_copy(x, dim, index, tensor, inplace=False)
2669
2670
2671def _index_copy(
2672    x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
2673):
2674    dim = utils.canonicalize_dims(x.ndim, dim)
2675    torch._check(
2676        index.ndim <= 1,
2677        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
2678    )
2679    # Treat scalars as elements of \R^1
2680    zero_dim = x.ndim == 0
2681    x1 = x.unsqueeze(0) if zero_dim else x
2682    index = index.unsqueeze(0) if index.ndim == 0 else index
2683    idx = (None,) * dim + (index,)
2684    index_put = aten.index_put_ if inplace else aten.index_put
2685    out = index_put(x1, idx, tensor)
2686    if inplace:
2687        return x
2688    else:
2689        return out.squeeze(0) if zero_dim else out.contiguous()
2690
2691
2692# nb: Should use acc_t, not op_math
2693@register_decomposition(aten.log_sigmoid_forward)
2694@out_wrapper("output", "buffer")
2695@pw_cast_for_opmath
2696def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
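    # Numerically stable log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)).
    # The buffer holds exp(-|x|) for the CPU backward kernel; the CUDA backward
    # does not use it, so an empty buffer is returned on CUDA.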
2697    min = torch.minimum(self.new_zeros(()), self)
2698    z = torch.exp(-torch.abs(self))
2699    if self.is_cuda:
2700        buffer = self.new_zeros((0,))
2701    else:
2702        buffer = z
2703    return min - torch.log1p(z), buffer
2704
2705
2706@register_decomposition(aten.uniform)
2707@out_wrapper()
2708def uniform(
2709    x: Tensor,
2710    low: Union[bool, int, float] = 0.0,
2711    high: Union[bool, int, float] = 1.0,
2712    generator: Optional[torch.Generator] = None,
2713):
2714    return prims._uniform_helper(
2715        x.shape,
2716        low=sym_float(low),
2717        high=sym_float(high),
2718        dtype=x.dtype,
2719        device=x.device,
2720        generator=generator,
2721    )
2722
2723
2724@register_decomposition(aten.uniform_)
2725def uniform_(self, low=0, high=1, generator=None):
2726    return self.copy_(uniform(self, low, high, generator))
2727
2728
2729# aten/src/ATen/native/UpSample.cpp compute_output_size
2730def upsample_compute_output_size(input_size, output_size, scale_factors):
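    # For example (illustrative): input_size (N, C, 4, 6) with
    # scale_factors [2.0, 1.5] yields output_size [8, 9]; integral scales take
    # the exact int path, fractional ones go through sym_int.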
2731    spatial_dimensions = len(input_size) - 2
2732    if output_size is not None:
2733        torch._check(
2734            scale_factors is None,
2735            lambda: "Must specify exactly one of output_size and scale_factors",
2736        )
        torch._check(
            len(output_size) == spatial_dimensions,
            lambda: f"Expected output_size to have {spatial_dimensions} elements",
        )
2738        return output_size
2739    if scale_factors is not None:
2740        # NB: this isn't necessary lol
2741        torch._check(
2742            output_size is None,
2743            lambda: "Must specify exactly one of output_size and scale_factors",
2744        )
        torch._check(
            len(scale_factors) == spatial_dimensions,
            lambda: f"Expected scale_factors to have {spatial_dimensions} elements",
        )
2746        output_size = []
2747        for i, s in enumerate(scale_factors):
2748            if int(s) == s:
2749                output_size.append(input_size[i + 2] * int(s))
2750            else:
2751                output_size.append(sym_int(input_size[i + 2] * s))
2752        return output_size
2753    torch._check(
2754        False, lambda: "Must specify exactly one of output_size and scale_factors"
2755    )
2756
2757
2758def get_scale_value(scales, idx):
2759    if scales is None:
2760        return None
2761    return scales[idx]
2762
2763
2764@register_decomposition(aten.upsample_nearest1d.vec)
2765@register_decomposition(aten.upsample_nearest2d.vec)
2766@register_decomposition(aten.upsample_nearest3d.vec)
2767@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
2768@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
2769@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
2770@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
2771@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
2772@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
2773def _upsample_nearest_vec(
2774    input: Tensor,
2775    output_size: Optional[List[int]],
2776    scale_factors: Optional[List[float]],
2777) -> Tensor:
2778    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
2779    scales = (
2780        scale_factors if scale_factors else [None] * len(osize)  # type: ignore[list-item]
2781    )
2782    return _upsample_nearest(input, osize, scales)
2783
2784
2785@register_decomposition(aten._upsample_nearest_exact1d.vec)
2786@register_decomposition(aten._upsample_nearest_exact2d.vec)
2787@register_decomposition(aten._upsample_nearest_exact3d.vec)
2788@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
2789@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
2790@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
2791@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
2792@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
2793@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
2794def _upsample_nearest_exact_vec(
2795    input: Tensor,
2796    output_size: Optional[List[int]],
2797    scale_factors: Optional[List[float]],
2798) -> Tensor:
2799    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
2800    scales = (
2801        scale_factors if scale_factors else [None] * len(osize)  # type: ignore[list-item]
2802    )
2803    return _upsample_nearest(input, osize, scales, exact=True)
2804
2805
2806def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
2807    # For each dim in output_size, compute the set of input indices used
2808    # to produce the upsampled output.
2809    indices = []
2810    num_spatial_dims = len(output_size)
2811    offset = 0.5 if exact else 0.0
2812
2813    for d in range(num_spatial_dims):
2814        # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
2815        #
        # Indices are computed as follows:
2817        # scale = isize / osize
2818        # Case: exact=False
2819        # input_index = floor(output_index * scale)
2820        # Same as OpenCV INTER_NEAREST
2821        #
        # Case: exact=True
2823        # index_f32 = (output_index + 0.5) * scale - 0.5
2824        # input_index = round(index_f32)
2825        # Same as Pillow and Scikit-Image/Scipy ndi.zoom
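        #
        # Worked example (illustrative): isize=4, osize=8 gives scale=0.5;
        # with exact=False, input_index = floor(output_index * 0.5) for
        # output_index in 0..7, i.e. [0, 0, 1, 1, 2, 2, 3, 3].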
2826        osize = output_size[d]
2827        isize = input.shape[-num_spatial_dims + d]
2828        scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize
2829
2830        output_indices = torch.arange(osize, dtype=torch.float32, device=input.device)
2831        input_indices = ((output_indices + offset) * scale).to(torch.int64)
2832        for _ in range(num_spatial_dims - 1 - d):
2833            input_indices = input_indices.unsqueeze(-1)
2834        indices.append(input_indices)
2835    return indices
2836
2837
2838@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
2839@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2840@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
2841@out_wrapper(preserve_memory_format=True, exact_dtype=True)
2842def upsample_nearest1d(
2843    input: Tensor,
2844    output_size: List[int],
2845    scales: Optional[float] = None,
2846) -> Tensor:
2847    return _upsample_nearest(input, output_size, [scales])
2848
2849
2850@register_decomposition(
2851    [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
2852)
2853@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2854@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
2855@out_wrapper(preserve_memory_format=True, exact_dtype=True)
2856def upsample_nearest_exact1d(
2857    input: Tensor,
2858    output_size: List[int],
2859    scales: Optional[float] = None,
2860) -> Tensor:
2861    return _upsample_nearest(input, output_size, [scales], exact=True)
2862
2863
2864@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
2865@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2866@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
2867@out_wrapper(preserve_memory_format=True, exact_dtype=True)
2868def upsample_nearest2d(
2869    input: Tensor,
2870    output_size: List[int],
2871    scales_h: Optional[float] = None,
2872    scales_w: Optional[float] = None,
2873) -> Tensor:
2874    return _upsample_nearest(input, output_size, [scales_h, scales_w])
2875
2876
2877@register_decomposition(
2878    [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
2879)
2880@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2881@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
2882@out_wrapper(preserve_memory_format=True, exact_dtype=True)
2883def _upsample_nearest_exact2d(
2884    input: Tensor,
2885    output_size: List[int],
2886    scales_h: Optional[float] = None,
2887    scales_w: Optional[float] = None,
2888) -> Tensor:
2889    return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
2890
2891
2892@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
2893@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2894@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
2895@out_wrapper(preserve_memory_format=True, exact_dtype=True)
2896def upsample_nearest3d(
2897    input: Tensor,
2898    output_size: List[int],
2899    scales_d: Optional[float] = None,
2900    scales_h: Optional[float] = None,
2901    scales_w: Optional[float] = None,
2902) -> Tensor:
2903    return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
2904
2905
2906@register_decomposition(
2907    [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
2908)
2909@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2910@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
2911@out_wrapper(preserve_memory_format=True, exact_dtype=True)
2912def _upsample_nearest_exact3d(
2913    input: Tensor,
2914    output_size: List[int],
2915    scales_d: Optional[float] = None,
2916    scales_h: Optional[float] = None,
2917    scales_w: Optional[float] = None,
2918) -> Tensor:
2919    return _upsample_nearest(
2920        input, output_size, [scales_d, scales_h, scales_w], exact=True
2921    )
2922
2923
2924@pw_cast_for_opmath
2925def _upsample_nearest(
2926    input: Tensor,
2927    output_size: List[int],
2928    scales: List[Optional[float]],
2929    exact: bool = False,
2930) -> Tensor:
2931    spatial_indices = _compute_upsample_nearest_indices(
2932        input, output_size, scales, exact=exact
2933    )
2934
2935    indices = [None, None] + spatial_indices
2936    result = aten._unsafe_index(input, indices)
2937
2938    if result.ndim == 4:
2939        # convert output to correct memory format, if necessary
2940        memory_format = utils.suggest_memory_format(input)
2941
2942        # following "heuristic: only use channels_last path when it's faster than the contiguous path"
2943        n_channels = input.shape[1]
2944        if input.device.type == "cuda" and n_channels < 4:
2945            memory_format = torch.contiguous_format
2946
2947        result = result.contiguous(memory_format=memory_format)
2948    return result
2949
2950
2951def gather_params(params, has_biases, has_projections):
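    # For example (illustrative): a single-layer, unidirectional LSTM with
    # biases and no projections passes the flat list [w_ih, w_hh, b_ih, b_hh];
    # with group_size = 4 this is regrouped into [(w_ih, w_hh, b_ih, b_hh)].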
2952    if has_biases and has_projections:
2953        group_size = 5
2954    elif has_biases:
2955        group_size = 4
2956    elif has_projections:
2957        group_size = 3
2958    else:
2959        group_size = 2
2960
2961    assert len(params) % group_size == 0, len(params)
2962    return [
2963        tuple(params[i : i + group_size]) for i in range(0, len(params), group_size)
2964    ]
2965
2966
2967def params_hiddens(params, hiddens, i, bidirectional):
2968    if bidirectional:
2969        cur_params, cur_hidden = params[2 * i], hiddens[2 * i]
2970        bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1]
2971    else:
2972        cur_params, cur_hidden = params[i], hiddens[i]
2973        bidir_params, bidir_hidden = None, None
2974
2975    return cur_params, cur_hidden, bidir_params, bidir_hidden
2976
2977
2978def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens):
2979    assert last_batch_size > batch_size
2980    hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size))
2981    return cur_hidden.narrow(0, 0, batch_size)
2982
2983
2984def update_hidden_for_packed_reverse(
2985    cur_hidden, last_batch_size, batch_size, inp_hidden
2986):
2987    if last_batch_size == batch_size:
2988        return cur_hidden
2989    assert last_batch_size < batch_size
2990    return torch.concat(
2991        (
2992            cur_hidden,
2993            inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size),
2994        )
2995    )
2996
2997
2998def one_layer_rnn_data(
2999    inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
3000):
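    # Runs one direction of one layer over a packed sequence: `inp` is the
    # packed data and `batch_sizes` the number of active sequences at each
    # time step (sorted largest -> smallest); hidden states of sequences that
    # end early are collected in `hiddens`.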
3001    ih_weight = params[0]
3002    hh_weight = params[1]
3003    ih_bias = params[2] if has_biases else None
3004    hh_bias = params[3] if has_biases else None
3005
3006    step_output = []
3007    hiddens: List[torch.Tensor] = []
3008
3009    last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
3010    cur_hidden = hidden.narrow(0, 0, last_batch_size)
3011    split_inp = torch.split(inp, list(batch_sizes))
3012    if reverse:
3013        split_inp = split_inp[::-1]
3014    for inp in split_inp:
3015        i = inp.shape[0]
3016
3017        if last_batch_size == i:
3018            pass  # don't update cur_hidden
        # batch sizes are sorted largest -> smallest: iterating in reverse the
        # batch can only grow (so the hidden state is extended), while
        # iterating forward it can only shrink (so finished rows are saved off)
3020        elif reverse:
3021            cur_hidden = update_hidden_for_packed_reverse(
3022                cur_hidden, last_batch_size, i, hidden
3023            )
3024        else:
3025            cur_hidden = update_hidden_for_packed(
3026                cur_hidden, last_batch_size, i, hiddens
3027            )
3028
3029        cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
3030        last_batch_size = i
3031        step_output.append(cur_hidden)
3032
3033    if reverse:
3034        step_output.reverse()
3035    else:
3036        hiddens.append(cur_hidden)
3037        hiddens.reverse()
3038
3039    out = torch.cat(step_output, 0)
3040    hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden
3041    return out, hidden_out
3042
3043
3044def rnn_cell(nonlinearity):
3045    def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
3046        return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
3047
3048    return inner
3049
3050
3051def rnn_cell_data(nonlinearity):
3052    def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
3053        i = F.linear(i, ih_weight, ih_bias)
3054        return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
3055
3056    return inner
3057
3058
3059def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
3060    ih_weight = params[0]
3061    hh_weight = params[1]
3062    ih_bias = params[2] if has_biases else None
3063    hh_bias = params[3] if has_biases else None
3064
3065    precomputed_input = F.linear(inp, ih_weight, ih_bias)
3066    precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
3067    cur_hidden = hidden.unsqueeze(0)
3068    step_output = []
3069    for i in precomputed_input:
3070        cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
3071        step_output.append(cur_hidden)
3072
3073    if reverse:
3074        step_output.reverse()
3075
3076    out = torch.cat(step_output, 0)
3077
3078    return out, cur_hidden.squeeze(0)
3079
3080
3081def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
3082    w0 = params[0]
3083    w1 = params[1]
3084    if has_biases:
3085        w2 = params[2]
3086        w3 = params[3]
3087    else:
3088        w2 = torch.zeros(w0.size())
3089        w3 = torch.zeros(w1.size())
3090
3091    hx = hidden[0].unsqueeze(0)
3092    cx = hidden[1].unsqueeze(0)
3093
3094    batch_sizes: List[int] = []
3095    mode = 2  # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2
3096    hidden_size = hx.size(2)
3097    num_layers = 1
3098
3099    # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here
3100    bidirectional = False
3101    batch_first = False
3102
3103    train = False
3104    # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here.
3105    # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous();
3106    inp = inp.contiguous()
3107    hx = hx.contiguous()
3108    cx = cx.contiguous()
3109    outputs = torch.ops.aten.mkldnn_rnn_layer.default(
3110        inp,
3111        w0,
3112        w1,
3113        w2,
3114        w3,
3115        hx,
3116        cx,
3117        reverse,
3118        batch_sizes,
3119        mode,
3120        hidden_size,
3121        num_layers,
3122        has_biases,
3123        bidirectional,
3124        batch_first,
3125        train,
3126    )
3127    y, hy, cy = outputs[0], outputs[1], outputs[2]
3128    return y, (hy.squeeze(0), cy.squeeze(0))
3129
3130
3131def _rnn_helper(
3132    input,
3133    hidden,
3134    params,
3135    has_biases,
3136    num_layers,
3137    dropout,
3138    train,
3139    bidirectional,
3140    batch_first,
3141    layer_fn,
3142):
3143    input = input.transpose(0, 1) if batch_first else input
3144    final_hiddens = []
3145
3146    for i in range(num_layers):
3147        cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens(
3148            params, hidden, i, bidirectional
3149        )
3150        dropout = dropout if (train and i < num_layers - 1) else 0.0
3151        fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases)
3152        final_hiddens.append(fwd_hidden)
3153
3154        if bidirectional:
3155            bwd_inp, bwd_hidden = layer_fn(
3156                input, bidir_hidden, bidir_params, has_biases, reverse=True
3157            )
3158            final_hiddens.append(bwd_hidden)
3159
3160        if bidirectional:
3161            input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)  # type: ignore[possibly-undefined]
3162        else:
3163            input = fwd_inp
3164
3165        if dropout != 0 and train and i < num_layers - 1:
3166            input = torch.dropout(input, dropout, train=True)
3167
3168    input = input.transpose(0, 1) if batch_first else input
3169    return input, final_hiddens
3170
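# Layout sketch: _rnn_helper appends the per-layer final hidden states in the
# order [layer0_fwd, layer0_bwd, layer1_fwd, ...], which is what the callers
# below re-stack via zip/torch.stack; for a bidirectional layer the two
# direction outputs are concatenated on the feature dim, so a layer_fn that
# returns (T, B, H) feeds a (T, B, 2 * H) input to the next layer. This tiny
# hypothetical helper just enumerates that ordering:
def _rnn_helper_hidden_order_sketch(num_layers: int, bidirectional: bool) -> List[str]:
    order = []
    for layer in range(num_layers):
        order.append(f"l{layer}_fwd")
        if bidirectional:
            order.append(f"l{layer}_bwd")
    return order  # e.g. ['l0_fwd', 'l0_bwd', 'l1_fwd', 'l1_bwd']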
3171
3172@register_decomposition(aten.rnn_tanh.input)
3173@aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd)
3174@aten.rnn_tanh.input.py_impl(DispatchKey.Autograd)
3175def rnn_tanh_input(
3176    input,
3177    hx,
3178    params,
3179    has_biases,
3180    num_layers,
3181    dropout,
3182    train,
3183    bidirectional,
3184    batch_first,
3185):
3186    hidden = hx.unbind(0)
3187    params = gather_params(params, has_biases, False)
3188    out, final_hiddens = _rnn_helper(
3189        input,
3190        hidden,
3191        params,
3192        has_biases,
3193        num_layers,
3194        dropout,
3195        train,
3196        bidirectional,
3197        batch_first,
3198        partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
3199    )
3200    return out, torch.stack(final_hiddens, 0)
3201
3202
3203@register_decomposition(aten.rnn_relu.input)
3204@aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd)
3205@aten.rnn_relu.input.py_impl(DispatchKey.Autograd)
3206def rnn_relu_input(
3207    input,
3208    hx,
3209    params,
3210    has_biases,
3211    num_layers,
3212    dropout,
3213    train,
3214    bidirectional,
3215    batch_first,
3216):
3217    hidden = hx.unbind(0)
3218    params = gather_params(params, has_biases, False)
3219    out, final_hiddens = _rnn_helper(
3220        input,
3221        hidden,
3222        params,
3223        has_biases,
3224        num_layers,
3225        dropout,
3226        train,
3227        bidirectional,
3228        batch_first,
3229        partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
3230    )
3231    return out, torch.stack(final_hiddens, 0)
3232
3233
3234@register_decomposition(aten.rnn_relu.data)
3235@aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd)
3236@aten.rnn_relu.data.py_impl(DispatchKey.Autograd)
3237def rnn_relu_data(
3238    data,
3239    batch_sizes,
3240    hx,
3241    params,
3242    has_biases,
3243    num_layers,
3244    dropout,
3245    train,
3246    bidirectional,
3247):
3248    hidden = hx.unbind(0)
3249    params = gather_params(params, has_biases, False)
3250    out, final_hiddens = _rnn_helper(
3251        data,
3252        hidden,
3253        params,
3254        has_biases,
3255        num_layers,
3256        dropout,
3257        train,
3258        bidirectional,
3259        False,
3260        partial(
3261            one_layer_rnn_data,
3262            batch_sizes=batch_sizes,
3263            hidden_fn=rnn_cell_data(torch.relu),
3264        ),
3265    )
3266    return out, torch.stack(final_hiddens, 0)
3267
3268
3269@register_decomposition(aten.rnn_tanh.data)
3270@aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd)
3271@aten.rnn_tanh.data.py_impl(DispatchKey.Autograd)
3272def rnn_tanh_data(
3273    data,
3274    batch_sizes,
3275    hx,
3276    params,
3277    has_biases,
3278    num_layers,
3279    dropout,
3280    train,
3281    bidirectional,
3282):
3283    hidden = hx.unbind(0)
3284    params = gather_params(params, has_biases, False)
3285    out, final_hiddens = _rnn_helper(
3286        data,
3287        hidden,
3288        params,
3289        has_biases,
3290        num_layers,
3291        dropout,
3292        train,
3293        bidirectional,
3294        False,
3295        partial(
3296            one_layer_rnn_data,
3297            batch_sizes=batch_sizes,
3298            hidden_fn=rnn_cell_data(torch.tanh),
3299        ),
3300    )
3301    return out, torch.stack(final_hiddens, 0)
3302
3303
3304def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
3305    gates = F.linear(hx, hh_weight, hh_bias) + inp
3306    chunked_gates = gates.chunk(4, chunk_dim)
3307    in_gate = chunked_gates[0].sigmoid()
3308    forget_gate = chunked_gates[1].sigmoid()
3309    cell_gate = chunked_gates[2].tanh()
3310    out_gate = chunked_gates[3].sigmoid()
3311    cy = forget_gate * cx + (in_gate * cell_gate)
3312    hy = out_gate * cy.tanh()
3313    hy = hy if hr_weight is None else F.linear(hy, hr_weight, None)
3314
3315    return hy, cy
3316
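# lstm_cell implements the standard LSTM update on pre-projected input (the
# callers fold x @ W_ih.T + b_ih into `inp`):
#   i_t = sigmoid(g0),  f_t = sigmoid(g1),  g_t = tanh(g2),  o_t = sigmoid(g3)
#   c_t = f_t * c_{t-1} + i_t * g_t
#   h_t = o_t * tanh(c_t), optionally projected by hr_weight
# A minimal shape sketch under assumed sizes (illustrative only):
def _lstm_cell_shape_sketch():
    batch, hidden = 2, 3
    inp = torch.randn(1, batch, 4 * hidden)  # one pre-projected input step
    hx = torch.zeros(1, batch, hidden)
    cx = torch.zeros(1, batch, hidden)
    hh_weight = torch.randn(4 * hidden, hidden)
    hy, cy = lstm_cell(inp, hx, cx, hh_weight, None, None, chunk_dim=2)
    return hy.shape, cy.shape  # both (1, batch, hidden)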
3317
3318def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
3319    ih_weight = params[0]
3320    hh_weight = params[1]
3321    ih_bias = params[2] if has_biases else None
3322    hh_bias = params[3] if has_biases else None
3323    hr_weight = (
3324        params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
3325    )
3326
3327    hx = hidden[0].unsqueeze(0)
3328    cx = hidden[1].unsqueeze(0)
3329
3330    precomputed_input = F.linear(inp, ih_weight, ih_bias)
3331    precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
3332    step_output = []
3333    for inp in precomputed_input:
3334        hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2)
3335        step_output.append(hx)
3336
3337    if reverse:
3338        step_output.reverse()
3339
3340    out = torch.cat(step_output, 0)
3341
3342    return out, (hx.squeeze(0), cx.squeeze(0))
3343
3344
3345def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False):
3346    ih_weight = params[0]
3347    hh_weight = params[1]
3348    ih_bias = params[2] if has_biases else None
3349    hh_bias = params[3] if has_biases else None
3350    hr_weight = (
3351        params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
3352    )
3353
3354    step_output = []
3355    hiddens = []
3356
3357    last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
3358    split_inp = torch.split(inp, list(batch_sizes))
3359    if reverse:
3360        split_inp = split_inp[::-1]
3361
3362    orig_hx = hidden[0]
3363    orig_cx = hidden[1]
3364    hx, cx = orig_hx.narrow(0, 0, last_batch_size), orig_cx.narrow(
3365        0, 0, last_batch_size
3366    )
3367
3368    for inp in split_inp:
3369        i = inp.shape[0]
3370        inp = F.linear(inp, ih_weight, ih_bias)
3371
3372        # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
3373        if i < last_batch_size:
3374            hiddens.append(
3375                (
3376                    hx.narrow(0, i, last_batch_size - i),
3377                    cx.narrow(0, i, last_batch_size - i),
3378                )
3379            )
3380            hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i)
3381
3382        # this will only happen when reverse=True
3383        if i > last_batch_size:
3384            hx = torch.concat(
3385                (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0
3386            )
3387            cx = torch.concat(
3388                (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0
3389            )
3390
3391        hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1)
3392        last_batch_size = i
3393        step_output.append(hx)
3394
3395    if reverse:
3396        step_output.reverse()
3397        hidden_out = (hx, cx)
3398    else:
3399        hiddens.append((hx, cx))
3400        hiddens.reverse()
3401        hidden0, hidden1 = zip(*hiddens)
3402        hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0)
3403
3404    out = torch.cat(step_output, 0)
3405    return out, hidden_out
3406
3407
3408def select_one_layer_lstm_function(input, hx, params):
3409    r"""Check whether we can decompose the LSTM with ``mkldnn_rnn_layer``.
3410    All of the conditions below need to be met:
3411        * ``torch._C._get_mkldnn_enabled()`` returns ``True``.
3412        * All the input args are on CPU.
3413        * The dtypes of the args are either ``torch.float`` or ``torch.bfloat16``.
3414        * Running inference (the input does not require grad).
3415        * ``has_projections`` is ``False``.
3416
3417    Args:
3418        * input: the input sequence to LSTM
3419        * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM
3420        * params: the weight and bias tensors of LSTM
3421    """
3422
3423    def use_mkldnn(input, hx, params):
3424        if not torch._C._get_mkldnn_enabled():
3425            return False
3426
3427        tensors = [input] + list(hx) + list(chain.from_iterable(params))
3428        devices = {t.device for t in tensors}
3429        if len(devices) != 1:
3430            return False
3431
3432        device = devices.pop()
3433        if device != torch.device("cpu"):
3434            return False
3435        # With autocast, it's possible to have mixed dtypes here
3436        dtypes = {t.dtype for t in tensors}
3437        for dtype in dtypes:
3438            if dtype not in [torch.float, torch.bfloat16]:
3439                return False
3440
3441        if input.requires_grad:
3442            return False
3443
3444        has_projections = hx[0].size(2) != hx[1].size(2)
3445        if has_projections:
3446            return False
3447
3448        return True
3449
3450    # mkldnn_one_layer_lstm does not depend on seq_len, while one_layer_lstm
3451    # unrolls a loop over the seq_len dim
3452    if use_mkldnn(input, hx, params):
3453        return mkldnn_one_layer_lstm
3454    else:
3455        return one_layer_lstm
3456
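# Illustrative example of the selection above: a CPU float32 inference call
# without projections takes the mkldnn path (when mkldnn is enabled in the
# build), while e.g. an input that requires grad falls back to the unrolled
# loop. The params layout below mimics gather_params output and is an
# assumption made for the sketch:
def _select_lstm_fn_sketch():
    inp = torch.randn(5, 2, 4)  # (seq, batch, input_size) on CPU
    hx = (torch.zeros(1, 2, 3), torch.zeros(1, 2, 3))  # (h_0, c_0)
    params = [(torch.randn(12, 4), torch.randn(12, 3))]  # one layer, no biases
    return select_one_layer_lstm_function(inp, hx, params)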
3457
3458@register_decomposition(aten.lstm.input)
3459@aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd)
3460@aten.lstm.input.py_impl(DispatchKey.Autograd)
3461def lstm_impl(
3462    input,
3463    hx,
3464    params,
3465    has_biases,
3466    num_layers,
3467    dropout,
3468    train,
3469    bidirectional,
3470    batch_first,
3471):
3472    assert len(hx) == 2, "lstm expects two hidden states"
3473    params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
3474    hidden = list(zip(hx[0], hx[1]))
3475    layer_fn = select_one_layer_lstm_function(input, hx, params)
3476    out, final_hiddens = _rnn_helper(
3477        input,
3478        hidden,
3479        params,
3480        has_biases,
3481        num_layers,
3482        dropout,
3483        train,
3484        bidirectional,
3485        batch_first,
3486        layer_fn,
3487    )
3488    final_hiddens = list(zip(*final_hiddens))
3489    return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
3490
3491
3492@register_decomposition(aten.lstm.data)
3493@aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd)
3494@aten.lstm.data.py_impl(DispatchKey.Autograd)
3495def lstm_data_impl(
3496    data,
3497    batch_sizes,
3498    hx,
3499    params,
3500    has_biases,
3501    num_layers,
3502    dropout,
3503    train,
3504    bidirectional,
3505):
3506    assert len(hx) == 2, "lstm expects two hidden states"
3507    params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
3508    hidden = list(zip(hx[0], hx[1]))
3509    out, final_hiddens = _rnn_helper(
3510        data,
3511        hidden,
3512        params,
3513        has_biases,
3514        num_layers,
3515        dropout,
3516        train,
3517        bidirectional,
3518        False,
3519        partial(one_layer_lstm_data, batch_sizes=batch_sizes),
3520    )
3521    final_hiddens = list(zip(*final_hiddens))
3522    return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
3523
3524
3525def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
3526    chunked_igates = inp.chunk(3, 1)
3527    chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
3528    reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
3529    input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
3530    new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
3531    return (cur_hidden - new_gate) * input_gate + new_gate
3532
3533
3534def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
3535    chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
3536    chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
3537    reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
3538    input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
3539    new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
3540    return (cur_hidden - new_gate) * input_gate + new_gate
3541
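# Both GRU cells implement, with r/z/n the reset, update and new gates:
#   r_t = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
#   z_t = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
#   n_t = tanh(W_in x + b_in + r_t * (W_hn h + b_hn))
#   h_t = (1 - z_t) * n_t + z_t * h_{t-1}  ==  (h_{t-1} - n_t) * z_t + n_t
# gru_cell consumes pre-projected input, gru_cell_data projects per step. A
# tiny check of the algebraic rewrite used in the return (illustrative only):
def _gru_update_rewrite_sketch():
    h, n, z = torch.randn(3), torch.randn(3), torch.rand(3)
    return torch.allclose((1 - z) * n + z * h, (h - n) * z + n)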
3542
3543@register_decomposition(aten.gru.data)
3544@aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
3545@aten.gru.data.py_impl(DispatchKey.Autograd)
3546def gru_impl_data(
3547    data,
3548    batch_sizes,
3549    hx,
3550    params,
3551    has_biases,
3552    num_layers,
3553    dropout,
3554    train,
3555    bidirectional,
3556):
3557    params = gather_params(params, has_biases, False)
3558    out, final_hiddens = _rnn_helper(
3559        data,
3560        hx.unbind(0),
3561        params,
3562        has_biases,
3563        num_layers,
3564        dropout,
3565        train,
3566        bidirectional,
3567        False,
3568        partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
3569    )
3570    return out, torch.stack(final_hiddens, 0)
3571
3572
3573@register_decomposition(aten.gru.input)
3574@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
3575@aten.gru.input.py_impl(DispatchKey.Autograd)
3576def gru_impl(
3577    input,
3578    hx,
3579    params,
3580    has_biases,
3581    num_layers,
3582    dropout,
3583    train,
3584    bidirectional,
3585    batch_first,
3586):
3587    params = gather_params(params, has_biases, False)
3588    out, final_hiddens = _rnn_helper(
3589        input,
3590        hx.unbind(0),
3591        params,
3592        has_biases,
3593        num_layers,
3594        dropout,
3595        train,
3596        bidirectional,
3597        batch_first,
3598        partial(one_layer_rnn, hidden_fn=gru_cell),
3599    )
3600    return out, torch.stack(final_hiddens, 0)
3601
3602
3603@register_decomposition(aten._upsample_bilinear2d_aa.vec)
3604@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
3605@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd)
3606def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors):
3607    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
3608    scale_h = get_scale_value(scale_factors, 0)
3609    scale_w = get_scale_value(scale_factors, 1)
3610    return torch.ops.aten._upsample_bilinear2d_aa(
3611        input, osize, align_corners, scale_h, scale_w
3612    )
3613
3614
3615@register_decomposition(aten._upsample_bicubic2d_aa.vec)
3616@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
3617@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd)
3618def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors):
3619    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
3620    scale_h = get_scale_value(scale_factors, 0)
3621    scale_w = get_scale_value(scale_factors, 1)
3622    return torch.ops.aten._upsample_bicubic2d_aa(
3623        input, osize, align_corners, scale_h, scale_w
3624    )
3625
3626
3627@register_decomposition(aten.upsample_bilinear2d.vec)
3628@register_decomposition(aten.upsample_trilinear3d.vec)
3629@aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
3630@aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd)
3631@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
3632@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
3633@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
3634@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd)
3635def _upsample_linear_vec(input, output_size, align_corners, scale_factors):
3636    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
3637    scales = scale_factors if scale_factors else [None] * len(osize)
3638    return _upsample_linear(input, osize, align_corners, scales)
3639
3640
3641@register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out])
3642@out_wrapper()
3643def upsample_linear1d(
3644    input: Tensor,
3645    output_size: List[int],
3646    align_corners: bool,
3647    scales_w: Optional[float] = None,
3648) -> Tensor:
3649    return _upsample_linear(input, output_size, align_corners, [scales_w])
3650
3651
3652@register_decomposition(
3653    [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out]
3654)
3655@aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
3656@out_wrapper()
3657def upsample_bilinear2d(
3658    input: Tensor,
3659    output_size: List[int],
3660    align_corners: bool,
3661    scales_h: Optional[float] = None,
3662    scales_w: Optional[float] = None,
3663) -> Tensor:
3664    return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w])
3665
3666
3667@register_decomposition(
3668    [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out]
3669)
3670@out_wrapper()
3671def upsample_trilinear3d(
3672    input: Tensor,
3673    output_size: List[int],
3674    align_corners: bool,
3675    scales_d: Optional[float] = None,
3676    scales_h: Optional[float] = None,
3677    scales_w: Optional[float] = None,
3678) -> Tensor:
3679    return _upsample_linear(
3680        input, output_size, align_corners, [scales_d, scales_h, scales_w]
3681    )
3682
3683
3684def _compute_scale(in_size, out_size, align_corners, scale=None):
3685    if align_corners:
3686        return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0
3687    else:
3688        return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size
3689
3690
3691def _compute_source_index(scale, dst_index, align_corners):
3692    if align_corners:
3693        return scale * dst_index
3694    else:
3695        return scale * (dst_index + 0.5) - 0.5
3696
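# Worked example of the two helpers above: upsampling a length-4 row to
# length 8 without align_corners gives scale = 4 / 8 = 0.5 and source index
# src = 0.5 * (dst + 0.5) - 0.5, so dst 0 -> -0.25 (callers clamp to 0) and
# dst 7 -> 3.25; with align_corners, scale = 3 / 7 and dst 7 -> exactly 3.0.
def _source_index_sketch():
    scale = _compute_scale(4, 8, align_corners=False)  # 0.5
    dst = torch.arange(8, dtype=torch.float64)
    return _compute_source_index(scale, dst, align_corners=False)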
3697
3698def _sum_tensors_uint8(
3699    src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor
3700) -> Tensor:
3701    output = _sum_tensors(
3702        s.to(torch.int32) * c.to(torch.int32) for s, c in zip(src, weights)
3703    ) + (1 << (weights_precision - 1))
3704    output = output >> weights_precision
3705    return torch.clamp(output, 0, 255).to(torch.uint8)
3706
3707
3708def _compute_weight_precision(weights: TensorSequenceType) -> Tensor:
3709    max_weight = torch.stack(weights).max()
3710    max_weight_precision = 22
3711    precisions = torch.arange(max_weight_precision, device=max_weight.device)
3712    values = 0.5 + max_weight * (1 << (precisions + 1))
3713    mask = values >= (1 << 15)
3714    return max_weight_precision - mask.sum()
3715
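# The two uint8 helpers above implement fixed-point interpolation: float
# weights are scaled by 2**precision and rounded to int16, products are
# accumulated in int32, and adding (1 << (precision - 1)) before the right
# shift rounds to nearest. Sketch with an assumed precision of 14 bits:
def _uint8_fixed_point_sketch():
    precision = torch.tensor(14)
    w_fp = (torch.tensor([0.25, 0.75]) * (1 << 14) + 0.5).to(torch.int16)
    src = [
        torch.tensor([100], dtype=torch.uint8),
        torch.tensor([200], dtype=torch.uint8),
    ]
    # 100 * 0.25 + 200 * 0.75 == 175
    return _sum_tensors_uint8(src, list(w_fp), precision)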
3716
3717@pw_cast_for_opmath
3718def _upsample_linear(
3719    input: Tensor,
3720    output_size: List[int],
3721    align_corners: bool,
3722    scales: List[Optional[float]],
3723) -> Tensor:
3724    # get dimensions of original image
3725    n_batch, n_channels = input.shape[:2]
3726    inp_sizes = input.shape[2:]
3727    n_dims = len(inp_sizes)
3728
3729    _, dtype = utils.elementwise_dtypes(
3730        input,
3731        type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
3732    )
3733
3734    def get_values(inp_size, out_size, scales, nsqueeze):
3735        # First Calculate scaling factor
3736        scale_factor = _compute_scale(inp_size, out_size, align_corners, scales)
3737        # We have to create arange with int64 dtype and use .to in order to avoid
3738        # inductor creating additional kernels and causing a perf slowdown
3739        i = torch.arange(out_size, device=input.device).to(dtype=dtype)
3740
3741        x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0)
3742        x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze))
3743        x = x_f32.to(torch.int64)
3744        xp1 = (x + 1).clamp(max=inp_size - 1)
3745        return x_f32, x, xp1
3746
3747    values = [
3748        get_values(inp_size, out_size, scales, n_dims - 1 - i)
3749        for i, (inp_size, out_size, scales) in enumerate(
3750            zip(inp_sizes, output_size, scales)
3751        )
3752    ]
3753    xs_f32, xs, xp1s = list(zip(*values))
3754
3755    vs = []
3756    for a in product(*[[0, 1]] * n_dims):
3757        idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)]
3758        v = aten._unsafe_index(input, idx)
3759        v = _maybe_convert_to_dtype(v, dtype)
3760        vs.append(v)
3761
3762    for i in reversed(range(n_dims)):
3763        xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
3764        vs = [
3765            # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha
3766            v1 + torch.mul(v2 - v1, xscale)
3767            for v1, v2 in zip(vs[::2], vs[1::2])
3768        ]
3769
3770    assert len(vs) == 1
3771    result = vs[0]
3772
3773    # convert output to correct memory format, if necessary
3774    memory_format = utils.suggest_memory_format(input)
3775
3776    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
3777    if input.device.type == "cuda" and n_channels < 16:
3778        memory_format = torch.contiguous_format
3779
3780    assert isinstance(result, torch.Tensor)
3781
3782    result = result.contiguous(memory_format=memory_format)
3783
3784    if not input.is_floating_point():
3785        result = result.round()
3786
3787    return result
3788
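# This decomposition backs aten.upsample_bilinear2d and friends, so its
# output should match the eager kernel. An illustrative cross-check (the
# tolerance is indicative, not normative):
def _upsample_linear_sketch():
    x = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)
    ref = F.interpolate(x, size=(8, 8), mode="bilinear", align_corners=False)
    dec = _upsample_linear(x, [8, 8], False, [None, None])
    return torch.allclose(ref, dec, atol=1e-6)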
3789
3790# We should be applying decompositions after all transformations
3791@register_decomposition(aten.is_same_size.default)
3792def is_same_size(a: Tensor, b: Tensor) -> bool:
3793    return a.shape == b.shape
3794
3795
3796@register_decomposition([aten._reshape_alias, aten._unsafe_view])
3797@out_wrapper()
3798def _reshape_alias(x, shape, *args):
3799    return aten.view(x, shape)
3800
3801
3802@register_decomposition([aten._unsafe_index])
3803def _unsafe_index(x, indices):
3804    return aten.index(x, indices)
3805
3806
3807@register_decomposition([aten._unsafe_index_put])
3808def _unsafe_index_put(x, indices, value, accumulate=False):
3809    return aten.index_put(x, indices, value, accumulate)
3810
3811
3812@register_decomposition([aten._unsafe_masked_index])
3813def _unsafe_masked_index(x, mask, indices, fill):
3814    for index in indices:
3815        if index is not None:
3816            torch._check(
3817                index.dtype in [torch.long, torch.int],
3818                lambda: "tensors used as indices must be long or int tensors",
3819            )
3820
3821    torch._check(
3822        mask.dtype == torch.bool,
3823        lambda: "tensors used as masks must be bool tensors",
3824    )
3825
3826    if x.numel() == 0:
3827        meta_result = torch._meta_registrations.meta_index_Tensor(x, indices)
3828        return x.new_full(meta_result.shape, fill)
3829
3830    for i in range(len(indices)):
3831        index = indices[i]
3832        if index is not None:
3833            indices[i] = index.clamp(min=0, max=x.size(i) - 1)
3834
3835    return aten._unsafe_index(x, indices).masked_fill(~mask, fill)
3836
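# Semantics sketch: indices are clamped into range so the gather is always
# in-bounds, then positions where `mask` is False are overwritten with
# `fill`. Illustrative values:
def _unsafe_masked_index_sketch():
    x = torch.tensor([10.0, 20.0, 30.0])
    idx = torch.tensor([0, 2, 5])  # 5 is out of bounds ...
    mask = torch.tensor([True, True, False])  # ... and masked out
    return _unsafe_masked_index(x, mask, [idx], fill=-1.0)  # [10., 30., -1.]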
3837
3838@register_decomposition([aten._unsafe_masked_index_put_accumulate])
3839def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
3840    for index in indices:
3841        if index is not None:
3842            torch._check(
3843                index.dtype in [torch.long, torch.int],
3844                lambda: "tensors used as indices must be long or int tensors",
3845            )
3846
3847    torch._check(
3848        mask.dtype == torch.bool,
3849        lambda: "tensors used as masks must be bool tensors",
3850    )
3851
3852    if x.numel() == 0:
3853        return x.clone()
3854
3855    for i in range(len(indices)):
3856        index = indices[i]
3857        if index is not None:
3858            indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1)
3859
3860    masked_value = values.masked_fill(~mask, 0)
3861    return aten._unsafe_index_put(x, indices, masked_value, accumulate=True)
3862
3863
3864def _nll_loss_forward(
3865    self: Tensor,
3866    target: Tensor,
3867    weight: Optional[Tensor],
3868    reduction: int,
3869    ignore_index: int,
3870) -> Tuple[Tensor, Tensor]:
3871    # self can be [N, C] or [C]
3872    # target can be [N] or []
3873
3874    n_dims = self.dim()
3875    channel_dim = 1
3876    if n_dims < 2:
3877        channel_dim = 0
3878
3879    if weight is not None:
3880        if n_dims > 1:
3881            shape = [
3882                1,
3883            ] * n_dims
3884            shape[channel_dim] = weight.shape[0]
3885            w = weight.view(shape)
3886        else:
3887            w = weight
3888        self = self * w
3889    safe_target = torch.where(target != ignore_index, target, 0)
3890    safe_target_ = safe_target.unsqueeze(channel_dim)
3891    # target can be [N, 1] or [1]
3892
3893    result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
3894
3895    result = torch.where(target != ignore_index, result, 0)
3896
3897    if reduction == Reduction.NONE.value and n_dims > 1:
3898        total_weight = self.new_full((), 0.0)
3899        return result, total_weight
3900
3901    if weight is not None:
3902        w = w.expand(self.shape)
3903        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
3904        wsum = torch.where(target != ignore_index, wsum, 0)
3905        total_weight = wsum.sum()
3906    else:
3907        total_weight = (target != ignore_index).sum().to(self)
3908
3909    if reduction == Reduction.SUM.value:
3910        result = result.sum()
3911    elif reduction == Reduction.MEAN.value:
3912        result = result.sum() / total_weight
3913
3914    return result, total_weight
3915
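# Worked example of the gather-based formulation: for log-probability inputs
# the per-sample loss is -log p(target), and targets equal to ignore_index
# contribute zero loss and zero weight. Illustrative values:
def _nll_loss_sketch():
    logp = torch.log_softmax(torch.randn(3, 5), dim=1)
    target = torch.tensor([1, 4, -100])  # the last sample is ignored
    out, total_weight = _nll_loss_forward(
        logp, target, weight=None, reduction=Reduction.MEAN.value, ignore_index=-100
    )
    return out, total_weight  # total_weight == 2.0, only two samples count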
3916
3917@register_decomposition(aten.nll_loss_forward)
3918@out_wrapper("output", "total_weight")
3919def nll_loss_forward(
3920    self: Tensor,
3921    target: Tensor,
3922    weight: Optional[Tensor],
3923    reduction: int,
3924    ignore_index: int,
3925) -> Tuple[Tensor, Tensor]:
3926    assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
3927    assert (
3928        target.dim() <= 1
3929    ), "0D or 1D target tensor expected, multi-target not supported"
3930
3931    no_batch_dim = self.dim() == 1 and target.dim() == 0
3932    assert no_batch_dim or (
3933        self.shape[0] == target.shape[0]
3934    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
3935
3936    n_classes = self.shape[-1]
3937
3938    assert weight is None or (
3939        weight.dim() == 1 and weight.numel() == n_classes
3940    ), f"weight tensor should be defined either for all {n_classes} classes or no classes but got weight tensor of shape: {weight.shape}"  # noqa: B950
3941
3942    return _nll_loss_forward(self, target, weight, reduction, ignore_index)
3943
3944
3945@register_decomposition(aten.nll_loss2d_forward)
3946@out_wrapper("output", "total_weight")
3947def nll_loss2d_forward(
3948    self: Tensor,
3949    target: Tensor,
3950    weight: Optional[Tensor],
3951    reduction: int,
3952    ignore_index: int,
3953) -> Tuple[Tensor, Tensor]:
3954    return _nll_loss_forward(self, target, weight, reduction, ignore_index)
3955
3956
3957# These are adapted from aten/src/ATen/native/UpSample.h, which is based on
3958# https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
3959def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor:
3960    return ((A + 2) * x - (A + 3)) * x * x + 1
3961
3962
3963def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor:
3964    return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
3965
3966
3967def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType:
3968    A = -0.75
3969
3970    if t.device == torch.device("cpu"):
3971        tt1 = torch.stack([t, 1.0 - t], dim=0)
3972        tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0)
3973        w03 = _upsample_cubic_convolution2(tt2, A)
3974        w12 = _upsample_cubic_convolution1(tt1, A)
3975        w0, w3 = torch.unbind(w03, dim=0)
3976        w1, w2 = torch.unbind(w12, dim=0)
3977        return w0, w1, w2, w3
3978    else:
3979        return (
3980            _upsample_cubic_convolution2(t + 1.0, A),
3981            _upsample_cubic_convolution1(t, A),
3982            _upsample_cubic_convolution1(1.0 - t, A),
3983            _upsample_cubic_convolution2(2.0 - t, A),
3984        )
3985
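# For the Keys kernel with A = -0.75 the four coefficients form a partition
# of unity: w0 + w1 + w2 + w3 == 1 for every t in [0, 1]. A quick
# illustrative check of that property:
def _cubic_coefficients_sketch():
    t = torch.linspace(0, 1, steps=11)
    w0, w1, w2, w3 = _upsample_get_cubic_coefficients(t)
    return torch.allclose(w0 + w1 + w2 + w3, torch.ones_like(t))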
3986
3987def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
3988    coeffs2 = _upsample_get_cubic_coefficients(ts)
3989    return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2))
3990
3991
3992# Need this instead of just sum() to keep mypy happy
3993def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
3994    return reduce(torch.add, ts)
3995
3996
3997def _linspace_from_neg_one(
3998    num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device
3999):
4000    if num_steps <= 1:
4001        return torch.tensor(0, device=device, dtype=dtype)
4002
4003    a = ((num_steps - 1) / num_steps) if not align_corners else 1
4004    return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype)
4005
4006
4007def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool):
4008    dtype = theta.dtype
4009    device = theta.device
4010
4011    # Using padding and summation generates a single kernel, whereas torch.stack
4012    # would generate 3 kernels, one per individual tensor: grid_x, grid_y, grid_one
4013    grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1)
4014    grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1)
4015    grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device)
4016
4017    # this is just a temporary hack and we should use torch.stack here once #104480 is merged
4018    grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0)
4019    grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0)
4020    grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0)
4021    return grid_x + grid_y + grid_one
4022
4023
4024def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool):
4025    dtype = theta.dtype
4026    device = theta.device
4027
4028    grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1)
4029    grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1)
4030    grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1)
4031    grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device)
4032
4033    # this is just a temporary hack and we should use torch.stack here once #104480 is merged
4034    grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0)
4035    grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0)
4036    grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0)
4037    grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0)
4038    return grid_x + grid_y + grid_z + grid_one
4039
4040
4041def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool):
4042    n, _, h, w = size
4043    base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners)
4044    # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3)
4045    # We do the matrix multiplication manually, which is faster than mm()
4046    # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2)
4047    grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2)
4048    return grid.view(n, h, w, 2)
4049
4050
4051def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: bool):
4052    n, _, d, h, w = size
4053    base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners)
4054    # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4)
4055    # We do the matrix multiplication manually, which is faster than mm()
4056    # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, d * h * w, 3)
4057    grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2)
4058    return grid.view(n, d, h, w, 3)
4059
4060
4061@register_decomposition(aten.affine_grid_generator)
4062@out_wrapper()
4063@pw_cast_for_opmath
4064def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool):
4065    torch._check(
4066        len(size) in (4, 5),
4067        lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.",
4068    )
4069    if len(size) == 4:
4070        return _affine_grid_generator_4d(theta, size, align_corners=align_corners)
4071    else:
4072        return _affine_grid_generator_5d(theta, size, align_corners=align_corners)
4073
4074
4075def _grid_sampler_2d(
4076    a: Tensor,
4077    grid: Tensor,
4078    interpolation_mode: int = 0,
4079    padding_mode: int = 0,
4080    align_corners: bool = False,
4081    _expand_grid: bool = True,
4082) -> Tensor:
4083    # This method is a copy of the grid_sampler_2d implementation with an additional
4084    # arg, _expand_grid, to optionally expand the input grid for performance reasons.
4085    # Experimenting locally, it was found that compiled CUDA code is accelerated by ~5x
4086    # and CPU code by ~2x on bicubic mode if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2).
4087    # However, this leads to a slowdown of ~0.8x on CPU bilinear mode, channels first,
4088    # so we do not expand the grid for that case.
4089
4090    torch._check(
4091        interpolation_mode in (0, 1, 2),
4092        lambda: f"Invalid interpolation mode {interpolation_mode}",
4093    )
4094    torch._check(
4095        padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
4096    )
4097
4098    def unnormalize(coords: Tensor, size: int) -> Tensor:
4099        # Rescale coordinates from [-1, 1] to:
4100        #   [0, size - 1] if align_corners is True
4101        #   [-.5, size -.5] if align_corners is False
4102        mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
4103        ofs = size * 0.5 - 0.5
4104        return coords * mul + ofs
4105
4106    # Reflects coordinates until they fall between low and high (inclusive).
4107    # The bounds are passed as twice their value so that half-integer values
4108    # can be represented as ints.
4109    def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
4110        if twice_low == twice_high:
4111            return torch.zeros_like(coords)
4112        coords_min = twice_low / 2
4113        coords_span = (twice_high - twice_low) / 2
4114        coords2 = (coords - coords_min).abs()
4115        extra = torch.fmod(coords2, coords_span)
4116        flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
4117        return torch.where(
4118            flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
4119        )
4120
4121    def compute_coordinates(coords: Tensor, size: int) -> Tensor:
4122        if padding_mode == 0:  # Zero
4123            return coords
4124        elif padding_mode == 1:  # Borders
4125            return torch.clamp(coords, 0, size - 1)
4126        else:  # padding_mode == 2, Reflection
4127            if align_corners:
4128                coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
4129            else:
4130                coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
4131            return torch.clamp(coords_reflected, 0, size - 1)
4132
4133    def compute_source_index(coords: Tensor, size: int) -> Tensor:
4134        coords_un = unnormalize(coords, size)
4135        return compute_coordinates(coords_un, size)
4136
4137    N, C, iH, iW = a.shape
4138    _, oH, oW, two = grid.shape
4139    assert two == 2
4140
4141    if _expand_grid:
4142        # Let's expand grid to [N, C, oH, oW, 2]
4143        # This allows us to generate a single triton cuda kernel instead of two kernels.
4144        # Two kernels arise because the source indices and weights have shape (N, 1, oH, oW), xnumel=N*oH*oW,
4145        # while the output has shape (N, C, oH, oW), xnumel=N*C*oH*oW.
4146        # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW
4147        grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2)
4148
4149    def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
4150        return torch.logical_and(
4151            0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
4152        )
4153
4154    N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
4155    C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)
4156
4157    def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
4158        cond = in_bounds_cond(xs, ys)
4159        # To clip to inside valid coordinates, we map the coordinates
4160        # to (x, y) = (0, 0) and also set the weight to 0
4161        # We also change the shape of the tensor to the appropriate one for
4162        # broadcasting with N_idx, C_idx for the purposes of advanced indexing
4163        c = C if _expand_grid else 1
4164        return tuple(
4165            torch.where(cond, t, 0).view(N, c, oH, oW)
4166            for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
4167        )
4168
4169    def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor:
4170        # Perform clipping, index into input tensor and multiply by weight
4171        idx_x, idx_y, w_ = clip(ix, iy, w)
4172        return a[N_idx, C_idx, idx_y, idx_x] * w_
4173
4174    x = grid[..., 0]
4175    y = grid[..., 1]
4176
4177    if interpolation_mode == 0:  # Bilinear
4178        ix = compute_source_index(x, iW)
4179        iy = compute_source_index(y, iH)
4180
4181        ix_nw, iy_nw = ix.floor(), iy.floor()
4182        ix_ne, iy_ne = ix_nw + 1, iy_nw
4183        ix_sw, iy_sw = ix_nw, iy_nw + 1
4184        ix_se, iy_se = ix_ne, iy_sw
4185
4186        w_nw = (ix_se - ix) * (iy_se - iy)
4187        w_ne = (ix - ix_sw) * (iy_sw - iy)
4188        w_sw = (ix_ne - ix) * (iy - iy_ne)
4189        w_se = (ix - ix_nw) * (iy - iy_nw)
4190
4191        return _sum_tensors(
4192            get_summand(ix, iy, w)
4193            for (ix, iy, w) in (
4194                (ix_nw, iy_nw, w_nw),
4195                (ix_ne, iy_ne, w_ne),
4196                (ix_sw, iy_sw, w_sw),
4197                (ix_se, iy_se, w_se),
4198            )
4199        )
4200    elif interpolation_mode == 1:  # Nearest
4201        ix = compute_source_index(x, iW)
4202        iy = compute_source_index(y, iH)
4203
4204        ix_nearest = ix.round()
4205        iy_nearest = iy.round()
4206
4207        return get_summand(ix_nearest, iy_nearest, 1)
4208    else:  # interpolation_mode == 2, Bicubic
4209        ix = unnormalize(x, iW)
4210        iy = unnormalize(y, iH)
4211
4212        ix_nw = ix.floor()
4213        iy_nw = iy.floor()
4214
4215        tx = ix - ix_nw
4216        ty = iy - iy_nw
4217
4218        if not _expand_grid:
4219            tx = tx.unsqueeze(1)
4220            ty = ty.unsqueeze(1)
4221
4222        def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor:
4223            x = compute_coordinates(ix, iW)
4224            y = compute_coordinates(iy, iH)
4225            return get_summand(x, y, 1)
4226
4227        def get_coeff(ofs: int) -> Tensor:
4228            iy_ofs = iy_nw + (ofs - 1)
4229            cs = (
4230                get_value_bounded(ix_nw - 1, iy_ofs),
4231                get_value_bounded(ix_nw, iy_ofs),
4232                get_value_bounded(ix_nw + 1, iy_ofs),
4233                get_value_bounded(ix_nw + 2, iy_ofs),
4234            )
4235            return _upsample_cubic_interp1d(cs, tx)
4236
4237        coeffs = tuple(get_coeff(ofs) for ofs in range(4))
4238        return _upsample_cubic_interp1d(coeffs, ty)
4239
4240
4241@register_decomposition(aten.grid_sampler_2d)
4242@out_wrapper()
4243@pw_cast_for_opmath
4244def grid_sampler_2d(
4245    a: Tensor,
4246    grid: Tensor,
4247    interpolation_mode: int = 0,
4248    padding_mode: int = 0,
4249    align_corners: bool = False,
4250) -> Tensor:
4251    return _grid_sampler_2d(
4252        a,
4253        grid=grid,
4254        interpolation_mode=interpolation_mode,
4255        padding_mode=padding_mode,
4256        align_corners=align_corners,
4257    )
4258
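# Illustrative cross-check of the decomposition against the eager kernel it
# replaces (bilinear interpolation, zero padding); the tolerance is
# indicative only:
def _grid_sampler_2d_sketch():
    a = torch.randn(1, 2, 5, 5)
    grid = torch.rand(1, 3, 3, 2) * 2 - 1  # normalized coords in [-1, 1]
    ref = F.grid_sample(a, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    dec = grid_sampler_2d(a, grid, interpolation_mode=0, padding_mode=0, align_corners=False)
    return torch.allclose(ref, dec, atol=1e-6)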
4259
4260@register_decomposition(aten.mv)
4261@out_wrapper()
4262@pw_cast_for_opmath
4263def mv(self, vec):
4264    torch._check(
4265        self.dim() == 2 and vec.dim() == 1,
4266        lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
4267    )
4268    torch._check(
4269        self.size(1) == vec.size(0),
4270        lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
4271    )
4272    return (self * vec).sum(dim=1)
4273
4274
4275@register_decomposition(aten.binary_cross_entropy_with_logits)
4276@out_wrapper()
4277def binary_cross_entropy_with_logits(
4278    self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value
4279):
4280    if pos_weight is not None:
4281        log_weight = (pos_weight - 1) * target + 1
4282        loss = (1 - target) * self - (log_weight * F.logsigmoid(self))
4283    else:
4284        loss = (1 - target) * self - F.logsigmoid(self)
4285
4286    if weight is not None:
4287        loss = loss * weight
4288
4289    return apply_loss_reduction(loss, reduction)
4290
4291
4292def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool:
4293    # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp
4294
4295    t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1)
4296
4297    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4298
4299    if not (t1.ndim >= 3 and t2.ndim <= 2):
4300        return False
4301    if t2.requires_grad and not is_out:
4302        return True
4303    if tensor1.ndim == 2:
4304        return False
4305    if guard_size_oblivious(t1.numel() == 0):
4306        return True
4307
4308    t1_shape = t1.shape
4309    t1_stride = t1.stride()
4310    return all(
4311        st1 == st2 * s2
4312        for (st1, st2, s2) in zip(t1_stride[:-2], t1_stride[1:-1], t1_shape[1:-1])
4313    )
4314
4315
4316@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
4317@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd)
4318@out_wrapper(pass_is_out=True)
4319def matmul(tensor1, tensor2, *, is_out=False):
4320    dim_tensor1 = tensor1.dim()
4321    dim_tensor2 = tensor2.dim()
4322    assert dim_tensor1 != 0 and dim_tensor2 != 0
4323    if dim_tensor1 == 1 and dim_tensor2 == 1:
4324        return torch.dot(tensor1, tensor2)
4325    elif dim_tensor1 == 2 and dim_tensor2 == 1:
4326        return torch.mv(tensor1, tensor2)
4327    elif dim_tensor1 == 1 and dim_tensor2 == 2:
4328        return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
4329    elif dim_tensor1 == 2 and dim_tensor2 == 2:
4330        return torch.mm(tensor1, tensor2)
4331    elif should_fold(tensor1, tensor2, is_out):
4332        # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
4333        # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
4334        # and some condition on the strides is fulfilled
4335
4336        # optimization: use mm instead of bmm by folding the batch of the larger tensor
4337        # into its leading matrix dimension
4338        transpose = dim_tensor2 > dim_tensor1
4339        t1 = tensor2.mT if transpose else tensor1
4340        t2 = (
4341            tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
4342        )
4343        # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
4344        #            and t1 and t2 are matmul-compatible
4345
4346        # Why not t1.view(-1, sizes_1[-1])?
4347        # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
4348        # This can happen in e.g. [3, 5, 0] @ [0, 0].
4349        sizes_1 = t1.shape
4350        output_shape = list(sizes_1[:-1])
4351        folded_dim1 = reduce(operator.mul, output_shape)
4352
4353        # Readjust output_shape if we are multiplying by a matrix
4354        t2_is_matrix = t2.dim() == 2
4355        if t2_is_matrix:
4356            output_shape.append(t2.shape[1])
4357
4358        # This will almost always be a view.
4359        # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation
4360        t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
4361        if t2_is_matrix:
4362            # This copies if we perform a 2D @ 3D and the first tensor requires_grad
4363            # See should_fold native/LinearAlgebra.cpp for why.
4364            output = t1_folded.mm(t2).view(output_shape)
4365            return output.mT.contiguous() if transpose else output
4366        else:
4367            return t1_folded.mv(t2).view(output_shape)
4368
4369    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
4370        # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
4371        # we track m1 vs m2 separately even though they must match for nicer error messages
4372        n = tensor1.size(-2) if dim_tensor1 > 1 else 1
4373        m1 = tensor1.size(-1)
4374        batch_tensor1 = tensor1.shape[:-2]
4375        m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
4376        p = tensor2.size(-1) if dim_tensor2 > 1 else 1
4377
4378        batch_tensor2: List[int] = []
4379        # TODO: handling of slice
4380        for i in range(dim_tensor2 - 2):
4381            batch_tensor2.append(tensor2.size(i))
4382
4383        # Same optimization for the gradients as that in should_fold
4384        # If we're going to broadcast, we force it to go through the should_fold branch
4385        if (
4386            dim_tensor1 == 3
4387            and dim_tensor2 == 3
4388            and batch_tensor1[0] != batch_tensor2[0]
4389        ):
4390            if batch_tensor1[0] == 1 and tensor1.requires_grad:
4391                return matmul(tensor1.squeeze(0), tensor2)
4392            if batch_tensor2[0] == 1 and tensor2.requires_grad:
4393                return matmul(tensor1, tensor2.squeeze(0))
4394
4395        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
4396        expand_batch_portion = list(
4397            torch.broadcast_shapes(batch_tensor1, batch_tensor2)
4398        )
4399
4400        tensor1_expand_size = expand_batch_portion + [n, m1]
4401
4402        expand_batch_product = prod(expand_batch_portion)
4403
4404        # HACK: We need reshape with symint support
4405        tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(
4406            expand_batch_product, n, m1
4407        )
4408
4409        vector_rhs = dim_tensor2 == 1
4410        if vector_rhs:
4411            tensor2_expand_size = expand_batch_portion + [m2]
4412            tensor2_expanded = (
4413                tensor2.expand(tensor2_expand_size)
4414                .reshape(expand_batch_product, m2)
4415                .unsqueeze(2)
4416            )
4417        else:
4418            tensor2_expand_size = expand_batch_portion + [m2, p]
4419            tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
4420                expand_batch_product, m2, p
4421            )
4422
4423        output_shape = expand_batch_portion
4424        if dim_tensor1 > 1:
4425            output_shape.append(n)
4426
4427        if dim_tensor2 > 1:
4428            output_shape.append(p)
4429
4430        if vector_rhs:
4431            return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape)
4432        else:
4433            return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
4434    else:
4435        torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
4436
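# Shape sketch for the folding branch above: a (2, 3, 4) @ (4, 5) matmul is
# computed as a single (6, 4) @ (4, 5) mm and viewed back to (2, 3, 5),
# avoiding a bmm. Illustrative check:
def _matmul_fold_sketch():
    t1, t2 = torch.randn(2, 3, 4), torch.randn(4, 5)
    out = matmul(t1, t2)
    return out.shape == (2, 3, 5) and torch.allclose(out, t1 @ t2, atol=1e-6)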
4437
4438@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
4439@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
4440@out_wrapper()
4441@pw_cast_for_opmath
4442def upsample_bicubic2d_default(
4443    input: Tensor,
4444    output_size: Tuple[int, int],
4445    align_corners: bool,
4446    scale_h: Optional[float] = None,
4447    scale_w: Optional[float] = None,
4448) -> Tensor:
4449    # get dimensions of original image
4450    _, _, in_h, in_w = input.shape
4451
4452    # Calculate horizontal and vertical scaling factor
4453    h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h)
4454    w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w)
4455
4456    _, dtype = utils.elementwise_dtypes(
4457        input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
4458    )
4459
4460    # We have to create arange with int64 dtype and use .to in order to avoid
4461    # inductor creating additional kernels and causing a perf slowdown
4462    i = torch.arange(output_size[0], device=input.device).to(dtype=dtype)
4463    j = torch.arange(output_size[1], device=input.device).to(dtype=dtype)
4464
4465    x_float = _compute_source_index(w_scale_factor, j, align_corners)
4466    y_float = _compute_source_index(h_scale_factor, i, align_corners)
4467    y_float = y_float.unsqueeze(-1)
4468
4469    x = x_float.floor()
4470    y = y_float.floor()
4471
4472    # We should also clamp xscale/yscale
4473    # See guard_index_and_lambda in UpSample.h
4474    yscale = (y_float - y).clamp(0.0, 1.0)
4475    xscale = (x_float - x).clamp(0.0, 1.0)
4476    x = x.to(torch.int64)
4477    y = y.to(torch.int64)
4478
4479    iys_ofs = (y - 1, y, y + 1, y + 2)
4480    ixs_ofs = (x - 1, x, x + 1, x + 2)
4481
4482    weights_x = _upsample_get_cubic_coefficients(xscale)
4483    weights_y = _upsample_get_cubic_coefficients(yscale)
4484
4485    weights_precision_x, weights_precision_y = None, None
4486    if input.dtype == torch.uint8:
4487        weights_precision_x = _compute_weight_precision(weights_x)
4488        weights_precision_y = _compute_weight_precision(weights_y)
4489
4490        weights_x = [
4491            (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16)
4492            for w in weights_x
4493        ]
4494        weights_y = [
4495            (w * (1 << weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16)
4496            for w in weights_y
4497        ]
4498
4499    def load_bounded(ys, xs):
4500        y_idx = torch.clamp(ys, 0, in_h - 1)
4501        x_idx = torch.clamp(xs, 0, in_w - 1)
4502        v = aten._unsafe_index(input, [None, None, y_idx, x_idx])
4503        return v
4504
4505    def get_x_interp(y):
4506        src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs)
4507        if input.dtype == torch.uint8:
4508            assert weights_precision_x is not None
4509            return _sum_tensors_uint8(src_x, weights_x, weights_precision_x)
4510        return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x))
4511
4512    src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs)
4513    if input.dtype == torch.uint8:
4514        assert weights_precision_y is not None
4515        result = _sum_tensors_uint8(src_y, weights_y, weights_precision_y)
4516    else:
4517        result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y))
4518
4519    # convert output to correct memory format, if necessary
4520    memory_format = utils.suggest_memory_format(input)
4521    result = result.contiguous(memory_format=memory_format)
4522    return result
4523
4524
4525@register_decomposition(aten.upsample_bicubic2d.vec)
4526@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
4527@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd)
4528@out_wrapper()
4529@pw_cast_for_opmath
4530def upsample_bicubic2d_vec(
4531    a: Tensor,
4532    output_size: Optional[Tuple[int, int]],
4533    align_corners: bool,
4534    scale_factors: Optional[Tuple[float, float]] = None,
4535) -> Tensor:
4536    torch._check(
4537        bool(output_size) + bool(scale_factors) == 1,
4538        lambda: "Must specify exactly one of output_size and scale_factors.",
4539    )
4540    if output_size is None:
4541        assert scale_factors is not None
4542        output_size = cast(
4543            Tuple[int, int],
4544            tuple(
4545                sym_int(sym_float(w) * scale)
4546                for w, scale in zip(a.shape[2:], scale_factors)
4547            ),
4548        )
4549    scale_h, scale_w = scale_factors if scale_factors else (None, None)
4550    return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
4551
4552
4553@register_decomposition(aten.reflection_pad1d)
4554@register_decomposition(aten.reflection_pad2d)
4555@register_decomposition(aten.reflection_pad3d)
4556@pw_cast_for_opmath
4557@out_wrapper()
4558def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
4559    def idx(left, middle, right):
4560        dim_idx = torch.arange(-left, middle + right, device=a.device)
4561        return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
4562
4563    return _reflection_or_replication_pad(
4564        a,
4565        padding,
4566        idx,
4567    )
4568
4569
4570@register_decomposition(aten.replication_pad1d)
4571@register_decomposition(aten.replication_pad2d)
4572@register_decomposition(aten.replication_pad3d)
4573@pw_cast_for_opmath
4574@out_wrapper()
4575def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
4576    def idx(left, middle, right):
4577        dim_idx = torch.arange(-left, middle + right, device=a.device)
4578        return torch.clamp(dim_idx, 0, middle - 1)
4579
4580    return _reflection_or_replication_pad(
4581        a,
4582        padding,
4583        idx,
4584    )
4585
4586
4587def _reflection_or_replication_pad(
4588    a: Tensor,
4589    padding: Tuple[int, ...],
4590    idx_fn: Callable[[int, int, int], Tensor],
4591) -> Tensor:
4592    dim = len(padding) // 2
4593    torch._check(
4594        a.dim() in (dim + 1, dim + 2),
        lambda: f"reflection_pad{dim}d / replication_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
4596    )
4597    inp_shape = a.shape[-dim:]
4598    nc_dim = a.dim() - dim
4599
4600    padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
4601    padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
4602
4603    result = a
4604    for i in range(dim):
4605        idx: List[Any] = [None] * result.dim()
4606        idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
4607        result = aten._unsafe_index(result, idx)
4608
4609    # convert output to correct memory format, if necessary
4610    memory_format = utils.suggest_memory_format(result)
4611    result = result.contiguous(memory_format=memory_format)
4612    return result
4613
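# Note on the padding layout (same convention as F.pad): padding lists the
# (left, right) pair of the last dim first. Illustrative sketch (hypothetical
# values): padding=(1, 2, 3, 4) with dim=2 gives padding_left=[3, 1] and
# padding_right=[4, 2], i.e. 3/4 rows above/below and 1/2 columns
# left/right of the input.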
4614
4615@register_decomposition(aten.reflection_pad1d_backward)
4616@register_decomposition(aten.reflection_pad2d_backward)
4617@register_decomposition(aten.reflection_pad3d_backward)
4618@out_wrapper("grad_input")
4619def _reflection_pad_backward(grad_output, x, padding):
4620    dim = len(padding) // 2
4621
4622    dhw = [h - 1 for h in x.shape[-dim:]]
4623
4624    padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
4625    padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
4626
4627    indices = []
4628    for i in range(x.ndim):
4629        view_shape = [1] * x.ndim
4630        view_shape[i] = -1
4631        indices.append(torch.arange(x.shape[i], device=x.device).view(view_shape))
4632
4633    b = indices[:-dim]
4634    xyz = indices[-dim:]
4635
4636    def index_range_condition(index_range):
4637        i, lb, ub = index_range
4638        return torch.logical_and(i >= lb, i <= ub)
4639
4640    # Areas after reflection:
4641    #
4642    #   top-left    |   top     |   top-right
4643    # -----------------------------------------
4644    #   left        |   center  |   right
4645    # -----------------------------------------
4646    #   bottom-left |   bottom  |   bottom-right
4647    #
4648    # The center area is the original matrix. Other areas are reflections.
4649
4650    center = [xyz[i] + padding_left[i] for i in range(dim)]
4651    left_reflect = [padding_left[i] - xyz[i] for i in range(dim)]
4652    right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)]
4653
4654    # Accumulate gradients from different areas
4655    # If some of the padding is negative, center load is not always valid
4656    range_c = [
4657        (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) for i in range(dim)
4658    ]
4659    cond = functools.reduce(
4660        aten.logical_and, [index_range_condition(range_c[i]) for i in range(dim)]
4661    )
4662    grad = aten._unsafe_masked_index(grad_output, cond, b + center, 0.0)
4663
4664    def accumulate(grad, out, index_ranges):
4665        # If the upper bound is less than the lower bound, we can get rid of one accumulation.
4666        # This happens when the padding size is zero.
4667        for i in range(dim):
4668            upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1]
4669            if isinstance(upper_less_than_lower, bool) and upper_less_than_lower:
4670                return grad
4671
4672        cond = functools.reduce(
4673            aten.logical_and,
4674            [index_range_condition(index_range) for index_range in index_ranges],
4675        )
4676        g = aten._unsafe_masked_index(grad_output, cond, b + out, 0.0)
4677        return grad + g
4678
4679    for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]):
4680        if area == tuple([0] * dim):
4681            # center, this is already done.
4682            continue
4683
4684        outs = []
4685        index_ranges = []
4686
4687        for i in range(dim):
4688            if area[i] == 0:
4689                out = center[i]
4690                index_range = range_c[i]
4691            elif area[i] == -1:
4692                out = left_reflect[i]
4693                index_range = (xyz[i], 1, padding_left[i])
4694            elif area[i] == 1:
4695                out = right_reflect[i]
4696                index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1)
4697
4698            outs.append(out)  # type: ignore[possibly-undefined]
4699            index_ranges.append(index_range)  # type: ignore[possibly-undefined]
4700
4701        grad = accumulate(grad, outs, index_ranges)
4702
4703    return grad
4704
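# Illustrative 1-D sketch of the accumulation above (hypothetical values):
# reflection_pad1d of x = [a, b, c] with padding (1, 1) produces
# [b, a, b, c, b], so the backward pass accumulates, for input index 1,
# grad_output[0] (left reflection), grad_output[2] (center), and
# grad_output[4] (right reflection); indices 0 and 2 only receive their
# center contribution.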
4705
4706@register_decomposition(aten.aminmax)
4707@out_wrapper("min", "max")
4708def aminmax(self, *, dim=None, keepdim=False):
4709    amin = torch.amin(self, dim=dim, keepdim=keepdim)
4710    amax = torch.amax(self, dim=dim, keepdim=keepdim)
4711    return amin, amax
4712
4713
4714@register_decomposition(aten.nansum)
4715@out_wrapper()
4716def nansum(self, dim=None, keepdim=False, *, dtype=None):
4717    return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype)
4718
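# Illustrative sketch (hypothetical values): for
# t = torch.tensor([1.0, float("nan"), 2.0]), the NaNs are replaced with 0
# before the reduction, so nansum(t) == 3.0.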
4719
4720@register_decomposition([aten.arange.default, aten.arange.out])
4721@out_wrapper()
4722def arange_default(
4723    end: NumberType,
4724    *,
4725    dtype: Optional[torch.dtype] = None,
4726    layout: torch.layout = torch.strided,
4727    device: Optional[torch.device] = None,
4728    pin_memory: bool = False,
4729):
4730    return aten.arange.start_step(
4731        0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
4732    )
4733
4734
4735@register_decomposition([aten.arange.start])
4736def arange_start(
4737    start: NumberType,
4738    end: NumberType,
4739    *,
4740    dtype: Optional[torch.dtype] = None,
4741    layout: torch.layout = torch.strided,
4742    device: Optional[torch.device] = None,
4743    pin_memory: bool = False,
4744):
4745    return aten.arange.start_step(
4746        start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
4747    )
4748
4749
4750@register_decomposition(out_dtype)
4751def out_dtype_decomp(*args, **kwargs):
4752    from torch._higher_order_ops.out_dtype import out_dtype_dense
4753
4754    return out_dtype_dense(*args, **kwargs)
4755
4756
4757@register_decomposition(aten.multi_margin_loss)
4758@aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd)
4759@out_wrapper()
4760def multi_margin_loss(
4761    input: Tensor,
4762    target: Tensor,
4763    p: NumberType = 1,
4764    margin: NumberType = 1,
4765    weight: Optional[Tensor] = None,
4766    reduction: int = Reduction.MEAN.value,
4767) -> Tensor:
4768    input = torch.atleast_2d(input)
4769    target = torch.atleast_1d(target)
4770    nframe = input.shape[0]
4771    dim = input.shape[1]
4772    torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
4773    torch._check(
4774        input.ndim == 2 and dim != 0,
4775        lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}",
4776    )
4777    torch._check(
4778        target.ndim == 1 and target.numel() == nframe,
4779        lambda: f"inconsistent target size, expected {nframe} but got {target.shape}",
4780    )
4781    if weight is not None:
4782        weight = torch.atleast_1d(weight)
4783        torch._check(
4784            weight.ndim == 1 and weight.numel() == dim,  # type: ignore[union-attr]
4785            lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}",  # type: ignore[union-attr]
4786        )
4787    target = target.unsqueeze(1)
4788    u = torch.gather(input, dim=1, index=target)
4789    z = margin - u + input
4790    z = z.clamp_min(0)
4791    z = z if p == 1 else z * z
4792    if weight is not None:
4793        z = z * weight[target]
4794    idx = torch.arange(dim, device=input.device)
4795    z = torch.where(idx != target, z, 0)
4796    if reduction == Reduction.MEAN.value:
4797        return z.mean()
4798    elif reduction == Reduction.SUM.value:
4799        return z.sum() / z.shape[1]
4800    else:
4801        return z.mean(dim=1)
4802
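# Worked sketch of the hinge above (hypothetical values): for
# input = [[0.1, 0.2, 0.7]], target = [2], p = 1, margin = 1,
# z = clamp_min(1 - 0.7 + [0.1, 0.2, 0.7], 0) = [0.4, 0.5, 1.0]; zeroing the
# target column gives [0.4, 0.5, 0.0], and the MEAN reduction averages over
# all dim * nframe entries: (0.4 + 0.5 + 0.0) / 3 = 0.3, matching
# F.multi_margin_loss.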
4803
4804@register_decomposition(aten.multilabel_margin_loss_forward)
4805@aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd)
4806@out_wrapper("output", "is_target")
4807def multilabel_margin_loss_forward(
4808    input: Tensor,
4809    target: Tensor,
4810    reduction: int,
4811) -> Tuple[Tensor, Tensor]:
4812    orig_input_shape = input.shape
4813    orig_target_shape = target.shape
4814    input = torch.atleast_2d(input)
4815    target = torch.atleast_2d(target)
4816    dim = input.shape[1]
4817    torch._check(
4818        len(orig_input_shape) <= 2 and dim != 0,
4819        lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}",
4820    )
4821    torch._check(
4822        len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape,
4823        lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}",
4824    )
4825    # ignores labels after the first -1, detects when -1 is not present
4826    idx = torch.arange(dim, device=target.device)
4827    is_end = target == -1
4828    end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True)
4829    # target indices
4830    target_mask = idx < end_idx
4831    # masks target to be able to use gather, which doesn't allow -1
4832    tidx0 = torch.where(target_mask, target, 0)
4833    u = torch.gather(input, dim=-1, index=tidx0)
4834    # is_target
4835    tidx1 = torch.where(target_mask, target, -1)
4836    is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1)
4837    # loss
4838    z = 1.0 - u.T.unsqueeze(dim=-1) + input
4839    z = z.clamp_min(0)
4840    z = z / dim
4841    # masks loss
4842    z = torch.where(is_target, 0, z)
4843    # reduction
4844    if reduction == Reduction.MEAN.value:
4845        z = z.sum(dim=(0, -1)).mean()
4846    elif reduction == Reduction.SUM.value:
4847        z = z.sum()
4848    else:
4849        z = z.sum(dim=(0, -1))
4850    # result
4851    is_target = is_target.to(input.dtype).reshape(orig_target_shape)
4852    return z, is_target
4853
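# Illustrative sketch of the target encoding above (hypothetical values):
# for target = [[2, 0, -1]], the valid labels are {2, 0} -- everything at or
# after the first -1 is ignored -- so is_target = [[1., 0., 1.]], and the
# hinge 1 - (x[y] - x[j]) is accumulated over valid labels y and
# non-label columns j.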
4854
# scaled_dot_product_attention used to be decomposed in pre-autograd, given
# that it calls _scaled_dot_product_attention_math and
# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd
# kernel. As a result, it's decomposed into ops with finer granularity.
# However, recent PRs (#103826 #105131 #115913) added new logic to
# scaled_dot_product_attention, and it now calls
# _scaled_dot_product_flash_attention_for_cpu in the export path. This
# results in _scaled_dot_product_flash_attention_for_cpu showing up in the
# export result. This decomposition ensures scaled_dot_product_attention is
# still decomposed the same way as before, i.e., going through
# _scaled_dot_product_attention_math. Note that this decomp rule should be
# excluded by inductor.
4867@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default)
4868def scaled_dot_product_flash_attention_for_cpu(
4869    query: Tensor,
4870    key: Tensor,
4871    value: Tensor,
4872    dropout_p: float = 0.0,
4873    is_causal: bool = False,
4874    *,
4875    attn_mask: Optional[Tensor] = None,
4876    scale: Optional[float] = None,
4877) -> Tuple[Tensor, Tensor]:
4878    dtype = query.dtype
4879    torch._check(
4880        torch.is_floating_point(query),
4881        lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
4882    )
4883    torch._check(
4884        query.dim() == 4 and key.dim() == 4 and value.dim() == 4,
4885        lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}",
4886    )
4887    torch._check(
4888        dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}"
4889    )
4890    torch._check(
4891        query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3],
4892        lambda: "q, k, v should have the same head size",
4893    )
4894
4895    output, attn = aten._scaled_dot_product_attention_math.default(
4896        query,
4897        key,
4898        value,
4899        attn_mask=attn_mask,
4900        dropout_p=dropout_p,
4901        is_causal=is_causal,
4902        dropout_mask=None,
4903        scale=scale,
4904    )
    # Why this change?
    # In pre-dispatch export, scaled_dot_product_attention is executed via
    # flash_attention. flash_attention allocates its output tensor as
    # (N, L, H, E) and then transposes it to (N, H, L, E), which is the
    # expected return layout for scaled_dot_product_attention.
    # Assume x: [N, H, L, E] is the sdpa output. In MHA code, this output is
    # permuted via (2, 0, 1, 3) to get an (L, N, H, E)-dim tensor,
    #   x = x.permute(2, 0, 1, 3).contiguous()
    # and then viewed via
    #   x = x.view(L * N, H * E)
    # During pre-autograd dispatch, the call to contiguous is not traced
    # because the flash_attention output is already contiguous after
    # x.permute, so the view is valid.
    # However, during 2nd-stage (post-dispatch) export, we run the _math
    # variant instead of flash* to get the decomposition. The _math variant
    # returns x: [N, H, L, E]; applying x.permute(2, 0, 1, 3) returns
    # x: [L, N, H, E], and without converting this to a contiguous tensor
    # the subsequent view is not valid and the export fails.
    # The solution is to make the tensor returned from this decomp have
    # exactly the same view as the *flash* variants:
    #   - the flash variants' output is contiguous as [N, L, H, E]
    #   - the _math variant's output is contiguous as [N, H, L, E]
    # out.transpose(1, 2).contiguous(...) makes the output contiguous in
    # [N, L, H, E]. The subsequent transpose(1, 2) then returns a view on
    # which the snippet above,
    #   x = x.permute(2, 0, 1, 3).contiguous()
    #   x = x.view(L * N, H * E)
    # is valid.

    # Really, the invariant to maintain is: a pre-dispatch op output and its
    # decomposed representation must return tensors with the same view and
    # dims.
4938    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
4939    return (output.transpose(1, 2), attn)
4940
4941
4942def register_inplace(aten_op, outplace_op):
4943    @register_decomposition(aten_op)
4944    def inplace_op(*args, **kwargs):
4945        out = outplace_op(*args, **kwargs)
4946        return args[0].copy_(out)
4947
4948    return inplace_op
4949
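# Illustrative sketch: the registrations at the bottom of this file make,
# e.g., aten.relu_(x) decompose into x.copy_(aten.relu(x)) -- the in-place
# variant is expressed as the out-of-place decomposition followed by a
# copy_ back into the first argument.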
4950
4951@register_decomposition([aten.baddbmm])
4952@out_wrapper()
4953@pw_cast_for_opmath
4954def baddbmm(self, batch1, batch2, beta=1, alpha=1):
4955    if not self.is_floating_point() and not self.is_complex():
4956        beta = int(beta)
4957        alpha = int(alpha)
4958    result = torch.bmm(batch1, batch2)
4959    if not isinstance(alpha, numbers.Number) or alpha != 1:
4960        result = result * alpha
4961    if beta == 0:
4962        return result
4963    if not isinstance(beta, numbers.Number) or beta != 1:
4964        self = self * beta
4965    return self + result
4966
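# Sketch of the math above (illustrative shapes): for self: [b, n, p],
# batch1: [b, n, m], batch2: [b, m, p], this computes
# beta * self + alpha * torch.bmm(batch1, batch2), with the beta == 0 case
# returning early so self is never read (matching eager semantics, where
# beta == 0 ignores nan/inf in self).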
4967
4968@register_decomposition(aten.floor_divide)
4969@out_wrapper()
4970def floor_divide(self, other):
4971    return torch.div(self, other, rounding_mode="floor")
4972
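# Illustrative sketch (hypothetical values): rounding_mode="floor" rounds
# toward -inf, so floor_divide(torch.tensor(-7), 2) == -4, unlike C-style
# truncation, which would give -3.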
4973
4974@register_decomposition(aten.sym_numel)
4975def sym_numel(t):
4976    return functools.reduce(operator.mul, t.shape, 1)
4977
4978
4979@register_decomposition([aten.sum.default, aten.sum.out])
4980def sum_default(
4981    self: Tensor,
4982    *,
4983    dtype: Optional[torch.dtype] = None,
4984    out: Optional[Tensor] = None,
4985) -> Tensor:
4986    if out is None:
4987        return aten.sum.dim_IntList(self, [], dtype=dtype)
4988    else:
4989        return aten.sum.IntList_out(self, [], dtype=dtype, out=out)
4990
4991
4992@register_decomposition([aten.squeeze.default, aten.squeeze.dim])
4993def squeeze_default(self: Tensor, dim: Optional[int] = None):
4994    # handle a scalar directly
4995    if not isinstance(self, torch.Tensor):
4996        return self
4997    # perform squeeze
4998    if dim is None:
4999        return aten.squeeze.dims(self, list(range(self.dim())))
5000    else:
5001        return aten.squeeze.dims(self, [dim])
5002
5003
5004@register_decomposition(torch.ops.aten._weight_norm_interface)
5005def _weight_norm_interface(v, g, dim=0):
5006    # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58
5007    keep_dim = tuple(i for i in range(len(v.shape)) if i != dim)
5008    # align with cuda behavior, keep norm in 'float' when g is 'bfloat16'
5009    norm_dtype = torch.float if g.dtype == torch.bfloat16 else None
5010    norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype)
5011    return v * (g / norm.to(g.dtype)), norm
5012
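# Illustrative sketch (hypothetical shapes): for v: [out, in] and dim=0,
# keep_dim == (1,), so norm is the per-row Euclidean norm of shape [out, 1],
# and the result w = v * (g / norm) rescales each row of v to have
# magnitude g[row].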
5013
5014@register_decomposition(aten.isin)
5015@out_wrapper()
5016def isin(elements, test_elements, *, assume_unique=False, invert=False):
    # handle the case where either elements or test_elements is a Scalar (they can't both be)
5018    if not isinstance(elements, torch.Tensor):
5019        elements = torch.tensor(elements, device=test_elements.device)
5020    if not isinstance(test_elements, torch.Tensor):
5021        test_elements = torch.tensor(test_elements, device=elements.device)
5022
5023    if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
5024        return isin_default(elements, test_elements, invert=invert)
5025    else:
5026        return isin_sorting(
5027            elements, test_elements, assume_unique=assume_unique, invert=invert
5028        )
5029
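# The cutover above mirrors NumPy's in1d heuristic: brute-force pairwise
# comparison is cheaper only while the test set is small relative to
# elements. Illustrative sketch (hypothetical values): with
# elements.numel() == 1000, the threshold is 10 * 1000**0.145 ~= 27.2, so
# up to 27 test elements go through isin_default and larger sets take the
# sort-based path.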
5030
5031def isin_default(elements, test_elements, *, invert=False):
5032    if elements.numel() == 0:
5033        return torch.empty_like(elements, dtype=torch.bool)
5034
5035    x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
5036    if not invert:
5037        cmp = x == test_elements
5038    else:
5039        cmp = x != test_elements
5040    dim = tuple(range(-1, -test_elements.ndim - 1, -1))
5041    return cmp.any(dim=dim)
5042
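# Illustrative sketch of the broadcast above (hypothetical shapes): for
# elements: [e1, e2] and test_elements: [t], x has shape [e1, e2, 1] and
# x == test_elements broadcasts to [e1, e2, t]; reducing with any() over the
# trailing test dims leaves the [e1, e2] membership mask.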
5043
5044def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False):
5045    elements_flat = elements.flatten()
5046    test_elements_flat = test_elements.flatten()
5047    if assume_unique:
5048        # This is the same as the aten implementation. For
5049        # assume_unique=False, we cannot use unique() here, so we use a
5050        # version with searchsorted instead.
5051        all_elements = torch.cat([elements_flat, test_elements_flat])
5052        sorted_elements, sorted_order = torch.sort(all_elements, stable=True)
5053
5054        duplicate_mask = sorted_elements[1:] == sorted_elements[:-1]
5055        duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False)
5056
5057        if invert:
5058            duplicate_mask = duplicate_mask.logical_not()
5059
5060        mask = torch.empty_like(duplicate_mask)
5061        mask = mask.index_copy(0, sorted_order, duplicate_mask)
5062
5063        return mask[0 : elements.numel()]
5064    else:
5065        sorted_test_elements, _ = torch.sort(test_elements_flat)
5066        idx = torch.searchsorted(sorted_test_elements, elements_flat)
5067        test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0)
5068        cmp = sorted_test_elements[test_idx] == elements_flat
5069        cmp = cmp.logical_not() if invert else cmp
5070        return cmp.reshape(elements.shape)
5071
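# Illustrative sketch of the searchsorted branch (hypothetical values): with
# test_elements = [5, 1, 3], sorted_test_elements = [1, 3, 5]; element 3
# gives idx 1 and sorted_test_elements[1] == 3 -> True, while element 7
# gives idx 3 (out of range, remapped to 0) and
# sorted_test_elements[0] == 1 != 7 -> False.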
5072
5073@register_decomposition(aten.take)
5074@out_wrapper()
5075def take(self, index):
5076    flattened = self.reshape(-1)
5077    return flattened[index]
5078
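# Illustrative sketch (hypothetical values): take indexes the flattened,
# row-major view, so for self = [[10, 11, 12], [13, 14, 15]] and
# index = torch.tensor([4]), the result is [14].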
5079
5080@register_decomposition(aten.resize_as)
5081def resize_as(self, other, memory_format=None):
5082    if memory_format is None:
5083        memory_format = torch.contiguous_format
5084    if memory_format == torch.preserve_format:
5085        memory_format = suggest_memory_format(other)
5086    return aten.resize(self, other.shape, memory_format=memory_format)
5087
5088
5089register_inplace(aten.addbmm_, aten.addbmm)
5090register_inplace(aten.addmm_, aten.addmm)
5091register_inplace(aten.addmv_, aten.addmv)
5092register_inplace(aten.baddbmm_, aten.baddbmm)
5093register_inplace(aten.fill_, aten.fill)
5094register_inplace(aten.gelu_, aten.gelu)
5095register_inplace(aten.hardswish_, aten.hardswish)
5096register_inplace(aten.hardtanh_, aten.hardtanh)
5097register_inplace(aten.hardsigmoid_, aten.hardsigmoid)
5098register_inplace(aten.__iand__, aten.__and__)
5099register_inplace(aten.__ilshift__, aten.__lshift__)
5100register_inplace(aten.index_put_, aten.index_put)
5101register_inplace(aten.index_reduce_, aten.index_reduce)
5102register_inplace(aten.__ior__, aten.__or__)
5103register_inplace(aten.__irshift__, aten.__rshift__)
5104register_inplace(aten.__ixor__, aten.__xor__)
5105register_inplace(aten.leaky_relu_, aten.leaky_relu)
5106register_inplace(aten.logit_, aten.logit)
5107register_inplace(aten.relu_, aten.relu)
5108register_inplace(aten.renorm_, aten.renorm)
5109register_inplace(aten.round_, aten.round)
5110register_inplace(aten.scatter_, aten.scatter)
5111register_inplace(aten.scatter_add_, aten.scatter_add)
5112register_inplace(aten.scatter_reduce_, aten.scatter_reduce)
5113register_inplace(aten.silu_, aten.silu)
5114