# mypy: allow-untyped-decorators
import functools
import logging
import math
import sys
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._decomp as decomp
import torch._prims_common as utils
import torch.ao.quantization.fx._decomposed
from torch._decomp import (
    core_aten_decompositions,
    get_decompositions,
    remove_decompositions,
)
from torch._decomp.decompositions import (
    _grid_sampler_2d as decomp_grid_sampler_2d,
    pw_cast_for_opmath,
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._dynamo.utils import counters
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.utils import pad_listlike
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    type_to_dtype,
)
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious

from . import config, inductor_prims
from .utils import (
    is_gpu,
    needs_fallback_due_to_atomic_add_limitations,
    use_scatter_fallback,
)


log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
quantized_decomposed = torch.ops.quantized_decomposed

inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.addmv,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.index_select,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten._batch_norm_with_update,
        aten._batch_norm_with_update_functional,
        aten._batch_norm_no_update,
        aten.batch_norm_backward,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.upsample_bilinear2d.vec,
        quantized.linear_dynamic_fp16_unpacked_weight,
        _quantized.wrapped_quantized_linear,
    ]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._unsafe_masked_index,
    aten._unsafe_masked_index_put_accumulate,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten._softmax_backward_data,
    aten.clamp_max,
    aten.clamp_min,
    aten.glu,  # inductor lowers this directly
    aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
]

remove_decompositions(decompositions, decomps_to_exclude)


def register_decomposition(
    ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
) -> Callable[..., Any]:
    for op in [ops] if callable(ops) else ops:  # type: ignore[attr-defined]
        if op in decompositions:
            log.warning("duplicate decomp: %s", ops)
    return decomp.register_decomposition(ops, decompositions)


# TODO: for now, inductor doesn't handle asserts
# because the condition is symbol -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return


# Following `assert_async_msg_decomp`, implement as a no-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return


@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(
    symbol: torch.SymInt,
    *,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> None:
    return


@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(
    x: torch.Tensor,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    if min is not None:
        x = x.clamp_min(min)
    if max is not None:
        x = x.clamp_max(max)
    return x


@register_decomposition([aten.full])
def full(
    size: List[Union[int, torch.SymInt]],
    fill_value: torch.types.Number,
    **kwargs: Any,
) -> torch.Tensor:
    dtype = kwargs.get("dtype")
    if dtype is None:
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return torch.full(size, fill_value, **kwargs)
    return NotImplemented


# Not really sure how to put this into the main library.  PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(
    size: List[Union[int, torch.SymInt]],
    physical_layout: List[int],
    **kwargs: Any,
) -> torch.Tensor:
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
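
# Illustrative example (values chosen here for exposition only):
# empty_permuted(size=[2, 3, 5], physical_layout=[2, 0, 1]) allocates
# torch.empty([5, 2, 3]) and permutes it with perm=[1, 2, 0], giving a tensor
# of logical shape [2, 3, 5] whose dim 1 is the densest (stride-1) dimension,
# matching the requested physical layout.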


@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias_sizes: List[int],
    stride: Union[int, List[int]],
    padding: Union[int, List[int]],
    dilation: Union[int, List[int]],
    transposed: bool,
    output_padding: List[int],
    groups: int,
    output_mask: List[bool],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if not output_mask[2] or not is_gpu(grad_output.device.type):
        return NotImplemented
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)


@register_decomposition([aten.round.decimals])
def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
    ten_pow_decimals = 10.0**decimals
    return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)


@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(
    self: torch.Tensor,
    batch2: torch.Tensor,
) -> torch.Tensor:
    if config.coordinate_descent_tuning:
        if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious(
            batch2.shape[2] == 1
        ):
            out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
            return out
    if self.device.type == "cpu":
        if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious(
            batch2.size(-1) == 1
        ):
            counters["inductor"]["decompose_bmm"] += 1
            return torch.sum(
                self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
            ).unsqueeze(1)
    return NotImplemented


@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(
    self: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    beta: torch.types.Number = 1,
    alpha: torch.types.Number = 1,
) -> torch.Tensor:
    if self.device.type == "cpu":
        if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious(
            mat2.size(-1) == 1
        ):
            counters["inductor"]["decompose_addmm"] += 1
            out = torch.sum(
                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
            return alpha * out + beta * self
        if (
            guard_size_oblivious(mat1.size(0) == 1)
            and definitely_true(mat2.size(0) <= 16)
            and definitely_true(mat2.size(1) <= 16)
        ):
            counters["inductor"]["decompose_addmm"] += 1
            out = (mat1.T * mat2).sum(dim=0, keepdim=True)
            return alpha * out + beta * self
    return NotImplemented


@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(
    self: torch.Tensor,
    input2: torch.Tensor,
) -> torch.Tensor:
    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
    if config.coordinate_descent_tuning:
        if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious(
            input2.shape[1] == 1
        ):
            return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        if (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented


# This pass does two things:
# - Eliminate cat when there is only one tensor input
# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
#   don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(
    tensors: List[torch.Tensor],
    dim: int = 0,
) -> torch.Tensor:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x: torch.Tensor) -> bool:
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # We'd like to eliminate naughtiness like this for downstream passes
        # like split_cat. The easiest way is to just drop such inputs
        # (guarding that they are non-zero).
        #
        # Is it permissible for this filtering to be size-oblivious? A case
        # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
        # happened to be zero, we would have liked to have filtered it out.
        # But actually, the ONLY way this could have passed is if u0 == 0,
        # so by the time we get here we have already installed a deferred
        # runtime assert forcing u0 to be zero. So if this hasn't happened,
        # we know that the unbacked SymInt has appropriate size and there are
        # no problems.
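        #
        # A concrete illustration of the dim-aware check below (example shapes
        # only): for cat([A, B], dim=1) with A of shape (2, 0, 4) and B of
        # shape (2, 3, 4), A is empty along the cat dim, so it is filtered out
        # and the result is simply B.clone().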
        if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
            return False

        if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
            return False

        return True

    filtered_tensors = list(filter(non_empty_tensor, tensors))

    if len(filtered_tensors) == 1:
        return filtered_tensors[0].clone()
    elif 1 < len(filtered_tensors) < len(tensors):
        # on the first call, when we remove empty tensors, we redispatch recursively
        return aten.cat.default(filtered_tensors, dim)

    # optimization, avoid concat for single, repeated input
    if len(filtered_tensors) > 1 and all(
        t is filtered_tensors[0] for t in filtered_tensors
    ):
        inp = filtered_tensors[0]
        shape = list(inp.shape)
        dim = dim + len(inp.shape) if dim < 0 else dim
        shape.insert(dim, len(filtered_tensors))
        return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()

    # when no 'filtering' has occurred, return NotImplemented to prevent
    # infinite recursion (no more decomposition needed)
    return NotImplemented


@register_decomposition([aten.angle])
def angle(x: torch.Tensor) -> torch.Tensor:
    if x.is_complex():
        return torch.where(
            torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
        )

    # when x is a real number
    #   if x >= 0, return 0
    #   if x < 0, return pi
    #   if x is nan, return nan
    _, dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
    ret = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), float("nan"), ret)


@register_decomposition([aten.add])
def add(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    alpha: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    # Require both x and y to be complex tensors.
    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
    if not x_is_complex_tensor or not y_is_complex_tensor:
        return NotImplemented
    z = y
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)

    # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problems
    # when broadcasting the add.
    def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor:
        """Reshape tensor from [*initial_dims, last_dim] to [*initial_dims, last_dim/2, 2]"""
        # Get the current shape of the tensor
        *initial_dims, last_dim = tensor.shape

        # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)`
        # doubles the last dimension for complex numbers.
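        #
        # For example (illustrative shapes only): a complex64 tensor of shape
        # (B, N) viewed as float32 has shape (B, 2*N) with real/imag values
        # interleaved; reshaping to (B, N, 2) pairs each (real, imag) back up,
        # so the addition broadcasts the same way it would on the complex shapes.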
        if last_dim % 2 != 0:
            raise AssertionError(
                "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]"
            )

        # Reshape the tensor
        new_shape = (*initial_dims, last_dim // 2, 2)
        reshaped_tensor = tensor.view(new_shape)
        return reshaped_tensor

    x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
    z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
    result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
    return result


@register_decomposition([aten.conj_physical])
def conj_physical(self: torch.Tensor) -> torch.Tensor:
    assert not self.is_complex(), "TODO: implement this"
    return self


@register_decomposition([aten.lift, aten.detach_])
def lift(self: torch.Tensor) -> torch.Tensor:
    return self


@register_decomposition([aten.bernoulli.default])
def bernoulli(
    self: torch.Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    assert generator is None
    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)


@register_decomposition([aten.fmin, prims.fmin])
def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    return torch.where(torch.isnan(other) | (other > self), self, other)


@register_decomposition([aten.fmax, prims.fmax])
def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    return torch.where(torch.isnan(other) | (other < self), self, other)


@register_decomposition(aten.amax)
def amax(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    if self.dtype == torch.bool:
        return torch.any(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition(aten.amin)
def amin(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    if self.dtype == torch.bool:
        return torch.all(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition([aten.narrow_copy])
def narrow_copy(
    self: torch.Tensor,
    dim: int,
    start: int,
    length: int,
) -> torch.Tensor:
    return torch.narrow(self, dim, start, length).clone()


@register_decomposition([aten.view_copy.default])
def view_copy_default(
    self: torch.Tensor,
    size: List[Union[int, torch.SymInt]],
) -> torch.Tensor:
    return aten.view(self, size).clone()


@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(
    self: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    return self.to(dtype).clone()


def get_like_layout(
    tensor: torch.Tensor,
    memory_format: Optional[torch.memory_format] = None,
) -> torch.memory_format:
    # TODO: _to_copy tensor to stride permutation
    if memory_format is torch.preserve_format or memory_format is None:
        return utils.suggest_memory_format(tensor)
    else:
        return memory_format


@register_decomposition(aten.rand_like)
def rand_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return torch.rand(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))
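
# A small illustration of how the *_like decompositions here pick their output
# layout (example values, not exercised in this module): with a 4-D
# channels_last input and no explicit memory_format, get_like_layout falls back
# to suggest_memory_format, so under this decomposition table
#
#   x = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
#   torch.rand_like(x)
#
# becomes torch.rand(...).to(memory_format=torch.channels_last), preserving the
# channels_last layout of the input.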


@register_decomposition(aten.randn_like)
def randn_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return torch.randn(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.full_like)
def full_like(
    self: torch.Tensor,
    fill_value: Union[int, float],
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
    return torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.default)
def randint_like(
    self: torch.Tensor,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return aten.randint.low(
        0,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self: torch.Tensor,
    low: int,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    return aten.randint.low(
        low,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint.default)
def randint(
    high: int,
    size: List[Union[int, torch.SymInt]],
    **kwargs: Any,
) -> torch.Tensor:
    return aten.randint.low(0, high, size, **kwargs)


@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
def linear_dynamic_fp16_unpacked_weight(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
    return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
        input, packed_weight, bias, weight.size()[0]
    )


@register_decomposition(_quantized.wrapped_quantized_linear.default)
def wrapped_quantized_linear(
    input: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    bias: torch.Tensor,
    out_scale: torch.Tensor,
    out_zero_point: torch.Tensor,
    out_channel: int,
) -> torch.Tensor:
    packed_weight = torch.ops._quantized._wrapped_linear_prepack(
        weight, weight_scale, weight_zero_point, bias
    )
    return torch.ops._quantized._wrapped_quantized_linear_prepacked(
        input,
        input_scale,
        input_zero_point,
        packed_weight,
        out_scale,
        out_zero_point,
        out_channel,
    )


@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
    def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
        x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
        if sys.byteorder == "little":
            return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
        else:
            return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    return packed[..., :-8].to(torch.float32) * scales + offsets


@register_decomposition([aten.grid_sampler_2d])
@pw_cast_for_opmath
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    # We do not expand the grid (_expand_grid=False) on cpu for performance reasons
    # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
    # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
    # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
    # Thus we apply this hack to not expand the grid for this case.
    _expand_grid = not (
        a.device == torch.device("cpu")
        and interpolation_mode == 0
        and a.is_contiguous(memory_format=torch.contiguous_format)
    )

    output = decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=_expand_grid,
    )
    return output


@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    return aten._foreach_add.List(
        self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    return aten._foreach_add.List(
        self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
    start_tensors: List[torch.Tensor],
    end_tensors: List[torch.Tensor],
    weight: torch.types.Number,
) -> List[torch.Tensor]:
    return aten._foreach_add.List(
        start_tensors,
        aten._foreach_mul.Scalar(
            aten._foreach_sub.List(end_tensors, start_tensors), weight
        ),
    )


@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if training:
        return (a, b, c)
    return (
        a,
        weight.new_zeros((0,)),
        weight.new_zeros((0,)),
    )


@functools.lru_cache(None)
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
    return {**decompositions, **extra_random_decomps}


# TODO(aakhundov): replace this (and the above) Any by more
# specific type and fix all the cascading mypy errors
def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
    """decomps can change based on config"""
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()


@register_decomposition(aten.masked_scatter)
def masked_scatter(
    self: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
) -> torch.Tensor:
    from .codegen.common import BackendFeature, has_backend_feature

    if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX):
        # This two-step algorithm is the same as eager CUDA; for eager CPU we
        # use a 1-shot serial iteration.
        self, mask = aten.broadcast_tensors([self, mask])
        source_idx = mask.reshape(-1).cumsum(0) - 1
        self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source))
        result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0)
        return torch.where(mask_flat, result, self_flat).view(self.shape)
    return NotImplemented


@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor,
    quant_min: int,
    quant_max: int,
    eps: float,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.max(scale, torch.Tensor([eps]))
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)


@register_decomposition(aten.put)
def put(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    flattened = self.flatten()
    flattened = torch.index_put(
        flattened, [index], source.reshape(index.shape), accumulate
    )
    return flattened.reshape(self.shape)


@register_decomposition(aten.put_)
def put_(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    out = aten.put(self, index, source, accumulate=accumulate)
    return self.copy_(out)


@register_decomposition(aten._softmax_backward_data.default)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: torch.Tensor,
    output: torch.Tensor,
    dim: int,
    input_dtype: torch.dtype,
) -> torch.Tensor:
    new_grad_output = grad_output * output
    sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
    # grad_input = new_grad_output - output * sum_new_grad
    grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output)

    # The CPU kernel doesn't respect input_dtype, but the following check doesn't work for meta tensors
    # if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()

    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input.contiguous()
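
# A minimal sketch of how the table assembled in this module is typically
# consumed (illustrative only; Inductor itself threads select_decomp_table()
# through AOTAutograd rather than tracing like this directly):
#
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def f(x):
#       return torch.round(x, decimals=2)
#
#   gm = make_fx(f, decomposition_table=select_decomp_table())(torch.randn(4))
#
# Here aten.round.decimals is rewritten via the round_dec decomposition above,
# while ops whose decompositions return NotImplemented are left intact for
# Inductor's lowerings to handle.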


@register_decomposition(aten.index_reduce)
def index_reduce(
    self: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src: torch.Tensor,
    reduction_type: str,
    *,
    include_self: bool = True,
) -> torch.Tensor:
    if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations(
        self.dtype
    ):
        true_division = self.dtype.is_floating_point or self.dtype.is_complex
        ones = torch.ones_like(src)
        if include_self:
            out = self
            counts = torch.ones_like(self).index_add(dim, index, ones)
        else:
            out = self.index_fill(dim, index, 0)
            counts = torch.zeros_like(self).index_add(dim, index, ones)
        counts = counts.masked_fill(counts < 1, 1)
        out = out.index_add(dim, index, src)
        return out / counts if true_division else out // counts

    if use_scatter_fallback(
        aten.scatter_reduce_.two,
        reduction_type,
        self.dtype,
        src.dtype,
        src.device.type,
        True,
    ):
        return NotImplemented

    repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel()
    index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim])
    perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim))
    scatter_index = (
        index.to(torch.int64)
        .repeat_interleave(repeats)
        .reshape(index_shape)
        .permute(perm)
    )
    return self.scatter_reduce(
        dim,
        scatter_index,
        src,
        reduction_type,
        include_self=include_self,
    )


@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
    x: torch.Tensor,
    kernel_size: List[int],
    stride: Optional[Union[int, List[int]]] = None,
    padding: Union[int, List[int]] = 0,
    dilation: Union[int, List[int]] = 1,
    ceil_mode: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if dilation == 1:
        dilation = [1, 1]

    if padding == 0:
        padding = [0, 0]

    if not stride:
        stride = kernel_size

    kernel_size = pad_listlike(kernel_size, 2)
    dilation = pad_listlike(dilation, 2)
    padding = pad_listlike(padding, 2)
    stride = pad_listlike(stride, 2)

    window_size = kernel_size[0] * kernel_size[1]
    # We fallback when using non-default dilation or when the window size is too large
    if (
        torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
            kernel_size, dilation
        )
        or window_size > torch.iinfo(torch.int8).max
    ):
        return NotImplemented

    vals, offsets = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    indices = prims._low_memory_max_pool2d_offsets_to_indices(
        offsets,
        kernel_size[1],
        x.size(-1),
        stride,
        padding,
    )
    return vals, indices
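
# A small worked illustration of the "mean" branch of index_reduce above
# (example values only, assuming a floating-point dtype and include_self=True):
#
#   self = [1., 2.], dim = 0, index = [0, 0], src = [3., 5.]
#   counts = ones_like(self).index_add(0, index, ones_like(src))  -> [3., 1.]
#   out    = self.index_add(0, index, src)                        -> [9., 2.]
#   out / counts                                                  -> [3., 2.]
#
# which matches eager index_reduce_ with reduce="mean": position 0 averages
# {1., 3., 5.} and position 1 keeps its original value.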