# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import itertools
import numbers
import operator
import sys
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union

import torch
import torch._meta_registrations
import torch._prims as prims
import torch._prims_common as utils
import torch.nn.functional as F
from torch import sym_float, sym_int, Tensor
from torch._decomp import register_decomposition
from torch._higher_order_ops.out_dtype import out_dtype
from torch._prims_common import (
    IntLike,
    NumberType,
    suggest_memory_format,
    TensorLike,
    TensorSequenceType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    out_wrapper,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map


DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]

# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []

aten = torch._ops.ops.aten


class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2


# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
# Will need to validate the non-elementwise uses
def type_casts(
    f: Callable,
    type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
    compute_dtype_only: bool = False,
):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        flat_args = [
            x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
        ]
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *flat_args, type_promotion_kind=type_promotion
        )

        # TODO: pretty sure this is not quite right
        def increase_prec(x):
            if isinstance(x, Tensor):
                return x.to(computation_dtype)
            else:
                return x

        def decrease_prec(x):
            if isinstance(x, Tensor):
                return x.to(result_dtype)
            else:
                return x

        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        if compute_dtype_only:
            return r
        else:
            return tree_map(decrease_prec, r)

    return inner


compute_only_pw_cast_for_opmath = partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    compute_dtype_only=True,
)
pw_cast_for_opmath = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
pw_cast_for_int_to_real = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
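

# Illustrative sketch of what these wrappers do (comments only; `my_decomp` is a
# hypothetical example, not a real decomposition):
#
#   @pw_cast_for_opmath
#   def my_decomp(x: Tensor) -> Tensor:
#       return x * x + 1
#
# For a float16 input, utils.elementwise_dtypes with the DEFAULT promotion kind
# picks float32 as the computation dtype and float16 as the result dtype, so the
# wrapped function computes in float32 and casts its result back to float16.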


# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor:
    for _ in range(dim - x.dim()):
        x = x.unsqueeze(-1)
    return x


@register_decomposition(aten.tanh_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y).conj_physical()


@register_decomposition(aten.sigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y)).conj_physical()


@register_decomposition(aten.softplus_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    z = (x * beta).exp()
    return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))


@register_decomposition(aten.elu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * (self_or_result + negcoef),
            grad_output * poscoef,
        )
    else:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
            grad_output * poscoef,
        )


@register_decomposition([aten.fill.Scalar])
def fill_scalar(self, value):
    return torch.full_like(self, value)


@register_decomposition([aten.fill.Tensor])
def fill_tensor(self, value: Tensor):
    torch._check(
        value.dim() == 0,
        lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
    )
    return aten.copy(self, value)


@register_decomposition(aten.hardsigmoid)
@out_wrapper()
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardsigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    return torch.where(
        (self > -3.0) & (self < 3.0),
        grad_output * (1.0 / 6.0),
        0.0,
    )


@register_decomposition(aten.hardtanh_backward)
@out_wrapper("grad_input")
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output)


@register_decomposition(aten.hardswish)
@out_wrapper()
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardswish_backward)
@out_wrapper()
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    return torch.where(
        self < -3,
        0.0,
        torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
    )


@register_decomposition(aten.threshold_backward)
@out_wrapper("grad_input")
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    return torch.where(self <= threshold, 0, grad_output)
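

# Illustrative sanity check for the activation backward decompositions above
# (a REPL sketch, not executed as part of this module; uses hardswish as an
# example):
#
#   >>> x = torch.randn(8, dtype=torch.double, requires_grad=True)
#   >>> (autograd_grad,) = torch.autograd.grad(F.hardswish(x).sum(), x)
#   >>> torch.allclose(autograd_grad, hardswish_backward(torch.ones_like(x), x))
#   True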


@register_decomposition(aten.leaky_relu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    return torch.where(self > 0, grad_output, grad_output * negative_slope)


@register_decomposition(aten.gelu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)

        left = 0.5 * self
        right = 1 + tanh_inner

        left_derivative = 0.5 * right

        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative

        return grad * (left_derivative + right_derivative)
    else:
        kAlpha = M_SQRT1_2
        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * kAlpha))
        pdf = kBeta * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)


@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    input_tanh_softplus = torch.tanh(F.softplus(input))
    input_sigmoid = torch.sigmoid(input)
    out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
    return grad_output * (input_tanh_softplus + out)


@register_decomposition(aten.silu)
@out_wrapper()
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    return self * torch.sigmoid(self)


@register_decomposition(aten.silu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + torch.exp(-self))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))


@register_decomposition(aten._prelu_kernel)
def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor:
    return torch.where(self > 0, self, weight * self)


@register_decomposition(aten._prelu_kernel_backward)
def _prelu_kernel_backward(
    grad_output: Tensor,
    self: Tensor,
    weight: Tensor,
) -> Tuple[Tensor, Tensor]:
    input_grad = torch.where(self > 0, grad_output, weight * grad_output)
    weight_grad = torch.where(self > 0, 0.0, self * grad_output)
    return (input_grad, weight_grad)


@register_decomposition(aten.rrelu_with_noise)
@aten.rrelu_with_noise.default.py_impl(DispatchKey.AutogradCUDA)
@out_wrapper()
@pw_cast_for_opmath
def rrelu_with_noise(
    self: Tensor,
    noise: Tensor,
    lower: float = 0.125,
    upper: float = 0.3333333333333333,
    training: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    assert generator is None
    if training:
        not_positive = self <= 0
        r = aten.uniform(self, lower, upper)
        output = torch.where(not_positive, self * r, self)
        noise.copy_(torch.where(not_positive, r, 1))
        return output
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu(self, negative_slope)
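

# Descriptive note on the rrelu decompositions here (a summary of the code, not
# normative documentation): in training mode the negative slope is sampled
# uniformly from [lower, upper] per element and recorded in `noise` (1 for
# positive inputs); in eval mode the op reduces to leaky_relu with the fixed
# slope (lower + upper) / 2.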


@register_decomposition(aten.rrelu_with_noise_)
@aten.rrelu_with_noise_.default.py_impl(DispatchKey.AutogradCUDA)
@pw_cast_for_opmath
def rrelu_with_noise_(
    self: Tensor,
    noise: Tensor,
    lower: float = 0.125,
    upper: float = 0.3333333333333333,
    training: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator))


@register_decomposition(aten.rrelu_with_noise_backward)
@out_wrapper()
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    if training and upper - lower > 1e-6:
        return grad_output.mul(noise)
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu_backward(
            grad_output, self, negative_slope, self_is_result
        )


@register_decomposition(aten.log_sigmoid_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    in_negative = self < 0
    max_deriv = torch.where(in_negative, 1, 0)
    sign = torch.where(in_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output


def apply_loss_reduction(loss: Tensor, reduction: int):
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    elif reduction == Reduction.SUM.value:
        return torch.sum(loss)
    else:
        return loss


def to_real_dtype(dtype: torch.dtype):
    if dtype == torch.complex32:
        return torch.float16
    elif dtype == torch.complex64:
        return torch.float32
    elif dtype == torch.complex128:
        return torch.float64
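

# Illustrative mapping for the integer `reduction` argument consumed by
# apply_loss_reduction above (a restatement of the Reduction enum):
# reduction=0 (NONE) returns the elementwise loss unchanged, reduction=1 (MEAN)
# returns loss.mean(), and reduction=2 (SUM) returns loss.sum().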


# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction


@register_decomposition(aten.mse_loss)
@out_wrapper()
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target) ** 2
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.mse_loss_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
    return norm * (input - target) * grad_output


@register_decomposition(aten._safe_softmax)
def safe_softmax(self, dim, dtype=None):
    out = torch.softmax(self, dim=dim, dtype=dtype)
    masked = self.eq(float("-inf"))
    masked_rows = torch.all(masked, dim=dim, keepdim=True)
    zeros = torch.zeros_like(out)
    return torch.where(masked_rows, zeros, out)


@register_decomposition(aten.smooth_l1_loss)
@out_wrapper()
@pw_cast_for_opmath
def smooth_l1_loss(
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
    beta: float = 1.0,
):
    loss = (self - target).abs()
    loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.smooth_l1_loss_backward.default)
@pw_cast_for_opmath
def smooth_l1_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    abs_x = torch.abs(x)
    norm_grad = norm * grad_output
    return torch.where(
        abs_x < beta,
        norm_grad * x / beta,
        norm_grad * torch.sign(x),
    )


@register_decomposition(aten.smooth_l1_loss_backward.grad_input)
@pw_cast_for_opmath
def smooth_l1_loss_backward_out(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int,
    beta: float,
    grad_input: Tensor,
):
    result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta)
    _maybe_resize_out(grad_input, result.shape)
    return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)


@register_decomposition(aten.huber_loss_backward.default)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    return torch.where(
        x < -delta,
        -norm * grad_output * delta,
        torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
    )


# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input'
@register_decomposition(aten.huber_loss_backward.out)
@pw_cast_for_opmath
def huber_loss_backward_out(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int,
    delta: float,
    grad_input: Tensor,
):
    result = huber_loss_backward(grad_output, self, target, reduction, delta)
    _maybe_resize_out(grad_input, result.shape)
    return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)


def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    channel_dim = 0 if self.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    safe_target = torch.where(target != ignore_index, target, 0)
    grad_input = torch.zeros_like(self)
    grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(self.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        grad_output = grad_output * weight

    grad_output = torch.where(target != ignore_index, grad_output, 0)

    return grad_input * grad_output


@register_decomposition(aten.glu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
    assert self.dim() > 0, "glu does not support 0-dimensional tensors"
    wrap_dim = utils.canonicalize_dim(self.dim(), dim)
    nIn = self.size(wrap_dim)
    assert (
        nIn % 2 == 0
    ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
    inputSize = nIn // 2
    firstHalf = self.narrow(wrap_dim, 0, inputSize)
    secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
    gradInputFirstHalf = torch.sigmoid(secondHalf)
    gradInputSecondHalf = (
        (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
    )
    gradInputFirstHalf = gradInputFirstHalf * grad_output
    return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)


@register_decomposition(aten.nll_loss_backward)
@out_wrapper("grad_input")
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )

    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"

    if reduction == Reduction.NONE.value and self.dim() == 2:
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )
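

# Illustrative shapes for the nll_loss backward path above (a sketch): for a 2D
# input of shape (N, C) and a target of shape (N,), channel_dim is 1, the target
# is unsqueezed to (N, 1), and scatter writes -1.0 into each row's target-class
# column of a zero (N, C) grad_input, which is then scaled by grad_output (and
# by weight, if given), with ignore_index rows zeroed out.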


@register_decomposition(aten.nll_loss2d_backward)
@out_wrapper("grad_input")
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"

    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"

    assert (
        self.shape[0] == target.shape[0]
        and self.shape[2] == target.shape[1]
        and self.shape[3] == target.shape[2]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"

    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.binary_cross_entropy)
@out_wrapper()
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    loss = (target - 1) * torch.maximum(
        torch.log1p(-self), self.new_full((), -100)
    ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
    if weight is not None:
        loss = loss * weight
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.binary_cross_entropy_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    EPSILON = 1e-12
    result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
    if weight is not None:
        result = result * weight
    if reduction == Reduction.MEAN.value:
        result = result / self.numel()
    return result


@register_decomposition(aten.soft_margin_loss)
@out_wrapper()
@pw_cast_for_opmath
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    loss = torch.log1p(torch.exp(-input * target))
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.soft_margin_loss_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def soft_margin_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
    if reduction == Reduction.MEAN.value:
        grad_input = grad_input / self.numel()
    return grad_input


@register_decomposition(aten.dist)
@out_wrapper()
def dist(input: Tensor, other: Tensor, p: float = 2):
    return aten.norm(input - other, p=p)


@register_decomposition(aten._euclidean_dist)
@out_wrapper()
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    x1_norm = x1.pow(2).sum(-1, True)
    x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
    x2_norm = x2.pow(2).sum(-1, True)
    x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
    result = x1_.matmul(x2_.mT)
    return result.clamp_min(0).sqrt()
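

# Descriptive note on _euclidean_dist above (a sketch of the algebra): it
# evaluates pairwise distances via the expansion
#   ||a - b||^2 = ||a||^2 - 2 * a.b + ||b||^2,
# packing -2*x1, ||x1||^2 and a column of ones into one matrix (and x2, ones,
# ||x2||^2 into the other) so a single matmul yields the squared distances;
# clamp_min(0) guards against small negative values from floating-point
# cancellation before the sqrt.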


@register_decomposition(aten.slice_backward)
@out_wrapper()
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)


@register_decomposition(aten.slice.Tensor)
def slice_forward(
    # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
    self: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
):
    from torch.fx.experimental.symbolic_shapes import (
        guard_size_oblivious,
        statically_known_true,
    )

    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
    dim = utils.canonicalize_dim(self.dim(), dim)
    sizes = list(self.size())
    strides = list(self.stride())

    if step <= 0:
        raise RuntimeError("slice step must be positive")

    start_val = start if start is not None else 0
    end_val = end if end is not None else sys.maxsize  # 2^63 - 1

    if guard_size_oblivious(start_val < 0):
        start_val += sizes[dim]

    if guard_size_oblivious(end_val < 0):
        end_val += sizes[dim]

    if guard_size_oblivious(start_val < 0):
        start_val = 0
    elif guard_size_oblivious(start_val > sizes[dim]):
        start_val = sizes[dim]

    if guard_size_oblivious(end_val < start_val):
        end_val = start_val
    elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
        end_val > sizes[dim]
    ):
        end_val = sizes[dim]

    storage_offset = self.storage_offset() + start_val * strides[dim]
    len = end_val - start_val
    sizes[dim] = (len + step - 1) // step
    strides[dim] *= step

    if self.is_quantized:
        raise NotImplementedError(
            "Slice decomposition for quantized tensors isn't implemented"
        )
    else:
        return self.as_strided(sizes, strides, storage_offset)


def _normalize_start_end(
    x: Tensor, dim: int, start: Optional[int], end: Optional[int]
) -> Tuple[int, int]:
    """
    Normalize start and end such that both are in the range
    [0, x.get_size()[dim]] and start <= end.
    """
    dim_size = x.shape[dim]

    def clamp_wrap(val, lower, upper, default) -> int:
        if val is None:
            return default
        if val < 0:
            val = val + dim_size
        return min(max(val, lower), upper)

    start = clamp_wrap(start, 0, dim_size, 0)
    end = clamp_wrap(end, start, dim_size, dim_size)
    return start, end
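

# Illustrative results for _normalize_start_end above, assuming x.shape[dim] == 10
# (comments only, not executed):
#   (start=None, end=None) -> (0, 10)
#   (start=-3,   end=None) -> (7, 10)
#   (start=2,    end=100)  -> (2, 10)
#   (start=8,    end=3)    -> (8, 8)   # end is clamped up to start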


# This is not in torch._refs because aten.index used by
# aten._unsafe_masked_index does not have a decomposition.
@register_decomposition(aten.slice_scatter)
@out_wrapper()
def slice_scatter(
    input: Tensor,
    src: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
):
    dim = utils.canonicalize_dim(input.ndim, dim)
    dim_size = input.shape[dim]
    start, end = _normalize_start_end(input, dim, start, end)

    src_size = list(input.shape)
    src_size[dim] = (end - start + (step - 1)) // step
    src = src.expand(src_size)

    if start == 0 and end == dim_size and step == 1:
        return src.clone()

    indices = [None] * input.dim()
    idx = torch.arange(dim_size, device=input.device)
    indices[dim] = (idx - start) // step

    mask = torch.ones(dim_size, device=input.device, dtype=torch.bool)
    if start != 0:
        mask = torch.logical_and(mask, idx >= start)

    if end != dim_size:
        mask = torch.logical_and(mask, idx < end)

    if step != 1:
        mask = torch.logical_and(mask, (idx - start) % step == 0)

    mask_shape = [1] * input.dim()
    mask_shape[dim] = -1
    mask = mask.view(mask_shape)
    return aten.where(mask, aten._unsafe_masked_index(src, mask, indices, 0), input)


@register_decomposition(aten.select_backward)
@out_wrapper()
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(grad_input, grad_output, dim, index)


@register_decomposition(aten.diagonal_backward)
@out_wrapper()
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)


def _cast_grad_to_input_dtype(
    grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype
):
    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input


@register_decomposition(aten._softmax_backward_data)
@out_wrapper("grad_input")
@compute_only_pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    new_grad_output = grad_output * output
    grad_input = new_grad_output - output * torch.sum(
        new_grad_output, dim=dim, keepdim=True
    )

    # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
    # if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()

    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous()


@register_decomposition(aten._log_softmax_backward_data)
@out_wrapper()
@compute_only_pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    grad_input = grad_output - torch.exp(output) * torch.sum(
        grad_output, dim=dim, keepdim=True
    )
    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)


def _im2col_col2im_indices_along_dim(
    input_d, kernel_d, dilation_d, padding_d, stride_d, device
):
    """Utility function to implement im2col and col2im"""
    blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)

    arange_kw = partial(torch.arange, dtype=torch.int64, device=device)

    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)

    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)

    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    return blocks_d_indices + kernel_grid
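

# Illustrative output of _im2col_col2im_indices_along_dim above (not executed):
# with input_d=5, kernel_d=2, dilation_d=1, padding_d=0, stride_d=1 it returns
# the 2 x 4 index matrix [[0, 1, 2, 3], [1, 2, 3, 4]]; row i holds the input
# positions touched by kernel offset i for each of the 4 sliding-block starts.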


@register_decomposition(aten.im2col)
@out_wrapper()
def im2col(
    input: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
    torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
    torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
    torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        torch._check(
            cond, lambda: f"{param_name} should be greater than zero, but got {param}"
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")

    shape = input.shape
    ndim = len(shape)
    torch._check(
        ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
        lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    output_size = tuple(
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            shape[-2:], padding, dilation, kernel_size, stride
        )
    )
    torch._check(
        all(c > 0 for c in output_size),
        lambda: f"Given an input with spatial size {tuple(shape[-2:])}, "
        f"kernel_size={kernel_size}, dilation={dilation}, "
        f"padding={padding}, stride={stride}, "
        "the calculated shape of the array of sliding blocks "
        f"is {output_size}, but its components must be at least one.",
    )
    batched_input = ndim == 4
    if not batched_input:
        input = input.unsqueeze(0)

    batch_dim, channel_dim, input_h, input_w = input.shape

    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    kernel_h, kernel_w = kernel_size

    blocks_row_indices = _im2col_col2im_indices_along_dim(
        input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
    )
    blocks_col_indices = _im2col_col2im_indices_along_dim(
        input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
    )

    # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom)
    # ugh
    padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))

    blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
    output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
    output = output.permute(0, 1, 2, 4, 3, 5)
    num_blocks_row = blocks_row_indices.size(1)
    num_blocks_col = blocks_col_indices.size(1)
    output = output.reshape(
        batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
    )

    if not batched_input:
        output = output.squeeze(0)
    return output
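

# Illustrative shape example for im2col above (not executed): for an input of
# shape (N, C, H, W) = (1, 3, 8, 8) with kernel_size=[2, 2], dilation=[1, 1],
# padding=[0, 0], stride=[1, 1], each spatial dim yields 7 sliding-block
# positions, so the output has shape (1, 3 * 2 * 2, 7 * 7) = (1, 12, 49).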


@register_decomposition(aten.col2im)
@out_wrapper()
@pw_cast_for_opmath
def col2im(
    input: Tensor,
    output_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
    torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
    torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
    torch._check(len(padding) == 2, lambda: "only 2D padding supported")
    torch._check(len(stride) == 2, lambda: "only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        torch._check(
            cond, lambda: f"{param_name} should be greater than zero, but got {param}"
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")
    check_positive(output_size, "output_size")

    shape = input.shape
    ndim = len(shape)
    torch._check(
        ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
        lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    prod_kernel_size = kernel_size[0] * kernel_size[1]
    torch._check(
        shape[-2] % prod_kernel_size == 0,
        lambda: "Expected size of input's first non-batch dimension to be divisible by the "
        f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
        f"kernel_size={kernel_size}",
    )
    col = [
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            output_size, padding, dilation, kernel_size, stride
        )
    ]
    L = col[0] * col[1]
    torch._check(
        shape[-1] == L,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    torch._check(
        L > 0,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    batched_input = ndim == 3
    if not batched_input:
        input = input.unsqueeze(0)

    shape = input.shape

    out_h, out_w = output_size
    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    kernel_h, kernel_w = kernel_size

    # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
    input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
    input = input.permute(0, 1, 2, 4, 3, 5)

    indices_row = _im2col_col2im_indices_along_dim(
        out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
    )
    indices_row = _unsqueeze_to_dim(indices_row, 4)
    indices_col = _im2col_col2im_indices_along_dim(
        out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
    )

    output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
    output = input.new_zeros(
        [shape[0], shape[1] // prod(kernel_size)] + output_padded_size
    )
    idx = (None, None, indices_row, indices_col)
    output = aten._unsafe_index_put(output, idx, input, accumulate=True)
    output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h))

    if not batched_input:
        output = output.squeeze(0)
    return output


@register_decomposition(aten.native_dropout_backward)
@out_wrapper()
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    # According to the CUDA kernel implementation we should have this test;
    # but it seems to fail tests!
    # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")

    # Mimicking CUDA kernel's behavior for output stride: output follows input's memory format
    # This is different from TensorIterator's behavior
    r = (grad_output * (mask.type_as(grad_output) * scale)).clone(
        memory_format=utils.suggest_memory_format(grad_output)
    )
    return r


@register_decomposition(aten.unfold_backward)
@out_wrapper()
def unfold_backward(
    grad: Tensor, input_size: List[int], dimension: int, size: int, step: int
) -> Tensor:
    if len(input_size) == 0:
        return torch.squeeze_copy(grad, 0)
    dim = utils.canonicalize_dim(len(input_size), dimension)
    idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32)
    idx = idx.unfold(0, size, step).flatten()
    grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1)
    # nb. At the moment this generates two kernels in triton
    # It could potentially be fused into one call to scatter_reduce,
    # in the case step <= size provided scatter_reduce generates 1 kernel
    grad_input = grad.new_zeros(input_size)
    index = (None,) * dim + (idx,)
    return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous()


@register_decomposition(aten.logit_backward.default)
@pw_cast_for_opmath
def logit_backward(
    grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
    if eps is not None:
        lo = eps
        hi = 1.0 - lo
        return torch.where(
            torch.logical_and(self >= lo, self <= hi),
            grad_output / (self * (1.0 - self)),
            0.0,
        )
    else:
        return torch.where(
            torch.logical_and(self >= 0.0, self <= 1.0),
            grad_output / (self * (1.0 - self)),
            self.new_full((), float("nan")),
        )


@register_decomposition(aten.dropout)
@aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.dropout.default.py_impl(DispatchKey.Autograd)
def dropout(input: Tensor, p: float, train: Optional[bool]):
    if train and p != 0:
        return aten.native_dropout(input, p, train)[0]
    else:
        return input.clone()


@register_decomposition(aten.native_dropout)
@out_wrapper("out0", "out1")
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
    if train and p != 0:
        if p == 1:
            return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
        if not input.dtype.is_floating_point:
            raise RuntimeError(
                "result type Float can't be cast to the desired output type Long"
            )
        bool_mask = torch.rand_like(input) > p
        res = bool_mask * input * float(1.0 / (1.0 - p))
        return (res, bool_mask)
    else:
        return (input, torch.ones_like(input, dtype=torch.bool))


@register_decomposition(aten._softmax)
@out_wrapper()
def _softmax(x: Tensor, dim: int, half_to_float: bool):
    # eager softmax returns a contiguous tensor. Ensure that decomp also returns
    # a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        unnormalized = torch.exp(x)
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        unnormalized = torch.exp(x - x_max)
    result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
    if not half_to_float:
        result = result.to(result_dtype)
    return result


@register_decomposition(aten._log_softmax)
@out_wrapper()
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
    # eager log_softmax returns a contiguous tensor. Ensure that decomp also
    # returns a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        shifted = x
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        shifted = x - x_max
    shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
    result = shifted - shifted_logsumexp
    if not half_to_float:
        result = result.to(result_dtype)
    return result


@register_decomposition(aten.embedding)
@out_wrapper()
def embedding(
    weight: Tensor,
    indices: Tensor,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    assert weight.dim() == 2, "'weight' must be 2-D"
    # Nb. scale_grad_by_freq is not used in the forward
    if indices.ndim <= 1:
        # We need this one as weight[indices] calls item() in these cases
        out = weight.index_select(0, indices)
        if indices.ndim == 0:
            out = out.squeeze(0)
        return out
    else:
        return weight[indices]


@register_decomposition(aten.embedding_dense_backward)
@out_wrapper()
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    grad_output = grad_output.to(computation_dtype)
    indices = _maybe_convert_to_dtype(indices, torch.long)  # type: ignore[assignment]
    if scale_grad_by_freq:
        counts = indices.new_zeros((num_weights,))
        ones = torch.ones_like(indices)
        counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True)
        grad_weights_scale = counts[indices]
        grad_output = grad_output / grad_weights_scale.unsqueeze(-1)

    mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
    grad = grad_output.masked_fill(mask, 0)
    grad_weight = grad_output.new_zeros(
        (num_weights,) + grad_output.shape[indices.ndim :]
    )
    return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to(
        result_dtype
    )


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r
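

# Illustrative values for prod above (not executed): prod([2, 3, 4]) == 24 and
# prod([]) == 1 (the empty product), which is the convention the shape
# arithmetic in this file relies on.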


def _pad_chunk(
    tensors: List[Tensor],
    dim: int,
    num_chunks: int,
) -> List[Tensor]:
    padded_tensors = []
    for tensor in tensors:
        tensor_size = tensor.size()
        pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks
        if pad_along_dim != tensor_size[dim]:
            # Use aten.constant_pad_nd instead of copy_ for functionalization
            pad = [0] * 2 * (tensor.ndim - dim - 1) + [
                0,
                pad_along_dim - tensor_size[dim],
            ]
            tensor = aten.constant_pad_nd(tensor, pad, 0)
        view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
        padded_tensors.append(tensor.view(view_size))
    return padded_tensors


def have_same_ndims(tensors: List[Tensor]):
    ndim = tensors[0].ndim
    for tensor in tensors:
        if tensor.ndim != ndim:
            return False
    return True


def leading_dimension_matches(tensors: List[Tensor], dim: int):
    leading_dim_sizes = tensors[0].size()[:dim]
    for tensor in tensors:
        torch._check(
            tensor.size()[:dim] == leading_dim_sizes,
            lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors",
        )


def _preprocess_chunk_cat_inputs(
    tensors: List[Tensor],
    dim: int,
    num_chunks: int,
):
    torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks")
    torch._check(
        len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list"
    )
    expected_dtype = tensors[0].dtype
    expected_device = tensors[0].device
    for tensor in tensors:
        torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor")
        torch._check(
            tensor.dtype == expected_dtype,
            lambda: "_chunk_cat expects all input tensors with the same dtype",
        )
        torch._check(
            tensor.device == expected_device,
            lambda: "_chunk_cat expects all input tensors on the same device",
        )
    if have_same_ndims(tensors):
        dim = utils.canonicalize_dim(tensors[0].dim(), dim)
    else:
        torch._check(
            dim >= 0,
            lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims",
        )
        for tensor in tensors:
            torch._check(
                dim < tensor.ndim,
                lambda: "_chunk_cat expects dim < ndim for all input tensors",
            )
    leading_dimension_matches(tensors, dim)
    return dim


@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
def _chunk_cat(
    tensors: List[Tensor],
    dim: int,
    num_chunks: int,
    out: Optional[Tensor] = None,
) -> Tensor:
    dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks)
    padded_tensors = _pad_chunk(tensors, dim, num_chunks)
    if out is None:
        return torch.cat(padded_tensors, dim + 1)
    else:
        torch.cat(padded_tensors, dim + 1, out=out)
        return out


@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    # NB: Perform the check_is_size tests first so that the
    # sum test does not try to do a replacement
    for i in range(len(split_sizes)):
        torch._check_is_size(
            split_sizes[i],
            lambda: "split_with_sizes expects split_sizes have only non-negative entries",
        )
    torch._check_with(
        ValueError,
        sum(split_sizes) == self.shape[dim],
        lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
    )
    num_splits = len(split_sizes)
    splits = []
    start_idx = 0

    for i in range(num_splits):
        length = split_sizes[i]
        splits.append(self.narrow(dim, start_idx, length))
        start_idx += length
    return splits
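

# Illustrative example for split_with_sizes above (not executed): splitting a
# tensor of shape (10,) with split_sizes=[3, 3, 4] along dim 0 returns three
# narrowed views of shapes (3,), (3,) and (4,); if the sizes do not add up to
# self.shape[dim], the check above raises ValueError.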


# out_wrapper currently does not allow optional outputs
@register_decomposition(
    [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out]
)
def split_with_sizes_copy(
    self: Tensor,
    split_sizes: List[int],
    dim: int = 0,
    out: Optional[List[Tensor]] = None,
) -> Optional[List[Tensor]]:
    splits = split_with_sizes(self, split_sizes, dim=dim)
    if out is None:
        return [s.clone(memory_format=torch.contiguous_format) for s in splits]
    else:
        for output, split in zip(out, splits):
            _maybe_resize_out(output, split.shape)
            _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True)
        return None


@register_decomposition(aten.unsafe_split.Tensor)
def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
    return aten.split.Tensor(input, split_size, dim)


@register_decomposition(aten.unsafe_split_with_sizes.default)
def unsafe_split_with_sizes(
    input: Tensor, split_sizes: List[int], dim: int = 0
) -> Tuple[Tensor, ...]:
    return aten.split_with_sizes.default(input, split_sizes, dim)


@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return (self,)
    chunks = (dim_size + split_size - 1) // split_size

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import guard_int

    chunks = guard_int(chunks)
    split_sizes = [split_size for i in range(chunks)]
    split_sizes[-1] = split_size - (split_size * chunks - dim_size)
    return torch.split(self, split_sizes, dim)


@aten.tensor_split.tensor_indices_or_sections.py_impl(
    DispatchKey.CompositeImplicitAutograd
)
def tensor_split_tensor_indices_or_sections_py_impl(
    self: Tensor,
    tensor_indices_or_sections: Tensor,
    dim: int = 0,
) -> Tuple[Tensor, ...]:
    assert tensor_indices_or_sections.device.type == "cpu"
    assert tensor_indices_or_sections.dtype == torch.int64
    split_dim = tensor_indices_or_sections.dim()
    torch._check(
        split_dim == 1 or split_dim == 0,
        lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional "
        f"or one-dimensional tensor, but got a tensor with {split_dim} dims",
    )
    if split_dim == 0:
        sections = tensor_indices_or_sections.item()
        assert isinstance(sections, IntLike)
        return self.tensor_split(sections, dim)
    else:
        indices = [i.item() for i in tensor_indices_or_sections]
        # WARNING: Tempted to torch._check_is_size on the indices here? You
        # can't: tensor_split works with negative values in indices:
        #
        # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5]))
        # (tensor([ 0.3540,  2.1074, -0.8507,  1.1639,  0.3055]), tensor([]),
        # tensor([-0.4285,  1.0692, -0.1776,  0.9362,  1.6143]))
        #
        # Sorry, I don't make the rules.  Explicitly do the item call in user
        # code if you KNOW that they are non-negative.
        return self.tensor_split(indices, dim)


# TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@out_wrapper()
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    out = alpha * torch.mm(mat1, mat2)
    if beta == 0:
        return out

    # The output of aten.addmm is contiguous; we need to match this behavior in the decomposition.
    # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
    # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
    # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input.
    # Alternatively, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
    # This implementation is not ideal, and we should revisit this when we have a better solution.
    return out + beta * self


@register_decomposition(aten._addmm_activation)
@out_wrapper()
@pw_cast_for_opmath
def _addmm_activation(
    self: Tensor,
    mat1: Tensor,
    mat2: Tensor,
    beta: int = 1,
    alpha: int = 1,
    use_gelu: bool = False,
):
    out = addmm(self, mat1, mat2, beta, alpha)
    if use_gelu:
        if self.is_cuda:
            return aten.gelu(out, approximate="tanh")
        else:
            return aten.gelu(out)
    return aten.relu(out)


@register_decomposition(aten.addmv)
@out_wrapper()
@pw_cast_for_opmath
def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    out = alpha * torch.mv(mat1, vec)
    if beta == 0:
        return out
    return out + beta * self
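

# Illustrative semantics for the addmm/addmv decompositions above (a summary of
# the code): addmm(self, mat1, mat2, beta, alpha) computes
# alpha * (mat1 @ mat2) + beta * self, and with beta == 0 the `self` term is
# skipped entirely, so NaN/inf in `self` are not propagated, matching the
# documented beta=0 behaviour of torch.addmm.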


@register_decomposition(aten.native_group_norm_backward.default)
@pw_cast_for_opmath
def native_group_norm_backward(
    grad_output: Tensor,
    input: Tensor,
    mean: Tensor,
    rstd: Tensor,
    gamma: Optional[Tensor],
    N: int,
    C: int,
    HxW: int,
    group: int,
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    utils.check_same_device(
        grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
    )
    utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
    utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
    torch._check(
        input.numel() == N * C * HxW,
        lambda: f"Expect input to have {N * C * HxW} elements",
    )
    torch._check(
        mean.shape == (N, group),
        lambda: f"Expect mean to have shape ({N}, {group}), but got {mean.shape}",
    )
    torch._check(
        gamma is None or gamma.numel() == C,
        lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
    )

    cpg, _rem = divmod(C, group)
    torch._check(
        _rem == 0,
        lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
    )

    # Compute internal gradients
    ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
    db = grad_output.view(N, C, HxW).sum(dim=[2])

    d_input: Optional[Tensor] = None
    d_gamma: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None
    if output_mask[0]:
        s = 1.0 / (HxW * cpg)
        if gamma is not None:
            ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
            db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
            c1 = torch.mul(
                rstd.unsqueeze(-1),
                gamma.reshape(1, group, cpg),
            )
        else:
            ds_val = ds.reshape(N, group, cpg).sum(2)
            db_val = db.reshape(N, group, cpg).sum(2)
            c1 = torch.mul(
                rstd.unsqueeze(-1),
                torch.ones((1, group, cpg), device=rstd.device),
            )
        c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
        c3 = -c2 * mean - db_val * rstd * s

        c1 = c1.unsqueeze(-1)
        c2 = _unsqueeze_to_dim(c2, 4)
        c3 = _unsqueeze_to_dim(c3, 4)
        d_input = (
            torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
            + torch.mul(input.reshape(N, group, cpg, HxW), c2)
            + c3
        )
        d_input = d_input.reshape(input.shape).to(input.dtype)
    if output_mask[1]:
        d_gamma = (
            (
                (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
                * rstd.unsqueeze(-1)
            )
            .sum(dim=[0])
            .reshape(C)
        )
    if output_mask[2]:
        d_bias = db.sum(dim=[0])

    return (d_input, d_gamma, d_bias)


# out_wrapper currently does not allow optional outputs
@register_decomposition(aten.native_group_norm_backward.out)
def native_group_norm_backward_out(
    grad_output: Tensor,
    input: Tensor,
    mean: Tensor,
    rstd: Tensor,
    gamma: Optional[Tensor],
    N: int,
    C: int,
    HxW: int,
    group: int,
    output_mask: List[bool],
    *,
    out0: torch.Tensor,
    out1: torch.Tensor,
    out2: torch.Tensor,
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    result = native_group_norm_backward(
        grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
    )
    grad_input = (out0, out1, out2)
    for i, r in enumerate(result):
        if r is not None:
            _maybe_resize_out(grad_input[i], r.shape)
            _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)

    return grad_input


def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
    if x is not None:
        return x.to(dtype)
    return x


# TODO: Take a closer look at the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward.default)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()
    computation_dtype = utils.get_computation_dtype(input.dtype)
    grad_out_cast, input_cast, weight_cast, bias_cast = (
        x.to(computation_dtype).contiguous() if x is not None else x
        for x in (grad_out, input, weight, bias)
    )
    assert grad_out_cast is not None

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: List[int] = []
    outer_dim_indices: List[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape) if output_mask[0] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
        )
    mean = _unsqueeze_to_dim(mean, input_cast.dim())  # type: ignore[union-attr]
    rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
    x_hat = (input_cast - mean) * rstd
    if weight_cast is not None:
        grad_x_hat = grad_out_cast * weight_cast
    else:
        grad_x_hat = grad_out_cast
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)

    inner = a - b - c3
    d_input: Optional[Tensor] = None
    d_weight: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None
    if output_mask[0]:
        d_input = (rstd / N) * inner

    if output_mask[1] and weight_cast is not None:
        if len(outer_dim_indices) > 0:
            d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
        else:
            d_weight = grad_out_cast * x_hat

    if output_mask[2] and bias_cast is not None:
        if len(outer_dim_indices) > 0:
            d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
        else:
            d_bias = grad_out_cast.clone()

    return (
        _maybe_cast(d_input, input.dtype),
        _maybe_cast(d_weight, input.dtype),
        _maybe_cast(d_bias, input.dtype),
    )


# out_wrapper currently does not allow optional outputs
@register_decomposition(aten.native_layer_norm_backward.out)
def native_layer_norm_backward_out(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
    *,
    out0: torch.Tensor,
    out1: torch.Tensor,
    out2: torch.Tensor,
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    result = native_layer_norm_backward(
        grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
    )
    grad_input = (out0, out1, out2)
    for i, r in enumerate(result):
        if r is not None:
            _maybe_resize_out(grad_input[i], r.shape)
            _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)

    return grad_input
functional: 1822 running_mean.copy_(new_running_mean) 1823 if running_var is not None: 1824 n = input.numel() / input.shape[1] 1825 # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction 1826 # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose 1827 # numerics probably don't matter. 1828 squeezed_var = torch.squeeze(biased_var, reduction_dims) 1829 unbiased_var = squeezed_var * (n / (n - 1)) 1830 new_running_var = momentum * unbiased_var + (1 - momentum) * running_var 1831 if not functional: 1832 running_var.copy_(new_running_var) 1833 else: 1834 assert running_mean is not None and running_var is not None 1835 running_mean = running_mean.to(dtype=computation_dtype, copy=True) 1836 new_running_mean = running_mean 1837 running_var = running_var.to(dtype=computation_dtype, copy=True) 1838 new_running_var = running_var 1839 mean = running_mean 1840 invstd = 1 / (torch.sqrt(running_var + eps)) 1841 # Very annoying inconsistency where CPU and CUDA give different shapes 1842 if input.device.type != "cpu": 1843 save_mean = running_mean 1844 save_rstd = invstd 1845 else: 1846 save_mean = input.new_zeros((0,)) 1847 save_rstd = input.new_zeros((0,)) 1848 mean = _unsqueeze_to_dim(mean, input.dim() - 1) 1849 invstd = _unsqueeze_to_dim(invstd, input.dim() - 1) 1850 output = (input - mean) * invstd 1851 1852 if weight is not None: 1853 weight = weight.flatten() 1854 weight = _unsqueeze_to_dim(weight, input.dim() - 1) 1855 output = output * weight 1856 1857 if bias is not None: 1858 bias = bias.flatten() 1859 bias = _unsqueeze_to_dim(bias, input.dim() - 1) 1860 output = output + bias 1861 1862 if input.device.type == "cpu": 1863 save_mean = save_mean.to(dtype=input.dtype) 1864 save_rstd = save_rstd.to(dtype=input.dtype) 1865 return ( 1866 output.to(dtype=input.dtype), 1867 save_mean, 1868 save_rstd, 1869 new_running_mean, 1870 new_running_var, 1871 ) 1872 1873 1874@register_decomposition(aten.native_batch_norm) 1875@out_wrapper("out", "save_mean", "save_invstd") 1876def native_batch_norm( 1877 input: Tensor, 1878 weight: Optional[Tensor], 1879 bias: Optional[Tensor], 1880 running_mean: Optional[Tensor], 1881 running_var: Optional[Tensor], 1882 training: bool, 1883 momentum: float, 1884 eps: float, 1885) -> Tuple[Tensor, Tensor, Tensor]: 1886 output, save_mean, save_rstd, _, _ = native_batch_norm_helper( 1887 input, weight, bias, running_mean, running_var, training, momentum, eps, False 1888 ) 1889 return output, save_mean, save_rstd 1890 1891 1892# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm 1893# with our new correctly schema'd _native_batch_norm_legit and its variants, but 1894# we cannot do that immediately in the C++ because it would be forwards incompatible 1895# with some mobile use cases. 1896# 1897# Since this change is most impactful for aot autograd/functionalization, we simply 1898# register this decomposition on the Autograd key for the python dispatcher (which is 1899# currently only used by aot autograd/functionalization and no one else, really). 1900# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm 1901# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations). 
@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
@aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def native_batch_norm_decomposition(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    if running_mean is None and running_var is None:
        return aten._native_batch_norm_legit(
            input, weight, bias, training, momentum, eps
        )
    if running_mean is None:
        raise RuntimeError(
            "running_mean is None, but running_var is provided. "
            "They should both be None or both be provided."
        )
    if running_var is None:
        raise RuntimeError(
            "running_var is None, but running_mean is provided. "
            "They should both be None or both be provided."
        )
    if training:
        # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg.
        return aten._native_batch_norm_legit(
            input, weight, bias, running_mean, running_var, training, momentum, eps
        )
    else:
        return aten._native_batch_norm_legit_no_training(
            input, weight, bias, running_mean, running_var, momentum, eps
        )


@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
    dim_size = tensor.size(dim)
    split_size = (dim_size + chunks - 1) // chunks

    if split_size == 0 and dim_size == 0:
        # `chunks` is an int, so iterate over range(chunks); iterating the int itself would raise.
        split_sizes = [split_size for _ in range(chunks)]
        split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
        return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
    return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)


@register_decomposition(aten._native_batch_norm_legit_no_training.default)
def _native_batch_norm_legit_no_training(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Tensor,
    running_var: Tensor,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    return aten._native_batch_norm_legit.default(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        False,  # training
        momentum,
        eps,
    )


@register_decomposition(aten._native_batch_norm_legit.default)
def _native_batch_norm_legit(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Tensor,
    running_var: Tensor,
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
        input, weight, bias, running_mean, running_var, training, momentum, eps, False
    )
    return output, save_mean, save_rstd


@register_decomposition(aten._native_batch_norm_legit.no_stats)
def _native_batch_norm_legit_no_stats(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
        input, weight, bias, None, None, training, momentum, eps, False
    )
    return output, save_mean, save_rstd


@register_decomposition(aten._native_batch_norm_legit_functional.default)
def
_native_batch_norm_legit_functional( 2007 input: Tensor, 2008 weight: Optional[Tensor], 2009 bias: Optional[Tensor], 2010 running_mean: Tensor, 2011 running_var: Tensor, 2012 training: bool, 2013 momentum: float, 2014 eps: float, 2015) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: 2016 ( 2017 output, 2018 save_mean, 2019 save_rstd, 2020 new_running_mean, 2021 new_running_var, 2022 ) = native_batch_norm_helper( 2023 input, weight, bias, running_mean, running_var, training, momentum, eps, True 2024 ) 2025 assert new_running_mean is not None, "new_running_mean should not be None" 2026 assert new_running_var is not None, "new_running_var should not be None" 2027 return output, save_mean, save_rstd, new_running_mean, new_running_var 2028 2029 2030def _get_batch_norm_reserve_tensor( 2031 input: Tensor, 2032 weight: Optional[Tensor], 2033 bias: Optional[Tensor], 2034 running_mean: Tensor, 2035 running_var: Tensor, 2036 eps: float, 2037 training: bool, 2038) -> Tensor: 2039 """ 2040 Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the 2041 backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`, 2042 which support a variety of backends including cudnn. We create this tensor here to get 2043 the correct shape in the traced graph if we detect that will call the cudnn kernel, 2044 and rely on DCE to avoid materializing this tensor. 2045 """ 2046 backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined] 2047 input, weight, bias, running_mean, running_var, True, eps 2048 ) 2049 reserve_size = 0 2050 if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined] 2051 reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training) # type: ignore[attr-defined] 2052 return torch.empty( 2053 reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device 2054 ) 2055 2056 2057@register_decomposition(aten._batch_norm_with_update.default) 2058def _batch_norm_with_update( 2059 input: Tensor, 2060 weight: Optional[Tensor], 2061 bias: Optional[Tensor], 2062 running_mean: Tensor, 2063 running_var: Tensor, 2064 momentum: float, 2065 eps: float, 2066) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 2067 output, save_mean, save_rstd, _, _ = native_batch_norm_helper( 2068 input, 2069 weight, 2070 bias, 2071 running_mean, 2072 running_var, 2073 True, # training 2074 momentum, 2075 eps, 2076 False, # functional 2077 ) 2078 reserve = _get_batch_norm_reserve_tensor( 2079 input, weight, bias, running_mean, running_var, eps, training=True 2080 ) 2081 return output, save_mean, save_rstd, reserve 2082 2083 2084@register_decomposition(aten._batch_norm_with_update_functional.default) 2085def _batch_norm_with_update_functional( 2086 input: Tensor, 2087 weight: Optional[Tensor], 2088 bias: Optional[Tensor], 2089 running_mean: Tensor, 2090 running_var: Tensor, 2091 momentum: float, 2092 eps: float, 2093) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: 2094 ( 2095 output, 2096 save_mean, 2097 save_rstd, 2098 new_rm, 2099 new_rv, 2100 ) = native_batch_norm_helper( 2101 input, weight, bias, running_mean, running_var, True, momentum, eps, True 2102 ) 2103 reserve = _get_batch_norm_reserve_tensor( 2104 input, weight, bias, running_mean, running_var, eps, training=True 2105 ) 2106 assert new_rm is not None, "new_running_mean should not be None" 2107 assert new_rv is not None, "new_running_var should not be None" 2108 return (output, save_mean, save_rstd, reserve, new_rm, new_rv) 2109 2110 
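# Usage sketch (illustrative, not part of the decomposition surface): the
# `*_functional` variants above return the updated running statistics instead of
# mutating `running_mean` / `running_var` in place (see `native_batch_norm_helper`
# with functional=True), so a functionalized trace threads the new state through
# explicitly, e.g.
#
#   out, save_mean, save_rstd, reserve, new_rm, new_rv = (
#       _batch_norm_with_update_functional(
#           x, weight, bias, running_mean, running_var, 0.1, 1e-5
#       )
#   )
#   # caller replaces its running_mean / running_var with new_rm / new_rv
#
# whereas `_batch_norm_with_update` copies the new statistics into the running
# buffers inside the helper (functional=False).
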
2111@register_decomposition(aten._batch_norm_no_update.default) 2112def _batch_norm_no_update( 2113 input: Tensor, 2114 weight: Optional[Tensor], 2115 bias: Optional[Tensor], 2116 running_mean: Tensor, 2117 running_var: Tensor, 2118 momentum: float, 2119 eps: float, 2120) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 2121 output, save_mean, save_rstd, _, _ = native_batch_norm_helper( 2122 input, 2123 weight, 2124 bias, 2125 running_mean, 2126 running_var, 2127 False, # training 2128 momentum, 2129 eps, 2130 False, # functional 2131 ) 2132 reserve = _get_batch_norm_reserve_tensor( 2133 input, weight, bias, running_mean, running_var, eps, training=False 2134 ) 2135 return output, save_mean, save_rstd, reserve 2136 2137 2138@register_decomposition(aten._fused_dropout) 2139@out_wrapper("out0", "out1") 2140@pw_cast_for_opmath 2141def _fused_dropout_decomposition(input, p, generator=None): 2142 assert generator is None 2143 mask = (torch.rand_like(input) < p).to(dtype=torch.uint8) 2144 res = mask.type_as(input) * input * (1.0 / p) 2145 return (res, mask) 2146 2147 2148@register_decomposition(aten._to_copy) 2149@out_wrapper() 2150def _to_copy( 2151 x: Union[Tensor, NumberType], 2152 *, 2153 dtype: Optional[torch.dtype] = None, 2154 layout=None, 2155 device: Optional[torch.device] = None, 2156 pin_memory: bool = False, 2157 non_blocking: bool = False, 2158 memory_format: Optional[torch.memory_format] = None, 2159): 2160 assert not layout or layout == torch.strided, "TODO" 2161 assert not pin_memory, "TODO" 2162 assert isinstance(x, (torch.Tensor, int, float, bool, complex)) 2163 if device is None and dtype is None and memory_format is None: 2164 if isinstance(x, torch.Tensor): 2165 return x.clone() 2166 else: 2167 return x 2168 dtype_converted = False 2169 2170 if isinstance(x, torch.Tensor): 2171 x_tensor = x 2172 else: 2173 x_tensor = torch.scalar_tensor(x) 2174 2175 if device is not None and device != x_tensor.device: 2176 # avoid conversions on cpu 2177 if dtype is not None and device.type == "cpu": 2178 x_tensor = torch._prims.convert_element_type(x_tensor, dtype) 2179 dtype_converted = True 2180 x_tensor = torch._prims.device_put(x_tensor, device) 2181 2182 if dtype is not None and not dtype_converted: 2183 x_tensor = torch._prims.convert_element_type(x_tensor, dtype) 2184 dtype_converted = True 2185 2186 if memory_format is not None: # no ref/prim for memory format 2187 return torch.clone(x_tensor, memory_format=memory_format) 2188 return x_tensor 2189 2190 2191# Questionable decompositions 2192# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. 2193# Note that this decomposition causes issues with in-place ops 2194@register_decomposition([aten.detach, aten.lift, aten.lift_fresh]) 2195@out_wrapper() 2196def nop_decomposition(x): 2197 return aten.alias(x) 2198 2199 2200# Also register to the Autograd dispatch key, so this decomp can run above autograd. 2201# native_batch_norm needs to decompose into other ops before autograd. 
2202@aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd) 2203@register_decomposition(aten.cudnn_batch_norm) 2204@out_wrapper("out0", "out1", "out2", "out3") 2205def cudnn_batch_norm( 2206 input: Tensor, 2207 weight: Tensor, 2208 bias: Optional[Tensor], 2209 running_mean: Optional[Tensor], 2210 running_var: Optional[Tensor], 2211 training: bool, 2212 exponential_average_factor: float, 2213 epsilon: float, 2214): 2215 a, b, c = aten.native_batch_norm( 2216 input, 2217 weight, 2218 bias, 2219 running_mean, 2220 running_var, 2221 training, 2222 exponential_average_factor, 2223 epsilon, 2224 ) 2225 # Cudnn return running mean and variance when training is True 2226 if training: 2227 return (a, b, c, input.new_zeros((0,), dtype=torch.uint8)) 2228 return ( 2229 a, 2230 weight.new_zeros((0,)), 2231 weight.new_zeros((0,)), 2232 input.new_zeros((0,), dtype=torch.uint8), 2233 ) 2234 2235 2236def _broadcast_batch_norm_backward(x, broadcast_mask): 2237 for axis, mask in enumerate(broadcast_mask): 2238 if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask): 2239 x = x.unsqueeze(axis) 2240 return x 2241 2242 2243@register_decomposition(aten.batch_norm_backward.default) 2244def batch_norm_backward( 2245 grad_out: Tensor, 2246 input: Tensor, 2247 weight: Optional[Tensor], 2248 running_mean: Optional[Tensor], 2249 running_var: Optional[Tensor], 2250 save_mean: Optional[Tensor], 2251 save_invstd: Optional[Tensor], 2252 train: bool, 2253 eps: float, 2254 output_mask: List[bool], 2255 reserve: Tensor, 2256) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 2257 return native_batch_norm_backward( 2258 grad_out, 2259 input, 2260 weight, 2261 running_mean, 2262 running_var, 2263 save_mean, 2264 save_invstd, 2265 train, 2266 eps, 2267 output_mask, 2268 ) 2269 2270 2271@register_decomposition(aten.native_batch_norm_backward.default) 2272def native_batch_norm_backward( 2273 grad_out: Tensor, 2274 input: Tensor, 2275 weight: Optional[Tensor], 2276 running_mean: Optional[Tensor], 2277 running_var: Optional[Tensor], 2278 save_mean: Optional[Tensor], 2279 save_invstd: Optional[Tensor], 2280 train: bool, 2281 eps: float, 2282 output_mask: List[bool], 2283) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 2284 input_dtype = input.dtype 2285 if weight is not None: 2286 weight_dtype = weight.dtype 2287 else: 2288 weight_dtype = input_dtype 2289 computation_dtype = utils.get_computation_dtype(input.dtype) 2290 ( 2291 grad_out_cast, 2292 input_cast, 2293 weight_cast, 2294 running_mean_cast, 2295 running_var_cast, 2296 save_mean_cast, 2297 save_invstd_cast, 2298 ) = ( 2299 x.to(computation_dtype) if x is not None else x 2300 for x in ( 2301 grad_out, 2302 input, 2303 weight, 2304 running_mean, 2305 running_var, 2306 save_mean, 2307 save_invstd, 2308 ) 2309 ) 2310 input_shape = input.shape 2311 input_rank = input.dim() 2312 assert input_rank >= 2, "rank of the input must be at least 2" 2313 2314 axis = 1 2315 num_features = prod(list(input_shape)) / input_shape[axis] 2316 mean = save_mean_cast 2317 invstd = save_invstd_cast 2318 if train: 2319 assert save_mean_cast is not None and save_invstd_cast is not None 2320 else: 2321 assert running_mean_cast is not None and running_var_cast is not None 2322 mean = running_mean_cast 2323 invstd = torch.rsqrt(running_var_cast + eps) 2324 2325 broadcast_mask: List[int] = [1] * input_rank 2326 broadcast_mask[axis] = input_shape[axis] 2327 2328 reduction_axes: List[int] = [] 2329 for i in range(input_rank): 2330 if i != axis: 2331 reduction_axes.append(i) 2332 2333 
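    # Illustrative summary (added commentary): with x_hat = (input - mean) * invstd
    # and n = num_features, the training-mode gradient computed below is the usual
    # batch-norm backward identity
    #   grad_input = (grad_out - mean(grad_out) - x_hat * mean(grad_out * x_hat))
    #                * invstd * weight
    # grad_output_sum and dot_p are the two per-channel reductions; grad_mean,
    # proj_scale and grad_scale are their broadcasted factors.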
mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type] 2334 norm = 1.0 / num_features 2335 grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type] 2336 dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator] 2337 2338 grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask) 2339 proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) # type: ignore[operator] 2340 2341 if weight_cast is None: 2342 grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type] 2343 else: 2344 grad_scale = _broadcast_batch_norm_backward( 2345 invstd * weight_cast, broadcast_mask 2346 ) 2347 2348 if train: 2349 proj = (input_cast - mean) * proj_scale # type: ignore[operator] 2350 grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale 2351 else: 2352 grad_input = grad_out_cast * grad_scale 2353 2354 if output_mask[1]: 2355 grad_weight = dot_p * invstd 2356 else: 2357 grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp 2358 2359 if output_mask[2]: 2360 grad_bias = grad_output_sum 2361 else: 2362 grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp 2363 2364 return ( 2365 grad_input.to(input_dtype), 2366 _maybe_cast(grad_weight, weight_dtype), 2367 _maybe_cast(grad_bias, weight_dtype), 2368 ) 2369 2370 2371# out_wrapper currently does not allow optional outputs 2372@register_decomposition(aten.native_batch_norm_backward.out) 2373def native_batch_norm_backward_out( 2374 grad_out: Tensor, 2375 input: Tensor, 2376 weight: Optional[Tensor], 2377 running_mean: Optional[Tensor], 2378 running_var: Optional[Tensor], 2379 save_mean: Optional[Tensor], 2380 save_invstd: Optional[Tensor], 2381 train: bool, 2382 eps: float, 2383 output_mask: List[bool], 2384 *, 2385 out0: torch.Tensor, 2386 out1: torch.Tensor, 2387 out2: torch.Tensor, 2388) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 2389 result = native_batch_norm_backward( 2390 grad_out, 2391 input, 2392 weight, 2393 running_mean, 2394 running_var, 2395 save_mean, 2396 save_invstd, 2397 train, 2398 eps, 2399 output_mask, 2400 ) 2401 grad_input = (out0, out1, out2) 2402 for i, r in enumerate(result): 2403 if r is not None: 2404 _maybe_resize_out(grad_input[i], r.shape) 2405 _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) 2406 2407 return grad_input 2408 2409 2410@register_decomposition(aten.miopen_batch_norm_backward) 2411@out_wrapper("out0", "out1", "out2") 2412def miopen_batch_norm_backward( 2413 input: Tensor, 2414 grad_output: Tensor, 2415 weight: Tensor, 2416 running_mean: Optional[Tensor], 2417 running_var: Optional[Tensor], 2418 save_mean: Optional[Tensor], 2419 save_var: Optional[Tensor], 2420 epsilon: float, 2421): 2422 return aten.native_batch_norm_backward( 2423 grad_output, 2424 input, 2425 weight, 2426 running_mean, 2427 running_var, 2428 save_mean, 2429 save_var, 2430 True, 2431 epsilon, 2432 [True, True, True], 2433 ) 2434 2435 2436@register_decomposition(aten.cudnn_batch_norm_backward) 2437@out_wrapper("out0", "out1", "out2") 2438def cudnn_batch_norm_backward( 2439 input: Tensor, 2440 grad_output: Tensor, 2441 weight: Tensor, 2442 running_mean: Optional[Tensor], 2443 running_var: Optional[Tensor], 2444 save_mean: Optional[Tensor], 2445 save_var: Optional[Tensor], 2446 epsilon: float, 2447 reserveSpace: Tensor, 2448): 2449 return aten.native_batch_norm_backward( 2450 
grad_output, 2451 input, 2452 weight, 2453 running_mean, 2454 running_var, 2455 save_mean, 2456 save_var, 2457 True, 2458 epsilon, 2459 [True, True, True], 2460 ) 2461 2462 2463@register_decomposition(aten._adaptive_avg_pool2d) 2464@out_wrapper() 2465@pw_cast_for_opmath 2466def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]): 2467 # Preconditions 2468 device = input.device 2469 shape = input.shape 2470 ndim = len(shape) 2471 torch._check( 2472 ndim in (3, 4), 2473 lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}", 2474 ) 2475 for d in input.shape[-2:]: 2476 torch._check( 2477 d != 0, 2478 lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for " 2479 f"non-batch dimensions, but input has shape {tuple(shape)}.", 2480 ) 2481 2482 # Optimisation (we should also do this in the kernel implementation) 2483 if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0: 2484 stride = tuple(i // o for i, o in zip(shape[-2:], output_size)) 2485 kernel = tuple( 2486 i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride) 2487 ) 2488 return torch.nn.functional.avg_pool2d(input, kernel, stride) 2489 2490 def start_index(a, b, c): 2491 return torch.div(a * c, b, rounding_mode="trunc") 2492 2493 def end_index(a, b, c): 2494 return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc") 2495 2496 def compute_idx(in_size, out_size): 2497 orange = torch.arange(out_size, device=device, dtype=torch.int64) 2498 i0 = start_index(orange, out_size, in_size) 2499 # Let length = end_index - start_index, i.e. the length of the pooling kernels 2500 # length.max() can be computed analytically as follows: 2501 maxlength = in_size // out_size + 1 2502 in_size_mod = in_size % out_size 2503 # adaptive = True iff there are kernels with different lengths 2504 adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) 2505 if adaptive: 2506 maxlength += 1 2507 elif in_size_mod == 0: 2508 maxlength -= 1 2509 2510 range_max = torch.arange(maxlength, device=device, dtype=torch.int64) 2511 idx = i0.unsqueeze(-1) + range_max 2512 if adaptive: 2513 # Need to clamp to avoid accessing out-of-bounds memory 2514 # TODO make minimum accept scalars 2515 maxval = torch.scalar_tensor( 2516 in_size - 1, dtype=idx.dtype, device=idx.device 2517 ) 2518 idx = torch.minimum(idx, maxval) 2519 2520 # Compute the length 2521 i1 = end_index(orange, out_size, in_size) 2522 length = i1 - i0 2523 else: 2524 length = maxlength 2525 return idx, length, range_max, adaptive 2526 2527 # length is not None if it's constant, otherwise we'll need to compute it 2528 idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2]) 2529 idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1]) 2530 2531 vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw] 2532 # Shortcut for the simpler case 2533 if not adaptive_h and not adaptive_w: 2534 return torch.mean(vals, dim=(-3, -1)) 2535 2536 def maybe_mask(vals, length, range_max, adaptive, dim): 2537 if isinstance(length, IntLike): 2538 return vals, length 2539 else: 2540 # zero-out the things we didn't really want to select 2541 assert dim < 0 2542 # hack 2543 mask = range_max >= length.unsqueeze(-1) 2544 if dim == -2: 2545 mask = _unsqueeze_to_dim(mask, 4) 2546 vals = torch.masked_fill(vals, mask, 0.0) 2547 # Compute the length of each window 2548 length = _unsqueeze_to_dim(length, -dim) 2549 return vals, length 2550 2551 vals, length_h = maybe_mask( 2552 vals, length_h, range_max_h, 
adaptive=adaptive_h, dim=-2 2553 ) 2554 vals, length_w = maybe_mask( 2555 vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1 2556 ) 2557 2558 # We unroll the sum as we assume that the kernels are going to be small 2559 ret = None 2560 for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])): 2561 if ret is None: 2562 ret = vals[..., i, :, j] 2563 else: 2564 ret = ret + vals[..., i, :, j] 2565 return ret / (length_h * length_w) 2566 2567 2568@register_decomposition(aten.index_add_) 2569def index_add_( 2570 x: TensorLike, 2571 dim: int, 2572 index: TensorLike, 2573 tensor: TensorLike, 2574 *, 2575 alpha: NumberType = 1, 2576): 2577 return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha) 2578 2579 2580@register_decomposition(aten.index_add) 2581@out_wrapper() 2582def index_add( 2583 x: TensorLike, 2584 dim: int, 2585 index: TensorLike, 2586 tensor: TensorLike, 2587 *, 2588 alpha: NumberType = 1, 2589): 2590 return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha) 2591 2592 2593def _index_add( 2594 x: TensorLike, 2595 dim: int, 2596 index: TensorLike, 2597 tensor: TensorLike, 2598 *, 2599 inplace: bool, 2600 alpha: NumberType = 1, 2601): 2602 dim = utils.canonicalize_dims(x.ndim, dim) 2603 torch._check( 2604 index.ndim <= 1, 2605 lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", 2606 ) 2607 index_size = index.size(0) if index.ndim == 1 else 1 2608 tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1 2609 torch._check( 2610 tensor_size == index_size, 2611 lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}", 2612 ) 2613 if alpha != 1: 2614 python_type = utils.dtype_to_type(x.dtype) 2615 torch._check( 2616 python_type == bool 2617 or utils.is_weakly_lesser_type(type(alpha), python_type), 2618 lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", 2619 ) 2620 tensor = tensor * alpha 2621 # Treat scalars as elements of \R^1 2622 zero_dim = x.ndim == 0 2623 x1 = x.unsqueeze(0) if zero_dim else x 2624 idx = (None,) * dim + (index,) 2625 index_put = aten.index_put_ if inplace else aten.index_put 2626 out = index_put(x1, idx, tensor, accumulate=True) 2627 if inplace: 2628 return x 2629 else: 2630 return out.squeeze(0) if zero_dim else out.contiguous() 2631 2632 2633@register_decomposition(aten.pad_sequence.default) 2634@aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd) 2635def pad_sequence(sequences, batch_first=False, padding_value=0.0): 2636 torch._check(len(sequences) > 0, lambda: "received an empty list of sequences") 2637 sequences_size = len(sequences) 2638 max_size = sequences[0].size() 2639 trailing_dims = max_size[1:] 2640 max_len = max(x.size(0) for x in sequences) 2641 if batch_first: 2642 out_dims = (sequences_size, max_len) 2643 else: 2644 out_dims = (max_len, sequences_size) 2645 out_dims = out_dims + trailing_dims 2646 out = sequences[0].new_full(out_dims, padding_value) 2647 dim_paddings = (0, 0) * len(trailing_dims) 2648 for i in range(sequences_size): 2649 currseq = sequences[i] 2650 row = aten.constant_pad_nd( 2651 currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value 2652 ) 2653 if batch_first: 2654 out = aten.select_scatter(out, row, dim=0, index=i) 2655 else: 2656 out = aten.select_scatter(out, row, dim=1, index=i) 2657 return out 2658 2659 2660@register_decomposition(aten.index_copy_) 2661def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): 2662 return 
_index_copy(x, dim, index, tensor, inplace=True) 2663 2664 2665@register_decomposition(aten.index_copy) 2666@out_wrapper() 2667def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): 2668 return _index_copy(x, dim, index, tensor, inplace=False) 2669 2670 2671def _index_copy( 2672 x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool 2673): 2674 dim = utils.canonicalize_dims(x.ndim, dim) 2675 torch._check( 2676 index.ndim <= 1, 2677 lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", 2678 ) 2679 # Treat scalars as elements of \R^1 2680 zero_dim = x.ndim == 0 2681 x1 = x.unsqueeze(0) if zero_dim else x 2682 index = index.unsqueeze(0) if index.ndim == 0 else index 2683 idx = (None,) * dim + (index,) 2684 index_put = aten.index_put_ if inplace else aten.index_put 2685 out = index_put(x1, idx, tensor) 2686 if inplace: 2687 return x 2688 else: 2689 return out.squeeze(0) if zero_dim else out.contiguous() 2690 2691 2692# nb: Should use acc_t, not op_math 2693@register_decomposition(aten.log_sigmoid_forward) 2694@out_wrapper("output", "buffer") 2695@pw_cast_for_opmath 2696def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: 2697 min = torch.minimum(self.new_zeros(()), self) 2698 z = torch.exp(-torch.abs(self)) 2699 if self.is_cuda: 2700 buffer = self.new_zeros((0,)) 2701 else: 2702 buffer = z 2703 return min - torch.log1p(z), buffer 2704 2705 2706@register_decomposition(aten.uniform) 2707@out_wrapper() 2708def uniform( 2709 x: Tensor, 2710 low: Union[bool, int, float] = 0.0, 2711 high: Union[bool, int, float] = 1.0, 2712 generator: Optional[torch.Generator] = None, 2713): 2714 return prims._uniform_helper( 2715 x.shape, 2716 low=sym_float(low), 2717 high=sym_float(high), 2718 dtype=x.dtype, 2719 device=x.device, 2720 generator=generator, 2721 ) 2722 2723 2724@register_decomposition(aten.uniform_) 2725def uniform_(self, low=0, high=1, generator=None): 2726 return self.copy_(uniform(self, low, high, generator)) 2727 2728 2729# aten/src/ATen/native/UpSample.cpp compute_output_size 2730def upsample_compute_output_size(input_size, output_size, scale_factors): 2731 spatial_dimensions = len(input_size) - 2 2732 if output_size is not None: 2733 torch._check( 2734 scale_factors is None, 2735 lambda: "Must specify exactly one of output_size and scale_factors", 2736 ) 2737 torch._check(len(output_size) == spatial_dimensions, lambda: "") 2738 return output_size 2739 if scale_factors is not None: 2740 # NB: this isn't necessary lol 2741 torch._check( 2742 output_size is None, 2743 lambda: "Must specify exactly one of output_size and scale_factors", 2744 ) 2745 torch._check(len(scale_factors) == spatial_dimensions, lambda: "") 2746 output_size = [] 2747 for i, s in enumerate(scale_factors): 2748 if int(s) == s: 2749 output_size.append(input_size[i + 2] * int(s)) 2750 else: 2751 output_size.append(sym_int(input_size[i + 2] * s)) 2752 return output_size 2753 torch._check( 2754 False, lambda: "Must specify exactly one of output_size and scale_factors" 2755 ) 2756 2757 2758def get_scale_value(scales, idx): 2759 if scales is None: 2760 return None 2761 return scales[idx] 2762 2763 2764@register_decomposition(aten.upsample_nearest1d.vec) 2765@register_decomposition(aten.upsample_nearest2d.vec) 2766@register_decomposition(aten.upsample_nearest3d.vec) 2767@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 2768@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd) 
2769@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 2770@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd) 2771@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 2772@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd) 2773def _upsample_nearest_vec( 2774 input: Tensor, 2775 output_size: Optional[List[int]], 2776 scale_factors: Optional[List[float]], 2777) -> Tensor: 2778 osize = upsample_compute_output_size(input.size(), output_size, scale_factors) 2779 scales = ( 2780 scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] 2781 ) 2782 return _upsample_nearest(input, osize, scales) 2783 2784 2785@register_decomposition(aten._upsample_nearest_exact1d.vec) 2786@register_decomposition(aten._upsample_nearest_exact2d.vec) 2787@register_decomposition(aten._upsample_nearest_exact3d.vec) 2788@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 2789@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd) 2790@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 2791@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd) 2792@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 2793@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd) 2794def _upsample_nearest_exact_vec( 2795 input: Tensor, 2796 output_size: Optional[List[int]], 2797 scale_factors: Optional[List[float]], 2798) -> Tensor: 2799 osize = upsample_compute_output_size(input.size(), output_size, scale_factors) 2800 scales = ( 2801 scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] 2802 ) 2803 return _upsample_nearest(input, osize, scales, exact=True) 2804 2805 2806def _compute_upsample_nearest_indices(input, output_size, scales, exact=False): 2807 # For each dim in output_size, compute the set of input indices used 2808 # to produce the upsampled output. 
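    # Illustrative worked example (added commentary): with isize = 5, osize = 3 and
    # no explicit scale, scale = 5 / 3, so
    #   exact=False: floor([0, 1, 2] * 5 / 3)         -> input indices [0, 1, 3]
    #   exact=True:  floor(([0, 1, 2] + 0.5) * 5 / 3) -> input indices [0, 2, 4]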
    indices = []
    num_spatial_dims = len(output_size)
    offset = 0.5 if exact else 0.0

    for d in range(num_spatial_dims):
        # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
        #
        # Indices are computed as follows:
        # scale = isize / osize
        # Case: exact=False
        # input_index = floor(output_index * scale)
        # Same as OpenCV INTER_NEAREST
        #
        # Case: exact=True
        # index_f32 = (output_index + 0.5) * scale - 0.5
        # input_index = round(index_f32)
        # Same as Pillow and Scikit-Image/Scipy ndi.zoom
        osize = output_size[d]
        isize = input.shape[-num_spatial_dims + d]
        scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize

        output_indices = torch.arange(osize, dtype=torch.float32, device=input.device)
        input_indices = ((output_indices + offset) * scale).to(torch.int64)
        for _ in range(num_spatial_dims - 1 - d):
            input_indices = input_indices.unsqueeze(-1)
        indices.append(input_indices)
    return indices


@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest1d(
    input: Tensor,
    output_size: List[int],
    scales: Optional[float] = None,
) -> Tensor:
    return _upsample_nearest(input, output_size, [scales])


@register_decomposition(
    [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest_exact1d(
    input: Tensor,
    output_size: List[int],
    scales: Optional[float] = None,
) -> Tensor:
    return _upsample_nearest(input, output_size, [scales], exact=True)


@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
    input: Tensor,
    output_size: List[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    return _upsample_nearest(input, output_size, [scales_h, scales_w])


@register_decomposition(
    [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def _upsample_nearest_exact2d(
    input: Tensor,
    output_size: List[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)


@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
2895@out_wrapper(preserve_memory_format=True, exact_dtype=True) 2896def upsample_nearest3d( 2897 input: Tensor, 2898 output_size: List[int], 2899 scales_d: Optional[float] = None, 2900 scales_h: Optional[float] = None, 2901 scales_w: Optional[float] = None, 2902) -> Tensor: 2903 return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w]) 2904 2905 2906@register_decomposition( 2907 [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out] 2908) 2909@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) 2910@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd) 2911@out_wrapper(preserve_memory_format=True, exact_dtype=True) 2912def _upsample_nearest_exact3d( 2913 input: Tensor, 2914 output_size: List[int], 2915 scales_d: Optional[float] = None, 2916 scales_h: Optional[float] = None, 2917 scales_w: Optional[float] = None, 2918) -> Tensor: 2919 return _upsample_nearest( 2920 input, output_size, [scales_d, scales_h, scales_w], exact=True 2921 ) 2922 2923 2924@pw_cast_for_opmath 2925def _upsample_nearest( 2926 input: Tensor, 2927 output_size: List[int], 2928 scales: List[Optional[float]], 2929 exact: bool = False, 2930) -> Tensor: 2931 spatial_indices = _compute_upsample_nearest_indices( 2932 input, output_size, scales, exact=exact 2933 ) 2934 2935 indices = [None, None] + spatial_indices 2936 result = aten._unsafe_index(input, indices) 2937 2938 if result.ndim == 4: 2939 # convert output to correct memory format, if necessary 2940 memory_format = utils.suggest_memory_format(input) 2941 2942 # following "heuristic: only use channels_last path when it's faster than the contiguous path" 2943 n_channels = input.shape[1] 2944 if input.device.type == "cuda" and n_channels < 4: 2945 memory_format = torch.contiguous_format 2946 2947 result = result.contiguous(memory_format=memory_format) 2948 return result 2949 2950 2951def gather_params(params, has_biases, has_projections): 2952 if has_biases and has_projections: 2953 group_size = 5 2954 elif has_biases: 2955 group_size = 4 2956 elif has_projections: 2957 group_size = 3 2958 else: 2959 group_size = 2 2960 2961 assert len(params) % group_size == 0, len(params) 2962 return [ 2963 tuple(params[i : i + group_size]) for i in range(0, len(params), group_size) 2964 ] 2965 2966 2967def params_hiddens(params, hiddens, i, bidirectional): 2968 if bidirectional: 2969 cur_params, cur_hidden = params[2 * i], hiddens[2 * i] 2970 bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1] 2971 else: 2972 cur_params, cur_hidden = params[i], hiddens[i] 2973 bidir_params, bidir_hidden = None, None 2974 2975 return cur_params, cur_hidden, bidir_params, bidir_hidden 2976 2977 2978def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens): 2979 assert last_batch_size > batch_size 2980 hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size)) 2981 return cur_hidden.narrow(0, 0, batch_size) 2982 2983 2984def update_hidden_for_packed_reverse( 2985 cur_hidden, last_batch_size, batch_size, inp_hidden 2986): 2987 if last_batch_size == batch_size: 2988 return cur_hidden 2989 assert last_batch_size < batch_size 2990 return torch.concat( 2991 ( 2992 cur_hidden, 2993 inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size), 2994 ) 2995 ) 2996 2997 2998def one_layer_rnn_data( 2999 inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False 3000): 3001 ih_weight = params[0] 3002 hh_weight = params[1] 3003 ih_bias = params[2] if 
has_biases else None 3004 hh_bias = params[3] if has_biases else None 3005 3006 step_output = [] 3007 hiddens: List[torch.Tensor] = [] 3008 3009 last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] 3010 cur_hidden = hidden.narrow(0, 0, last_batch_size) 3011 split_inp = torch.split(inp, list(batch_sizes)) 3012 if reverse: 3013 split_inp = split_inp[::-1] 3014 for inp in split_inp: 3015 i = inp.shape[0] 3016 3017 if last_batch_size == i: 3018 pass # don't update cur_hidden 3019 # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest 3020 elif reverse: 3021 cur_hidden = update_hidden_for_packed_reverse( 3022 cur_hidden, last_batch_size, i, hidden 3023 ) 3024 else: 3025 cur_hidden = update_hidden_for_packed( 3026 cur_hidden, last_batch_size, i, hiddens 3027 ) 3028 3029 cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) 3030 last_batch_size = i 3031 step_output.append(cur_hidden) 3032 3033 if reverse: 3034 step_output.reverse() 3035 else: 3036 hiddens.append(cur_hidden) 3037 hiddens.reverse() 3038 3039 out = torch.cat(step_output, 0) 3040 hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden 3041 return out, hidden_out 3042 3043 3044def rnn_cell(nonlinearity): 3045 def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): 3046 return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) 3047 3048 return inner 3049 3050 3051def rnn_cell_data(nonlinearity): 3052 def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): 3053 i = F.linear(i, ih_weight, ih_bias) 3054 return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) 3055 3056 return inner 3057 3058 3059def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False): 3060 ih_weight = params[0] 3061 hh_weight = params[1] 3062 ih_bias = params[2] if has_biases else None 3063 hh_bias = params[3] if has_biases else None 3064 3065 precomputed_input = F.linear(inp, ih_weight, ih_bias) 3066 precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input 3067 cur_hidden = hidden.unsqueeze(0) 3068 step_output = [] 3069 for i in precomputed_input: 3070 cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) 3071 step_output.append(cur_hidden) 3072 3073 if reverse: 3074 step_output.reverse() 3075 3076 out = torch.cat(step_output, 0) 3077 3078 return out, cur_hidden.squeeze(0) 3079 3080 3081def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False): 3082 w0 = params[0] 3083 w1 = params[1] 3084 if has_biases: 3085 w2 = params[2] 3086 w3 = params[3] 3087 else: 3088 w2 = torch.zeros(w0.size()) 3089 w3 = torch.zeros(w1.size()) 3090 3091 hx = hidden[0].unsqueeze(0) 3092 cx = hidden[1].unsqueeze(0) 3093 3094 batch_sizes: List[int] = [] 3095 mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2 3096 hidden_size = hx.size(2) 3097 num_layers = 1 3098 3099 # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here 3100 bidirectional = False 3101 batch_first = False 3102 3103 train = False 3104 # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here. 
3105 # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous(); 3106 inp = inp.contiguous() 3107 hx = hx.contiguous() 3108 cx = cx.contiguous() 3109 outputs = torch.ops.aten.mkldnn_rnn_layer.default( 3110 inp, 3111 w0, 3112 w1, 3113 w2, 3114 w3, 3115 hx, 3116 cx, 3117 reverse, 3118 batch_sizes, 3119 mode, 3120 hidden_size, 3121 num_layers, 3122 has_biases, 3123 bidirectional, 3124 batch_first, 3125 train, 3126 ) 3127 y, hy, cy = outputs[0], outputs[1], outputs[2] 3128 return y, (hy.squeeze(0), cy.squeeze(0)) 3129 3130 3131def _rnn_helper( 3132 input, 3133 hidden, 3134 params, 3135 has_biases, 3136 num_layers, 3137 dropout, 3138 train, 3139 bidirectional, 3140 batch_first, 3141 layer_fn, 3142): 3143 input = input.transpose(0, 1) if batch_first else input 3144 final_hiddens = [] 3145 3146 for i in range(num_layers): 3147 cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens( 3148 params, hidden, i, bidirectional 3149 ) 3150 dropout = dropout if (train and num_layers < i - 1) else 0.0 3151 fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases) 3152 final_hiddens.append(fwd_hidden) 3153 3154 if bidirectional: 3155 bwd_inp, bwd_hidden = layer_fn( 3156 input, bidir_hidden, bidir_params, has_biases, reverse=True 3157 ) 3158 final_hiddens.append(bwd_hidden) 3159 3160 if bidirectional: 3161 input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined] 3162 else: 3163 input = fwd_inp 3164 3165 if dropout != 0 and train and i < num_layers - 1: 3166 input = torch.dropout(input, dropout, train=True) 3167 3168 input = input.transpose(0, 1) if batch_first else input 3169 return input, final_hiddens 3170 3171 3172@register_decomposition(aten.rnn_tanh.input) 3173@aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd) 3174@aten.rnn_tanh.input.py_impl(DispatchKey.Autograd) 3175def rnn_tanh_input( 3176 input, 3177 hx, 3178 params, 3179 has_biases, 3180 num_layers, 3181 dropout, 3182 train, 3183 bidirectional, 3184 batch_first, 3185): 3186 hidden = hx.unbind(0) 3187 params = gather_params(params, has_biases, False) 3188 out, final_hiddens = _rnn_helper( 3189 input, 3190 hidden, 3191 params, 3192 has_biases, 3193 num_layers, 3194 dropout, 3195 train, 3196 bidirectional, 3197 batch_first, 3198 partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)), 3199 ) 3200 return out, torch.stack(final_hiddens, 0) 3201 3202 3203@register_decomposition(aten.rnn_relu.input) 3204@aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd) 3205@aten.rnn_relu.input.py_impl(DispatchKey.Autograd) 3206def rnn_relu_input( 3207 input, 3208 hx, 3209 params, 3210 has_biases, 3211 num_layers, 3212 dropout, 3213 train, 3214 bidirectional, 3215 batch_first, 3216): 3217 hidden = hx.unbind(0) 3218 params = gather_params(params, has_biases, False) 3219 out, final_hiddens = _rnn_helper( 3220 input, 3221 hidden, 3222 params, 3223 has_biases, 3224 num_layers, 3225 dropout, 3226 train, 3227 bidirectional, 3228 batch_first, 3229 partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)), 3230 ) 3231 return out, torch.stack(final_hiddens, 0) 3232 3233 3234@register_decomposition(aten.rnn_relu.data) 3235@aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd) 3236@aten.rnn_relu.data.py_impl(DispatchKey.Autograd) 3237def rnn_relu_data( 3238 data, 3239 batch_sizes, 3240 hx, 3241 params, 3242 has_biases, 3243 num_layers, 3244 dropout, 3245 train, 3246 bidirectional, 3247): 3248 hidden = hx.unbind(0) 3249 params = gather_params(params, 
has_biases, False) 3250 out, final_hiddens = _rnn_helper( 3251 data, 3252 hidden, 3253 params, 3254 has_biases, 3255 num_layers, 3256 dropout, 3257 train, 3258 bidirectional, 3259 False, 3260 partial( 3261 one_layer_rnn_data, 3262 batch_sizes=batch_sizes, 3263 hidden_fn=rnn_cell_data(torch.relu), 3264 ), 3265 ) 3266 return out, torch.stack(final_hiddens, 0) 3267 3268 3269@register_decomposition(aten.rnn_tanh.data) 3270@aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd) 3271@aten.rnn_tanh.data.py_impl(DispatchKey.Autograd) 3272def rnn_tanh_data( 3273 data, 3274 batch_sizes, 3275 hx, 3276 params, 3277 has_biases, 3278 num_layers, 3279 dropout, 3280 train, 3281 bidirectional, 3282): 3283 hidden = hx.unbind(0) 3284 params = gather_params(params, has_biases, False) 3285 out, final_hiddens = _rnn_helper( 3286 data, 3287 hidden, 3288 params, 3289 has_biases, 3290 num_layers, 3291 dropout, 3292 train, 3293 bidirectional, 3294 False, 3295 partial( 3296 one_layer_rnn_data, 3297 batch_sizes=batch_sizes, 3298 hidden_fn=rnn_cell_data(torch.tanh), 3299 ), 3300 ) 3301 return out, torch.stack(final_hiddens, 0) 3302 3303 3304def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim): 3305 gates = F.linear(hx, hh_weight, hh_bias) + inp 3306 chunked_gates = gates.chunk(4, chunk_dim) 3307 in_gate = chunked_gates[0].sigmoid() 3308 forget_gate = chunked_gates[1].sigmoid() 3309 cell_gate = chunked_gates[2].tanh() 3310 out_gate = chunked_gates[3].sigmoid() 3311 cy = forget_gate * cx + (in_gate * cell_gate) 3312 hy = out_gate * cy.tanh() 3313 hy = hy if hr_weight is None else F.linear(hy, hr_weight, None) 3314 3315 return hy, cy 3316 3317 3318def one_layer_lstm(inp, hidden, params, has_biases, reverse=False): 3319 ih_weight = params[0] 3320 hh_weight = params[1] 3321 ih_bias = params[2] if has_biases else None 3322 hh_bias = params[3] if has_biases else None 3323 hr_weight = ( 3324 params[4] if len(params) == 5 else params[2] if len(params) == 3 else None 3325 ) 3326 3327 hx = hidden[0].unsqueeze(0) 3328 cx = hidden[1].unsqueeze(0) 3329 3330 precomputed_input = F.linear(inp, ih_weight, ih_bias) 3331 precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input 3332 step_output = [] 3333 for inp in precomputed_input: 3334 hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2) 3335 step_output.append(hx) 3336 3337 if reverse: 3338 step_output.reverse() 3339 3340 out = torch.cat(step_output, 0) 3341 3342 return out, (hx.squeeze(1), cx.squeeze(1)) 3343 3344 3345def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False): 3346 ih_weight = params[0] 3347 hh_weight = params[1] 3348 ih_bias = params[2] if has_biases else None 3349 hh_bias = params[3] if has_biases else None 3350 hr_weight = ( 3351 params[4] if len(params) == 5 else params[2] if len(params) == 3 else None 3352 ) 3353 3354 step_output = [] 3355 hiddens = [] 3356 3357 last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] 3358 split_inp = torch.split(inp, list(batch_sizes)) 3359 if reverse: 3360 split_inp = split_inp[::-1] 3361 3362 orig_hx = hidden[0] 3363 orig_cx = hidden[1] 3364 hx, cx = orig_hx.narrow(0, 0, last_batch_size), orig_cx.narrow( 3365 0, 0, last_batch_size 3366 ) 3367 3368 for inp in split_inp: 3369 i = inp.shape[0] 3370 inp = F.linear(inp, ih_weight, ih_bias) 3371 3372 # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest 3373 if i < last_batch_size: 3374 hiddens.append( 3375 ( 3376 hx.narrow(0, i, 
last_batch_size - i), 3377 cx.narrow(0, i, last_batch_size - i), 3378 ) 3379 ) 3380 hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i) 3381 3382 # this will only happen when reverse=True 3383 if i > last_batch_size: 3384 hx = torch.concat( 3385 (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0 3386 ) 3387 cx = torch.concat( 3388 (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0 3389 ) 3390 3391 hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1) 3392 last_batch_size = i 3393 step_output.append(hx) 3394 3395 if reverse: 3396 step_output.reverse() 3397 hidden_out = (hx, cx) 3398 else: 3399 hiddens.append((hx, cx)) 3400 hiddens.reverse() 3401 hidden0, hidden1 = zip(*hiddens) 3402 hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0) 3403 3404 out = torch.cat(step_output, 0) 3405 return out, hidden_out 3406 3407 3408def select_one_layer_lstm_function(input, hx, params): 3409 r"""Check whether we could use decompose lstm with mkldnn_rnn_layer. 3410 All the below conditions need to be met: 3411 * ``torch._C._get_mkldnn_enabled()`` returns ``True``. 3412 * All the input args are on CPU. 3413 * The dtypes of args are either torch.float or torch.bfloat16. 3414 * Inference. 3415 * ``has_projections`` returns ``False``. 3416 3417 Args: 3418 * input: the input sequence to LSTM 3419 * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM 3420 * params: the weight and bias tensors of LSTM 3421 """ 3422 3423 def use_mkldnn(input, hx, params): 3424 if not torch._C._get_mkldnn_enabled(): 3425 return False 3426 3427 tensors = [input] + list(hx) + list(chain.from_iterable(params)) 3428 devices = {t.device for t in tensors} 3429 if len(devices) != 1: 3430 return False 3431 3432 device = devices.pop() 3433 if device != torch.device("cpu"): 3434 return False 3435 # With autocast, possible to have mixed dtype here 3436 dtypes = {t.dtype for t in tensors} 3437 for dtype in dtypes: 3438 if dtype not in [torch.float, torch.bfloat16]: 3439 return False 3440 3441 if input.requires_grad: 3442 return False 3443 3444 has_projections = hx[0].size(2) != hx[1].size(2) 3445 if has_projections: 3446 return False 3447 3448 return True 3449 3450 # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm 3451 # will expand over the seq_len dim 3452 if use_mkldnn(input, hx, params): 3453 return mkldnn_one_layer_lstm 3454 else: 3455 return one_layer_lstm 3456 3457 3458@register_decomposition(aten.lstm.input) 3459@aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd) 3460@aten.lstm.input.py_impl(DispatchKey.Autograd) 3461def lstm_impl( 3462 input, 3463 hx, 3464 params, 3465 has_biases, 3466 num_layers, 3467 dropout, 3468 train, 3469 bidirectional, 3470 batch_first, 3471): 3472 assert len(hx) == 2, "lstm expects two hidden states" 3473 params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) 3474 hidden = list(zip(hx[0], hx[1])) 3475 layer_fn = select_one_layer_lstm_function(input, hx, params) 3476 out, final_hiddens = _rnn_helper( 3477 input, 3478 hidden, 3479 params, 3480 has_biases, 3481 num_layers, 3482 dropout, 3483 train, 3484 bidirectional, 3485 batch_first, 3486 layer_fn, 3487 ) 3488 final_hiddens = list(zip(*final_hiddens)) 3489 return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) 3490 3491 3492@register_decomposition(aten.lstm.data) 3493@aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd) 3494@aten.lstm.data.py_impl(DispatchKey.Autograd) 3495def lstm_data_impl( 3496 data, 
3497 batch_sizes, 3498 hx, 3499 params, 3500 has_biases, 3501 num_layers, 3502 dropout, 3503 train, 3504 bidirectional, 3505): 3506 assert len(hx) == 2, "lstm expects two hidden states" 3507 params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) 3508 hidden = list(zip(hx[0], hx[1])) 3509 out, final_hiddens = _rnn_helper( 3510 data, 3511 hidden, 3512 params, 3513 has_biases, 3514 num_layers, 3515 dropout, 3516 train, 3517 bidirectional, 3518 False, 3519 partial(one_layer_lstm_data, batch_sizes=batch_sizes), 3520 ) 3521 final_hiddens = list(zip(*final_hiddens)) 3522 return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) 3523 3524 3525def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): 3526 chunked_igates = inp.chunk(3, 1) 3527 chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2) 3528 reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() 3529 input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() 3530 new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() 3531 return (cur_hidden - new_gate) * input_gate + new_gate 3532 3533 3534def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): 3535 chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1) 3536 chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1) 3537 reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() 3538 input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() 3539 new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() 3540 return (cur_hidden - new_gate) * input_gate + new_gate 3541 3542 3543@register_decomposition(aten.gru.data) 3544@aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd) 3545@aten.gru.data.py_impl(DispatchKey.Autograd) 3546def gru_impl_data( 3547 data, 3548 batch_sizes, 3549 hx, 3550 params, 3551 has_biases, 3552 num_layers, 3553 dropout, 3554 train, 3555 bidirectional, 3556): 3557 params = gather_params(params, has_biases, False) 3558 out, final_hiddens = _rnn_helper( 3559 data, 3560 hx.unbind(0), 3561 params, 3562 has_biases, 3563 num_layers, 3564 dropout, 3565 train, 3566 bidirectional, 3567 False, 3568 partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data), 3569 ) 3570 return out, torch.stack(final_hiddens, 0) 3571 3572 3573@register_decomposition(aten.gru.input) 3574@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd) 3575@aten.gru.input.py_impl(DispatchKey.Autograd) 3576def gru_impl( 3577 input, 3578 hx, 3579 params, 3580 has_biases, 3581 num_layers, 3582 dropout, 3583 train, 3584 bidirectional, 3585 batch_first, 3586): 3587 params = gather_params(params, has_biases, False) 3588 out, final_hiddens = _rnn_helper( 3589 input, 3590 hx.unbind(0), 3591 params, 3592 has_biases, 3593 num_layers, 3594 dropout, 3595 train, 3596 bidirectional, 3597 batch_first, 3598 partial(one_layer_rnn, hidden_fn=gru_cell), 3599 ) 3600 return out, torch.stack(final_hiddens, 0) 3601 3602 3603@register_decomposition(aten._upsample_bilinear2d_aa.vec) 3604@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 3605@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd) 3606def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors): 3607 osize = upsample_compute_output_size(input.size(), output_size, scale_factors) 3608 scale_h = get_scale_value(scale_factors, 0) 3609 scale_w = get_scale_value(scale_factors, 1) 3610 return 
torch.ops.aten._upsample_bilinear2d_aa( 3611 input, osize, align_corners, scale_h, scale_w 3612 ) 3613 3614 3615@register_decomposition(aten._upsample_bicubic2d_aa.vec) 3616@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 3617@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd) 3618def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors): 3619 osize = upsample_compute_output_size(input.size(), output_size, scale_factors) 3620 scale_h = get_scale_value(scale_factors, 0) 3621 scale_w = get_scale_value(scale_factors, 1) 3622 return torch.ops.aten._upsample_bicubic2d_aa( 3623 input, osize, align_corners, scale_h, scale_w 3624 ) 3625 3626 3627@register_decomposition(aten.upsample_bilinear2d.vec) 3628@register_decomposition(aten.upsample_trilinear3d.vec) 3629@aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 3630@aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd) 3631@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 3632@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) 3633@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 3634@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd) 3635def _upsample_linear_vec(input, output_size, align_corners, scale_factors): 3636 osize = upsample_compute_output_size(input.size(), output_size, scale_factors) 3637 scales = scale_factors if scale_factors else [None] * len(osize) 3638 return _upsample_linear(input, osize, align_corners, scales) 3639 3640 3641@register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out]) 3642@out_wrapper() 3643def upsample_linear1d( 3644 input: Tensor, 3645 output_size: List[int], 3646 align_corners: bool, 3647 scales_w: Optional[float] = None, 3648) -> Tensor: 3649 return _upsample_linear(input, output_size, align_corners, [scales_w]) 3650 3651 3652@register_decomposition( 3653 [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out] 3654) 3655@aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) 3656@out_wrapper() 3657def upsample_bilinear2d( 3658 input: Tensor, 3659 output_size: List[int], 3660 align_corners: bool, 3661 scales_h: Optional[float] = None, 3662 scales_w: Optional[float] = None, 3663) -> Tensor: 3664 return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w]) 3665 3666 3667@register_decomposition( 3668 [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out] 3669) 3670@out_wrapper() 3671def upsample_trilinear3d( 3672 input: Tensor, 3673 output_size: List[int], 3674 align_corners: bool, 3675 scales_d: Optional[float] = None, 3676 scales_h: Optional[float] = None, 3677 scales_w: Optional[float] = None, 3678) -> Tensor: 3679 return _upsample_linear( 3680 input, output_size, align_corners, [scales_d, scales_h, scales_w] 3681 ) 3682 3683 3684def _compute_scale(in_size, out_size, align_corners, scale=None): 3685 if align_corners: 3686 return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0 3687 else: 3688 return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size 3689 3690 3691def _compute_source_index(scale, dst_index, align_corners): 3692 if align_corners: 3693 return scale * dst_index 3694 else: 3695 return scale * (dst_index + 0.5) - 0.5 3696 3697 3698def _sum_tensors_uint8( 3699 src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor 3700) -> Tensor: 3701 output = _sum_tensors( 3702 s.to(torch.int32) * c.to(torch.int32) for s, c in 
zip(src, weights) 3703 ) + (1 << (weights_precision - 1)) 3704 output = output >> weights_precision 3705 return torch.clamp(output, 0, 255).to(torch.uint8) 3706 3707 3708def _compute_weight_precision(weights: TensorSequenceType) -> Tensor: 3709 max_weight = torch.stack(weights).max() 3710 max_weight_precision = 22 3711 precisions = torch.arange(max_weight_precision, device=max_weight.device) 3712 values = 0.5 + max_weight * (1 << (precisions + 1)) 3713 mask = values >= (1 << 15) 3714 return max_weight_precision - mask.sum() 3715 3716 3717@pw_cast_for_opmath 3718def _upsample_linear( 3719 input: Tensor, 3720 output_size: List[int], 3721 align_corners: bool, 3722 scales: List[Optional[float]], 3723) -> Tensor: 3724 # get dimensions of original image 3725 n_batch, n_channels = input.shape[:2] 3726 inp_sizes = input.shape[2:] 3727 n_dims = len(inp_sizes) 3728 3729 _, dtype = utils.elementwise_dtypes( 3730 input, 3731 type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 3732 ) 3733 3734 def get_values(inp_size, out_size, scales, nsqueeze): 3735 # First Calculate scaling factor 3736 scale_factor = _compute_scale(inp_size, out_size, align_corners, scales) 3737 # We have to create arange with int64 dtype and use .to in order to avoid 3738 # additional kernels creation in inductor and get a perf slowdown 3739 i = torch.arange(out_size, device=input.device).to(dtype=dtype) 3740 3741 x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0) 3742 x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze)) 3743 x = x_f32.to(torch.int64) 3744 xp1 = (x + 1).clamp(max=inp_size - 1) 3745 return x_f32, x, xp1 3746 3747 values = [ 3748 get_values(inp_size, out_size, scales, n_dims - 1 - i) 3749 for i, (inp_size, out_size, scales) in enumerate( 3750 zip(inp_sizes, output_size, scales) 3751 ) 3752 ] 3753 xs_f32, xs, xp1s = list(zip(*values)) 3754 3755 vs = [] 3756 for a in product(*[[0, 1]] * n_dims): 3757 idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)] 3758 v = aten._unsafe_index(input, idx) 3759 v = _maybe_convert_to_dtype(v, dtype) 3760 vs.append(v) 3761 3762 for i in reversed(range(n_dims)): 3763 xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype) 3764 vs = [ 3765 # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha 3766 v1 + torch.mul(v2 - v1, xscale) 3767 for v1, v2 in zip(vs[::2], vs[1::2]) 3768 ] 3769 3770 assert len(vs) == 1 3771 result = vs[0] 3772 3773 # convert output to correct memory format, if necessary 3774 memory_format = utils.suggest_memory_format(input) 3775 3776 # following "heuristic: only use channels_last path when it's faster than the contiguous path" 3777 if input.device.type == "cuda" and n_channels < 16: 3778 memory_format = torch.contiguous_format 3779 3780 assert isinstance(result, torch.Tensor) 3781 3782 result = result.contiguous(memory_format=memory_format) 3783 3784 if not input.is_floating_point(): 3785 result = result.round() 3786 3787 return result 3788 3789 3790# We should be applying decompositions after all transformations 3791@register_decomposition(aten.is_same_size.default) 3792def is_same_size(a: Tensor, b: Tensor) -> bool: 3793 return a.shape == b.shape 3794 3795 3796@register_decomposition([aten._reshape_alias, aten._unsafe_view]) 3797@out_wrapper() 3798def _reshape_alias(x, shape, *args): 3799 return aten.view(x, shape) 3800 3801 3802@register_decomposition([aten._unsafe_index]) 3803def _unsafe_index(x, indices): 3804 return aten.index(x, indices) 3805 3806 
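# Illustrative helper, not used by any decomposition in this file: a minimal sketch
# of the index/weight bookkeeping that _upsample_linear above relies on, assuming a
# 1D resize from in_size to out_size with align_corners=False. The default sizes and
# the name _example_upsample_linear_indices are arbitrary choices for the example.
def _example_upsample_linear_indices(in_size: int = 4, out_size: int = 8):
    scale = _compute_scale(in_size, out_size, align_corners=False)
    dst = torch.arange(out_size, dtype=torch.float64)
    # Fractional source coordinate of every output element, clamped at 0 as in
    # _upsample_linear above.
    src = _compute_source_index(scale, dst, align_corners=False).clamp(min=0.0)
    x = src.to(torch.int64)  # left neighbor index
    xp1 = (x + 1).clamp(max=in_size - 1)  # right neighbor index
    # Interpolation weight, used as v0 + (v1 - v0) * alpha in _upsample_linear.
    alpha = (src - x).clamp(0.0, 1.0)
    return x, xp1, alpha

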
3807@register_decomposition([aten._unsafe_index_put]) 3808def _unsafe_index_put(x, indices, value, accumulate=False): 3809 return aten.index_put(x, indices, value, accumulate) 3810 3811 3812@register_decomposition([aten._unsafe_masked_index]) 3813def _unsafe_masked_index(x, mask, indices, fill): 3814 for index in indices: 3815 if index is not None: 3816 torch._check( 3817 index.dtype in [torch.long, torch.int], 3818 lambda: "tensors used as indices must be long or int tensors", 3819 ) 3820 3821 torch._check( 3822 mask.dtype == torch.bool, 3823 lambda: "tensors used as masks must be bool tensors", 3824 ) 3825 3826 if x.numel() == 0: 3827 meta_result = torch._meta_registrations.meta_index_Tensor(x, indices) 3828 return x.new_full(meta_result.shape, fill) 3829 3830 for i in range(len(indices)): 3831 index = indices[i] 3832 if index is not None: 3833 indices[i] = index.clamp(min=0, max=x.size(i) - 1) 3834 3835 return aten._unsafe_index(x, indices).masked_fill(~mask, fill) 3836 3837 3838@register_decomposition([aten._unsafe_masked_index_put_accumulate]) 3839def _unsafe_masked_index_put_accumulate(x, mask, indices, values): 3840 for index in indices: 3841 if index is not None: 3842 torch._check( 3843 index.dtype in [torch.long, torch.int], 3844 lambda: "tensors used as indices must be long or int tensors", 3845 ) 3846 3847 torch._check( 3848 mask.dtype == torch.bool, 3849 lambda: "tensors used as masks must be bool tensors", 3850 ) 3851 3852 if x.numel() == 0: 3853 return x.clone() 3854 3855 for i in range(len(indices)): 3856 index = indices[i] 3857 if index is not None: 3858 indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1) 3859 3860 masked_value = values.masked_fill(~mask, 0) 3861 return aten._unsafe_index_put(x, indices, masked_value, accumulate=True) 3862 3863 3864def _nll_loss_forward( 3865 self: Tensor, 3866 target: Tensor, 3867 weight: Optional[Tensor], 3868 reduction: int, 3869 ignore_index: int, 3870) -> Tuple[Tensor, Tensor]: 3871 # self can be [N, C] or [C] 3872 # target can be [N] or [] 3873 3874 n_dims = self.dim() 3875 channel_dim = 1 3876 if n_dims < 2: 3877 channel_dim = 0 3878 3879 if weight is not None: 3880 if n_dims > 1: 3881 shape = [ 3882 1, 3883 ] * n_dims 3884 shape[channel_dim] = weight.shape[0] 3885 w = weight.view(shape) 3886 else: 3887 w = weight 3888 self = self * w 3889 safe_target = torch.where(target != ignore_index, target, 0) 3890 safe_target_ = safe_target.unsqueeze(channel_dim) 3891 # target can be [N, 1] or [1] 3892 3893 result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) 3894 3895 result = torch.where(target != ignore_index, result, 0) 3896 3897 if reduction == Reduction.NONE.value and n_dims > 1: 3898 total_weight = self.new_full((), 0.0) 3899 return result, total_weight 3900 3901 if weight is not None: 3902 w = w.expand(self.shape) 3903 wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) 3904 wsum = torch.where(target != ignore_index, wsum, 0) 3905 total_weight = wsum.sum() 3906 else: 3907 total_weight = (target != ignore_index).sum().to(self) 3908 3909 if reduction == Reduction.SUM.value: 3910 result = result.sum() 3911 elif reduction == Reduction.MEAN.value: 3912 result = result.sum() / total_weight 3913 3914 return result, total_weight 3915 3916 3917@register_decomposition(aten.nll_loss_forward) 3918@out_wrapper("output", "total_weight") 3919def nll_loss_forward( 3920 self: Tensor, 3921 target: Tensor, 3922 weight: Optional[Tensor], 3923 reduction: int, 3924 ignore_index: int, 3925) -> Tuple[Tensor, 
Tensor]:
    assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"

    n_classes = self.shape[-1]

    assert weight is None or (
        weight.dim() == 1 and weight.numel() == n_classes
    ), f"weight tensor should be defined either for all {n_classes} classes or no classes but got weight tensor of shape: {weight.shape}"  # noqa: B950

    return _nll_loss_forward(self, target, weight, reduction, ignore_index)


@register_decomposition(aten.nll_loss2d_forward)
@out_wrapper("output", "total_weight")
def nll_loss2d_forward(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
) -> Tuple[Tensor, Tensor]:
    return _nll_loss_forward(self, target, weight, reduction, ignore_index)


# These are adapted from aten/src/ATen/native/UpSample.h, which is based on
# https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor:
    return ((A + 2) * x - (A + 3)) * x * x + 1


def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor:
    return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A


def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType:
    A = -0.75

    if t.device == torch.device("cpu"):
        tt1 = torch.stack([t, 1.0 - t], dim=0)
        tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0)
        w03 = _upsample_cubic_convolution2(tt2, A)
        w12 = _upsample_cubic_convolution1(tt1, A)
        w0, w3 = torch.unbind(w03, dim=0)
        w1, w2 = torch.unbind(w12, dim=0)
        return w0, w1, w2, w3
    else:
        return (
            _upsample_cubic_convolution2(t + 1.0, A),
            _upsample_cubic_convolution1(t, A),
            _upsample_cubic_convolution1(1.0 - t, A),
            _upsample_cubic_convolution2(2.0 - t, A),
        )


def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
    coeffs2 = _upsample_get_cubic_coefficients(ts)
    return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2))


# Need this instead of just sum() to keep mypy happy
def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
    return reduce(torch.add, ts)


def _linspace_from_neg_one(
    num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device
):
    if num_steps <= 1:
        return torch.tensor(0, device=device, dtype=dtype)

    a = ((num_steps - 1) / num_steps) if not align_corners else 1
    return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype)


def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool):
    dtype = theta.dtype
    device = theta.device

    # Using padding and summation generates a single kernel vs using torch.stack where 3 kernels are generated,
    # corresponding to each individual tensor: grid_x, grid_y, grid_one
    grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1)
    grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1)
    grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device)

    # this is just a temporary hack and we
should use torch.stack here once #104480 is merged 4018 grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0) 4019 grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0) 4020 grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0) 4021 return grid_x + grid_y + grid_one 4022 4023 4024def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool): 4025 dtype = theta.dtype 4026 device = theta.device 4027 4028 grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1) 4029 grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1) 4030 grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1) 4031 grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device) 4032 4033 # this is just a temporary hack and we should use torch.stack here once #104480 is merged 4034 grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0) 4035 grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0) 4036 grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0) 4037 grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0) 4038 return grid_x + grid_y + grid_z + grid_one 4039 4040 4041def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool): 4042 n, _, h, w = size 4043 base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners) 4044 # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3) 4045 # We do manually a matrix multiplication which is faster than mm() 4046 # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2) 4047 grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2) 4048 return grid.view(n, h, w, 2) 4049 4050 4051def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: bool): 4052 n, _, d, h, w = size 4053 base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners) 4054 # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4) 4055 # We do manually a matrix multiplication which is faster than mm() 4056 # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, h * w, 3) 4057 grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2) 4058 return grid.view(n, d, h, w, 3) 4059 4060 4061@register_decomposition(aten.affine_grid_generator) 4062@out_wrapper() 4063@pw_cast_for_opmath 4064def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool): 4065 torch._check( 4066 len(size) in (4, 5), 4067 lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.", 4068 ) 4069 if len(size) == 4: 4070 return _affine_grid_generator_4d(theta, size, align_corners=align_corners) 4071 else: 4072 return _affine_grid_generator_5d(theta, size, align_corners=align_corners) 4073 4074 4075def _grid_sampler_2d( 4076 a: Tensor, 4077 grid: Tensor, 4078 interpolation_mode: int = 0, 4079 padding_mode: int = 0, 4080 align_corners: bool = False, 4081 _expand_grid: bool = True, 4082) -> Tensor: 4083 # This method is a copy of grid_sampler_2d implementation and introduced with additional arg _expand_grid to 4084 # optionally expand the input grid for performance reasons. 4085 # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x 4086 # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) 4087 # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. 
4088 # Thus we apply this hack to not expand the grid for this case. 4089 4090 torch._check( 4091 interpolation_mode in (0, 1, 2), 4092 lambda: f"Invalid interpolation mode {interpolation_mode}", 4093 ) 4094 torch._check( 4095 padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" 4096 ) 4097 4098 def unnormalize(coords: Tensor, size: int) -> Tensor: 4099 # Rescale coordinates from [-1, 1] to: 4100 # [0, size - 1] if align_corners is True 4101 # [-.5, size -.5] if align_corners is False 4102 mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) 4103 ofs = size * 0.5 - 0.5 4104 return coords * mul + ofs 4105 4106 # Reflects coordinates until they fall between low and high (inclusive). 4107 # The bounds are passed as twice their value so that half-integer values 4108 # can be represented as ints. 4109 def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor: 4110 if twice_low == twice_high: 4111 return torch.zeros_like(coords) 4112 coords_min = twice_low / 2 4113 coords_span = (twice_high - twice_low) / 2 4114 coords2 = (coords - coords_min).abs() 4115 extra = torch.fmod(coords2, coords_span) 4116 flips = (coords2 / coords_span).floor().to(dtype=torch.int8) 4117 return torch.where( 4118 flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra 4119 ) 4120 4121 def compute_coordinates(coords: Tensor, size: int) -> Tensor: 4122 if padding_mode == 0: # Zero 4123 return coords 4124 elif padding_mode == 1: # Borders 4125 return torch.clamp(coords, 0, size - 1) 4126 else: # padding_mode == 2, Reflection 4127 if align_corners: 4128 coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) 4129 else: 4130 coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) 4131 return torch.clamp(coords_reflected, 0, size - 1) 4132 4133 def compute_source_index(coords: Tensor, size: int) -> Tensor: 4134 coords_un = unnormalize(coords, size) 4135 return compute_coordinates(coords_un, size) 4136 4137 N, C, iH, iW = a.shape 4138 _, oH, oW, two = grid.shape 4139 assert two == 2 4140 4141 if _expand_grid: 4142 # Let's expand grid to [N, C, oH, oW, 2] 4143 # This allows to generate a single triton cuda kernel instead of two kernels. 
4144 # Two kernels are due source indices, weights have shape (N, 1, oH, oW), xnumel=N*oH*oW 4145 # and output has shape (N, C, oH, oW), xnumel=N*C*oH*oW 4146 # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW 4147 grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2) 4148 4149 def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor: 4150 return torch.logical_and( 4151 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH)) 4152 ) 4153 4154 N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1) 4155 C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1) 4156 4157 def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: 4158 cond = in_bounds_cond(xs, ys) 4159 # To clip to inside valid coordinates, we map the coordinates 4160 # to (x, y) = (0, 0) and also set the weight to 0 4161 # We also change the shape of the tensor to the appropriate one for 4162 # broadcasting with N_idx, C_idx for the purposes of advanced indexing 4163 c = C if _expand_grid else 1 4164 return tuple( 4165 torch.where(cond, t, 0).view(N, c, oH, oW) 4166 for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws) 4167 ) 4168 4169 def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor: 4170 # Perform clipping, index into input tensor and multiply by weight 4171 idx_x, idx_y, w_ = clip(ix, iy, w) 4172 return a[N_idx, C_idx, idx_y, idx_x] * w_ 4173 4174 x = grid[..., 0] 4175 y = grid[..., 1] 4176 4177 if interpolation_mode == 0: # Bilinear 4178 ix = compute_source_index(x, iW) 4179 iy = compute_source_index(y, iH) 4180 4181 ix_nw, iy_nw = ix.floor(), iy.floor() 4182 ix_ne, iy_ne = ix_nw + 1, iy_nw 4183 ix_sw, iy_sw = ix_nw, iy_nw + 1 4184 ix_se, iy_se = ix_ne, iy_sw 4185 4186 w_nw = (ix_se - ix) * (iy_se - iy) 4187 w_ne = (ix - ix_sw) * (iy_sw - iy) 4188 w_sw = (ix_ne - ix) * (iy - iy_ne) 4189 w_se = (ix - ix_nw) * (iy - iy_nw) 4190 4191 return _sum_tensors( 4192 get_summand(ix, iy, w) 4193 for (ix, iy, w) in ( 4194 (ix_nw, iy_nw, w_nw), 4195 (ix_ne, iy_ne, w_ne), 4196 (ix_sw, iy_sw, w_sw), 4197 (ix_se, iy_se, w_se), 4198 ) 4199 ) 4200 elif interpolation_mode == 1: # Nearest 4201 ix = compute_source_index(x, iW) 4202 iy = compute_source_index(y, iH) 4203 4204 ix_nearest = ix.round() 4205 iy_nearest = iy.round() 4206 4207 return get_summand(ix_nearest, iy_nearest, 1) 4208 else: # interpolation_mode == 2, Bicubic 4209 ix = unnormalize(x, iW) 4210 iy = unnormalize(y, iH) 4211 4212 ix_nw = ix.floor() 4213 iy_nw = iy.floor() 4214 4215 tx = ix - ix_nw 4216 ty = iy - iy_nw 4217 4218 if not _expand_grid: 4219 tx = tx.unsqueeze(1) 4220 ty = ty.unsqueeze(1) 4221 4222 def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor: 4223 x = compute_coordinates(ix, iW) 4224 y = compute_coordinates(iy, iH) 4225 return get_summand(x, y, 1) 4226 4227 def get_coeff(ofs: int) -> Tensor: 4228 iy_ofs = iy_nw + (ofs - 1) 4229 cs = ( 4230 get_value_bounded(ix_nw - 1, iy_ofs), 4231 get_value_bounded(ix_nw, iy_ofs), 4232 get_value_bounded(ix_nw + 1, iy_ofs), 4233 get_value_bounded(ix_nw + 2, iy_ofs), 4234 ) 4235 return _upsample_cubic_interp1d(cs, tx) 4236 4237 coeffs = tuple(get_coeff(ofs) for ofs in range(4)) 4238 return _upsample_cubic_interp1d(coeffs, ty) 4239 4240 4241@register_decomposition(aten.grid_sampler_2d) 4242@out_wrapper() 4243@pw_cast_for_opmath 4244def grid_sampler_2d( 4245 a: Tensor, 4246 grid: Tensor, 4247 interpolation_mode: int = 0, 4248 padding_mode: int = 0, 4249 align_corners: bool = False, 4250) -> Tensor: 4251 return _grid_sampler_2d( 4252 a, 4253 grid=grid, 4254 
interpolation_mode=interpolation_mode, 4255 padding_mode=padding_mode, 4256 align_corners=align_corners, 4257 ) 4258 4259 4260@register_decomposition(aten.mv) 4261@out_wrapper() 4262@pw_cast_for_opmath 4263def mv(self, vec): 4264 torch._check( 4265 self.dim() == 2 and vec.dim() == 1, 4266 lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}", 4267 ) 4268 torch._check( 4269 self.size(1) == vec.size(0), 4270 lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})", 4271 ) 4272 return (self * vec).sum(dim=1) 4273 4274 4275@register_decomposition(aten.binary_cross_entropy_with_logits) 4276@out_wrapper() 4277def binary_cross_entropy_with_logits( 4278 self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value 4279): 4280 if pos_weight is not None: 4281 log_weight = (pos_weight - 1) * target + 1 4282 loss = (1 - target) * self - (log_weight * F.logsigmoid(self)) 4283 else: 4284 loss = (1 - target) * self - F.logsigmoid(self) 4285 4286 if weight is not None: 4287 loss = loss * weight 4288 4289 return apply_loss_reduction(loss, reduction) 4290 4291 4292def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool: 4293 # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp 4294 4295 t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1) 4296 4297 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 4298 4299 if not (t1.ndim >= 3 and t2.ndim <= 2): 4300 return False 4301 if t2.requires_grad and not is_out: 4302 return True 4303 if tensor1.ndim == 2: 4304 return False 4305 if guard_size_oblivious(t1.numel() == 0): 4306 return True 4307 4308 t1_shape = t1.shape 4309 t1_stride = t1.stride() 4310 return all( 4311 st1 == st2 * s2 4312 for (st1, st2, s2) in zip(t1_stride[:-2], t1_stride[1:-1], t1_shape[1:-1]) 4313 ) 4314 4315 4316@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd) 4317@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd) 4318@out_wrapper(pass_is_out=True) 4319def matmul(tensor1, tensor2, *, is_out=False): 4320 dim_tensor1 = tensor1.dim() 4321 dim_tensor2 = tensor2.dim() 4322 assert dim_tensor1 != 0 and dim_tensor2 != 0 4323 if dim_tensor1 == 1 and dim_tensor2 == 1: 4324 return torch.dot(tensor1, tensor2) 4325 elif dim_tensor1 == 2 and dim_tensor2 == 1: 4326 return torch.mv(tensor1, tensor2) 4327 elif dim_tensor1 == 1 and dim_tensor2 == 2: 4328 return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0) 4329 elif dim_tensor1 == 2 and dim_tensor2 == 2: 4330 return torch.mm(tensor1, tensor2) 4331 elif should_fold(tensor1, tensor2, is_out): 4332 # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) || 4333 # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2) 4334 # and some condition on the strides is fulfilled 4335 4336 # optimization: use mm instead of bmm by folding the batch of the larger tensor 4337 # into its leading matrix dimension 4338 transpose = dim_tensor2 > dim_tensor1 4339 t1 = tensor2.mT if transpose else tensor1 4340 t2 = ( 4341 tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1) 4342 ) 4343 # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2) 4344 # and t1 and t2 are matmul-compatible 4345 4346 # Why not t1.view(-1, sizes_1[-1])? 4347 # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous. 4348 # This can happen in e.g. [3, 5, 0] @ [0, 0]. 
4349 sizes_1 = t1.shape 4350 output_shape = list(sizes_1[:-1]) 4351 folded_dim1 = reduce(operator.mul, output_shape) 4352 4353 # Readjust output_shape if we are multiplying by a matrix 4354 t2_is_matrix = t2.dim() == 2 4355 if t2_is_matrix: 4356 output_shape.append(t2.shape[1]) 4357 4358 # This will almost always be a view. 4359 # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation 4360 t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) 4361 if t2_is_matrix: 4362 # This copies if we perform a 2D @ 3D and the first tensor requires_grad 4363 # See should_fold native/LinearAlgebra.cpp for why. 4364 output = t1_folded.mm(t2).view(output_shape) 4365 return output.mT.contiguous() if transpose else output 4366 else: 4367 return t1_folded.mv(t2).view(output_shape) 4368 4369 elif dim_tensor1 >= 1 and dim_tensor2 >= 1: 4370 # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); 4371 # we track m1 vs m2 separately even though they must match for nicer error messages 4372 n = tensor1.size(-2) if dim_tensor1 > 1 else 1 4373 m1 = tensor1.size(-1) 4374 batch_tensor1 = tensor1.shape[:-2] 4375 m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) 4376 p = tensor2.size(-1) if dim_tensor2 > 1 else 1 4377 4378 batch_tensor2: List[int] = [] 4379 # TODO: handling of slice 4380 for i in range(dim_tensor2 - 2): 4381 batch_tensor2.append(tensor2.size(i)) 4382 4383 # Same optimization for the gradients as that in should_fold 4384 # If we're going to broadcast, we force it to go through the should_fold branch 4385 if ( 4386 dim_tensor1 == 3 4387 and dim_tensor2 == 3 4388 and batch_tensor1[0] != batch_tensor2[0] 4389 ): 4390 if batch_tensor1[0] == 1 and tensor1.requires_grad: 4391 return matmul(tensor1.squeeze(0), tensor2) 4392 if batch_tensor2[0] == 1 and tensor2.requires_grad: 4393 return matmul(tensor1, tensor2.squeeze(0)) 4394 4395 # expand the batch portion (i.e. 
cut off matrix dimensions and expand rest) 4396 expand_batch_portion = list( 4397 torch.broadcast_shapes(batch_tensor1, batch_tensor2) 4398 ) 4399 4400 tensor1_expand_size = expand_batch_portion + [n, m1] 4401 4402 expand_batch_product = prod(expand_batch_portion) 4403 4404 # HACK: We need reshape with symint support 4405 tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( 4406 expand_batch_product, n, m1 4407 ) 4408 4409 vector_rhs = dim_tensor2 == 1 4410 if vector_rhs: 4411 tensor2_expand_size = expand_batch_portion + [m2] 4412 tensor2_expanded = ( 4413 tensor2.expand(tensor2_expand_size) 4414 .reshape(expand_batch_product, m2) 4415 .unsqueeze(2) 4416 ) 4417 else: 4418 tensor2_expand_size = expand_batch_portion + [m2, p] 4419 tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( 4420 expand_batch_product, m2, p 4421 ) 4422 4423 output_shape = expand_batch_portion 4424 if dim_tensor1 > 1: 4425 output_shape.append(n) 4426 4427 if dim_tensor2 > 1: 4428 output_shape.append(p) 4429 4430 if vector_rhs: 4431 return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape) 4432 else: 4433 return tensor1_expanded.bmm(tensor2_expanded).view(output_shape) 4434 else: 4435 torch._check(False, lambda: "both arguments to matmul need to be at least 1D") 4436 4437 4438@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out]) 4439@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd) 4440@out_wrapper() 4441@pw_cast_for_opmath 4442def upsample_bicubic2d_default( 4443 input: Tensor, 4444 output_size: Tuple[int, int], 4445 align_corners: bool, 4446 scale_h: Optional[float] = None, 4447 scale_w: Optional[float] = None, 4448) -> Tensor: 4449 # get dimensions of original image 4450 _, _, in_h, in_w = input.shape 4451 4452 # Calculate horizontal and vertical scaling factor 4453 h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h) 4454 w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w) 4455 4456 _, dtype = utils.elementwise_dtypes( 4457 input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT 4458 ) 4459 4460 # We have to create arange with int64 dtype and use .to in order to avoid 4461 # additional kernels creation in inductor and get a perf slowdown 4462 i = torch.arange(output_size[0], device=input.device).to(dtype=dtype) 4463 j = torch.arange(output_size[1], device=input.device).to(dtype=dtype) 4464 4465 x_float = _compute_source_index(w_scale_factor, j, align_corners) 4466 y_float = _compute_source_index(h_scale_factor, i, align_corners) 4467 y_float = y_float.unsqueeze(-1) 4468 4469 x = x_float.floor() 4470 y = y_float.floor() 4471 4472 # We should also clamp xscale/yscale 4473 # See guard_index_and_lambda in UpSample.h 4474 yscale = (y_float - y).clamp(0.0, 1.0) 4475 xscale = (x_float - x).clamp(0.0, 1.0) 4476 x = x.to(torch.int64) 4477 y = y.to(torch.int64) 4478 4479 iys_ofs = (y - 1, y, y + 1, y + 2) 4480 ixs_ofs = (x - 1, x, x + 1, x + 2) 4481 4482 weights_x = _upsample_get_cubic_coefficients(xscale) 4483 weights_y = _upsample_get_cubic_coefficients(yscale) 4484 4485 weights_precision_x, weights_precision_y = None, None 4486 if input.dtype == torch.uint8: 4487 weights_precision_x = _compute_weight_precision(weights_x) 4488 weights_precision_y = _compute_weight_precision(weights_y) 4489 4490 weights_x = [ 4491 (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16) 4492 for w in weights_x 4493 ] 4494 weights_y = [ 4495 (w * (1 << 
weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16) 4496 for w in weights_y 4497 ] 4498 4499 def load_bounded(ys, xs): 4500 y_idx = torch.clamp(ys, 0, in_h - 1) 4501 x_idx = torch.clamp(xs, 0, in_w - 1) 4502 v = aten._unsafe_index(input, [None, None, y_idx, x_idx]) 4503 return v 4504 4505 def get_x_interp(y): 4506 src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs) 4507 if input.dtype == torch.uint8: 4508 assert weights_precision_x is not None 4509 return _sum_tensors_uint8(src_x, weights_x, weights_precision_x) 4510 return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x)) 4511 4512 src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs) 4513 if input.dtype == torch.uint8: 4514 assert weights_precision_y is not None 4515 result = _sum_tensors_uint8(src_y, weights_y, weights_precision_y) 4516 else: 4517 result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y)) 4518 4519 # convert output to correct memory format, if necessary 4520 memory_format = utils.suggest_memory_format(input) 4521 result = result.contiguous(memory_format=memory_format) 4522 return result 4523 4524 4525@register_decomposition(aten.upsample_bicubic2d.vec) 4526@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) 4527@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd) 4528@out_wrapper() 4529@pw_cast_for_opmath 4530def upsample_bicubic2d_vec( 4531 a: Tensor, 4532 output_size: Optional[Tuple[int, int]], 4533 align_corners: bool, 4534 scale_factors: Optional[Tuple[float, float]] = None, 4535) -> Tensor: 4536 torch._check( 4537 bool(output_size) + bool(scale_factors) == 1, 4538 lambda: "Must specify exactly one of output_size and scale_factors.", 4539 ) 4540 if output_size is None: 4541 assert scale_factors is not None 4542 output_size = cast( 4543 Tuple[int, int], 4544 tuple( 4545 sym_int(sym_float(w) * scale) 4546 for w, scale in zip(a.shape[2:], scale_factors) 4547 ), 4548 ) 4549 scale_h, scale_w = scale_factors if scale_factors else (None, None) 4550 return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w) 4551 4552 4553@register_decomposition(aten.reflection_pad1d) 4554@register_decomposition(aten.reflection_pad2d) 4555@register_decomposition(aten.reflection_pad3d) 4556@pw_cast_for_opmath 4557@out_wrapper() 4558def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: 4559 def idx(left, middle, right): 4560 dim_idx = torch.arange(-left, middle + right, device=a.device) 4561 return middle - 1 - (middle - 1 - dim_idx.abs()).abs() 4562 4563 return _reflection_or_replication_pad( 4564 a, 4565 padding, 4566 idx, 4567 ) 4568 4569 4570@register_decomposition(aten.replication_pad1d) 4571@register_decomposition(aten.replication_pad2d) 4572@register_decomposition(aten.replication_pad3d) 4573@pw_cast_for_opmath 4574@out_wrapper() 4575def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: 4576 def idx(left, middle, right): 4577 dim_idx = torch.arange(-left, middle + right, device=a.device) 4578 return torch.clamp(dim_idx, 0, middle - 1) 4579 4580 return _reflection_or_replication_pad( 4581 a, 4582 padding, 4583 idx, 4584 ) 4585 4586 4587def _reflection_or_replication_pad( 4588 a: Tensor, 4589 padding: Tuple[int, ...], 4590 idx_fn: Callable[[int, int, int], Tensor], 4591) -> Tensor: 4592 dim = len(padding) // 2 4593 torch._check( 4594 a.dim() in (dim + 1, dim + 2), 4595 lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", 4596 ) 4597 inp_shape = a.shape[-dim:] 4598 nc_dim = a.dim() - dim 4599 4600 
padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] 4601 padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] 4602 4603 result = a 4604 for i in range(dim): 4605 idx: List[Any] = [None] * result.dim() 4606 idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) 4607 result = aten._unsafe_index(result, idx) 4608 4609 # convert output to correct memory format, if necessary 4610 memory_format = utils.suggest_memory_format(result) 4611 result = result.contiguous(memory_format=memory_format) 4612 return result 4613 4614 4615@register_decomposition(aten.reflection_pad1d_backward) 4616@register_decomposition(aten.reflection_pad2d_backward) 4617@register_decomposition(aten.reflection_pad3d_backward) 4618@out_wrapper("grad_input") 4619def _reflection_pad_backward(grad_output, x, padding): 4620 dim = len(padding) // 2 4621 4622 dhw = [h - 1 for h in x.shape[-dim:]] 4623 4624 padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] 4625 padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] 4626 4627 indices = [] 4628 for i in range(x.ndim): 4629 view_shape = [1] * x.ndim 4630 view_shape[i] = -1 4631 indices.append(torch.arange(x.shape[i], device=x.device).view(view_shape)) 4632 4633 b = indices[:-dim] 4634 xyz = indices[-dim:] 4635 4636 def index_range_condition(index_range): 4637 i, lb, ub = index_range 4638 return torch.logical_and(i >= lb, i <= ub) 4639 4640 # Areas after reflection: 4641 # 4642 # top-left | top | top-right 4643 # ----------------------------------------- 4644 # left | center | right 4645 # ----------------------------------------- 4646 # bottom-left | bottom | bottom-right 4647 # 4648 # The center area is the original matrix. Other areas are reflections. 4649 4650 center = [xyz[i] + padding_left[i] for i in range(dim)] 4651 left_reflect = [padding_left[i] - xyz[i] for i in range(dim)] 4652 right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)] 4653 4654 # Accumulate gradients from different areas 4655 # If some of the padding is negative, center load is not always valid 4656 range_c = [ 4657 (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) for i in range(dim) 4658 ] 4659 cond = functools.reduce( 4660 aten.logical_and, [index_range_condition(range_c[i]) for i in range(dim)] 4661 ) 4662 grad = aten._unsafe_masked_index(grad_output, cond, b + center, 0.0) 4663 4664 def accumulate(grad, out, index_ranges): 4665 # If the upper bound is less than the lower bound, we can get rid of one accumulation. 4666 # This happens when the padding size is zero. 4667 for i in range(dim): 4668 upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1] 4669 if isinstance(upper_less_than_lower, bool) and upper_less_than_lower: 4670 return grad 4671 4672 cond = functools.reduce( 4673 aten.logical_and, 4674 [index_range_condition(index_range) for index_range in index_ranges], 4675 ) 4676 g = aten._unsafe_masked_index(grad_output, cond, b + out, 0.0) 4677 return grad + g 4678 4679 for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]): 4680 if area == tuple([0] * dim): 4681 # center, this is already done. 
4682 continue 4683 4684 outs = [] 4685 index_ranges = [] 4686 4687 for i in range(dim): 4688 if area[i] == 0: 4689 out = center[i] 4690 index_range = range_c[i] 4691 elif area[i] == -1: 4692 out = left_reflect[i] 4693 index_range = (xyz[i], 1, padding_left[i]) 4694 elif area[i] == 1: 4695 out = right_reflect[i] 4696 index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1) 4697 4698 outs.append(out) # type: ignore[possibly-undefined] 4699 index_ranges.append(index_range) # type: ignore[possibly-undefined] 4700 4701 grad = accumulate(grad, outs, index_ranges) 4702 4703 return grad 4704 4705 4706@register_decomposition(aten.aminmax) 4707@out_wrapper("min", "max") 4708def aminmax(self, *, dim=None, keepdim=False): 4709 amin = torch.amin(self, dim=dim, keepdim=keepdim) 4710 amax = torch.amax(self, dim=dim, keepdim=keepdim) 4711 return amin, amax 4712 4713 4714@register_decomposition(aten.nansum) 4715@out_wrapper() 4716def nansum(self, dim=None, keepdim=False, *, dtype=None): 4717 return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype) 4718 4719 4720@register_decomposition([aten.arange.default, aten.arange.out]) 4721@out_wrapper() 4722def arange_default( 4723 end: NumberType, 4724 *, 4725 dtype: Optional[torch.dtype] = None, 4726 layout: torch.layout = torch.strided, 4727 device: Optional[torch.device] = None, 4728 pin_memory: bool = False, 4729): 4730 return aten.arange.start_step( 4731 0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory 4732 ) 4733 4734 4735@register_decomposition([aten.arange.start]) 4736def arange_start( 4737 start: NumberType, 4738 end: NumberType, 4739 *, 4740 dtype: Optional[torch.dtype] = None, 4741 layout: torch.layout = torch.strided, 4742 device: Optional[torch.device] = None, 4743 pin_memory: bool = False, 4744): 4745 return aten.arange.start_step( 4746 start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory 4747 ) 4748 4749 4750@register_decomposition(out_dtype) 4751def out_dtype_decomp(*args, **kwargs): 4752 from torch._higher_order_ops.out_dtype import out_dtype_dense 4753 4754 return out_dtype_dense(*args, **kwargs) 4755 4756 4757@register_decomposition(aten.multi_margin_loss) 4758@aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd) 4759@out_wrapper() 4760def multi_margin_loss( 4761 input: Tensor, 4762 target: Tensor, 4763 p: NumberType = 1, 4764 margin: NumberType = 1, 4765 weight: Optional[Tensor] = None, 4766 reduction: int = Reduction.MEAN.value, 4767) -> Tensor: 4768 input = torch.atleast_2d(input) 4769 target = torch.atleast_1d(target) 4770 nframe = input.shape[0] 4771 dim = input.shape[1] 4772 torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported") 4773 torch._check( 4774 input.ndim == 2 and dim != 0, 4775 lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}", 4776 ) 4777 torch._check( 4778 target.ndim == 1 and target.numel() == nframe, 4779 lambda: f"inconsistent target size, expected {nframe} but got {target.shape}", 4780 ) 4781 if weight is not None: 4782 weight = torch.atleast_1d(weight) 4783 torch._check( 4784 weight.ndim == 1 and weight.numel() == dim, # type: ignore[union-attr] 4785 lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}", # type: ignore[union-attr] 4786 ) 4787 target = target.unsqueeze(1) 4788 u = torch.gather(input, dim=1, index=target) 4789 z = margin - u + input 4790 z = z.clamp_min(0) 4791 z = z if p == 1 else z * z 4792 if weight is not None: 4793 z = z * 
weight[target] 4794 idx = torch.arange(dim, device=input.device) 4795 z = torch.where(idx != target, z, 0) 4796 if reduction == Reduction.MEAN.value: 4797 return z.mean() 4798 elif reduction == Reduction.SUM.value: 4799 return z.sum() / z.shape[1] 4800 else: 4801 return z.mean(dim=1) 4802 4803 4804@register_decomposition(aten.multilabel_margin_loss_forward) 4805@aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd) 4806@out_wrapper("output", "is_target") 4807def multilabel_margin_loss_forward( 4808 input: Tensor, 4809 target: Tensor, 4810 reduction: int, 4811) -> Tuple[Tensor, Tensor]: 4812 orig_input_shape = input.shape 4813 orig_target_shape = target.shape 4814 input = torch.atleast_2d(input) 4815 target = torch.atleast_2d(target) 4816 dim = input.shape[1] 4817 torch._check( 4818 len(orig_input_shape) <= 2 and dim != 0, 4819 lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}", 4820 ) 4821 torch._check( 4822 len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape, 4823 lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}", 4824 ) 4825 # ignores labels after the first -1, detects when -1 is not present 4826 idx = torch.arange(dim, device=target.device) 4827 is_end = target == -1 4828 end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True) 4829 # target indices 4830 target_mask = idx < end_idx 4831 # masks target to be able to use gather, which doesn't allow -1 4832 tidx0 = torch.where(target_mask, target, 0) 4833 u = torch.gather(input, dim=-1, index=tidx0) 4834 # is_target 4835 tidx1 = torch.where(target_mask, target, -1) 4836 is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1) 4837 # loss 4838 z = 1.0 - u.T.unsqueeze(dim=-1) + input 4839 z = z.clamp_min(0) 4840 z = z / dim 4841 # masks loss 4842 z = torch.where(is_target, 0, z) 4843 # reduction 4844 if reduction == Reduction.MEAN.value: 4845 z = z.sum(dim=(0, -1)).mean() 4846 elif reduction == Reduction.SUM.value: 4847 z = z.sum() 4848 else: 4849 z = z.sum(dim=(0, -1)) 4850 # result 4851 is_target = is_target.to(input.dtype).reshape(orig_target_shape) 4852 return z, is_target 4853 4854 4855# scaled_dot_product_attention used to be decomposed in pre-autograd, given that 4856# it calls _scaled_dot_product_attention_math and 4857# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd 4858# kernel. As a result it's decomposed into ops with finer granularity. 4859# However recent PRs (#103826 #105131 #115913) added new logic in 4860# scaled_dot_product_attention and now it calls 4861# _scaled_dot_product_flash_attention_for_cpu in export path. This results 4862# in _scaled_dot_product_flash_attention_for_cpu showing up in export result. 4863# This decomposition ensures scaled_dot_product_attention is still decomposed 4864# the same way as before, i.e., going through 4865# _scaled_dot_product_attention_math. Notice that this decomp rule should be 4866# excluded by inductor. 
@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default)
def scaled_dot_product_flash_attention_for_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
) -> Tuple[Tensor, Tensor]:
    dtype = query.dtype
    torch._check(
        torch.is_floating_point(query),
        lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
    )
    torch._check(
        query.dim() == 4 and key.dim() == 4 and value.dim() == 4,
        lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}",
    )
    torch._check(
        dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}"
    )
    torch._check(
        query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3],
        lambda: "q, k, v should have the same head size",
    )

    output, attn = aten._scaled_dot_product_attention_math.default(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        dropout_mask=None,
        scale=scale,
    )
    # Why this change?
    # In pre-dispatch export scaled_dot_product_attention is executed via
    # * flash_attention.
    # flash_attention allocates the output tensor as (N, L, H, E);
    # it then transposes that to get (N, H, L, E), which is supposed to be the return
    # tensor dim for scaled_dot_product_attention.
    # Assume x: [N, H, L, E] is the output of sdpa.
    # In MHA code, this output is then permuted via (2, 0, 1, 3) to get a
    # (L, N, H, E) dim tensor:
    # x = x.permute(2, 0, 1, 3).contiguous() and then viewed via
    # x = x.view(L * N, H * E)
    # During pre-autograd dispatch the call to contiguous is not traced because
    # the flash_attention output after the x.permute is already contiguous,
    # on which the view is valid.
    # However, during 2nd stage export, post-dispatch, we run the _math variant
    # instead of flash* to get the decomposition. The _math variant returns
    # x: [N, H, L, E]; applying x.permute(2, 0, 1, 3) returns
    # x: [L, N, H, E], and without converting this to a contiguous tensor the
    # subsequent view is not valid and the export fails.
    # The solution is to keep the return tensor view from the decomp exactly
    # the same as in the *flash* variant:
    # the flash variants' output is contiguous as [N, L, H, E],
    # the _math variant's output is contiguous as [N, H, L, E],
    # out = out.transpose(1, 2).contiguous() gets the output contiguous
    # in [N, L, H, E].
    # The subsequent transpose(1, 2) then returns a view on which the
    # aforementioned code snippet, as shown below, is valid:
    # x = x.permute(2, 0, 1, 3).contiguous() and then viewed via
    # x = x.view(L * N, H * E)

    # Really the invariant you want to maintain is:
    # pre-dispatch op-output and its decomposed representation must
    # return a tensor with the same view and dims
    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
    return (output.transpose(1, 2), attn)


def register_inplace(aten_op, outplace_op):
    @register_decomposition(aten_op)
    def inplace_op(*args, **kwargs):
        out = outplace_op(*args, **kwargs)
        return args[0].copy_(out)

    return inplace_op


@register_decomposition([aten.baddbmm])
@out_wrapper()
@pw_cast_for_opmath
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    result = torch.bmm(batch1, batch2)
    if not isinstance(alpha, numbers.Number) or alpha != 1:
        result = result * alpha
    if beta == 0:
        return result
    if not isinstance(beta, numbers.Number) or beta != 1:
        self = self * beta
    return self + result


@register_decomposition(aten.floor_divide)
@out_wrapper()
def floor_divide(self, other):
    return torch.div(self, other, rounding_mode="floor")


@register_decomposition(aten.sym_numel)
def sym_numel(t):
    return functools.reduce(operator.mul, t.shape, 1)


@register_decomposition([aten.sum.default, aten.sum.out])
def sum_default(
    self: Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> Tensor:
    if out is None:
        return aten.sum.dim_IntList(self, [], dtype=dtype)
    else:
        return aten.sum.IntList_out(self, [], dtype=dtype, out=out)


@register_decomposition([aten.squeeze.default, aten.squeeze.dim])
def squeeze_default(self: Tensor, dim: Optional[int] = None):
    # handle a scalar directly
    if not isinstance(self, torch.Tensor):
        return self
    # perform squeeze
    if dim is None:
        return aten.squeeze.dims(self, list(range(self.dim())))
    else:
        return aten.squeeze.dims(self, [dim])


@register_decomposition(torch.ops.aten._weight_norm_interface)
def _weight_norm_interface(v, g, dim=0):
    # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58
    keep_dim = tuple(i for i in range(len(v.shape)) if i != dim)
    # align with cuda behavior, keep norm in 'float' when g is 'bfloat16'
    norm_dtype = torch.float if g.dtype == torch.bfloat16 else None
    norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype)
    return v * (g / norm.to(g.dtype)), norm


@register_decomposition(aten.isin)
@out_wrapper()
def isin(elements, test_elements, *, assume_unique=False, invert=False):
    # handle when either elements or test_elements are Scalars (they can't both be)
    if not isinstance(elements, torch.Tensor):
        elements = torch.tensor(elements, device=test_elements.device)
    if not isinstance(test_elements, torch.Tensor):
        test_elements = torch.tensor(test_elements, device=elements.device)

    if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
        return isin_default(elements, test_elements, invert=invert)
    else:
5026 return isin_sorting( 5027 elements, test_elements, assume_unique=assume_unique, invert=invert 5028 ) 5029 5030 5031def isin_default(elements, test_elements, *, invert=False): 5032 if elements.numel() == 0: 5033 return torch.empty_like(elements, dtype=torch.bool) 5034 5035 x = elements.view(*elements.shape, *((1,) * test_elements.ndim)) 5036 if not invert: 5037 cmp = x == test_elements 5038 else: 5039 cmp = x != test_elements 5040 dim = tuple(range(-1, -test_elements.ndim - 1, -1)) 5041 return cmp.any(dim=dim) 5042 5043 5044def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False): 5045 elements_flat = elements.flatten() 5046 test_elements_flat = test_elements.flatten() 5047 if assume_unique: 5048 # This is the same as the aten implementation. For 5049 # assume_unique=False, we cannot use unique() here, so we use a 5050 # version with searchsorted instead. 5051 all_elements = torch.cat([elements_flat, test_elements_flat]) 5052 sorted_elements, sorted_order = torch.sort(all_elements, stable=True) 5053 5054 duplicate_mask = sorted_elements[1:] == sorted_elements[:-1] 5055 duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False) 5056 5057 if invert: 5058 duplicate_mask = duplicate_mask.logical_not() 5059 5060 mask = torch.empty_like(duplicate_mask) 5061 mask = mask.index_copy(0, sorted_order, duplicate_mask) 5062 5063 return mask[0 : elements.numel()] 5064 else: 5065 sorted_test_elements, _ = torch.sort(test_elements_flat) 5066 idx = torch.searchsorted(sorted_test_elements, elements_flat) 5067 test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0) 5068 cmp = sorted_test_elements[test_idx] == elements_flat 5069 cmp = cmp.logical_not() if invert else cmp 5070 return cmp.reshape(elements.shape) 5071 5072 5073@register_decomposition(aten.take) 5074@out_wrapper() 5075def take(self, index): 5076 flattened = self.reshape(-1) 5077 return flattened[index] 5078 5079 5080@register_decomposition(aten.resize_as) 5081def resize_as(self, other, memory_format=None): 5082 if memory_format is None: 5083 memory_format = torch.contiguous_format 5084 if memory_format == torch.preserve_format: 5085 memory_format = suggest_memory_format(other) 5086 return aten.resize(self, other.shape, memory_format=memory_format) 5087 5088 5089register_inplace(aten.addbmm_, aten.addbmm) 5090register_inplace(aten.addmm_, aten.addmm) 5091register_inplace(aten.addmv_, aten.addmv) 5092register_inplace(aten.baddbmm_, aten.baddbmm) 5093register_inplace(aten.fill_, aten.fill) 5094register_inplace(aten.gelu_, aten.gelu) 5095register_inplace(aten.hardswish_, aten.hardswish) 5096register_inplace(aten.hardtanh_, aten.hardtanh) 5097register_inplace(aten.hardsigmoid_, aten.hardsigmoid) 5098register_inplace(aten.__iand__, aten.__and__) 5099register_inplace(aten.__ilshift__, aten.__lshift__) 5100register_inplace(aten.index_put_, aten.index_put) 5101register_inplace(aten.index_reduce_, aten.index_reduce) 5102register_inplace(aten.__ior__, aten.__or__) 5103register_inplace(aten.__irshift__, aten.__rshift__) 5104register_inplace(aten.__ixor__, aten.__xor__) 5105register_inplace(aten.leaky_relu_, aten.leaky_relu) 5106register_inplace(aten.logit_, aten.logit) 5107register_inplace(aten.relu_, aten.relu) 5108register_inplace(aten.renorm_, aten.renorm) 5109register_inplace(aten.round_, aten.round) 5110register_inplace(aten.scatter_, aten.scatter) 5111register_inplace(aten.scatter_add_, aten.scatter_add) 5112register_inplace(aten.scatter_reduce_, aten.scatter_reduce) 5113register_inplace(aten.silu_, 
aten.silu) 5114
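

# Illustrative helper, not used anywhere in this module: a minimal sketch of what a
# single register_inplace(...) registration above expands to at call time. The
# tensor values and the choice of aten.relu_ / aten.relu are arbitrary assumptions
# for the example, and the name _example_inplace_decomposition is ours.
def _example_inplace_decomposition() -> Tensor:
    x = torch.tensor([-1.0, 0.5, 2.0])
    # register_inplace(aten.relu_, aten.relu) rewrites the in-place op as the
    # out-of-place op followed by copy_ back into the first argument:
    out = aten.relu(x)
    return x.copy_(out)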