1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3import functools 4import itertools 5import logging 6import math 7import operator 8import os 9import warnings 10from collections import defaultdict 11from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union 12from unittest.mock import patch 13 14import sympy 15 16import torch 17import torch.ao.quantization.fx._decomposed 18import torch.fx 19import torch.utils._pytree as pytree 20from torch._higher_order_ops.associative_scan import associative_scan_op 21from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation 22from torch._prims_common import ( 23 canonicalize_dim, 24 canonicalize_dims, 25 check, 26 dtype_to_type, 27 elementwise_dtypes, 28 ELEMENTWISE_TYPE_PROMOTION_KIND, 29 get_computation_dtype, 30 is_boolean_dtype, 31 is_float_dtype, 32 is_integer_dtype, 33 Number, 34) 35from torch.fx.experimental.sym_node import magic_methods, method_to_operator 36from torch.utils._sympy.functions import ( 37 CeilDiv, 38 FloorDiv, 39 Identity, 40 IntTrueDiv, 41 ModularIndexing, 42) 43 44from .._dynamo.utils import import_submodule 45from . import config, inductor_prims, ir, test_operators # NOQA: F401 46from .decomposition import decompositions, get_decompositions 47from .ir import ( 48 DtypeView, 49 ExpandView, 50 IndexingConstant, 51 is_triton, 52 ops_wrapper, 53 PermuteView, 54 Pointwise, 55 Reduction, 56 SqueezeView, 57 TensorBox, 58 validate_ir, 59 View, 60) 61from .utils import ( 62 ceildiv, 63 decode_device, 64 is_dynamic, 65 is_gpu, 66 is_pointwise_use, 67 needs_fallback_due_to_atomic_add_limitations, 68 pad_listlike, 69 sympy_product, 70 use_scatter_fallback, 71) 72from .virtualized import ops, V 73 74 75log = logging.getLogger(__name__) 76lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} 77# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints 78_maybe_layout_constraints: Dict[ 79 torch._ops.OpOverload, Optional[Callable[..., Any]] 80] = {} 81fallbacks: Set[torch._ops.OpOverload] = set() 82aten = torch.ops.aten 83tr_c10d = torch.ops.tr_c10d 84prims = torch.ops.prims 85needs_realized_inputs: Set[torch._ops.OpOverload] = set() 86foreach_ops: Set[torch._ops.OpOverload] = set() 87inplace_foreach_ops: Set[torch._ops.OpOverload] = set() 88inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} 89quantized_decomposed = torch.ops.quantized_decomposed 90 91 92def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]: 93 """Get layout constraints. Returns None if there are no layout constraints.""" 94 if not isinstance(fn, torch._ops.OpOverload): 95 # Only OpOverloads have layout constraints. 96 return None 97 if fn in _maybe_layout_constraints: 98 return _maybe_layout_constraints[fn] 99 # OpOverload with custom lowerings override tag-based layout constraints 100 if fn in lowerings: 101 _maybe_layout_constraints[fn] = None 102 return None 103 # We lazily register tag-based layout constraints. 
104 105 def handle_layout_constraint_tag(tag): 106 if tag is torch._C.Tag.needs_fixed_stride_order: 107 _maybe_layout_constraints[fn] = constrain_to_fx_strides 108 return _maybe_layout_constraints[fn] 109 elif tag is torch._C.Tag.flexible_layout: 110 _maybe_layout_constraints[fn] = None 111 return None 112 else: 113 raise AssertionError(f"Unknown layout constraint tag: {tag}") 114 115 tag = get_layout_constraint_tag(fn) 116 return handle_layout_constraint_tag(tag) 117 118 119def get_layout_constraint_tag(fn): 120 tags_by_priority = [ 121 torch._C.Tag.needs_fixed_stride_order, 122 torch._C.Tag.flexible_layout, 123 ] 124 for tag in tags_by_priority: 125 if tag in fn.tags: 126 return tag 127 return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) 128 129 130def assert_nyi(cond, msg): 131 if not cond: 132 raise NotImplementedError(f"inductor does not support {msg}") 133 134 135def add_needs_realized_inputs(fn): 136 if isinstance(fn, (list, tuple, set)): 137 return [add_needs_realized_inputs(x) for x in fn] 138 needs_realized_inputs.add(fn) 139 if isinstance(fn, torch._ops.OpOverloadPacket): 140 needs_realized_inputs.update( 141 getattr(fn, overload) for overload in fn.overloads() 142 ) 143 144 145def add_layout_constraint(fn, constraint): 146 if isinstance(fn, torch._ops.OpOverloadPacket): 147 for overload in fn.overloads(): 148 _maybe_layout_constraints[getattr(fn, overload)] = constraint 149 else: 150 _maybe_layout_constraints[fn] = constraint 151 152 153add_needs_realized_inputs( 154 [ 155 aten.as_strided, 156 aten.as_strided_copy, 157 aten.avg_pool2d, 158 aten.avg_pool2d_backward, 159 aten.bmm, 160 aten.convolution, 161 aten.convolution_backward, 162 aten.max_pool2d_with_indices, 163 aten.max_pool2d_with_indices_backward, 164 aten.mm, 165 aten.upsample_nearest2d, 166 aten._upsample_nearest_exact2d, 167 aten._int_mm, 168 ] 169) 170 171# TODO(jansel): ezyang says we won't need this in the future, try removing it 172# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 173DTYPE_ID_LOOKUP = { 174 0: torch.uint8, 175 1: torch.int8, 176 2: torch.int16, 177 3: torch.int32, 178 4: torch.int64, 179 5: torch.float16, 180 6: torch.float32, 181 7: torch.float64, 182 8: torch.complex32, 183 9: torch.complex64, 184 10: torch.complex32, 185 11: torch.bool, 186 15: torch.bfloat16, 187 # TODO(jansel): add quantized types? 
188 # _(c10::qint8, QInt8) /* 12 */ 189 # _(c10::quint8, QUInt8) /* 13 */ 190 # _(c10::qint32, QInt32) /* 14 */ 191 # _(c10::quint4x2, QUInt4x2) /* 16 */ 192 # _(c10::quint2x4, QUInt2x4) /* 17 */ 193} 194 195 196def decode_dtype(dtype: int): 197 if not isinstance(dtype, int): 198 return dtype 199 assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" 200 dtype = DTYPE_ID_LOOKUP[dtype] 201 return dtype 202 203 204def is_integer_type(x): 205 if isinstance(x, TensorBox): 206 return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) 207 elif isinstance(x, sympy.Expr): 208 return x.is_integer is True # type: ignore[attr-defined] 209 else: 210 return isinstance(x, int) 211 212 213def is_boolean_type(x): 214 if isinstance(x, TensorBox): 215 return is_boolean_dtype(x.get_dtype()) 216 else: 217 return isinstance(x, bool) 218 219 220def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): 221 def construct_input(inp): 222 if isinstance(inp, (Number, sympy.Basic)): 223 return inp 224 else: 225 assert hasattr(inp, "get_dtype") 226 dim = len(inp.get_size()) 227 # construct a tmp tensor to feed into torch.result_type 228 return torch.zeros([1] * dim, dtype=inp.get_dtype()) 229 230 inps = [construct_input(arg) for arg in args] 231 _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) 232 return dtype 233 234 235def get_overloads(aten_fn): 236 if not isinstance(aten_fn, (list, tuple)): 237 aten_fn = [aten_fn] 238 else: 239 aten_fn = list(aten_fn) 240 241 for fn in list(aten_fn): 242 if isinstance(fn, torch._ops.OpOverloadPacket): 243 for overload in fn.overloads(): 244 other_fn = getattr(fn, overload) 245 if other_fn not in lowerings: 246 aten_fn.append(other_fn) 247 248 return aten_fn 249 250 251def in_namespace(op, namespace): 252 if isinstance(op, torch._ops.OpOverloadPacket): 253 return namespace in op._qualified_op_name 254 elif isinstance(op, torch._ops.OpOverload): 255 return namespace in op.name() 256 return False 257 258 259def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): 260 indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] 261 if (type_promotion_kind or convert_input_to_bool) and indices: 262 if convert_input_to_bool: 263 dtype = torch.bool 264 else: 265 # FIXME that's a crude approximation for promoting args 266 promoting_args = [ 267 a 268 for a in args 269 if isinstance(a, (Number, sympy.Basic)) 270 or getattr(a, "dtype", None) is not None 271 ] 272 dtype = get_promoted_dtype( 273 *promoting_args, type_promotion_kind=type_promotion_kind 274 ) 275 276 # sometimes args are an immutable list so we can't mutate them 277 def promote(arg): 278 if isinstance(arg, TensorBox): 279 return to_dtype(arg, dtype) 280 elif isinstance(arg, ir.Constant): 281 return ir.Constant(arg.value, dtype, args[indices[0]].get_device()) 282 else: 283 return arg 284 285 args = [promote(a) for a in args] 286 if broadcast and indices: 287 for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): 288 args[i] = x 289 for i in range(len(args)): 290 if isinstance(args[i], ir.Constant): 291 args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) 292 293 return args 294 295 296def _register_foreach_lowering(aten_fn, decomp_fn): 297 """ 298 Add a foreach lowering to lowerings dict. 
299 300 Arguments: 301 aten_fn: torch.ops.aten.* fn we are lowering 302 decomp_fn: alternate implementation on our IR 303 broadcast: True to apply broadcasting to tensor inputs 304 type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion 305 convert_input_to_bool: some logical ops require inputs are converted to bool 306 """ 307 308 @functools.wraps(decomp_fn) 309 def wrapped(*args, **kwargs): 310 assert len(args) <= 2 311 out = decomp_fn(*args, **kwargs) 312 validate_ir(out) 313 return out 314 315 aten_fns = get_overloads(aten_fn) 316 foreach_ops.update(aten_fns) 317 lowerings.update(dict.fromkeys(aten_fns, wrapped)) 318 return wrapped 319 320 321def _register_lowering( 322 aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool 323): 324 """ 325 Add a lowering to lowerings dict 326 327 Arguments: 328 aten_fn: torch.ops.aten.* fn we are lowering 329 decomp_fn: alternate implementation on our IR 330 broadcast: True to apply broadcasting to tensor inputs 331 type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion 332 convert_input_to_bool: some logical ops require inputs are converted to bool 333 """ 334 335 @functools.wraps(decomp_fn) 336 def wrapped(*args, **kwargs): 337 args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args) 338 unpacked = False 339 # TODO maybe we need to use pytrees here 340 if len(args) == 1 and isinstance(args[0], (list, tuple)): 341 unpacked = True 342 args = args[0] 343 344 # kwargs tensors not supported yet unless it's a fallback op 345 if not all( 346 (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn 347 ): 348 assert not any(isinstance(x, TensorBox) for x in kwargs.values()) 349 # explicitly assert for "out=" ops for better error messages 350 assert not any( 351 x == "out" for x in kwargs.keys() 352 ), "out= ops aren't yet supported" 353 354 args = transform_args( 355 args, broadcast, type_promotion_kind, convert_input_to_bool 356 ) 357 358 if unpacked: 359 args = [args] 360 361 out = decomp_fn(*args, **kwargs) 362 validate_ir(out) 363 364 return out 365 366 aten_fn = get_overloads(aten_fn) 367 368 lowerings.update(dict.fromkeys(aten_fn, wrapped)) 369 return wrapped 370 371 372def register_lowering( 373 aten_fn, 374 broadcast=False, 375 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 376 convert_input_to_bool=False, 377): 378 """ 379 Shim to support decorator syntax. 380 """ 381 return functools.partial( 382 _register_lowering, 383 aten_fn, 384 broadcast=broadcast, 385 type_promotion_kind=type_promotion_kind, 386 convert_input_to_bool=convert_input_to_bool, 387 ) 388 389 390def broadcast_symbolic_shapes(a, b): 391 """ 392 Broadcasting logic based on symbolic shapes. 393 394 We give the shapes 0 and 1 concrete values, while all other shapes 395 are symbolic sympy formulas. 
396 """ 397 output = [] 398 for x, y in itertools.zip_longest( 399 reversed(a), reversed(b), fillvalue=sympy.Integer(1) 400 ): 401 if y == 1: 402 output.append(x) 403 elif x == 1: 404 output.append(y) 405 else: 406 V.graph.sizevars.guard_equals(x, y) 407 if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): 408 output.append(y) # prefer shorter formula 409 else: 410 output.append(x) 411 return tuple(reversed(output)) 412 413 414def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): 415 assert ( 416 override_return_dtype is None or type_promotion_kind is None 417 ), "only one of override_return_dtype or type_promotion_kind may be given" 418 419 if override_return_dtype is None and type_promotion_kind is None: 420 type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT 421 422 if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): 423 return inputs 424 if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): 425 dtype = override_return_dtype or get_promoted_dtype( 426 *inputs, type_promotion_kind=type_promotion_kind 427 ) 428 429 def const_func(x): 430 if isinstance(x, sympy.Basic): 431 return ir.IndexingConstant(x, dtype, decode_device(None)) 432 else: 433 return ir.Constant(x, dtype, decode_device(None)) 434 435 return [const_func(x) for x in inputs] 436 ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant))) 437 out = [] 438 for x in inputs: 439 if isinstance(x, (int, float)): 440 out.append( 441 ExpandView.create( 442 ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) 443 ) 444 ) 445 elif isinstance(x, sympy.Basic): 446 out.append( 447 ExpandView.create( 448 IndexingConstant(x, ex.get_dtype(), ex.get_device()), 449 list(ex.get_size()), 450 ) 451 ) 452 else: 453 out.append(x) 454 455 return out 456 457 458def make_pointwise( 459 fn, 460 override_return_dtype=None, 461 override_device=None, 462 override_fn_when_input_bool=None, 463 override_fn_when_gpu_float64=None, 464 allow_alpha=False, 465 triton_fallback=None, 466): 467 def inner(*inputs: List[TensorBox], alpha=None): 468 if triton_fallback is not None and any(map(is_triton, inputs)): 469 assert not allow_alpha # not implemented 470 return triton_fallback(*inputs) 471 472 inputs = promote_constants(inputs, override_return_dtype) 473 if allow_alpha: 474 if alpha is not None and alpha != 1: 475 inputs = list(inputs) 476 inputs[-1] = mul(inputs[-1], alpha) 477 else: 478 assert alpha is None 479 loaders = [x.make_loader() for x in inputs] 480 ranges = inputs[0].get_size() 481 dtype = override_return_dtype or inputs[0].get_dtype() 482 is_gpu_device = is_gpu(decode_device(inputs[0].get_device()).type) 483 484 for other in inputs[1:]: 485 assert isinstance(other, ir.BaseConstant) or len(ranges) == len( 486 other.get_size() 487 ), f"ndim mismatch {fn} {ranges} {other.get_size()}" 488 489 # in tracing, we will annotate pointwise nodes that correspond to the output of 490 # a pointwise node that would have been run in eager. intermediary pointwise nodes 491 # during decompositions are not annotated. 
492 emulate_precision_casts = ( 493 V.graph is not None 494 and getattr(V.graph, "current_node", None) is not None 495 and V.graph.current_node.meta is not None 496 and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) 497 and dtype in (torch.bfloat16, torch.float16) 498 ) 499 500 def inner_fn(index): 501 assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" 502 if dtype == torch.bool and override_fn_when_input_bool is not None: 503 return override_fn_when_input_bool(*[load(index) for load in loaders]) 504 elif ( 505 override_fn_when_gpu_float64 506 and is_gpu_device 507 and dtype == torch.float64 508 ): 509 return override_fn_when_gpu_float64(*[load(index) for load in loaders]) 510 else: 511 inputs_loaded = [] 512 for load in loaders: 513 out = load(index) 514 if emulate_precision_casts: 515 downcast = ops.to_dtype(out, dtype, use_compute_types=False) 516 out = ops.to_dtype(downcast, dtype) 517 inputs_loaded.append(out) 518 519 out = fn(*inputs_loaded) 520 if emulate_precision_casts: 521 # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, 522 # then upcasting again, to emulate casts that eager would do. 523 downcast = ops.to_dtype(out, dtype, use_compute_types=False) 524 return ops.to_dtype(downcast, dtype) 525 return out 526 527 if not override_device: 528 device = None 529 for i in inputs: 530 if is_gpu(i.get_device().type): 531 device = i.get_device() 532 break 533 if not device: 534 device = inputs[0].get_device() 535 536 device = override_device or device 537 538 return Pointwise.create( 539 device=device, 540 dtype=dtype, 541 inner_fn=inner_fn, 542 ranges=ranges, 543 ) 544 545 return inner 546 547 548def make_foreach_pointwise(pw_fn, allow_alpha=False): 549 def inner(*inputs: List[List[TensorBox]], alpha=1): 550 # group by device, whether any of the inputs are dynamic, and whether their types match 551 # (proxy for type promotion) 552 def group_args(arg_pairs): 553 out = defaultdict(list) 554 for i, args in enumerate(arg_pairs): 555 use_foreach = ( 556 not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes 557 ) 558 device = None 559 for t in args: 560 if isinstance(t, TensorBox): 561 device = t.data.get_device() 562 break 563 assert ( 564 device is not None 565 ), "foreach op should have at least one tensor arg" 566 out[(device, use_foreach)].append((i, args)) 567 return out 568 569 realize_outputs = ( 570 len(V.graph.current_node.users) == 0 571 or V.graph.current_node.target in inplace_foreach_ops 572 ) 573 for node in V.graph.current_node.users: 574 for user in node.users: 575 if not (user.op == "call_function" and (user.target in foreach_ops)): 576 realize_outputs = True 577 578 a_list_input = None 579 for input in inputs: 580 if isinstance(input, (list, tuple)): 581 a_list_input = input 582 break 583 assert ( 584 a_list_input is not None 585 ), "at least one input must be a list to a foreach op" 586 587 # broadcast scalar inputs to match length of list inputs 588 broadcast_inputs = [] 589 for input in inputs: 590 if not isinstance(input, (list, tuple)): 591 broadcast_inputs.append([input] * len(a_list_input)) 592 else: 593 broadcast_inputs.append(input) 594 595 groups = group_args(zip(*broadcast_inputs)) 596 597 outputs = [None] * len(a_list_input) 598 for (device, use_foreach), group in groups.items(): 599 operation_list: List[str] = [] 600 for ( 601 output_ind, 602 args, 603 ) in group: 604 if allow_alpha: 605 output = pw_fn(*args, alpha=alpha) 606 else: 607 output = pw_fn(*args) 608 609 outputs[output_ind] 
= output 610 611 if ( 612 V.graph.has_feature(device, BackendFeature.FOREACH) 613 and use_foreach 614 and realize_outputs 615 ): 616 output.realize() 617 operation_list.append(output.get_operation_name()) 618 619 if operation_list: 620 V.graph.register_operation_list(operation_list) 621 622 assert all(x is not None for x in outputs) 623 return outputs 624 625 return inner 626 627 628def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): 629 src_dtype = x.get_dtype() 630 if src_dtype == dtype: 631 return clone(x) if copy else x 632 633 def _to_dtype(x): 634 return ops.to_dtype(x, dtype, src_dtype=src_dtype) 635 636 return make_pointwise(_to_dtype, override_return_dtype=dtype)(x) 637 638 639@register_lowering(prims.convert_element_type, type_promotion_kind=None) 640def _convert_element_type(x: TensorBox, dtype: torch.dtype): 641 if dtype.is_complex or x.get_dtype().is_complex: 642 if x.get_size(): 643 # Decompose since aa aten fallback is more friendly for c++ codegen. 644 # This decomposition doesn't work for empty tensor, which needs more investigation. 645 dst = empty_like(x, dtype=dtype) 646 ir.InplaceCopyFallback.create(dst, x) 647 return dst 648 else: 649 return fallback_handler( 650 prims.convert_element_type.default, add_to_fallback_set=False 651 )(x, dtype) 652 return to_dtype(x, dtype, copy=True) 653 654 655def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): 656 x_dtype = x.get_dtype() 657 if x_dtype == dtype: 658 return clone(x) if copy else x 659 660 def _get_primitive_bitwidth(dtype): 661 if dtype.is_floating_point: 662 return torch.finfo(dtype).bits 663 else: 664 return torch.iinfo(dtype).bits 665 666 src_bits = _get_primitive_bitwidth(x_dtype) 667 dst_bits = _get_primitive_bitwidth(dtype) 668 if src_bits != dst_bits: 669 # fallback to aten eager implementation for differing bitwidths 670 return fallback_handler(aten.view.dtype)(x, dtype) 671 else: 672 return TensorBox(DtypeView.create(x, dtype)) 673 674 675@register_lowering(aten.view.dtype, type_promotion_kind=None) 676def _view_dtype(x: TensorBox, dtype: torch.dtype): 677 if dtype.is_complex or x.get_dtype().is_complex: 678 return TensorBox.create( 679 ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) 680 ) 681 return to_dtype_bitcast(x, dtype) 682 683 684def to_device(x: TensorBox, device: torch.device, *, copy=False): 685 device = decode_device(device) 686 if x.get_device() == device: 687 return clone(x) if copy else x 688 return TensorBox.create(ir.DeviceCopy.create(x, device)) 689 690 691@register_lowering(prims.device_put, type_promotion_kind=None) 692def _device_put(x: TensorBox, device: torch.device): 693 return to_device(x, device, copy=True) 694 695 696def register_pointwise( 697 aten_fn, 698 name=None, 699 broadcast=True, 700 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 701 convert_input_to_bool=False, 702 override_return_dtype=None, 703 override_fn_when_input_bool=None, 704 allow_alpha=False, 705 use_libdevice_for_f64=False, 706 triton_fallback=None, 707): 708 """A pointwise function that maps ops.{name} to inputs""" 709 name = name or aten_fn.__name__ 710 fn = ops_wrapper(name) 711 if use_libdevice_for_f64: 712 fn_libdevice = ops_wrapper("libdevice_" + name) 713 if override_fn_when_input_bool is not None: 714 override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) 715 716 fn = make_pointwise( 717 fn, 718 override_return_dtype=override_return_dtype, 719 override_fn_when_input_bool=override_fn_when_input_bool, 720 
override_fn_when_gpu_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined] 721 allow_alpha=allow_alpha, 722 triton_fallback=triton_fallback, 723 ) 724 fn = register_lowering( 725 aten_fn, 726 broadcast=broadcast, 727 type_promotion_kind=type_promotion_kind, 728 convert_input_to_bool=convert_input_to_bool, 729 )(fn) 730 731 if hasattr(prims, name): 732 register_lowering( 733 getattr(prims, name), 734 type_promotion_kind=None, 735 convert_input_to_bool=convert_input_to_bool, 736 )(fn) 737 return fn 738 739 740def register_frexp(): 741 """A pointwise function that maps ops.frexp to inputs""" 742 name = "frexp" 743 frexp = ops_wrapper("frexp") 744 745 def frexp0(*args, **kwargs): 746 return frexp(*args, **kwargs)[0] # type: ignore[index] # next PR 747 748 def frexp1(*args, **kwargs): 749 return frexp(*args, **kwargs)[1] # type: ignore[index] # next PR 750 751 pw_fns = [ 752 make_pointwise(frexp0), 753 make_pointwise(frexp1, override_return_dtype=torch.int32), 754 ] 755 756 def fn(*args, **kwargs): 757 return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) 758 759 fn = register_lowering( 760 aten.frexp, 761 )(fn) 762 763 if hasattr(prims, name): 764 register_lowering( 765 getattr(prims, name), 766 type_promotion_kind=None, 767 )(fn) 768 return fn 769 770 771register_frexp() 772 773 774def register_foreach_pointwise( 775 aten_fn, 776 pointwise_lowering_fn, 777 allow_alpha=False, 778): 779 fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) 780 fn = _register_foreach_lowering(aten_fn, fn) 781 return fn 782 783 784@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) 785def where(cond, a, b): 786 def fn(*args): 787 return ops.where(*args) 788 789 if isinstance(a, (float, int)): 790 a = constant_like(a)(b) 791 if isinstance(b, (float, int)): 792 b = constant_like(b)(a) 793 794 args = [cond, a, b] 795 dtype = get_promoted_dtype( 796 args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT 797 ) 798 indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] 799 for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): 800 args[i] = x 801 for i in range(len(args)): 802 if isinstance(args[i], ir.Constant): 803 args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) 804 return make_pointwise(fn, override_return_dtype=dtype)( 805 args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) 806 ) 807 808 809@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) 810def broadcast_tensors(*inputs): 811 if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): 812 return broadcast_tensors(*inputs[0]) 813 target: List[sympy.Expr] = functools.reduce( 814 broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] 815 ) 816 outputs = [] 817 for x in inputs: 818 sizes = x.get_size() 819 if len(sizes) != len(target) or any( 820 ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target) 821 ): 822 x = expand(x, target) 823 outputs.append(x) 824 return outputs 825 826 827@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) 828def nop(x): 829 return x # AOT autograd handles this for us 830 831 832if hasattr(aten, "lift_fresh"): 833 register_lowering(aten.lift_fresh)(nop) 834 835 836@register_lowering(aten.squeeze, type_promotion_kind=None) 837def squeeze(x, dim=None): 838 assert isinstance(x, TensorBox) 839 if dim is None: 840 return TensorBox(SqueezeView.create(x.data)) 841 842 dim 
= ( 843 V.graph.sizevars.evaluate_static_shape(dim) 844 if isinstance(dim, (int, sympy.Expr)) 845 else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) 846 ) 847 dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] 848 dims = set((dim,) if not isinstance(dim, tuple) else dim) 849 850 new_shape = [] 851 for d, s in enumerate(x.get_size()): 852 if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))): 853 new_shape.append(s) 854 855 # squeeze does nothing if the size isn't 1 856 return view(x, new_shape) if new_shape != x.get_size() else x 857 858 859@register_lowering(aten.squeeze_copy, type_promotion_kind=None) 860def squeeze_copy(x, dim=None): 861 return clone(squeeze(x, dim)) 862 863 864@register_lowering([aten.squeeze_]) 865def squeeze_(x, dim=None): 866 val = squeeze(x, dim) 867 assert isinstance(x, TensorBox) 868 assert isinstance(val, TensorBox) 869 x.data = val.data 870 return x 871 872 873@register_lowering(aten.isinf) 874def isinf(x): 875 if is_integer_type(x): 876 return full_like(x, False, dtype=torch.bool) 877 fn = ops_wrapper("isinf") 878 return make_pointwise(fn, override_return_dtype=torch.bool)(x) 879 880 881@register_lowering(aten.isnan) 882def isnan(x): 883 if is_integer_type(x): 884 return full_like(x, False, dtype=torch.bool) 885 fn = ops_wrapper("isnan") 886 return make_pointwise(fn, override_return_dtype=torch.bool)(x) 887 888 889@register_lowering(aten.ceil) 890def ceil(x): 891 if is_integer_type(x): 892 return clone(x) 893 fn = ops_wrapper("ceil") 894 return make_pointwise(fn)(x) 895 896 897@register_lowering(aten.floor) 898def floor(x): 899 if is_integer_type(x): 900 return clone(x) 901 fn = ops_wrapper("floor") 902 return make_pointwise(fn)(x) 903 904 905@register_lowering(aten.round.default) 906def round(x): 907 if is_integer_type(x): 908 return clone(x) 909 else: 910 fn = ops_wrapper("round") 911 return make_pointwise(fn)(x) 912 913 914@register_lowering(aten.trunc) 915def trunc(x): 916 if is_integer_type(x): 917 return clone(x) 918 fn = ops_wrapper("trunc") 919 return make_pointwise(fn)(x) 920 921 922@register_lowering(aten.expand, type_promotion_kind=None) 923def expand(x, sizes): 924 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols 925 926 (x,) = promote_constants([x]) 927 if isinstance(x, ir.BaseConstant): 928 return ExpandView.create(x, tuple(sizes)) 929 assert isinstance(x, TensorBox) 930 assert isinstance(sizes, (list, tuple)) 931 if tuple(x.get_size()) == tuple(sizes): 932 return x 933 934 if not free_unbacked_symbols(x.get_size()): 935 x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) 936 # TODO: It would be better to realize the input if any of its sizes 937 # are unbacked, because typically the size will be non-zero. 
However, 938 # this cannot be done directly as below as we'll choke on the size_hint 939 # here 940 if x_size_product > 0 and not free_unbacked_symbols(sizes): 941 # maybe realize input before broadcasting it 942 x.mark_reuse( 943 V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product 944 ) 945 return TensorBox(ExpandView.create(x.data, tuple(sizes))) 946 947 948@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) 949def broadcast_in_dim(a, shape, broadcast_dimensions): 950 s = list(shape) 951 for broadcast_dimension in broadcast_dimensions: 952 s[broadcast_dimension] = -1 953 954 v = a 955 for idx, x in enumerate(s): 956 if x != -1: 957 v = unsqueeze(v, idx) 958 959 return expand(v, shape) 960 961 962@register_lowering(aten.expand_as, type_promotion_kind=None) 963def expand_as(x, y): 964 return expand(x, y.get_size()) 965 966 967@register_lowering(aten.repeat) 968def repeat(x, repeats): 969 old_size = list(x.get_size()) 970 if len(repeats) > len(old_size): 971 old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size 972 x = view(x, list(old_size)) 973 assert len(repeats) == len(x.get_size()) 974 975 new_size = list(x.get_size()) 976 977 zero_tensor = False 978 for i in range(len(repeats)): 979 if repeats[i] == 0: 980 zero_tensor = True 981 new_size[i] = new_size[i] * repeats[i] 982 983 if zero_tensor: 984 return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) 985 if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): 986 return clone(expand(x, new_size)) 987 988 x_loader: Callable[[Any], Any] 989 990 def inner_fn(index): 991 assert len(index) == len(repeats) 992 index = list(index) 993 for i in range(len(repeats)): 994 if repeats[i] != 1: 995 if old_size[i] == 1: 996 index[i] = sympy.Integer(0) 997 else: 998 index[i] = ModularIndexing(index[i], 1, old_size[i]) 999 return x_loader(index) 1000 1001 old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) 1002 if old_size_product > 0: 1003 # maybe realize the input 1004 x.mark_reuse( 1005 V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product 1006 ) 1007 1008 x_loader = x.make_loader() 1009 return Pointwise.create( 1010 device=x.get_device(), 1011 dtype=x.get_dtype(), 1012 inner_fn=inner_fn, 1013 ranges=list(new_size), 1014 ) 1015 1016 1017@register_lowering(aten._unsafe_view, type_promotion_kind=None) 1018@register_lowering(aten.view, type_promotion_kind=None) 1019@register_lowering(aten.reshape, type_promotion_kind=None) 1020def view(x, sizes): 1021 assert isinstance(x, TensorBox) 1022 assert isinstance(sizes, (list, tuple)) 1023 return TensorBox(View.create(x.data, sizes)) 1024 1025 1026@register_lowering(aten.permute, type_promotion_kind=None) 1027def permute(x, dims): 1028 assert isinstance(x, TensorBox) 1029 assert isinstance(dims, (list, tuple)) 1030 return TensorBox(PermuteView.create(x.data, tuple(dims))) 1031 1032 1033@register_lowering(aten.slice, type_promotion_kind=None) 1034def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): 1035 assert isinstance(x, TensorBox) 1036 dim = _validate_dim(x, dim, 0) 1037 return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)) 1038 1039 1040@register_lowering(aten.as_strided, type_promotion_kind=None) 1041def as_strided(x, size, stride, storage_offset=None): 1042 if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): 1043 # as_strided ignores views 1044 x = x.data.unwrap_view() 1045 x.realize() 1046 if not ir.is_storage_and_layout(x): 1047 raise 
NotImplementedError(f"unrealized as_strided({x}, ...)") 1048 storage, old_layout = ir.as_storage_and_layout(x) 1049 new_layout = ir.FixedLayout( 1050 old_layout.device, 1051 old_layout.dtype, 1052 [sympy.expand(s) for s in size], 1053 [sympy.expand(s) for s in stride], 1054 sympy.expand(storage_offset or 0), 1055 ) 1056 return TensorBox(ir.ReinterpretView(storage, new_layout)) 1057 1058 1059@register_lowering(aten.as_strided_, type_promotion_kind=None) 1060def as_strided_(x, size, stride, storage_offset=None): 1061 assert isinstance(x, TensorBox) 1062 x.data = as_strided(x, size, stride, storage_offset).data 1063 return x 1064 1065 1066@register_lowering(aten.as_strided_copy, type_promotion_kind=None) 1067def as_strided_copy(x, size, stride, storage_offset=None): 1068 result = as_strided(x, size, stride, storage_offset) 1069 return clone(result) 1070 1071 1072def pointwise_cat(inputs, dim=0): 1073 # (inclusive, exclusive) 1074 inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] 1075 prev_end = 0 1076 for inp in inputs: 1077 inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] 1078 prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] 1079 1080 inputs_loaders = [inp.make_loader() for inp in inputs] 1081 1082 def inner_fn(idx): 1083 idx_dim = ops.index_expr(idx[dim], torch.int64) 1084 1085 masks = [] 1086 masked_loads = [] 1087 for i in range(len(inputs)): 1088 start = ( 1089 ops.constant(0, torch.int64) 1090 if i == 0 1091 else ops.index_expr(inputs_ranges[i][0], torch.int64) 1092 ) 1093 end = ops.index_expr(inputs_ranges[i][1], torch.int64) 1094 1095 start_cond = ops.ge(idx_dim, start) 1096 end_cond = ops.lt(idx_dim, end) 1097 if i == 0: 1098 mask = end_cond 1099 elif i == len(inputs) - 1: 1100 mask = start_cond 1101 else: 1102 mask = ops.and_(start_cond, end_cond) 1103 1104 masks.append(mask) 1105 idx_load = list(idx) 1106 1107 # if we're concatting [4], [2] 1108 # when we index the second tensor for 5 we want to index 5 - 4 1109 # Use Identity to prevent expansion of index * stride to keep expression 1110 # in same int bitwidth as shape 1111 idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) 1112 1113 masked_loads.append( 1114 ops.masked( 1115 mask, 1116 lambda: inputs_loaders[i](idx_load), 1117 0.0, # this value should be unused 1118 ), 1119 ) 1120 1121 next_val = masked_loads[-1] 1122 for i in range((len(inputs)) - 2, -1, -1): 1123 next_val = ops.where( 1124 masks[i], 1125 masked_loads[i], 1126 next_val, 1127 ) 1128 return next_val 1129 1130 new_size = list(inputs[0].get_size()) 1131 new_size[dim] = inputs_ranges[-1][-1] 1132 1133 return Pointwise.create( 1134 device=inputs[0].get_device(), 1135 dtype=inputs[0].get_dtype(), 1136 inner_fn=inner_fn, 1137 ranges=new_size, 1138 ) 1139 1140 1141@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) 1142def quantized_decomposed_quantize_per_channel( 1143 input: TensorBox, 1144 scales: TensorBox, 1145 zero_points: TensorBox, 1146 axis: int, 1147 quant_min: int, 1148 quant_max: int, 1149 dtype: torch.dtype, 1150) -> TensorBox: 1151 assert len(scales.get_size()) == 1, "expect scales 1 dim" 1152 assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" 1153 1154 if input.get_dtype() == torch.bfloat16: 1155 input = to_dtype(input, torch.float32) 1156 assert ( 1157 input.get_dtype() == torch.float32 1158 ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" 1159 assert axis < len( 1160 input.get_size() 1161 ), 
f"Expecting axis to be < {len(input.get_size())}" 1162 1163 input_loader = input.make_loader() 1164 scales_loader = scales.make_loader() 1165 zero_points_loader = zero_points.make_loader() 1166 1167 def inner_fn(idx): 1168 channel_idx = (idx[axis],) 1169 1170 input = input_loader(idx) 1171 scale = scales_loader(channel_idx) 1172 zero_point = zero_points_loader(channel_idx) 1173 qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) 1174 1175 if scales.dtype != torch.float32: 1176 scale = ops.to_dtype(scale, torch.float32) 1177 if zero_points.dtype != torch.int32: 1178 zero_point = ops.to_dtype(zero_point, torch.int32) 1179 inv_scale = ops.reciprocal(scale) 1180 val = ops.round(input * inv_scale) + zero_point 1181 clamped = ops.maximum(qmin, ops.minimum(qmax, val)) 1182 return ops.to_dtype(clamped, dtype) 1183 1184 return Pointwise.create( 1185 device=input.get_device(), 1186 dtype=dtype, 1187 inner_fn=inner_fn, 1188 ranges=input.get_size(), 1189 ) 1190 1191 1192@register_lowering( 1193 quantized_decomposed.dequantize_per_channel, type_promotion_kind=None 1194) 1195def quantized_decomposed_dequantize_per_channel( 1196 input: TensorBox, 1197 scales: TensorBox, 1198 zero_points: TensorBox, 1199 axis: int, 1200 quant_min: int, 1201 quant_max: int, 1202 dtype: torch.dtype, 1203) -> TensorBox: 1204 assert len(scales.get_size()) == 1, "expect scales 1 dim" 1205 assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" 1206 assert ( 1207 input.get_dtype() == dtype 1208 ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" 1209 assert axis < len( 1210 input.get_size() 1211 ), f"Expecting axis to be < {len(input.get_size())}" 1212 1213 input_loader = input.make_loader() 1214 scales_loader = scales.make_loader() 1215 zero_points_loader = zero_points.make_loader() 1216 1217 def inner_fn(idx): 1218 channel_idx = (idx[axis],) 1219 1220 input = input_loader(idx) 1221 scale = scales_loader(channel_idx) 1222 zero_point = zero_points_loader(channel_idx) 1223 1224 if scales.dtype != torch.float32: 1225 scale = ops.to_dtype(scale, torch.float32) 1226 if zero_points.dtype != torch.float32: 1227 zero_point = ops.to_dtype(zero_point, torch.float32) 1228 val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale 1229 return val 1230 1231 return Pointwise.create( 1232 device=input.get_device(), 1233 dtype=torch.float32, 1234 inner_fn=inner_fn, 1235 ranges=input.get_size(), 1236 ) 1237 1238 1239@register_lowering( 1240 quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None 1241) 1242def quantized_decomposed_quantize_per_tensor_default( 1243 input: TensorBox, 1244 scale: float, 1245 zero_point: int, 1246 quant_min: int, 1247 quant_max: int, 1248 dtype: torch.dtype, 1249) -> TensorBox: 1250 if input.get_dtype() == torch.bfloat16: 1251 input = to_dtype(input, torch.float32) 1252 assert ( 1253 input.get_dtype() == torch.float32 1254 ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" 1255 1256 input_loader = input.make_loader() 1257 1258 def inner_fn(idx, scale, zero_point): 1259 input = input_loader(idx) 1260 inv_scale, zero_point = _create_constants( 1261 1.0 / scale, zero_point, dtype=torch.float32 1262 ) 1263 val = ops.round(input * inv_scale) + zero_point 1264 qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) 1265 clamped = ops.minimum(ops.maximum(val, qmin), qmax) 1266 return ops.to_dtype(clamped, dtype) 1267 1268 return Pointwise.create( 1269 device=input.get_device(), 
1270 dtype=dtype, 1271 inner_fn=functools.partial( 1272 inner_fn, scale=float(scale), zero_point=int(zero_point) 1273 ), 1274 ranges=input.get_size(), 1275 ) 1276 1277 1278@register_lowering( 1279 quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None 1280) 1281def quantized_decomposed_dequantize_per_tensor_default( 1282 input: TensorBox, 1283 scale: float, 1284 zero_point: int, 1285 quant_min: int, 1286 quant_max: int, 1287 dtype: torch.dtype, 1288) -> TensorBox: 1289 assert ( 1290 input.get_dtype() == dtype 1291 ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" 1292 1293 input_loader = input.make_loader() 1294 1295 def inner_fn(idx, scale, zero_point): 1296 input = input_loader(idx) 1297 scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32) 1298 val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale 1299 return val 1300 1301 return Pointwise.create( 1302 device=input.get_device(), 1303 dtype=torch.float32, 1304 inner_fn=functools.partial( 1305 inner_fn, scale=float(scale), zero_point=int(zero_point) 1306 ), 1307 ranges=input.get_size(), 1308 ) 1309 1310 1311@register_lowering( 1312 quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None 1313) 1314def quantized_decomposed_quantize_per_tensor_tensor( 1315 input: TensorBox, 1316 scale: TensorBox, 1317 zero_point: TensorBox, 1318 quant_min: int, 1319 quant_max: int, 1320 dtype: torch.dtype, 1321) -> TensorBox: 1322 if input.get_dtype() == torch.bfloat16: 1323 input = to_dtype(input, torch.float32) 1324 assert ( 1325 input.get_dtype() == torch.float32 1326 ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" 1327 assert len(scale.get_size()) == 0 or ( 1328 len(scale.get_size()) == 1 and scale.get_size()[0] == 1 1329 ), "expect scale as scalar tensor" 1330 assert len(zero_point.get_size()) == 0 or ( 1331 len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 1332 ), "expect zero_point as scalar tensor" 1333 1334 input_loader = input.make_loader() 1335 scale_loader = scale.make_loader() 1336 zero_point_loader = zero_point.make_loader() 1337 1338 def inner_fn(idx): 1339 input = input_loader(idx) 1340 _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) 1341 _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) 1342 if scale.dtype != torch.float32: 1343 _scale = ops.to_dtype(_scale, torch.float32) 1344 if zero_point.dtype != torch.float32: 1345 _zero_point = ops.to_dtype(_zero_point, torch.float32) 1346 val = ops.round(input * ops.reciprocal(_scale)) + _zero_point 1347 qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) 1348 clamped = ops.minimum(ops.maximum(val, qmin), qmax) 1349 return ops.to_dtype(clamped, dtype) 1350 1351 return Pointwise.create( 1352 device=input.get_device(), 1353 dtype=dtype, 1354 inner_fn=inner_fn, 1355 ranges=input.get_size(), 1356 ) 1357 1358 1359@register_lowering( 1360 quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None 1361) 1362def quantized_decomposed_dequantize_per_tensor_tensor( 1363 input: TensorBox, 1364 scale: TensorBox, 1365 zero_point: TensorBox, 1366 quant_min: int, 1367 quant_max: int, 1368 dtype: torch.dtype, 1369) -> TensorBox: 1370 assert len(scale.get_size()) == 0 or ( 1371 len(scale.get_size()) == 1 and scale.get_size()[0] == 1 1372 ), "expect scale as scalar tensor" 1373 assert len(zero_point.get_size()) == 0 or ( 1374 len(zero_point.get_size()) == 1 and 
zero_point.get_size()[0] == 1 1375 ), "expect zero_point as scalar tensor" 1376 assert ( 1377 input.get_dtype() == dtype 1378 ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" 1379 1380 input_loader = input.make_loader() 1381 scale_loader = scale.make_loader() 1382 zero_point_loader = zero_point.make_loader() 1383 1384 def inner_fn(idx): 1385 input = input_loader(idx) 1386 _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) 1387 _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) 1388 if scale.dtype != torch.float32: 1389 _scale = ops.to_dtype(_scale, torch.float32) 1390 if zero_point.dtype != torch.float32: 1391 _zero_point = ops.to_dtype(_zero_point, torch.float32) 1392 val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale 1393 return val 1394 1395 return Pointwise.create( 1396 device=input.get_device(), 1397 dtype=torch.float32, 1398 inner_fn=inner_fn, 1399 ranges=input.get_size(), 1400 ) 1401 1402 1403@register_lowering(aten.cat) 1404def cat(inputs, dim=0): 1405 cpu_device = inputs[0].get_device().type == "cpu" 1406 if cpu_device and all( 1407 input.get_dtype() in [torch.int8, torch.uint8] for input in inputs 1408 ): 1409 # TODO <leslie> Remove this fallback when we support vectorization 1410 # code gen with uint8 data type directly. 1411 for input in inputs: 1412 input.realize() 1413 if all(len(input.get_size()) == 4 for input in inputs): 1414 inputs, _ = require_channels_last(aten.cat, *inputs) 1415 return fallback_handler(aten.cat.default)(inputs, dim) 1416 1417 if len(inputs) == 1: 1418 return clone(inputs[0]) 1419 1420 dim = _validate_dim(inputs[0], dim, 0) 1421 dtype = get_promoted_dtype( 1422 *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT 1423 ) 1424 inputs = [to_dtype(inp, dtype) for inp in inputs] 1425 1426 def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: 1427 if isinstance(x, TensorBox): 1428 if isinstance(x.data, ir.BaseView): 1429 return x.data.unwrap_view() 1430 else: 1431 return x.data 1432 1433 if isinstance(x, ir.StorageBox): 1434 return x.data 1435 1436 return x 1437 1438 def is_reduction(t): 1439 return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) 1440 1441 def can_fuse_reduction(t): 1442 if isinstance(t, (TensorBox, ir.StorageBox)): 1443 return can_fuse_reduction(unwrap_tensor(t)) 1444 return ( 1445 is_reduction(t) 1446 or isinstance(t, ir.Pointwise) 1447 and any( 1448 can_fuse_reduction(V.graph.get_buffer(read)) 1449 for read in t.get_read_names() 1450 ) 1451 ) 1452 1453 # fusing reducutions into computed concat buffer can cause regressions. 1454 fusable_reduction = any(can_fuse_reduction(t) for t in inputs) 1455 1456 def should_lower_cat_input(x) -> bool: 1457 # Unrealized inputs will not be storage and layouts, and we dont want to realize 1458 # them in case we want to fuse 1459 if ir.is_storage_and_layout(x): 1460 storage, _ = ir.as_storage_and_layout(x, freeze=False) 1461 return not ir.ConcatKernel.can_realize_into_without_copy(storage) 1462 1463 if isinstance(x, (TensorBox, ir.StorageBox)): 1464 return should_lower_cat_input(unwrap_tensor(x)) 1465 1466 if isinstance(x, ir.Pointwise): 1467 return True 1468 1469 return False 1470 1471 # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. 1472 # We will revisit this later after enabling vectorization on index_expr. 
1473 if cpu_device: 1474 return TensorBox(ir.ConcatKernel.create(inputs, dim)) 1475 1476 def op_count(x): 1477 if isinstance(x, (TensorBox, ir.StorageBox)): 1478 return op_count(unwrap_tensor(x)) 1479 1480 # this will correspond to a direct memory read 1481 if not isinstance(x, ir.Pointwise): 1482 return 0 1483 1484 count = x.inner_fn_opcount().num_ops 1485 for read in x.get_read_names(): 1486 count += op_count(V.graph.get_buffer(read)) 1487 1488 return count 1489 1490 # as of inputs increase, possibility for register spilling also increases 1491 # past a certain threshold of inputs we only fuse if the if the input kernels 1492 # are simple 1493 # not sure if we want to expose to users via config since logic may change in future 1494 MAX_COMPLEX_POINTWISE_CAT = 8 1495 MAX_SIMPLE_OP_COUNT = 2 1496 1497 def additional_pointwise_ops(op: torch._ops.OpOverload): 1498 return op in (aten.cat.default, aten.constant_pad_nd.default) 1499 1500 if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( 1501 (len(inputs) <= config.max_pointwise_cat_inputs) 1502 and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) 1503 ): 1504 pointwise_uses = all( 1505 is_pointwise_use(use, additional_pointwise_ops) 1506 for use in V.current_node.users 1507 ) 1508 # fuse in case we will be used in a pointwise node, and there are any inputs we 1509 # we can prevent materialization of. 1510 fuse_pointwise_use = ( 1511 any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses 1512 ) 1513 1514 # horizontal fuse in case all inputs will require a copy kernel anyway. 1515 # only horizontally fuse pointwise kernels 1516 horizontal_fuse_cat = all( 1517 should_lower_cat_input(inp) for inp in inputs 1518 ) and not any(can_fuse_reduction(t) for t in inputs) 1519 if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction): 1520 return pointwise_cat(inputs, dim) 1521 1522 return TensorBox(ir.ConcatKernel.create(inputs, dim)) 1523 1524 1525@register_lowering(aten.diagonal, type_promotion_kind=None) 1526def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): 1527 original_shape = input.get_size() 1528 num_dims = len(original_shape) 1529 dim1 = canonicalize_dim(idx=dim1, rank=num_dims) 1530 dim2 = canonicalize_dim(idx=dim2, rank=num_dims) 1531 1532 check( 1533 dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" 1534 ) 1535 1536 offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) 1537 if offset_negative: 1538 diag_size = V.graph.sizevars.evaluate_max( 1539 V.graph.sizevars.evaluate_min( 1540 original_shape[dim1] + offset, original_shape[dim2] 1541 ), 1542 0, # type: ignore[arg-type] 1543 ) 1544 else: 1545 diag_size = V.graph.sizevars.evaluate_max( 1546 V.graph.sizevars.evaluate_min( 1547 original_shape[dim1], original_shape[dim2] - offset 1548 ), 1549 0, # type: ignore[arg-type] 1550 ) 1551 1552 base_idx = (0, 0) 1553 if offset_negative: 1554 base_idx = (-offset, 0) 1555 else: 1556 base_idx = (0, offset) 1557 1558 sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] 1559 sizes.append(diag_size) 1560 1561 def reindexer(idx): 1562 diag_idx = idx[-1] 1563 original_idx = [0] * len(original_shape) 1564 cur_dim = 0 1565 for d in range(num_dims): 1566 if d == dim1: 1567 original_idx[d] = diag_idx + base_idx[0] 1568 elif d == dim2: 1569 original_idx[d] = diag_idx + base_idx[1] 1570 else: 1571 original_idx[d] = idx[cur_dim] 1572 cur_dim += 1 1573 1574 assert cur_dim == len(original_shape) - 2 1575 return original_idx 1576 1577 return 
TensorBox(ir.GenericView.create(input, sizes, reindexer)) 1578 1579 1580@register_lowering(aten.diagonal_copy, type_promotion_kind=None) 1581def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): 1582 return clone(diagonal(input, offset, dim1, dim2)) 1583 1584 1585@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) 1586def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): 1587 output = clone(input) 1588 target = diagonal(output, offset, dim1, dim2) 1589 mutate_to(target, src) 1590 return output 1591 1592 1593@register_lowering(aten.select, type_promotion_kind=None) 1594def select(x, dim, idx): 1595 idx = View.handle_negative_index(idx, x.get_size()[dim]) 1596 return squeeze(slice_(x, dim, idx, idx + 1), dim) 1597 1598 1599@register_lowering(aten.split, type_promotion_kind=None) 1600def split(x, sizes, dim=0, clamp=True): 1601 dim = _validate_dim(x, dim, 0) 1602 if isinstance(sizes, sympy.Expr): 1603 # TODO: We don't have to guard on sizes per se, but the number 1604 # of splits must stay constant 1605 sizes = V.graph.sizevars.evaluate_static_shape(sizes) 1606 if isinstance(sizes, (int, sympy.Integer)): 1607 x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) 1608 sizes = [sizes] * ((x_size + sizes - 1) // sizes) 1609 result = [] 1610 start = 0 1611 for size in sizes: 1612 end = start + size 1613 result.append(slice_(x, dim, start, end, clamp=clamp)) 1614 start = end 1615 return result 1616 1617 1618@register_lowering(aten.split_with_sizes, type_promotion_kind=None) 1619def split_with_sizes(x, sizes, dim=0): 1620 return split(x, sizes, dim, clamp=False) 1621 1622 1623@register_lowering(aten.unbind, type_promotion_kind=None) 1624def unbind(x, dim=0): 1625 dim = _validate_dim(x, dim, 0) 1626 x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) 1627 result = [] 1628 for i in range(x_size): 1629 result.append(select(x, dim, i)) 1630 return result 1631 1632 1633@register_lowering(aten.unfold, type_promotion_kind=None) 1634def unfold(x, dimension, size, step): 1635 sizes = x.get_size() 1636 ndim = len(sizes) 1637 dim = canonicalize_dim(ndim, dimension) 1638 1639 if ndim == 0: 1640 return slice_(unsqueeze(x, 0), end=size) 1641 1642 dim_size = sizes[dim] 1643 sizevars = V.graph.sizevars 1644 sizevars.guard_leq(size, dim_size) 1645 sizevars.guard_lt(0, step) # type: ignore[arg-type] 1646 1647 new_dim_size = FloorDiv(dim_size - size, step) + 1 1648 if sizevars.size_hint(dim_size) > 0: 1649 x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size))) 1650 1651 out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] 1652 1653 def reindexer(idx): 1654 dim_idx = idx[-1] + idx[dim] * step 1655 return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) 1656 1657 return TensorBox(ir.GenericView.create(x, out_size, reindexer)) 1658 1659 1660@register_lowering(aten.unsqueeze, type_promotion_kind=None) 1661def unsqueeze(x, dim): 1662 dim = _validate_dim(x, dim, 1) 1663 new_shape = list(x.get_size()) 1664 new_shape.insert(dim, sympy.Integer(1)) 1665 return view(x, new_shape) 1666 1667 1668@register_lowering(aten.unsqueeze_, type_promotion_kind=None) 1669def unsqueeze_(x, dim): 1670 val = unsqueeze(x, dim) 1671 assert isinstance(x, TensorBox) 1672 assert isinstance(val, TensorBox) 1673 x.data = val.data 1674 return x 1675 1676 1677def _validate_dim(x, dim, offset=0): 1678 dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) 1679 ndim = len(x.get_size()) 1680 if dim < 0: 1681 dim += ndim + offset 
1682 assert 0 <= dim < ndim + offset 1683 return dim 1684 1685 1686@register_lowering(aten.glu) 1687def glu(x, dim=-1): 1688 dim = _validate_dim(x, dim, 0) 1689 # TODO: don't guard on static shape here 1690 new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 1691 a = slice_(x, dim, 0, new_len) 1692 b = slice_(x, dim, new_len, new_len * 2) 1693 return mul(a, sigmoid(b)) 1694 1695 1696def fallback_handler(kernel, add_to_fallback_set=True): 1697 if add_to_fallback_set: 1698 fallbacks.add(kernel) 1699 1700 def handler(*args, **kwargs): 1701 def wrap_tensors(x): 1702 return TensorBox.create(x) if isinstance(x, ir.IRNode) else x 1703 1704 return pytree.tree_map( 1705 wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs) 1706 ) 1707 1708 return handler 1709 1710 1711@functools.lru_cache(None) 1712def _warn_complex_not_supported(): 1713 warnings.warn( 1714 "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." 1715 ) 1716 1717 1718# There are some types (CPU) which we accept as input but not as 1719# output. 1720def unsupported_input_tensor(t: torch.Tensor, parent=None): 1721 "Do not support reading or writing to this tensor" 1722 if t.is_complex(): 1723 # Complex views are supported with IR ComplexView 1724 if parent and parent.target in ( 1725 torch.ops.aten.view.dtype, 1726 torch.ops.prims.convert_element_type.default, 1727 ): 1728 return False 1729 _warn_complex_not_supported() 1730 return True 1731 return False 1732 1733 1734def unsupported_output_tensor(t: torch.Tensor, parent=None): 1735 "Do not support writing tensor but can read from it" 1736 if unsupported_input_tensor(t, parent): 1737 return True 1738 return t.is_cpu and config.disable_cpp_codegen 1739 1740 1741def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): 1742 # Custom fallback lowering 1743 if node.target is aten.view_as_complex.default: 1744 return False 1745 1746 # We should be able to remove this special case once `disable_cpp_codegen` is killed. 1747 if node.target is aten.lift_fresh_copy.default: 1748 return False 1749 1750 def check_skip_condition(node, parent, is_output): 1751 if not isinstance(node, torch.fx.Node): 1752 return False 1753 1754 if "val" not in node.meta: 1755 return False 1756 1757 for meta in pytree.tree_leaves(node.meta["val"]): 1758 if not isinstance(meta, torch._subclasses.FakeTensor): 1759 continue 1760 1761 if is_output: 1762 if unsupported_output_tensor(meta, parent): 1763 return True 1764 else: 1765 if unsupported_input_tensor(meta, parent): 1766 return True 1767 1768 return False 1769 1770 # only skip codegen if there is a cpu output, not input 1771 for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): 1772 if check_skip_condition(arg, node, is_output=False): 1773 return True 1774 1775 return check_skip_condition(node, node, is_output=True) 1776 1777 1778def make_fallback(op, layout_constraint=None, warn=True): 1779 assert op not in decompositions, f"both a fallback and a decomp for same op: {op}" 1780 if ( 1781 warn 1782 and bool(os.getenv("CI")) 1783 and get_decompositions([op]) 1784 # if fallback_random, we allow not decomposing random 1785 and not ( 1786 config.fallback_random 1787 and op in torch._decomp.decompositions_for_rng.extra_random_decomps 1788 ) 1789 ): 1790 # Note: 'warn' is holdover from when this was a warning, but for ops that previously 1791 # set warn=False we do not want a CI error. 
1792 # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not 1793 # likely to be triggered preferentially on one CI config over another. 1794 if torch._dynamo.config.suppress_errors: 1795 torch._dynamo.config.suppress_errors = False 1796 log.warning( 1797 "A make_fallback error occurred in suppress_errors config," 1798 " and suppress_errors is being disabled to surface it." 1799 ) 1800 raise AssertionError( 1801 f"make_fallback({op}): a decomposition exists, we should switch to it." 1802 " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" 1803 " or inductor_decompositions, and delete the corresponding `make_fallback` line." 1804 " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", 1805 ) 1806 1807 def register_fallback(op_overload): 1808 add_needs_realized_inputs(op_overload) 1809 if layout_constraint is not None: 1810 add_layout_constraint(op_overload, layout_constraint) 1811 return register_lowering(op_overload, type_promotion_kind=None)( 1812 fallback_handler(op_overload) 1813 ) 1814 1815 if isinstance(op, torch._ops.OpOverloadPacket): 1816 for ol in op.overloads(): 1817 op_overload = getattr(op, ol) 1818 register_fallback(op_overload) 1819 elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): 1820 register_fallback(op) 1821 else: 1822 raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") 1823 1824 1825def philox_rand_offset(shape): 1826 """ 1827 TorchInductor offset calculation differs from PyTorch eager offset 1828 calculation for random ops (tl.rand vs torch.rand). In future, we should 1829 strive for same impl for tl.rand and torch.rand. 1830 """ 1831 numel = 1 1832 for s in shape: 1833 numel = numel * s 1834 return tensor(numel, dtype=torch.int64) 1835 1836 1837@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) 1838def philox_rand(size, seed, offset, stride, device, dtype): 1839 # stride arg is optional and will be used in future for distributed random 1840 # ops. Currently, its unused. 1841 random_pos = ir.FixedLayout( 1842 device, 1843 dtype, 1844 size, 1845 ir.FlexibleLayout.contiguous_strides(size), 1846 ).make_indexer() 1847 seed_loader = seed.make_loader() 1848 offset_loader = offset.make_loader() 1849 1850 def inner_fn(index): 1851 # Both seed and offset in the philox_rand op are tensors. 
1852 # torch seed and offsets are of type int64, but tl.rand accepts int32 1853 seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) 1854 offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) 1855 # Get the offset'd position 1856 rand_index_expr = ops.add( 1857 ops.index_expr(random_pos(index), torch.int32), offset_index_expr 1858 ) 1859 result = ops.rand( 1860 seed_index_expr, 1861 rand_index_expr, 1862 ) 1863 return ops.to_dtype(result, dtype) 1864 1865 random_values_node = Pointwise.create( 1866 device=device, 1867 dtype=dtype, 1868 inner_fn=inner_fn, 1869 ranges=list(size), 1870 ) 1871 1872 offset_node = philox_rand_offset(size) 1873 return random_values_node, offset_node 1874 1875 1876@register_lowering(aten.native_dropout, type_promotion_kind=None) 1877def native_dropout(x, p, train): 1878 if config.fallback_random: 1879 return pytree.tree_map( 1880 TensorBox.create, 1881 ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), 1882 ) 1883 else: 1884 raise AssertionError("should be handled in replace_random.py") 1885 1886 1887@register_lowering(aten.bernoulli_, type_promotion_kind=None) 1888def bernoulli_(x, *args): 1889 assert config.fallback_random or x.get_device() == torch.device( 1890 "cpu" 1891 ), "this should be handled in decomps unless config.fallback_random or the device is CPU" 1892 x.realize() 1893 op_overload = ( 1894 aten.bernoulli_.float 1895 if len(args) == 0 or isinstance(args[0], float) 1896 else aten.bernoulli_.Tensor 1897 ) 1898 ir.InplaceBernoulliFallback(op_overload, x, *args) 1899 return x 1900 1901 1902@register_lowering(aten.bernoulli.p, type_promotion_kind=None) 1903def bernoulli_p(x, *args): 1904 assert config.fallback_random or x.get_device() == torch.device( 1905 "cpu" 1906 ), "this should be handled in decomps unless config.fallback_random or the device is CPU" 1907 return bernoulli_(clone(x), *args) 1908 1909 1910# This shouldn't be called in general 1911@register_lowering(aten._foobar) 1912def _foobar(_): 1913 raise AssertionError 1914 1915 1916@functools.lru_cache(1) 1917def _warn_triton_random(salt): 1918 log.info("using triton random, expect difference from eager") 1919 1920 1921def warn_triton_random(): 1922 # only warn once per graph 1923 _warn_triton_random(V.graph.creation_time) 1924 1925 1926fallback_rand_default = fallback_handler(aten.rand.default) 1927fallback_rand_generator = fallback_handler(aten.rand.generator) 1928fallback_randn_default = fallback_handler(aten.randn.default) 1929fallback_randn_generator = fallback_handler(aten.randn.generator) 1930make_fallback(aten.randint) 1931 1932 1933@register_lowering(aten.rand) 1934def rand(*args, **kwargs): 1935 if kwargs.get("generator", None) is not None: 1936 return fallback_rand_generator(*args, **kwargs) 1937 elif config.fallback_random: 1938 kwargs.pop("generator", None) 1939 return fallback_rand_default(*args, **kwargs) 1940 raise AssertionError("should have been handled in replace_random.py") 1941 1942 1943@register_lowering(aten.randn) 1944def randn(*args, **kwargs): 1945 if kwargs.get("generator", None) is not None: 1946 return fallback_randn_generator(*args, **kwargs) 1947 elif config.fallback_random: 1948 kwargs.pop("generator", None) 1949 return fallback_randn_default(*args, **kwargs) 1950 raise AssertionError("should have been handled in replace_random.py") 1951 1952 1953@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) 1954def inductor_force_stride_order(input_tensor, stride): 1955 stride_order = 
ir.get_stride_order(stride) 1956 return ir.ExternKernel.require_stride_order(input_tensor, stride_order) 1957 1958 1959@register_lowering(inductor_prims.seed, type_promotion_kind=None) 1960def inductor_seed(device: torch.device): 1961 raise AssertionError("should be handled in fuse_seed_creation_pass()") 1962 1963 1964@register_lowering(inductor_prims.seeds, type_promotion_kind=None) 1965def inductor_seeds(count, device): 1966 warn_triton_random() 1967 return TensorBox.create(ir.RandomSeeds(count, decode_device(device))) 1968 1969 1970@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) 1971def inductor_lookup_seed(seeds, index): 1972 def inner_fn(_): 1973 return ops.load_seed(seeds.get_name(), index) 1974 1975 return Pointwise.create( 1976 device=seeds.get_device(), 1977 dtype=seeds.get_dtype(), 1978 inner_fn=inner_fn, 1979 ranges=[], 1980 ) 1981 1982 1983@register_lowering(inductor_prims.random, type_promotion_kind=None) 1984def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0): 1985 assert not config.fallback_random 1986 assert mode in ("rand", "randn") 1987 size = [*size] 1988 dtype = torch.float32 1989 device = seed.get_device() 1990 random_pos = ir.FixedLayout( 1991 device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset 1992 ).make_indexer() 1993 seed_loader = seed.make_loader() 1994 1995 def inner_fn(index): 1996 return getattr(ops, mode)( 1997 seed_loader([]), 1998 ops.index_expr(random_pos(index), torch.int32), 1999 ) 2000 2001 result = Pointwise.create( 2002 device=device, 2003 dtype=dtype, 2004 inner_fn=inner_fn, 2005 ranges=[*size], 2006 ) 2007 result.realize() 2008 return result 2009 2010 2011@register_lowering(inductor_prims.randint, type_promotion_kind=None) 2012def inductor_randint( 2013 low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0 2014): 2015 assert not config.fallback_random 2016 size = [*size] 2017 dtype = torch.int64 2018 device = seed.get_device() 2019 random_pos = ir.FixedLayout( 2020 device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset 2021 ).make_indexer() 2022 seed_loader = seed.make_loader() 2023 2024 def inner_fn(index): 2025 return ops.randint64( 2026 seed_loader([]), 2027 ops.index_expr(random_pos(index), torch.int32), 2028 ops.index_expr(low, torch.int64), 2029 ops.index_expr(high, torch.int64), 2030 ) 2031 2032 return Pointwise.create( 2033 device=device, 2034 dtype=dtype, 2035 inner_fn=inner_fn, 2036 ranges=[*size], 2037 ) 2038 2039 2040@register_lowering(aten.bucketize, type_promotion_kind=None) 2041def bucketize( 2042 input: TensorBox, 2043 boundaries: TensorBox, 2044 *, 2045 out_int32: bool = False, 2046 right: bool = False, 2047): 2048 assert len(boundaries.get_size()) == 1 2049 2050 if not ( 2051 V.graph.has_feature(input, BackendFeature.BUCKETIZE) 2052 and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE) 2053 ): 2054 return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( 2055 input, boundaries, out_int32=out_int32, right=right 2056 ) 2057 2058 # The entire boundaries tensor needs to be used by ops.bucketize, so we 2059 # need to realize it into global memory; or in other words, we can't 2060 # guarantee that boundaries.get_name() (used below) will exist unless 2061 # we call boundaries.realize(). 
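    # For reference (illustrative, mirrors torch.bucketize semantics): each input
    # element maps to the index of the first boundary that is >= the element when
    # right=False, or strictly > the element when right=True.  E.g. with
    # boundaries [1, 3, 5], an input value of 3 produces 1 (right=False) or
    # 2 (right=True).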
    boundaries.realize()
    boundaries_size = boundaries.get_size()[0]
    device = input.get_device()
    input_loader = input.make_loader()

    index_dtype = torch.int32 if out_int32 else torch.int64

    def inner_fn(index):
        val = input_loader(index)
        indices = ops.bucketize(
            val,
            boundaries.get_name(),
            boundaries_size,
            index_dtype,
            right,
        )

        return indices

    return Pointwise.create(
        device=device,
        dtype=index_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )


def require_dense(_, *args, **kwargs):
    args, kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs)
    )
    return args, kwargs


def require_contiguous(_, *args, **kwargs):
    args, kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs)
    )
    return args, kwargs


def require_channels_last(_, *args, **kwargs):
    args, kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
    )
    return args, kwargs


def constrain_to_fx_strides(fx_node, *args, ignore_mutated_args_FIXME=False, **kwargs):
    def apply_constraint(arg, fx_arg):
        if isinstance(arg, ir.IRNode):
            stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
            return ir.ExternKernel.require_stride_order(arg, stride_order)
        return arg

    # There's a silent incorrectness bug where if we constrain a mutated arg,
    # we may end up cloning it, writing in-place to the clone, and then using
    # the original value (instead of the cloned value). Our short-term fix for this
    # is to never constrain mutated args; longer term we do want to fix this.
    # https://github.com/pytorch/pytorch/issues/128084
    if ignore_mutated_args_FIXME:
        assert isinstance(fx_node.target, torch._ops.OpOverload)
        schema = fx_node.target._schema

        def maybe_apply_constraint(schema_arg, arg, fx_arg):
            if schema_arg.alias_info is not None and schema_arg.alias_info.is_write:
                return arg
            return apply_constraint(arg, fx_arg)

        new_args = []
        new_kwargs = {}

        for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args)):
            schema_arg = schema.arguments[idx]
            new_args.append(maybe_apply_constraint(schema_arg, arg, fx_arg))

        schema_kwargs = {arg.name: arg for arg in schema.arguments}

        for key in kwargs.keys():
            arg = kwargs[key]
            fx_arg = fx_node.kwargs[key]
            schema_arg = schema_kwargs[key]
            new_kwargs[key] = maybe_apply_constraint(schema_arg, arg, fx_arg)

        return tuple(new_args), new_kwargs

    args = tuple(
        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
    )
    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
    return args, kwargs


# TODO(jansel): we should implement decomps or lowerings for these
# https://github.com/pytorch/torchdynamo/issues/327
FALLBACK_ALLOW_LIST = {
    "torchvision::roi_align",
}


def sdpa_constraint(fx_node, *args, **kwargs):
    # sdpa requires a dense last dimension

    def apply_constraint(arg, fx_arg):
        if not isinstance(arg, ir.IRNode):
            return arg

        meta_val = fx_arg.meta["val"]
        meta_stride = meta_val.stride()

        stride_order = ir.get_stride_order(meta_stride)
        if stride_order and stride_order[-1] != 0:
            # contiguous stride order
            stride_order = list(reversed(range(len(arg.get_size()))))

        if not meta_val.is_cuda:
            return ir.ExternKernel.require_stride_order(arg, stride_order)

        # This is the minimum alignment required by SDPA kernels for attention_bias.
2181 # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask 2182 ALIGNMENT = 8 2183 2184 assert isinstance(arg, TensorBox) 2185 if len(arg.get_size()) not in (3, 4): 2186 return arg 2187 2188 def is_aligned_realized_tensor(x): 2189 aligned_strides = all( 2190 (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0 2191 for i in range(len(x.get_stride()) - 1) 2192 ) 2193 return ( 2194 V.graph.sizevars.size_hint(x.get_stride()[-1]) 2195 ) == 1 and aligned_strides 2196 2197 try: 2198 arg.get_stride() 2199 if is_aligned_realized_tensor(arg): 2200 return V.graph.try_match_insignificant_strides( 2201 ir.ExternKernel.realize_input(arg), meta_stride 2202 ) 2203 except AttributeError: 2204 pass 2205 2206 def is_aligned(x): 2207 return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 2208 2209 if isinstance(arg.data, ir.BaseView): 2210 if not is_aligned(arg): 2211 if is_aligned(arg.unwrap_view()): 2212 return V.graph.try_match_insignificant_strides( 2213 ir.ExternKernel.realize_input(arg), meta_stride 2214 ) 2215 2216 return ir.ExternKernel.require_stride_order(arg, stride_order) 2217 2218 args = tuple( 2219 apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) 2220 ) 2221 kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} 2222 return args, kwargs 2223 2224 2225# WIP 2226make_fallback(aten._adaptive_avg_pool3d) # @isuruf 2227make_fallback(aten.adaptive_max_pool3d) # @isuruf 2228make_fallback(aten.fractional_max_pool3d) # @isuruf 2229make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) 2230 2231 2232# 1) Easy 2233make_fallback(aten.uniform, warn=False) 2234make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) 2235make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks 2236make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? 2237make_fallback(aten.searchsorted) # bucketized is implemented (see eager impl) 2238 2239 2240# 1.5) Easy or Impossible 2241make_fallback(aten._cdist_forward) # p=2 should be feasible 2242make_fallback(aten._cdist_backward) 2243 2244# 2) Medium 2245make_fallback(aten.max_unpool2d) 2246make_fallback(aten.max_unpool3d) 2247make_fallback(aten._trilinear) 2248 2249 2250# 3) Difficult 2251# Scans 2252# See the discussion at 2253# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 2254make_fallback(aten.segment_reduce.default) 2255make_fallback(aten._segment_reduce_backward.default) 2256 2257# Histogram (need to implement Histogram IR) 2258make_fallback(aten.histc) 2259make_fallback(aten.histogram.bin_ct) 2260make_fallback(aten._histogramdd_bin_edges.default) 2261make_fallback(aten._histogramdd_from_bin_cts.default) 2262 2263# Need templated kernel 2264make_fallback(aten.addbmm) 2265make_fallback(aten._addmm_activation, warn=False) 2266 2267# Need templated kernel. 
Probably impossible to write efficiently 2268make_fallback(aten.convolution_backward, constrain_to_fx_strides) 2269make_fallback(aten._cudnn_rnn, require_dense) 2270make_fallback(aten._cudnn_rnn_backward, require_contiguous) 2271 2272# Haven't checked but sound difficult / impossible 2273make_fallback(aten._embedding_bag, require_contiguous) 2274make_fallback(aten._embedding_bag_forward_only, require_contiguous) 2275make_fallback(aten._embedding_bag_backward) 2276make_fallback(aten._embedding_bag_per_sample_weights_backward) 2277make_fallback(aten._embedding_bag_per_sample_weights_backward) 2278make_fallback(aten._fused_moving_avg_obs_fq_helper) 2279make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) 2280 2281 2282# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp 2283make_fallback(aten.max_pool3d_with_indices_backward) 2284make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) 2285make_fallback(aten._adaptive_avg_pool3d_backward) 2286make_fallback(aten.adaptive_max_pool2d_backward) 2287make_fallback(aten.adaptive_max_pool3d_backward) 2288make_fallback(aten.fractional_max_pool2d_backward) 2289make_fallback(aten.fractional_max_pool3d_backward) 2290make_fallback(aten.replication_pad1d_backward) 2291make_fallback(aten.replication_pad2d_backward) 2292make_fallback(aten.upsample_linear1d_backward) 2293make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) 2294make_fallback(aten.upsample_trilinear3d_backward) 2295make_fallback(aten.grid_sampler_2d_backward, require_dense) 2296make_fallback(aten._pdist_backward) 2297 2298 2299# 5) Impossible (missing triton/CPU features) 2300 2301# Sorting / Sorting-like 2302make_fallback(aten.sort) 2303make_fallback(aten.sort.stable) 2304make_fallback(aten.kthvalue) 2305make_fallback(aten.topk) 2306make_fallback(aten.mode) 2307make_fallback(aten.median) 2308make_fallback(aten.nanmedian) 2309make_fallback(aten.randperm) 2310# see: https://github.com/pytorch/pytorch/pull/121354 2311make_fallback(aten.resize_) 2312make_fallback(aten.resize_as_) 2313 2314# Linalg 2315make_fallback(aten._linalg_det) 2316make_fallback(aten.linalg_householder_product) 2317make_fallback(aten.linalg_inv_ex) 2318make_fallback(aten.linalg_ldl_factor_ex) 2319make_fallback(aten.linalg_ldl_solve) 2320make_fallback(aten.linalg_lu) 2321make_fallback(aten.linalg_lu_factor_ex) 2322make_fallback(aten.linalg_lu_solve) 2323make_fallback(aten.linalg_matrix_exp) 2324make_fallback(aten.linalg_qr) 2325make_fallback(aten._linalg_slogdet) 2326make_fallback(aten._linalg_solve_ex) 2327make_fallback(aten.linalg_solve_triangular) 2328make_fallback(aten._linalg_svd) 2329make_fallback(aten.lu_unpack) 2330make_fallback(aten.ormqr) 2331make_fallback(aten._linalg_check_errors) 2332make_fallback(aten.linalg_pinv.atol_rtol_tensor) 2333make_fallback(aten._linalg_eigh) 2334make_fallback(aten.triangular_solve) 2335make_fallback(aten.linalg_cholesky_ex) 2336make_fallback(aten.cholesky_inverse) 2337make_fallback(aten.cholesky_solve) 2338make_fallback(aten.geqrf) 2339make_fallback(aten._fft_r2c) # needs complex as well 2340 2341# Data dependent (are these necessary?) 2342make_fallback(aten.nonzero.default) 2343 2344# Misc 2345make_fallback(aten.gcd.default, warn=False) 2346make_fallback(aten._thnn_fused_lstm_cell, require_dense) 2347make_fallback(torch._prims.rng_prims.run_and_save_rng_state) 2348make_fallback(torch._prims.rng_prims.run_with_rng_state) 2349 2350# Implmented / Half implemented 2351# Scans. 
Implemented for CUDA, missing CPU 2352make_fallback(aten.masked_scatter) 2353make_fallback(aten.masked_scatter_backward) 2354 2355# Complex number support 2356make_fallback(aten.view_as_complex, require_contiguous) 2357make_fallback(aten.angle) # needs complex 2358 2359# Needs efficentzerotensor 2360make_fallback(aten._efficientzerotensor) 2361 2362# Needs Sparse 2363make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) 2364make_fallback(aten.to_sparse) 2365make_fallback(aten._to_sparse) 2366 2367# Needs dimname support 2368make_fallback(aten.zeros.names) 2369 2370# 6) Pattern-matched 2371make_fallback( 2372 aten._scaled_dot_product_efficient_attention.default, 2373 sdpa_constraint, 2374 warn=False, 2375) 2376make_fallback( 2377 aten._scaled_dot_product_efficient_attention_backward.default, 2378 sdpa_constraint, 2379 warn=False, 2380) 2381make_fallback( 2382 aten._scaled_dot_product_flash_attention.default, 2383 sdpa_constraint, 2384 warn=False, 2385) 2386make_fallback( 2387 aten._scaled_dot_product_flash_attention_backward.default, 2388 sdpa_constraint, 2389 warn=False, 2390) 2391make_fallback( 2392 aten._scaled_dot_product_cudnn_attention.default, 2393 sdpa_constraint, 2394 warn=False, 2395) 2396make_fallback( 2397 aten._scaled_dot_product_cudnn_attention_backward.default, 2398 sdpa_constraint, 2399 warn=False, 2400) 2401make_fallback( 2402 aten._scaled_dot_product_flash_attention_for_cpu.default, 2403 sdpa_constraint, 2404 warn=False, 2405) 2406make_fallback( 2407 aten._scaled_dot_product_flash_attention_for_cpu_backward.default, 2408 sdpa_constraint, 2409 warn=False, 2410) 2411make_fallback(aten._flash_attention_forward.default, sdpa_constraint) 2412make_fallback(aten._flash_attention_backward.default, sdpa_constraint) 2413make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) 2414make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) 2415 2416# index_reduce requires fallback when use_scatter_fallback(...) returns True 2417make_fallback(aten.index_reduce) 2418 2419 2420# Register with type_promotion_kind None. 2421# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. 
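# (Under the default elementwise promotion an fp16/fp32 pair would be computed
# in fp32; with type_promotion_kind=None the copy below keeps the destination
# dtype and explicitly casts `src` via to_dtype.)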
2422@register_lowering(aten.copy, type_promotion_kind=None) 2423def copy(self, src, non_blocking=False): 2424 x = src 2425 if self.get_device() != src.get_device(): 2426 x = to_device(x, self.get_device()) 2427 if self.get_dtype() != src.get_dtype(): 2428 x = to_dtype(x, self.get_dtype()) 2429 2430 if self.get_size() != src.get_size(): 2431 out = expand(x, self.get_size()) 2432 return clone(out) 2433 return clone(x) 2434 2435 2436@register_lowering(aten.clone) 2437def clone(x, *, memory_format=None): 2438 # TODO(jansel): memory format 2439 return Pointwise.create( 2440 device=x.get_device(), 2441 dtype=x.get_dtype(), 2442 inner_fn=x.make_loader(), 2443 ranges=list(x.get_size()), 2444 ) 2445 2446 2447def clone_preserve_reinterpret_view(x): 2448 reinterpret_view_layouts = [] 2449 if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): 2450 x = x.data # unwrap TensorBox 2451 while isinstance(x, ir.ReinterpretView): 2452 reinterpret_view_layouts.append(x.get_layout()) 2453 x = x.data 2454 x = TensorBox(x) 2455 2456 x = clone(x) 2457 2458 if reinterpret_view_layouts: 2459 x = x.data # unwrap TensorBox 2460 for layout in reinterpret_view_layouts[::-1]: 2461 x = ir.ReinterpretView(x, layout) 2462 x = TensorBox(x) 2463 2464 return x 2465 2466 2467if hasattr(aten, "lift_fresh_copy"): 2468 register_lowering(aten.lift_fresh_copy)(clone) 2469 2470 2471@register_lowering(prims.iota) 2472def iota( 2473 length, 2474 *, 2475 start, 2476 step, 2477 dtype, 2478 device, 2479 requires_grad, 2480): 2481 def fn(index): 2482 return ops.index_expr(step * index[0] + start, dtype=dtype) 2483 2484 return Pointwise.create( 2485 device=decode_device(device), 2486 dtype=dtype, 2487 inner_fn=fn, 2488 ranges=[length], 2489 ) 2490 2491 2492@register_lowering(aten.select_scatter, type_promotion_kind=None) 2493def select_scatter(x, src, dim: int, index: int): 2494 assert x.get_dtype() == src.get_dtype() 2495 x_loader = x.make_loader() 2496 dim = _validate_dim(x, dim, 0) 2497 if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): 2498 index = index + x.get_size()[dim] 2499 V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] 2500 V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] 2501 src = expand(unsqueeze(src, dim), x.get_size()) 2502 src_loader = src.make_loader() 2503 2504 def inner_fn(idx): 2505 return ops.where( 2506 ops.eq( 2507 ops.index_expr(idx[dim], torch.int32), 2508 ops.index_expr(index, torch.int32), 2509 ), 2510 src_loader(idx), 2511 x_loader(idx), 2512 ) 2513 2514 return Pointwise.create( 2515 device=x.get_device(), 2516 dtype=x.get_dtype(), 2517 inner_fn=inner_fn, 2518 ranges=list(x.get_size()), 2519 ) 2520 2521 2522@register_lowering(aten.slice_scatter, type_promotion_kind=None) 2523def slice_scatter(x, src, dim=0, start=None, end=None, step=1): 2524 assert x.get_dtype() == src.get_dtype() 2525 x_loader = x.make_loader() 2526 dim = _validate_dim(x, dim, 0) 2527 dim_size = x.get_size()[dim] 2528 2529 start, end = ir.SliceView.normalize_start_end(x, dim, start, end) 2530 2531 src_size = list(x.get_size()) 2532 src_size[dim] = FloorDiv(end - start + (step - 1), step) 2533 src = expand(src, src_size) 2534 src_loader = src.make_loader() 2535 2536 def inner_fn(idx): 2537 if start == 0 and end == dim_size and step == 1: 2538 # selecting every element is the same as just src.clone() 2539 return src_loader(idx) 2540 2541 idx_dim = ops.index_expr(idx[dim], torch.int64) 2542 src_idx = list(idx) 2543 src_idx[dim] = FloorDiv(idx[dim] - start, step) 2544 2545 mask = [] 
2546 if start != 0: 2547 mask.append( 2548 ops.ge( 2549 idx_dim, 2550 ops.index_expr(sympy.expand(start), torch.int64), 2551 ) 2552 ) 2553 if end != dim_size: 2554 mask.append( 2555 ops.lt( 2556 idx_dim, 2557 ops.index_expr(sympy.expand(end), torch.int64), 2558 ) 2559 ) 2560 if step != 1: 2561 mask.append( 2562 ops.eq( 2563 ops.index_expr( 2564 ModularIndexing(idx[dim] - start, 1, step), torch.int64 2565 ), 2566 ops.constant(0, torch.int64), 2567 ) 2568 ) 2569 assert mask 2570 mask = functools.reduce(ops.and_, mask) 2571 src_val = ops.masked( 2572 mask, 2573 lambda: src_loader(src_idx), 2574 0 if is_integer_type(x) else 0.0, 2575 ) 2576 return ops.where( 2577 mask, 2578 src_val, 2579 x_loader(idx), 2580 ) 2581 2582 return Pointwise.create( 2583 device=x.get_device(), 2584 dtype=x.get_dtype(), 2585 inner_fn=inner_fn, 2586 ranges=list(x.get_size()), 2587 ) 2588 2589 2590def _unwrap(x): 2591 if isinstance(x, (list, tuple)) and len(x) > 0: 2592 return _unwrap(x[0]) 2593 return x 2594 2595 2596@register_lowering([torch.tensor, aten.scalar_tensor]) 2597def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): 2598 assert_nyi(layout in (None, torch.strided), f"layout={layout}") 2599 assert_nyi(not pin_memory, "pin_memory") 2600 if isinstance(_unwrap(data), int): 2601 dtype = dtype or torch.int64 2602 else: 2603 dtype = dtype or torch.get_default_dtype() 2604 2605 ranges: List[sympy.Expr] = [] 2606 2607 if isinstance(data, sympy.Basic): 2608 2609 def inner_fn(index): 2610 return ops.index_expr(data, dtype) 2611 2612 elif isinstance(data, (float, int)): 2613 2614 def inner_fn(index): 2615 return ops.constant(data, dtype) 2616 2617 elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: 2618 # inline small tensors 2619 ranges.append(sympy.Integer(len(data))) 2620 2621 def inner_fn(index): 2622 def binary_search(start, end): 2623 assert start < end 2624 if end - start == 1: 2625 return ops.constant(data[start], dtype) 2626 mid = (end - start) // 2 + start 2627 return ops.where( 2628 ops.lt( 2629 ops.index_expr(index[0], torch.int64), 2630 ops.constant(mid, torch.int64), 2631 ), 2632 binary_search(start, mid), 2633 binary_search(mid, end), 2634 ) 2635 2636 if len(data) == 0: 2637 return ops.constant(0, dtype) 2638 return binary_search(0, len(data)) 2639 2640 else: 2641 return V.graph.add_tensor_constant( 2642 torch.tensor(data, dtype=dtype, device=device) 2643 ) 2644 2645 return Pointwise.create( 2646 device=decode_device(device), 2647 dtype=dtype, 2648 inner_fn=inner_fn, 2649 ranges=ranges, 2650 ) 2651 2652 2653@register_lowering(torch.as_tensor) 2654def as_tensor(data, dtype=None, device=None): 2655 if isinstance(data, TensorBox): 2656 if dtype is not None: 2657 data = to_dtype(data, dtype) 2658 if device is not None: 2659 data = to_device(data, device) 2660 return data 2661 return tensor(data, dtype=dtype, device=device) 2662 2663 2664@register_lowering(torch.LongTensor) 2665def long_tensor(data): 2666 return tensor(data, dtype=torch.int64) 2667 2668 2669@register_lowering(aten._local_scalar_dense) 2670def _local_scalar_dense(data): 2671 from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings 2672 2673 # This is interesting! Most lowerings return tensors, so you can just 2674 # return the buffer you allocated and it will get used (or not used, if 2675 # it's dead.) 
But _local_scalar_dense (aka item) returns an int, 2676 # not a Tensor, so you would have a type mismatch if you return a buffer; 2677 # we are obligated to return a sympy expression instead. However, 2678 # we need to actually codegen the .item() call somehow. We do this 2679 # by registering a faux buffer for the DynamicScalar IR node, which is 2680 # solely responsible for generating this .item(). The buffer is 2681 # not used for anything (notice we discard it); at codegen time, 2682 # the "buffer" just gets assigned None. 2683 unbacked_bindings = resolve_unbacked_bindings( 2684 V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] 2685 ) 2686 assert len(unbacked_bindings) == 1, unbacked_bindings 2687 # NB: Have to be very careful here. V.graph.current_node.meta["val"] 2688 # seemingly also contains a symbol which you want to do binding for, 2689 # but it actually isn't. In particular, if we have later performed 2690 # a deferred runtime assert saying that u0 == s0, you will actually 2691 # see s0 from expr! This is bad because we need to actually generate 2692 # the assert that says u0 == s0, so we need to know where to get u0 2693 # from (this call). In particular, we must use unbacked_bindings, which 2694 # is guaranteed to have the original, unreplaced symbol in question. 2695 # 2696 # NB2: Another thing we have to be very careful about are symbol bindings 2697 # that require nontrivial refinement, e.g., when you have a binding site 2698 # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division 2699 # in order to appropriately bind u0. This is communicated via the keypath 2700 # in unbacked_bindings, and we need to hold onto it in order to generate 2701 # code appropriately for this case. 2702 binding_sym, keypath = next(iter(unbacked_bindings.items())) 2703 buffer = ir.DynamicScalar(binding_sym, keypath, data) 2704 buffer.name = V.graph.register_buffer(buffer) 2705 V.graph.register_operation(buffer) 2706 # NB: the replaced expr is OK to use directly downstream, we want 2707 # simplifications in this case! 
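    # In short: the DynamicScalar buffer registered above exists only so that
    # codegen emits the .item() call and binds the unbacked symbol; the value
    # returned from this lowering is the sympy expression below, not a tensor.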
2708 val = V.graph.current_node.meta["val"] 2709 if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): 2710 return val.node.expr 2711 else: 2712 return sympy.sympify(val) 2713 2714 2715@register_lowering(aten._assert_scalar) 2716def _assert_scalar(data, msg): 2717 # NB: These will be handled at codegen time 2718 # Not sure if we are guaranteed to be able to serve out truth from the 2719 # deferred_runtime_asserts, TODO: try this assert out 2720 # assert bool(data.scalar), data 2721 return None 2722 2723 2724def _full(fill_value, device, dtype, size): 2725 value = fill_value 2726 if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): 2727 value = value.value 2728 2729 if isinstance(value, (int, float)): 2730 2731 def inner_fn(index): 2732 return ops.constant(value, dtype) 2733 2734 elif isinstance(value, sympy.Basic): 2735 2736 def inner_fn(index): 2737 return ops.index_expr(value, dtype) 2738 2739 else: 2740 assert len(value.get_size()) == 0 2741 value_loader = value.make_loader() 2742 2743 def inner_fn(index): 2744 return value_loader([]) 2745 2746 return Pointwise.create( 2747 device=device, 2748 dtype=dtype, 2749 inner_fn=inner_fn, 2750 ranges=list(size), 2751 ) 2752 2753 2754@register_lowering(aten.full_like, type_promotion_kind=None) 2755def full_like(x, fill_value, **kwargs): 2756 return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) 2757 2758 2759def tensor_constructor(fill_value): 2760 # torch.zeros, torch.ones, etc 2761 def inner( 2762 *size, 2763 names=None, 2764 dtype=None, 2765 device=None, 2766 layout=None, 2767 pin_memory=False, 2768 memory_format=None, 2769 ): 2770 assert_nyi(names is None, "named tensors") 2771 assert_nyi(layout in (None, torch.strided), f"layout={layout}") 2772 assert_nyi(not pin_memory, "pin_memory") 2773 device = decode_device(device) 2774 dtype = dtype or torch.get_default_dtype() 2775 if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): 2776 size = tuple(size[0]) 2777 # See https://github.com/pytorch/pytorch/issues/118102 2778 # All sizes at lowering time should be sympy.Symbol, not SymInt! 2779 for s in size: 2780 assert not isinstance(s, torch.SymInt) 2781 size = [sympy.expand(s) for s in size] 2782 return _full(fill_value, device, dtype, size) 2783 2784 return inner 2785 2786 2787@register_lowering([torch.empty, aten.empty]) 2788def empty( 2789 *size, 2790 names=None, 2791 dtype=None, 2792 layout=None, 2793 device=None, 2794 pin_memory=None, 2795 memory_format=None, 2796): 2797 assert_nyi(names is None, "named tensors") 2798 device = decode_device(device) 2799 if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): 2800 size = tuple(size[0]) 2801 return empty_strided( 2802 size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory 2803 ) 2804 2805 2806def create_tensor_like(creation_fn): 2807 """ 2808 Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). 
2809 """ 2810 2811 def _constant_like( 2812 x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None 2813 ): 2814 assert_nyi(not pin_memory, "pin_memory") 2815 assert_nyi(layout in (None, torch.strided), f"layout={layout}") 2816 if dtype is None: 2817 dtype = x.get_dtype() 2818 else: 2819 dtype = decode_dtype(dtype) 2820 device = device or x.get_device() 2821 size = list(x.get_size()) 2822 return creation_fn( 2823 size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory 2824 ) 2825 2826 return _constant_like 2827 2828 2829def constant_like(fill_value): 2830 return create_tensor_like(tensor_constructor(fill_value)) 2831 2832 2833empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) 2834ones_like = create_tensor_like(tensor_constructor(1)) 2835zeros_like = create_tensor_like(tensor_constructor(0)) 2836 2837 2838def new_constant(fill_value): 2839 def _new_constant( 2840 x, size, *, dtype=None, layout=None, device=None, pin_memory=None 2841 ): 2842 assert isinstance(size, (list, tuple)) 2843 assert_nyi(not pin_memory, "pin_memory") 2844 assert_nyi(layout in (None, torch.strided), f"layout={layout}") 2845 dtype = decode_dtype(dtype) or x.get_dtype() 2846 device = device or x.get_device() 2847 size = [sympy.Integer(s) for s in size] 2848 return _full(fill_value, device, dtype, size) 2849 2850 return _new_constant 2851 2852 2853@register_lowering(aten.new_empty) 2854def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): 2855 if dtype is None: 2856 dtype = x.get_dtype() 2857 if device is None: 2858 device = x.get_device() 2859 return empty_strided( 2860 size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory 2861 ) 2862 2863 2864@register_lowering(aten.empty_strided) 2865def empty_strided( 2866 size, stride, *, dtype=None, layout=None, device=None, pin_memory=None 2867): 2868 assert isinstance(size, (list, tuple)) 2869 assert isinstance(stride, (list, tuple, type(None))) 2870 assert_nyi(not pin_memory, "pin_memory") 2871 assert_nyi(layout in (None, torch.strided), f"layout={layout}") 2872 dtype = decode_dtype(dtype) or torch.get_default_dtype() 2873 device = device or torch.tensor(0.0).device 2874 pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) 2875 pointwise.realize() 2876 buffer = pointwise.data.data 2877 # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode 2878 buffer.data.ranges = [0] * len(size) 2879 assert isinstance(buffer, ir.ComputedBuffer) 2880 size = [sympy.expand(s) for s in size] 2881 stride = ( 2882 [sympy.expand(s) for s in stride] 2883 if stride 2884 else ir.FlexibleLayout.contiguous_strides(size) 2885 ) 2886 buffer.layout = ir.FixedLayout( 2887 device=device, 2888 dtype=dtype, 2889 size=size, 2890 stride=stride, 2891 ) 2892 return pointwise 2893 2894 2895@register_lowering(aten.new_empty_strided) 2896def new_empty_strided( 2897 x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None 2898): 2899 if dtype is None: 2900 dtype = x.get_dtype() 2901 if device is None: 2902 device = x.get_device() 2903 return empty_strided( 2904 size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory 2905 ) 2906 2907 2908@register_lowering(prims.copy_strided.default) 2909def copy_strided(x, stride): 2910 stride = [V.graph.sizevars.size_hint(s) for s in stride] 2911 stride_order = sorted(range(len(stride)), key=stride.__getitem__) 2912 return ir.ExternKernel.require_stride_order(x, stride_order) 2913 2914 
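# Worked example (illustrative only) of the stride-order computation in
# copy_strided above: sorting dimension indices by their stride hint is an
# argsort of the strides, e.g.
#
#     stride = [1, 8, 4]
#     sorted(range(len(stride)), key=stride.__getitem__)  # -> [0, 2, 1]
#
# i.e. dim 0 (stride 1) varies fastest, then dim 2 (stride 4), then dim 1.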
2915@register_lowering([torch.full, aten.full]) 2916def full(size, fill_value, **kwargs): 2917 assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" 2918 return tensor_constructor(fill_value)(size, **kwargs) 2919 2920 2921@register_lowering(aten.gather, type_promotion_kind=None) 2922def gather(x, dim, index, sparse_grad=False): 2923 # sparse_grad doesn't affect forward computation, 2924 # and backward tracing is taken care of by AOT Autograd 2925 assert isinstance(x, TensorBox) 2926 if index.get_numel() == 0: 2927 # Empty index case. Return an empty array with the same shape 2928 return new_empty(x, index.get_size()) 2929 2930 assert index.get_dtype() == torch.int64 2931 size = x.get_size() 2932 offset = len(size) == 0 2933 dim = _validate_dim(x, dim, offset) 2934 2935 if offset: 2936 x = expand(x, [1]) 2937 size = [1] 2938 2939 x_loader = x.make_loader() 2940 index_loader = index.make_loader() 2941 2942 def fn(idx): 2943 idx = list(idx) 2944 gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) 2945 if len(idx) == 0: 2946 idx = [gather_idx] 2947 else: 2948 idx[dim] = gather_idx 2949 return x_loader(idx) 2950 2951 return Pointwise.create( 2952 device=x.get_device(), 2953 dtype=x.get_dtype(), 2954 inner_fn=fn, 2955 ranges=index.get_size(), 2956 ) 2957 2958 2959@register_lowering(aten.embedding, type_promotion_kind=None) 2960def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): 2961 assert not sparse 2962 assert isinstance(weight, TensorBox) 2963 assert isinstance(indices, TensorBox) 2964 assert "int" in str(indices.get_dtype()) 2965 2966 weight_loader = weight.make_loader() 2967 indices_loader = indices.make_loader() 2968 indices_ndim = len(indices.get_size()) 2969 weight_size = weight.get_size() 2970 new_size = [*indices.get_size(), *weight_size[1:]] 2971 2972 def fn(idx): 2973 assert len(idx) == len(new_size), f"{idx} != {new_size}" 2974 var_index = indices_loader(idx[:indices_ndim]) 2975 weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ 2976 *idx[indices_ndim:] 2977 ] 2978 return weight_loader(weight_idx) 2979 2980 return Pointwise.create( 2981 device=weight.get_device(), 2982 dtype=weight.get_dtype(), 2983 inner_fn=fn, 2984 ranges=new_size, 2985 ) 2986 2987 2988def check_and_broadcast_indices(indices, device): 2989 assert all( 2990 i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) 2991 for i in indices 2992 if i is not None 2993 ), f"indices must be int64, byte or bool. 
Got {[i.get_dtype() for i in indices if i is not None]}" 2994 if any( 2995 i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None 2996 ): 2997 raise NotImplementedError("Fallback for bool indices") 2998 2999 valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] 3000 assert len(valid_idxs) > 0, "requires at least 1 non-None index" 3001 new_indices = [None] * len(indices) 3002 for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): 3003 # Eager allows indices to be CPU tensor when running on CUDA 3004 # FIXME: Calling to_device(x, device) should work but 3005 # test_advancedindex_mixed_cpu_devices still fails 3006 if x.get_device() != device: 3007 raise NotImplementedError("Fallback when indices is on a different device") 3008 new_indices[i] = x 3009 return new_indices, valid_idxs 3010 3011 3012def index_output_size_and_inner_fn( 3013 x_size, 3014 indices, 3015 tensor_indices, 3016 tensor_size, 3017 indices_loaders, 3018 indexed_size, 3019 x_loader, 3020 check, 3021): 3022 # Note that behavior of indexing differs when there are non consecutive 3023 # tensors. In this case, the tensor index is pulled to the beginning. 3024 # 3025 # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) 3026 # x = torch.tensor[1,2] 3027 # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will 3028 # be pulled to the front. 3029 non_consecutive_tensors = False 3030 for previous, current in zip(tensor_indices, tensor_indices[1:]): 3031 if current - previous != 1: 3032 non_consecutive_tensors = True 3033 3034 output_size = [x_size[i] for i, val in enumerate(indices) if val is None] 3035 output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] 3036 3037 first_tensor_index = tensor_indices[0] 3038 if non_consecutive_tensors: 3039 output_size = tensor_size + output_size 3040 else: 3041 output_size = ( 3042 output_size[:first_tensor_index] 3043 + tensor_size 3044 + output_size[first_tensor_index:] 3045 ) 3046 3047 def fn(idx): 3048 assert len(idx) == len(output_size) 3049 assert len(indices_loaders) == len(indexed_size) 3050 3051 rank = len(tensor_size) 3052 new_index = [] 3053 first_tensor_index = tensor_indices[0] 3054 start_offset = 0 if non_consecutive_tensors else first_tensor_index 3055 next_idx = 0 3056 for i in range(tensor_indices[-1] + 1): 3057 if i == start_offset: 3058 next_idx += rank 3059 if indices[i] is None: 3060 assert next_idx < len(idx) 3061 new_index.append(idx[next_idx]) 3062 next_idx += 1 3063 else: 3064 loader = indices_loaders[i] 3065 assert loader is not None 3066 size = indexed_size[i] 3067 new_index.append( 3068 ops.indirect_indexing( 3069 loader(idx[start_offset : start_offset + rank]), 3070 size, 3071 check=check, 3072 ) 3073 ) 3074 new_index = [ 3075 *new_index, 3076 *idx[next_idx:], 3077 ] 3078 return new_index if x_loader is None else x_loader(new_index) 3079 3080 return output_size, fn 3081 3082 3083def index_impl(x, indices, check): 3084 output_size, inner_fn, _ = index_impl_helper(x, indices, check) 3085 3086 return Pointwise.create( 3087 device=x.get_device(), 3088 dtype=x.get_dtype(), 3089 inner_fn=inner_fn, 3090 ranges=output_size, 3091 ) 3092 3093 3094def index_impl_helper(x, indices, check): 3095 assert isinstance(indices, (list, tuple)) 3096 x_loader = x.make_loader() 3097 indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) 3098 assert len(tensor_indices) > 0, "Must have at least one valid idx" 3099 3100 indices_loaders = 
[i.make_loader() if i is not None else None for i in indices] 3101 # no guards on output size, all the guards are set in broadcast_tensors 3102 3103 # We can use the first one since they are all required to be the same size 3104 tensor_size = list(indices[tensor_indices[0]].get_size()) 3105 3106 x_size = x.get_size() 3107 3108 indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] 3109 if check and 0 in indexed_size and 0 not in tensor_size: 3110 raise IndexError("index is out of bounds for dimension with size 0") 3111 3112 indexed_size = [x_size[i] for i in range(len(indices))] 3113 output_size, index_inner_fn = index_output_size_and_inner_fn( 3114 x_size, 3115 indices, 3116 tensor_indices, 3117 tensor_size, 3118 indices_loaders, 3119 indexed_size, 3120 None, 3121 check=check, 3122 ) 3123 3124 def inner_fn(idx): 3125 return x_loader(index_inner_fn(idx)) 3126 3127 return output_size, inner_fn, index_inner_fn 3128 3129 3130@register_lowering(aten.index, type_promotion_kind=None) 3131def index(x, indices): 3132 try: 3133 return index_impl(x, indices, check=True) 3134 except NotImplementedError: 3135 # Fallback to ATen for boolean indexing 3136 x.realize() 3137 return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( 3138 x, indices 3139 ) 3140 3141 3142@register_lowering(aten._unsafe_index, type_promotion_kind=None) 3143def _unsafe_index(x, indices): 3144 return index_impl(x, indices, check=False) 3145 3146 3147# All the indexing decompositions are written in terms of index, index_put, and index_put_ 3148# We cannot have this lowering as a decomposition as it introduces 3149# mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead 3150# code elimination and common subexpression elimination optimizations, which 3151# assume graphs to be side-effect free. More details at 3152# https://github.com/pytorch/torchdynamo/issues/1235 3153# and 3154# https://github.com/pytorch/torchdynamo/issues/1863 3155@register_lowering(aten.index_put) 3156def index_put(x, indices, values, accumulate=False): 3157 return index_put_(clone(x), indices, values, accumulate) 3158 3159 3160@register_lowering(aten._unsafe_index_put) 3161def _unsafe_index_put(x, indices, values, accumulate=False): 3162 return index_put_impl_(clone(x), indices, values, accumulate, check=False) 3163 3164 3165def index_put_as_masked_fill(self, indices, value, accumulate): 3166 if value.get_device() != self.get_device(): 3167 value = to_device(value, self.get_device()) 3168 if accumulate: 3169 value = add(self, value) 3170 return mutate_to(self, where(indices[0], value, self)) 3171 3172 3173def index_put_fallback(self, indices, values, accumulate): 3174 deterministic = torch.are_deterministic_algorithms_enabled() 3175 if is_triton(values) and (accumulate or deterministic): 3176 msg = ( 3177 "index put with accumulate." 3178 if not deterministic 3179 else "deterministic index put." 
3180 ) 3181 if stack_trace := V.graph.current_node.meta.get("stack_trace", None): 3182 msg = f"{msg} Found from : \n {stack_trace}" 3183 V.graph.disable_cudagraphs_reason = msg 3184 3185 ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) 3186 return self 3187 3188 3189@register_lowering(aten.index_put_, type_promotion_kind=None) 3190def index_put_(self, indices, values, accumulate=False): 3191 return index_put_impl_(self, indices, values, accumulate, check=True) 3192 3193 3194@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) 3195def _unsafe_index_put_(self, indices, values, accumulate=False): 3196 return index_put_impl_(self, indices, values, accumulate, check=False) 3197 3198 3199def index_put_impl_(self, indices, values, accumulate, check): 3200 # Dispatch to masked fill for single boolean index with single value 3201 if ( 3202 values.get_numel() == 1 3203 and len(indices) == 1 3204 and indices[0].get_dtype() in {torch.bool, torch.uint8} 3205 ): 3206 mask = indices[0] 3207 for _ in range(len(mask.get_size()), len(self.get_size())): 3208 mask = unsqueeze(mask, -1) 3209 return index_put_as_masked_fill(self, [mask], values, accumulate) 3210 3211 # Fallback in torch deterministic mode 3212 if torch.are_deterministic_algorithms_enabled(): 3213 return index_put_fallback(self, indices, values, accumulate) 3214 3215 # Fallback if there is a boolean index 3216 for index in indices: 3217 if index is not None and index.get_dtype() in {torch.bool, torch.uint8}: 3218 return index_put_fallback(self, indices, values, accumulate) 3219 3220 x_size = self.get_size() 3221 x_ndim = len(x_size) 3222 3223 if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): 3224 # self is an scalar Tensor 3225 if x_ndim == 0: 3226 self = view(self, [1]) 3227 self = index_put_fallback(self, indices, values, accumulate) 3228 if x_ndim == 0: 3229 self = view(self, []) 3230 return self 3231 3232 values = to_dtype(values, self.get_dtype()) 3233 3234 try: 3235 # Note that code will only get here when dtype is uint32 3236 indices, tensor_indices = check_and_broadcast_indices( 3237 indices, self.get_device() 3238 ) 3239 except NotImplementedError: 3240 return index_put_fallback(self, indices, values, accumulate) 3241 3242 indices_loaders = [i.make_loader() if i is not None else None for i in indices] 3243 3244 assert isinstance(self, TensorBox) 3245 self.realize() 3246 3247 # self is an scalar Tensor 3248 if x_ndim == 0: 3249 self = view(self, [1]) 3250 3251 # We can use the first one since they are all required to be the same size 3252 tensor_size = list(indices[tensor_indices[0]].get_size()) 3253 indexed_size = [x_size[i] for i in range(len(indices))] 3254 3255 expected_vals_size, inner_fn = index_output_size_and_inner_fn( 3256 x_size, 3257 indices, 3258 tensor_indices, 3259 tensor_size, 3260 indices_loaders, 3261 indexed_size, 3262 None, 3263 check=check, 3264 ) 3265 3266 values = expand(values, expected_vals_size) 3267 # all guards are set above during broadcast_tensors and expand 3268 3269 scatter = ir.Scatter( 3270 device=self.get_device(), 3271 dtype=self.get_dtype(), 3272 inner_fn=values.make_loader(), 3273 ranges=expected_vals_size, # iter_ranges, 3274 output_indexer=inner_fn, 3275 scatter_mode="atomic_add" if accumulate else None, 3276 ) 3277 buffer = ir.ComputedBuffer( 3278 None, 3279 ir.MutationLayoutSHOULDREMOVE(self), 3280 scatter, 3281 ) 3282 buffer.name = V.graph.register_buffer(buffer) 3283 V.graph.register_operation(buffer) 
3284 3285 if x_ndim == 0: 3286 self = view(self, []) 3287 return self 3288 3289 3290fallback__unsafe_masked_index = fallback_handler( 3291 aten._unsafe_masked_index.default, add_to_fallback_set=False 3292) 3293 3294fallback__unsafe_masked_index_put_accumulate = fallback_handler( 3295 aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False 3296) 3297 3298 3299@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) 3300def _unsafe_masked_index(self, mask, indices, fill): 3301 ranges, _, _unsafe_index_fn = index_impl_helper(self, indices, check=False) 3302 mask_loader = mask.make_loader() 3303 self_loader = self.make_loader() 3304 3305 def inner_fn(idx): 3306 if mask.dtype != torch.bool: 3307 mask_val = ops.to_dtype(mask_loader(idx), torch.bool) 3308 else: 3309 mask_val = mask_loader(idx) 3310 return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill) 3311 3312 return Pointwise.create( 3313 device=self.get_device(), 3314 dtype=self.get_dtype(), 3315 inner_fn=inner_fn, 3316 ranges=ranges, 3317 ) 3318 3319 3320@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) 3321def _unsafe_masked_index_put_accumulate(x, mask, indices, values): 3322 masked_value = where(mask, values, 0) 3323 shape = x.get_size() 3324 clamped_indices = [ 3325 clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None 3326 for i in range(len(indices)) 3327 ] 3328 # TODO: use a masked store for this. currently only triton 3329 # supports masked stores and cpp backend does not. 3330 return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) 3331 3332 3333@make_pointwise 3334def clamp(a, min, max): 3335 return ops.maximum(min, ops.minimum(max, a)) 3336 3337 3338@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) 3339def as_strided_scatter(self, src, size, stride, storage_offset=None): 3340 output = clone(self) 3341 output_view = as_strided(output, size, stride, storage_offset) 3342 copy_(output_view, src) 3343 return output 3344 3345 3346@register_lowering(aten.scatter, type_promotion_kind=None) 3347def scatter(x, dim: int, index, src, **kwargs): 3348 return scatter_(clone(x), dim, index, src, **kwargs) 3349 3350 3351def scatter_fallback( 3352 op_overload: torch._ops.OpOverload, 3353 self, 3354 dim: int, 3355 index, 3356 src, 3357 *, 3358 reduce: Optional[str] = None, 3359 include_self: bool = True, 3360): 3361 src_is_tensor = isinstance(src, TensorBox) 3362 if use_scatter_fallback( 3363 op_overload, 3364 reduce, 3365 self.get_dtype(), 3366 src.get_dtype() if src_is_tensor else type(src), 3367 src.get_device().type if src_is_tensor else "not impl", 3368 src_is_tensor, 3369 ): 3370 ir.ScatterFallback( 3371 op_overload, 3372 self, 3373 dim, 3374 index, 3375 src, 3376 reduce=reduce, 3377 include_self=include_self, 3378 ) 3379 return self 3380 3381 return None 3382 3383 3384@register_lowering(aten.scatter_, type_promotion_kind=None) 3385def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): 3386 assert reduce in {None, "add", "multiply"} 3387 if reduce is None: 3388 op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr] 3389 fallback_result = scatter_fallback( 3390 op_overload, self, dim, index, src, reduce=reduce 3391 ) 3392 if fallback_result is not None: 3393 return fallback_result 3394 3395 if reduce == "add": 3396 reduce = "sum" 3397 elif reduce == "multiply": 3398 reduce = "prod" 3399 return scatter_reduce_(self, dim, index, 
src, reduce) 3400 3401 3402@register_lowering(aten.scatter_add, type_promotion_kind=None) 3403def scatter_add(x, dim: int, index, src): 3404 return scatter_add_(clone(x), dim, index, src) 3405 3406 3407@register_lowering(aten.scatter_add_, type_promotion_kind=None) 3408def scatter_add_(x, dim: int, index, src): 3409 return scatter_reduce_(x, dim, index, src, "sum") 3410 3411 3412@register_lowering(aten.scatter_reduce, type_promotion_kind=None) 3413def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): 3414 return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) 3415 3416 3417@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) 3418def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): 3419 assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} 3420 assert ( 3421 len(aten.scatter_reduce_.overloads()) == 1 3422 and "two" in aten.scatter_reduce_.overloads() 3423 ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_" 3424 3425 if isinstance(src, Number): 3426 src = full_like(self, src) 3427 3428 fallback_result = scatter_fallback( 3429 aten.scatter_reduce_.two, 3430 self, 3431 dim, 3432 index, 3433 src, 3434 reduce=reduce, 3435 include_self=include_self, 3436 ) 3437 3438 if fallback_result: 3439 return fallback_result 3440 3441 assert isinstance(self, TensorBox) 3442 assert "int" in str(index.get_dtype()) 3443 3444 ndim = len(self.get_size()) 3445 if ndim == 0: 3446 self = view(self, [1]) 3447 3448 if isinstance(src, TensorBox) and len(src.get_size()) == 0: 3449 src = view(src, [1]) 3450 3451 if isinstance(index, TensorBox) and len(index.get_size()) == 0: 3452 index = view(index, [1]) 3453 3454 if index.get_numel() == 0: 3455 return self 3456 3457 dim = _validate_dim(self, dim) 3458 3459 self.realize() 3460 index_loader = index.make_loader() 3461 src_loader = src.make_loader() if isinstance(src, TensorBox) else None 3462 3463 def output_indexer(idx): 3464 # self is captured from the end of the function, so it may have 0 dim 3465 shape = self.get_size() 3466 ndim = len(shape) 3467 indirect_idx = list(idx) 3468 indirect_idx[dim] = ops.indirect_indexing( 3469 index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False 3470 ) 3471 return indirect_idx 3472 3473 def fn(idx): 3474 if src_loader: 3475 return src_loader(idx) 3476 else: 3477 # src is a scalar 3478 return ops.constant(src, self.get_dtype()) 3479 3480 def backend_reduce_str(reduce): 3481 if reduce == "sum": 3482 return "atomic_add" 3483 else: 3484 # TODO: Need to support more reduction type 3485 assert reduce is None 3486 return None 3487 3488 if not include_self: 3489 # zero out the corresponding elements first 3490 zero_out = ir.Scatter( 3491 device=self.get_device(), 3492 dtype=self.get_dtype(), 3493 inner_fn=lambda index: ops.constant(0, self.get_dtype()), 3494 ranges=index.get_size(), 3495 output_indexer=output_indexer, 3496 scatter_mode=None, 3497 ) 3498 buffer = ir.ComputedBuffer( 3499 None, 3500 ir.MutationLayoutSHOULDREMOVE(self), 3501 zero_out, 3502 ) 3503 buffer.name = V.graph.register_buffer(buffer) 3504 V.graph.register_operation(buffer) 3505 3506 # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 3507 # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 3508 # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 3509 scatter = ir.Scatter( 3510 device=self.get_device(), 3511 dtype=self.get_dtype(), 3512 inner_fn=fn, 3513 ranges=index.get_size(), 3514 output_indexer=output_indexer, 3515 
scatter_mode=backend_reduce_str(reduce), 3516 ) 3517 buffer = ir.ComputedBuffer( 3518 None, 3519 ir.MutationLayoutSHOULDREMOVE(self), 3520 scatter, 3521 ) 3522 buffer.name = V.graph.register_buffer(buffer) 3523 V.graph.register_operation(buffer) 3524 3525 if ndim == 0: 3526 self = view(self, []) 3527 return self 3528 3529 3530def upsample_nearestnd( 3531 x, 3532 output_size, 3533 scales_x: Tuple[Optional[float], ...], 3534 n: int = 2, 3535 exact: bool = False, 3536): 3537 x.realize_hint() # elements are reused 3538 x_loader = x.make_loader() 3539 i_sizes = x.get_size()[-n:] 3540 batch = x.get_size()[:-n] 3541 i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] 3542 3543 assert len(scales_x) == n 3544 o_sizes = output_size 3545 3546 inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] 3547 for i, scale in enumerate(scales_x): 3548 if scale is not None: 3549 inv_scales[i] = 1.0 / scale 3550 3551 def scale_fn(x, scale, size): 3552 # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) 3553 # = floor(scale * (output_index + 0.5)) 3554 # Nearest: input_index = floor(scale * output_index) 3555 x = ops.index_expr(x, torch.float32) 3556 if exact: 3557 x = ops.add(x, ops.constant(0.5, torch.float32)) 3558 x = ops.mul(x, ops.constant(scale, torch.float32)) 3559 x = ops.to_dtype(x, torch.int32) 3560 return ops.indirect_indexing(x, size, check=False) 3561 3562 def fn(idx): 3563 x = idx[-n:] 3564 b = idx[:-n] 3565 return x_loader( 3566 [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] 3567 ) 3568 3569 return Pointwise.create( 3570 device=x.get_device(), 3571 dtype=x.get_dtype(), 3572 inner_fn=fn, 3573 ranges=[*batch, *o_sizes], 3574 ) 3575 3576 3577@register_lowering(aten.upsample_nearest1d.default) 3578def upsample_nearest1d(x, output_size, scales: Optional[float] = None): 3579 return upsample_nearestnd(x, output_size, (scales,), n=1) 3580 3581 3582@register_lowering(aten._upsample_nearest_exact1d.default) 3583def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): 3584 return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) 3585 3586 3587@register_lowering(aten.upsample_nearest2d.default) 3588def upsample_nearest2d( 3589 x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None 3590): 3591 return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) 3592 3593 3594@register_lowering(aten._upsample_nearest_exact2d.default) 3595def _upsample_nearest_exact2d( 3596 x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None 3597): 3598 return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) 3599 3600 3601@register_lowering(aten.upsample_nearest3d.default) 3602def upsample_nearest3d( 3603 x, 3604 output_size, 3605 scales_d: Optional[float] = None, 3606 scales_h: Optional[float] = None, 3607 scales_w: Optional[float] = None, 3608): 3609 return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) 3610 3611 3612@register_lowering(aten._upsample_nearest_exact3d.default) 3613def _upsample_nearest_exact3d( 3614 x, 3615 output_size, 3616 scales_d: Optional[float] = None, 3617 scales_h: Optional[float] = None, 3618 scales_w: Optional[float] = None, 3619): 3620 return upsample_nearestnd( 3621 x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True 3622 ) 3623 3624 3625def _create_constants(*args, dtype): 3626 return tuple(ops.constant(a, dtype) for a in args) 3627 3628 
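# A minimal pure-Python sketch (editorial; `_nearest_src_index` is hypothetical and
# not used by the lowerings above) of the index math in upsample_nearestnd, assuming
# inv_scale = input_size / output_size when no explicit scale is given:
#   nearest:       src = floor(out_idx * inv_scale)
#   nearest-exact: src = floor((out_idx + 0.5) * inv_scale)
def _nearest_src_index(out_idx: int, inv_scale: float, exact: bool = False) -> int:
    # `math` is already imported at the top of this module
    x = out_idx + 0.5 if exact else float(out_idx)
    return math.floor(x * inv_scale)


# e.g. resizing a length-3 input to length 5 (inv_scale = 0.6):
#   [_nearest_src_index(i, 0.6) for i in range(5)]              -> [0, 0, 1, 1, 2]
#   [_nearest_src_index(i, 0.6, exact=True) for i in range(5)]  -> [0, 0, 1, 2, 2]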
3629@register_lowering(prims.rev.default) 3630def rev(x, dims): 3631 # note - dims pre-canonicalized 3632 x_loader = x.make_loader() 3633 sizes = x.get_size() 3634 3635 def loader(idx): 3636 idx = list(idx) 3637 assert len(idx) == len(sizes) 3638 for dim in dims: 3639 idx[dim] = (sizes[dim] - 1) - idx[dim] 3640 3641 return x_loader(idx) 3642 3643 return Pointwise.create( 3644 device=x.get_device(), 3645 dtype=x.get_dtype(), 3646 inner_fn=loader, 3647 ranges=sizes, 3648 ) 3649 3650 3651@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) 3652def constant_pad_nd(x, padding, fill_value=0): 3653 assert (len(padding) % 2) == 0 3654 if all(p == 0 for p in padding): 3655 return clone(x) 3656 3657 sizes = x.get_size() 3658 3659 bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) 3660 n = len(sizes) - len(bounds) 3661 3662 # if padding is a complicated expression, hoist it 3663 bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] 3664 for l, h in bounds: 3665 bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] 3666 3667 output_size = list(sizes[:n]) 3668 mask_sizes = [] 3669 for (low, high), size in zip(bounds, sizes[n:]): 3670 mask_sizes.append(size) 3671 output_size.append(sympy.expand(size + low + high)) 3672 assert len(output_size) == len(sizes) 3673 fill_value = dtype_to_type(x.get_dtype())(fill_value) 3674 3675 def mask(index): 3676 mask = [] 3677 for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): 3678 if low != 0: 3679 mask.append(range_mask_low(idx, 0)) 3680 if high != 0: 3681 mask.append(range_mask_high(idx, length)) 3682 mask = functools.reduce(ops.and_, mask) 3683 return ops.masked(mask, lambda: x_loader(index), fill_value) 3684 3685 def offset_fn(index): 3686 new_index = list(index[:n]) 3687 for idx, (low, high) in zip(index[n:], bounds_precomp): 3688 new_index.append(idx - low) 3689 assert len(new_index) == len(index) 3690 return mask(new_index) 3691 3692 x_loader = x.make_loader() 3693 return Pointwise.create( 3694 device=x.get_device(), 3695 dtype=x.get_dtype(), 3696 inner_fn=offset_fn, 3697 ranges=output_size, 3698 ) 3699 3700 3701def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): 3702 return ops.ge( 3703 ops.index_expr(i, torch.int64), 3704 ops.index_expr(sympy.Integer(low), torch.int64), 3705 ) 3706 3707 3708def range_mask_high(i: sympy.Expr, high: sympy.Expr): 3709 return ops.lt( 3710 ops.index_expr(i, torch.int64), 3711 ops.index_expr(high, torch.int64), 3712 ) 3713 3714 3715def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): 3716 return ops.and_( 3717 range_mask_low(i, low), 3718 range_mask_high(i, high), 3719 ) 3720 3721 3722def constant_boundary_condition( 3723 x, fill_value, padding=None, pad_fill_value=1.0, dim=None 3724): 3725 h = x.get_size()[-dim:] 3726 x_loader = x.make_loader() 3727 padding_h = padding or [0] * dim 3728 3729 def load(index): 3730 prefix = index[:-dim] 3731 ih = index[-dim:] 3732 3733 mask = functools.reduce( 3734 ops.and_, 3735 [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)], 3736 ) 3737 return ( 3738 ops.masked( 3739 mask, 3740 lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)( 3741 [*prefix, *ih] 3742 ), 3743 fill_value, 3744 ) 3745 if padding 3746 else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value) 3747 ) 3748 3749 return load 3750 3751 3752def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): 3753 x_out = FloorDiv( 3754 x + 2 * padding[i] - (kernel_size[i] - 1) + 
(stride[i] - 1), stride[i] 3755 ) 3756 3757 if ceil_mode: 3758 x_alt = FloorDiv( 3759 x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i] 3760 ) 3761 if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: 3762 # Sliding windows must start within the input or left padding 3763 x_alt -= 1 # type: ignore[assignment] 3764 V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] 3765 if V.graph.sizevars.size_hint(x_out - x_alt) == 0: 3766 # ceil mode is actually a no-op, lets guard on that 3767 V.graph.sizevars.guard_equals(x_out, x_alt) 3768 ceil_mode = False 3769 else: 3770 x_out = x_alt 3771 return x_out, ceil_mode 3772 3773 3774def should_fallback_max_pool2d_with_indices(kernel_size, dilation): 3775 kernel_size = pad_listlike(kernel_size, 2) 3776 window_size = kernel_size[0] * kernel_size[1] 3777 return (window_size > 25) or any(d > 1 for d in dilation) 3778 3779 3780def max_pool2d_checks( 3781 x, kernel_size, stride, padding, dilation, *, assert_fallback=None 3782): 3783 if padding == 0: 3784 padding = [0, 0] 3785 if dilation == 1: 3786 dilation = [1, 1] 3787 if not stride: 3788 stride = kernel_size 3789 3790 kernel_size = pad_listlike(kernel_size, 2) 3791 stride = pad_listlike(stride, 2) 3792 padding = pad_listlike(padding, 2) 3793 dilation = pad_listlike(dilation, 2) 3794 3795 assert isinstance(x, TensorBox) 3796 assert len(kernel_size) == 2 3797 assert len(stride) == 2 3798 assert len(padding) == 2 3799 assert len(dilation) == 2 3800 assert len(x.get_size()) in (3, 4) 3801 3802 use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation) 3803 if assert_fallback is not None: 3804 assert use_fallback == assert_fallback 3805 3806 return kernel_size, stride, padding, dilation, use_fallback 3807 3808 3809@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None) 3810def _low_memory_max_pool2d_with_offsets( 3811 x, 3812 kernel_size, 3813 stride, 3814 padding, 3815 dilation, 3816 ceil_mode=False, 3817): 3818 # assert we are not on a fallback path, the inductor decomp should have guaranteed this 3819 kernel_size, stride, padding, dilation, _ = max_pool2d_checks( 3820 x, kernel_size, stride, padding, dilation, assert_fallback=False 3821 ) 3822 3823 x.realize_hint() 3824 *batch, h, w = x.get_size() 3825 3826 h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) 3827 w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) 3828 3829 dtype = x.dtype 3830 min_value = ( 3831 False 3832 if dtype is torch.bool 3833 else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) 3834 ) 3835 3836 new_size = list(batch) + [h_out, w_out] 3837 if padding[0] or padding[1] or ceil_mode1 or ceil_mode2: 3838 x_loader = constant_boundary_condition(x, min_value, dim=2) 3839 else: 3840 x_loader = x.make_loader() 3841 3842 def fn(idx, return_index): 3843 *prefix, bh, bw = idx 3844 maxval = None 3845 maxindex = None 3846 for h_inc, w_inc in itertools.product( 3847 range(kernel_size[0]), range(kernel_size[1]) 3848 ): 3849 ih = bh * stride[0] + h_inc - padding[0] 3850 iw = bw * stride[1] + w_inc - padding[1] 3851 val = x_loader([*prefix, ih, iw]) 3852 if return_index: 3853 index = ops.index_expr(h_inc * kernel_size[1] + w_inc, torch.int8) 3854 if maxindex is None: 3855 maxindex = index 3856 else: 3857 maxindex = ops.where(ops.gt(val, maxval), index, maxindex) 3858 if maxval is None: 3859 maxval = val 3860 else: 3861 maxval = 
ops.maximum(val, maxval) 3862 if return_index: 3863 return maxindex 3864 else: 3865 return maxval 3866 3867 out = Pointwise.create( 3868 device=x.get_device(), 3869 dtype=x.get_dtype(), 3870 inner_fn=functools.partial(fn, return_index=False), 3871 ranges=new_size, 3872 ) 3873 offsets = Pointwise.create( 3874 device=x.get_device(), 3875 dtype=torch.int8, 3876 inner_fn=functools.partial(fn, return_index=True), 3877 ranges=new_size, 3878 ) 3879 return out, offsets 3880 3881 3882@register_lowering( 3883 prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None 3884) 3885def _low_memory_max_pool2d_offsets_to_indices( 3886 offsets, kernel_width, input_width, stride, padding 3887): 3888 # TODO: Generalize to other max pooling flavors, and arbitrary dim 3889 3890 offsets_loader = offsets.make_loader() 3891 3892 def increments_to_index(h_inc, w_inc, bh, bw): 3893 w_in = ops.index_expr(input_width, torch.int64) 3894 hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64) 3895 wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64) 3896 ih = hbase + h_inc 3897 iw = wbase + w_inc 3898 return ih * w_in + iw 3899 3900 def offsets_to_indices(idx): 3901 *prefix, bh, bw = idx 3902 offset = offsets_loader([*prefix, bh, bw]) 3903 kw_const = ops.constant(kernel_width, torch.int32) 3904 h_inc = offset // kw_const 3905 w_inc = offset - (h_inc * kw_const) 3906 return increments_to_index(h_inc, w_inc, bh, bw) 3907 3908 indices = Pointwise.create( 3909 device=offsets.get_device(), 3910 dtype=torch.int64, 3911 inner_fn=offsets_to_indices, 3912 ranges=offsets.get_size(), 3913 ) 3914 return indices 3915 3916 3917# Fallback selected when we do not decompose to the low-memory path. 3918make_fallback(aten.max_pool2d_with_indices) 3919 3920 3921fallback_max_pool2d_with_indices_backward = fallback_handler( 3922 aten.max_pool2d_with_indices_backward.default, 3923 add_to_fallback_set=False, 3924) 3925 3926 3927@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) 3928def max_pool2d_with_indices_backward( 3929 grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices 3930): 3931 if padding == 0: 3932 padding = [0, 0] 3933 if dilation == 1: 3934 dilation = [1, 1] 3935 if not stride: 3936 stride = kernel_size 3937 3938 assert isinstance(x, TensorBox) 3939 assert len(kernel_size) == 2 3940 assert len(stride) == 2 3941 assert len(padding) == 2 3942 assert len(dilation) == 2 3943 assert len(x.get_size()) in (3, 4) 3944 3945 # we will read this many times, so make sure it is computed 3946 grad_output.realize_hint() 3947 try: 3948 gO_stride = grad_output.get_stride() 3949 except AttributeError: 3950 # some classes don't have `get_stride` 3951 # TODO will need a better way of determining if inputs are channels-last 3952 gO_stride = None 3953 if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] 3954 data = x.data.data # type: ignore[attr-defined] 3955 x_buffer = ir.ComputedBuffer( 3956 name=None, 3957 layout=ir.FlexibleLayout( 3958 device=data.get_device(), 3959 dtype=data.get_dtype(), 3960 size=data.get_size(), 3961 ), 3962 data=data, 3963 ) 3964 x_buffer.decide_layout() 3965 x_stride = x_buffer.get_stride() 3966 else: 3967 try: 3968 x_stride = x.get_stride() 3969 except AttributeError: 3970 x_stride = None 3971 3972 is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( 3973 gO_stride is not None and gO_stride[1] == 1 3974 ) 3975 if any(d != 1 for d in dilation): 3976 # dilation NYI 3977 return 
fallback_max_pool2d_with_indices_backward( 3978 grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices 3979 ) 3980 3981 *batch, height, width = x.get_size() 3982 *_, pooled_height, pooled_width = grad_output.get_size() 3983 3984 indices_loader = indices.make_loader() 3985 grad_loader = grad_output.make_loader() 3986 new_size = list(x.get_size()) 3987 3988 h_window_size = max( 3989 max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) 3990 for h in range(kernel_size[0] * 2) 3991 ) 3992 w_window_size = max( 3993 max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) 3994 for w in range(kernel_size[1] * 2) 3995 ) 3996 3997 window_size = h_window_size * w_window_size 3998 3999 if window_size > 25: 4000 # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 4001 return fallback_max_pool2d_with_indices_backward( 4002 grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices 4003 ) 4004 4005 indices_size = indices.get_size() 4006 4007 def fn(idx): 4008 *prefix, h, w = idx 4009 index_test = ops.index_expr(h * width + w, torch.int32) 4010 h = h + padding[0] 4011 w = w + padding[1] 4012 phstart = ops.index_expr( 4013 FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 4014 ) 4015 pwstart = ops.index_expr( 4016 FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 4017 ) 4018 phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) 4019 pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) 4020 4021 phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) 4022 pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) 4023 phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) 4024 pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) 4025 4026 gradient = None 4027 for ph_ in range(h_window_size): 4028 for pw_ in range(w_window_size): 4029 ph = ops.add(phstart, ops.constant(ph_, torch.int32)) 4030 pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) 4031 grad_index = [ 4032 *prefix, 4033 ops.indirect_indexing( 4034 ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), 4035 indices_size[-2], 4036 check=False, 4037 ), 4038 ops.indirect_indexing( 4039 ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), 4040 indices_size[-1], 4041 check=False, 4042 ), 4043 ] 4044 4045 index_actual = indices_loader(grad_index) 4046 grad_part = grad_loader(grad_index) 4047 check = ops.eq(index_actual, index_test) 4048 4049 if gradient is None: 4050 # don't need mask for 0, 0 4051 gradient = ops.where( 4052 check, grad_part, ops.constant(0.0, torch.float32) 4053 ) 4054 else: 4055 mask = ops.and_( 4056 ops.and_( 4057 ops.lt(ph, phend), 4058 ops.lt(pw, pwend), 4059 ), 4060 check, 4061 ) 4062 gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) 4063 assert gradient is not None 4064 return gradient 4065 4066 out = Pointwise.create( 4067 device=grad_output.get_device(), 4068 dtype=grad_output.get_dtype(), 4069 inner_fn=fn, 4070 ranges=new_size, 4071 ) 4072 if is_channels_last: 4073 return ir.ExternKernel.require_channels_last(out) 4074 else: 4075 return out 4076 4077 4078def pad_adaptive_loader(x, pad_val=0.0): 4079 *_, h, w = x.get_size() 4080 x_loader = x.make_loader() 4081 4082 def load(prefix, increments, start_indices, end_indices): 4083 ih, iw = increments 4084 h_start_index, w_start_index = start_indices 4085 h_end_index, w_end_index = end_indices 4086 4087 mask = ops.and_( 4088 ops.lt( 4089 ops.index_expr(h_start_index + ih, 
torch.int64), 4090 ops.index_expr(h_end_index, torch.int64), 4091 ), 4092 ops.lt( 4093 ops.index_expr(w_start_index + iw, torch.int64), 4094 ops.index_expr(w_end_index, torch.int64), 4095 ), 4096 ) 4097 4098 return ops.masked( 4099 mask, 4100 lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), 4101 pad_val, 4102 ) 4103 4104 return load 4105 4106 4107def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): 4108 h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) 4109 h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) 4110 4111 w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) 4112 w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) 4113 4114 return h_start_index, h_end_index, w_start_index, w_end_index 4115 4116 4117def _adaptive_pooling_fn( 4118 start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn 4119): 4120 h_in, w_in = in_sizes 4121 h_out, w_out = out_sizes 4122 4123 ( 4124 h_start_index_fn, 4125 h_end_index_fn, 4126 w_start_index_fn, 4127 w_end_index_fn, 4128 ) = compute_indices_adaptive_pooling( 4129 start_index, end_index, h_in, w_in, h_out, w_out 4130 ) 4131 4132 def fn(idx, loader): 4133 *prefix, bh, bw = idx 4134 4135 h_start_index = h_start_index_fn(bh) 4136 h_end_index = h_end_index_fn(bh) 4137 4138 w_start_index = w_start_index_fn(bw) 4139 w_end_index = w_end_index_fn(bw) 4140 4141 result = None 4142 for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): 4143 val = loader( 4144 prefix, 4145 [ih, iw], 4146 [h_start_index, w_start_index], 4147 [h_end_index, w_end_index], 4148 ) 4149 if result is None: 4150 result = val 4151 else: 4152 result = pooling_fn(val, result) 4153 return result 4154 4155 return fn 4156 4157 4158def _adaptive_pooling_fn_with_idx( 4159 start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn 4160): 4161 h_in, w_in = in_sizes 4162 h_out, w_out = out_sizes 4163 4164 ( 4165 h_start_index_fn, 4166 h_end_index_fn, 4167 w_start_index_fn, 4168 w_end_index_fn, 4169 ) = compute_indices_adaptive_pooling( 4170 start_index, end_index, h_in, w_in, h_out, w_out 4171 ) 4172 4173 def fn(idx, loader): 4174 *prefix, bh, bw = idx 4175 4176 h_start_index = h_start_index_fn(bh) 4177 h_end_index = h_end_index_fn(bh) 4178 4179 w_start_index = w_start_index_fn(bw) 4180 w_end_index = w_end_index_fn(bw) 4181 4182 maxval = None 4183 maxindex = None 4184 for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): 4185 val = loader( 4186 prefix, 4187 [ih, iw], 4188 [h_start_index, w_start_index], 4189 [h_end_index, w_end_index], 4190 ) 4191 4192 index = ops.index_expr( 4193 (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 4194 ) 4195 4196 if maxindex is None: 4197 maxindex = index 4198 else: 4199 maxindex = ops.where(ops.gt(val, maxval), index, maxindex) 4200 4201 if maxval is None: 4202 maxval = val 4203 else: 4204 maxval = pooling_fn(val, maxval) 4205 4206 return maxindex 4207 4208 return fn 4209 4210 4211fallback_adaptive_avg_pool2d = fallback_handler( 4212 aten._adaptive_avg_pool2d.default, add_to_fallback_set=False 4213) 4214 4215 4216@register_lowering(aten._adaptive_avg_pool2d) 4217def _adaptive_avg_pool2d(x, output_size): 4218 assert isinstance(x, TensorBox) 4219 assert len(output_size) == 2 4220 x.realize_hint() 4221 4222 *batch, h_in, w_in = x.get_size() 4223 4224 h_in = V.graph.sizevars.evaluate_static_shape(h_in) 4225 w_in = 
V.graph.sizevars.evaluate_static_shape(w_in) 4226 4227 h_out, w_out = output_size 4228 4229 # no-op if the same input and output 4230 if h_in == h_out and w_in == w_out: 4231 return clone(x) 4232 4233 if h_out == 0 or w_out == 0: 4234 o_size = [*batch, h_out, w_out] 4235 return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) 4236 if h_in % h_out == 0 and w_in % w_out == 0: 4237 kernel_size = [h_in // h_out, w_in // w_out] 4238 return avg_pool2d(x, kernel_size) 4239 4240 h_kernel_max = ceildiv((h_in + h_out - 1), h_out) 4241 w_kernel_max = ceildiv((w_in + w_out - 1), w_out) 4242 4243 new_size = list(batch) + [h_out, w_out] 4244 dtype = x.get_dtype() 4245 4246 window_size = h_kernel_max * w_kernel_max 4247 if window_size > 25: 4248 # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 4249 return fallback_adaptive_avg_pool2d(x, output_size) 4250 4251 def start_index(index, out_dim, inp_dim): 4252 return FloorDiv((index * inp_dim), out_dim) 4253 4254 def end_index(index, out_dim, inp_dim): 4255 return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) 4256 4257 fn_sum = _adaptive_pooling_fn( 4258 start_index=start_index, 4259 end_index=end_index, 4260 kernel_maxes=[h_kernel_max, w_kernel_max], 4261 in_sizes=[h_in, w_in], 4262 out_sizes=[h_out, w_out], 4263 pooling_fn=ops.add, 4264 ) 4265 4266 ones_loader = pad_adaptive_loader(ones_like(x)) 4267 4268 def fn(idx): 4269 return ops.truediv( 4270 fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) 4271 ) 4272 4273 rv = Pointwise.create( 4274 device=x.get_device(), 4275 dtype=dtype, 4276 inner_fn=fn, 4277 ranges=new_size, 4278 ) 4279 # TODO: should we force these to be realized? 4280 return rv 4281 4282 4283fallback_adaptive_max_pool2d = fallback_handler( 4284 aten.adaptive_max_pool2d.default, add_to_fallback_set=False 4285) 4286 4287 4288@register_lowering(aten.adaptive_max_pool2d) 4289def adaptive_max_pool2d(x, output_size): 4290 assert isinstance(x, TensorBox) 4291 assert len(output_size) == 2 4292 x.realize_hint() 4293 4294 *batch, h_in, w_in = x.get_size() 4295 4296 h_in = V.graph.sizevars.evaluate_static_shape(h_in) 4297 w_in = V.graph.sizevars.evaluate_static_shape(w_in) 4298 4299 h_out, w_out = output_size 4300 4301 if h_out == 0 or w_out == 0: 4302 o_size = [*batch, h_out, w_out] 4303 return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( 4304 o_size, dtype=torch.int64, device=x.get_device() 4305 ) 4306 if h_in % h_out == 0 and w_in % w_out == 0: 4307 kernel_size = [h_in // h_out, w_in // w_out] 4308 if should_fallback_max_pool2d_with_indices(kernel_size, dilation=[1, 1]): 4309 return max_pool2d_with_indices(x, kernel_size) # type: ignore[name-defined] # noqa: F821 4310 else: 4311 v, offsets = _low_memory_max_pool2d_with_offsets( 4312 x, 4313 kernel_size, 4314 stride=kernel_size, 4315 padding=[0, 0], 4316 dilation=[1, 1], 4317 ceil_mode=False, 4318 ) 4319 indices = _low_memory_max_pool2d_offsets_to_indices( 4320 offsets, kernel_size[1], w_in, kernel_size, padding=[0, 0] 4321 ) 4322 return v, indices 4323 4324 h_kernel_max = ceildiv((h_in + h_out - 1), h_out) 4325 w_kernel_max = ceildiv((w_in + w_out - 1), w_out) 4326 4327 new_size = list(batch) + [h_out, w_out] 4328 dtype = x.get_dtype() 4329 4330 window_size = h_kernel_max * w_kernel_max 4331 if window_size > 25: 4332 # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
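        # For instance (illustrative numbers): adapting h_in = w_in = 100 down to
        # h_out = w_out = 7 gives h_kernel_max = w_kernel_max = ceildiv(106, 7) = 16,
        # i.e. up to 16 * 16 = 256 loads per output element, far beyond the
        # 25-element unroll budget, so defer to the ATen kernel below.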
4333 return fallback_adaptive_max_pool2d(x, output_size) 4334 4335 def start_index(index, out_dim, inp_dim): 4336 return FloorDiv((index * inp_dim), out_dim) 4337 4338 def end_index(index, out_dim, inp_dim): 4339 return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) 4340 4341 inner_func_max_val = _adaptive_pooling_fn( 4342 start_index=start_index, 4343 end_index=end_index, 4344 kernel_maxes=[h_kernel_max, w_kernel_max], 4345 in_sizes=[h_in, w_in], 4346 out_sizes=[h_out, w_out], 4347 pooling_fn=ops.maximum, 4348 ) 4349 4350 inner_func_max_idx = _adaptive_pooling_fn_with_idx( 4351 start_index=start_index, 4352 end_index=end_index, 4353 kernel_maxes=[h_kernel_max, w_kernel_max], 4354 in_sizes=[h_in, w_in], 4355 out_sizes=[h_out, w_out], 4356 pooling_fn=ops.maximum, 4357 ) 4358 4359 def inner_fn_max_val(idx): 4360 return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) 4361 4362 def inner_fn_max_idx(idx): 4363 return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) 4364 4365 rv = Pointwise.create( 4366 device=x.get_device(), 4367 dtype=dtype, 4368 inner_fn=inner_fn_max_val, 4369 ranges=new_size, 4370 ) 4371 ri = Pointwise.create( 4372 device=x.get_device(), 4373 dtype=torch.int64, 4374 inner_fn=inner_fn_max_idx, 4375 ranges=new_size, 4376 ) 4377 return rv, ri 4378 4379 4380fallback_fractional_max_pool2d = fallback_handler( 4381 aten.fractional_max_pool2d.default, add_to_fallback_set=False 4382) 4383 4384 4385def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): 4386 out_sz = out_sz[dim] 4387 in_sz = in_sz[dim] 4388 kernel_sz = kernel_sz[dim] 4389 alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) 4390 samples_loader = samples.make_loader() 4391 4392 def load(prefix, i): 4393 sample = samples_loader([*prefix, dim]) 4394 i_expr = ops.index_expr(i, samples.get_dtype()) 4395 alpha_expr = ops.index_expr(alpha, samples.get_dtype()) 4396 seq_i = ops.floor((i_expr + sample) * alpha_expr) - ops.floor( 4397 sample * alpha_expr 4398 ) 4399 seq_i = ops.to_dtype(seq_i, torch.int64) 4400 4401 mask = ops.lt( 4402 i_expr, 4403 ops.index_expr(out_sz - 1, torch.int64), 4404 ) 4405 return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)) 4406 4407 return load 4408 4409 4410@register_lowering(aten.fractional_max_pool2d) 4411def fractional_max_pool2d(x, kernel_size, output_size, random_samples): 4412 x.realize_hint() 4413 *batch, inp_h, inp_w = x.get_size() 4414 kernel_h, kernel_w = kernel_size 4415 h_out, w_out = output_size 4416 4417 if kernel_h * kernel_w >= 25: 4418 return fallback_fractional_max_pool2d( 4419 x, kernel_size, output_size, random_samples 4420 ) 4421 4422 gen_offsets_for_dim = functools.partial( 4423 _fractional_pooling_offsets, 4424 samples=random_samples, 4425 in_sz=[inp_h, inp_w], 4426 out_sz=output_size, 4427 kernel_sz=kernel_size, 4428 ) 4429 4430 h_index_fn = gen_offsets_for_dim(dim=0) 4431 w_index_fn = gen_offsets_for_dim(dim=1) 4432 x_loader = x.make_loader() 4433 4434 def fn(idx, return_index): 4435 *prefix, bh, bw = idx 4436 4437 h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h) 4438 w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w) 4439 4440 maxval = None 4441 maxindex = None 4442 for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): 4443 val = x_loader([*prefix, h_start_index + ih, w_start_index + iw]) 4444 if return_index: 4445 index = ops.index_expr( 4446 (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64 4447 ) 4448 if maxindex is 
None: 4449 maxindex = index 4450 else: 4451 maxindex = ops.where( 4452 ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex 4453 ) 4454 if maxval is None: 4455 maxval = val 4456 else: 4457 maxval = ops.maximum(val, maxval) 4458 if return_index: 4459 return maxindex 4460 else: 4461 return maxval 4462 4463 new_size = list(batch) + [h_out, w_out] 4464 rv = Pointwise.create( 4465 device=x.get_device(), 4466 dtype=x.get_dtype(), 4467 inner_fn=functools.partial(fn, return_index=False), 4468 ranges=new_size, 4469 ) 4470 4471 ri = Pointwise.create( 4472 device=x.get_device(), 4473 dtype=torch.int64, 4474 inner_fn=functools.partial(fn, return_index=True), 4475 ranges=new_size, 4476 ) 4477 return rv, ri 4478 4479 4480@register_lowering(aten.upsample_nearest2d_backward.default) 4481def upsample_nearest2d_backward( 4482 x, output_size=None, input_size=None, scales_h=None, scales_w=None 4483): 4484 x.realize_hint() 4485 4486 *batch, inp_h, inp_w = x.get_size() 4487 inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) 4488 inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) 4489 4490 *batch, out_h, out_w = input_size 4491 4492 if inp_h % out_h == 0 and inp_w % out_w == 0: 4493 return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1) 4494 4495 h_kernel_max = ceildiv(inp_h, out_h) 4496 w_kernel_max = ceildiv(inp_w, out_w) 4497 4498 def start_index(index, out_dim, inp_dim): 4499 return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) 4500 4501 def end_index(index, out_dim, inp_dim): 4502 return start_index((index + 1), out_dim, inp_dim) 4503 4504 fn_sum = _adaptive_pooling_fn( 4505 start_index=start_index, 4506 end_index=end_index, 4507 kernel_maxes=[h_kernel_max, w_kernel_max], 4508 in_sizes=[inp_h, inp_w], 4509 out_sizes=[out_h, out_w], 4510 pooling_fn=ops.add, 4511 ) 4512 4513 def fn(idx): 4514 return fn_sum(idx, pad_adaptive_loader(x)) 4515 4516 rv = Pointwise.create( 4517 device=x.get_device(), 4518 dtype=x.get_dtype(), 4519 inner_fn=fn, 4520 ranges=list(input_size), 4521 ) 4522 4523 return rv 4524 4525 4526fallback_avg_pool2d = fallback_handler( 4527 aten.avg_pool2d.default, add_to_fallback_set=False 4528) 4529fallback_avg_pool3d = fallback_handler( 4530 aten.avg_pool3d.default, add_to_fallback_set=False 4531) 4532 4533 4534@register_lowering(aten.avg_pool2d, type_promotion_kind=None) 4535def avg_pool2d( 4536 x, 4537 kernel_size, 4538 stride=(), 4539 padding=0, 4540 ceil_mode=False, 4541 count_include_pad=True, 4542 divisor_override=None, 4543): 4544 return _avg_poolnd( 4545 x, 4546 kernel_size, 4547 stride, 4548 padding, 4549 ceil_mode, 4550 count_include_pad, 4551 divisor_override, 4552 dim=2, 4553 ) 4554 4555 4556@register_lowering(aten.avg_pool3d, type_promotion_kind=None) 4557def avg_pool3d( 4558 x, 4559 kernel_size, 4560 stride=(), 4561 padding=0, 4562 ceil_mode=False, 4563 count_include_pad=True, 4564 divisor_override=None, 4565): 4566 return _avg_poolnd( 4567 x, 4568 kernel_size, 4569 stride, 4570 padding, 4571 ceil_mode, 4572 count_include_pad, 4573 divisor_override, 4574 dim=3, 4575 ) 4576 4577 4578def _avg_poolnd( 4579 x, 4580 kernel_size, 4581 stride, 4582 padding, 4583 ceil_mode, 4584 count_include_pad, 4585 divisor_override, 4586 dim, 4587): 4588 if not stride: 4589 stride = kernel_size 4590 if not padding: 4591 padding = [0] * dim 4592 kernel_size = pad_listlike(kernel_size, dim) 4593 stride = pad_listlike(stride, dim) 4594 padding = pad_listlike(padding, dim) 4595 4596 assert isinstance(x, TensorBox) 4597 assert len(kernel_size) == dim 4598 assert 
len(stride) == dim 4599 assert len(padding) == dim 4600 assert len(x.get_size()) in (dim + 1, dim + 2) 4601 4602 x.realize_hint() 4603 batch = x.get_size()[:-dim] 4604 h = x.get_size()[-dim:] 4605 4606 h_out, ceil_modes = zip( 4607 *[ 4608 pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode) 4609 for i in range(dim) 4610 ] 4611 ) 4612 4613 if any(padding) or any(ceil_modes): 4614 x_loader = constant_boundary_condition(x, 0.0, dim=dim) 4615 had_padding = True 4616 else: 4617 x_loader = x.make_loader() 4618 had_padding = False 4619 4620 new_size = list(batch) + list(h_out) 4621 dtype = x.get_dtype() 4622 4623 window_size = functools.reduce(operator.mul, kernel_size) 4624 if window_size > 25: 4625 # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 4626 if dim == 2: 4627 fallback = fallback_avg_pool2d 4628 elif dim == 3: 4629 fallback = fallback_avg_pool3d 4630 else: 4631 raise ValueError(f"Unknown dim: {dim}") 4632 4633 return fallback( 4634 x, 4635 kernel_size, 4636 stride, 4637 padding, 4638 ceil_mode, 4639 count_include_pad, 4640 divisor_override, 4641 ) 4642 4643 def fn_sum(idx, loader): 4644 prefix = idx[:-dim] 4645 b = idx[-dim:] 4646 total = None 4647 for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]): 4648 inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)] 4649 val = loader([*prefix, *inp]) 4650 if total is None: 4651 total = val 4652 else: 4653 total = ops.add(val, total) 4654 return total 4655 4656 if not had_padding or divisor_override: 4657 if divisor_override: 4658 scale = 1 / divisor_override 4659 else: 4660 scale = 1.0 / window_size 4661 4662 def fn(idx): 4663 return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) 4664 4665 else: 4666 4667 def fn(idx): 4668 prefix = idx[:-dim] 4669 bh = idx[-dim:] 4670 4671 divide_factors = [] 4672 for i in range(dim): 4673 hstart = bh[i] * stride[i] - padding[i] 4674 hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i]) 4675 if not count_include_pad: 4676 hstart = sympy.Max(hstart, 0) 4677 hend = sympy.Min(hend, h[i]) 4678 factor = ops.index_expr(hend - hstart, torch.int32) 4679 divide_factors.append(factor) 4680 divide_factor = functools.reduce(ops.mul, divide_factors) 4681 return ops.truediv(fn_sum(idx, x_loader), divide_factor) 4682 4683 rv = Pointwise.create( 4684 device=x.get_device(), 4685 dtype=dtype, 4686 inner_fn=fn, 4687 ranges=new_size, 4688 ) 4689 # TODO(jansel): should we force these to be realized? 
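    # Worked example of the divisor logic above (illustrative numbers): with
    # kernel_size=[3], stride=[1], padding=[1] and count_include_pad=False, the
    # first output along a dim gets hstart = max(-1, 0) = 0 and hend = min(2, h),
    # so it is divided by 2 rather than the full window size 3; with
    # count_include_pad=True the divisor stays hend - hstart = 2 - (-1) = 3.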
4690 return rv 4691 4692 4693fallback_avg_pool2d_backward = fallback_handler( 4694 aten.avg_pool2d_backward.default, add_to_fallback_set=False 4695) 4696 4697 4698@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) 4699def avg_pool2d_backward( 4700 grad_output, 4701 x, 4702 kernel_size, 4703 stride, 4704 padding, 4705 ceil_mode, 4706 count_include_pad, 4707 divisor_override=None, 4708): 4709 assert divisor_override is None or divisor_override != 0, "divisor must be not zero" 4710 if not stride: 4711 stride = kernel_size 4712 if not padding: 4713 padding = [0, 0] 4714 4715 assert isinstance(grad_output, TensorBox) 4716 assert isinstance(x, TensorBox) 4717 assert len(kernel_size) == 2 4718 assert len(stride) == 2 4719 assert len(padding) == 2 4720 assert len(x.get_size()) in (3, 4) 4721 4722 grad_output.realize_hint() # we will read this many times, so make sure it is computed 4723 4724 *batch, height, width = x.get_size() 4725 4726 h_out, ceil_mode1 = pooling_size(height, 0, kernel_size, stride, padding, ceil_mode) 4727 w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode) 4728 4729 grad_loader = grad_output.make_loader() 4730 4731 had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2 4732 4733 *_, pooled_height, pooled_width = grad_output.get_size() 4734 new_size = list(x.get_size()) 4735 dtype = x.get_dtype() 4736 4737 h_window_size = max( 4738 max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) 4739 for h in range(kernel_size[0] * 2) 4740 ) 4741 w_window_size = max( 4742 max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) 4743 for w in range(kernel_size[1] * 2) 4744 ) 4745 4746 window_size = h_window_size * w_window_size 4747 if window_size > 25: 4748 # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
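        # h_window_size / w_window_size bound how many pooled outputs can overlap a
        # single input element; e.g. (illustrative numbers) kernel_size=[7, 7] with
        # stride=[1, 1] gives 7 * 7 = 49 overlapping windows per element, which is
        # over the 25-element unroll budget, so use the eager kernel instead.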
4749 return fallback_avg_pool2d_backward( 4750 grad_output, 4751 x, 4752 kernel_size, 4753 stride, 4754 padding, 4755 ceil_mode, 4756 count_include_pad, 4757 divisor_override, 4758 ) 4759 4760 def compute_pool_size_without_padding(ph, pw): 4761 """ 4762 This computes the scaling factor that we will divide an element 4763 by when `count_include_pad=False` 4764 """ 4765 stride_h = ops.constant(stride[0], torch.int32) 4766 stride_w = ops.constant(stride[1], torch.int32) 4767 pad_h = ops.constant(padding[0], torch.int32) 4768 pad_w = ops.constant(padding[1], torch.int32) 4769 kernel_h = ops.constant(kernel_size[0], torch.int32) 4770 kernel_w = ops.constant(kernel_size[1], torch.int32) 4771 hstart = ops.sub(ops.mul(ph, stride_h), pad_h) 4772 wstart = ops.sub(ops.mul(pw, stride_w), pad_w) 4773 hend = ops.minimum( 4774 ops.add(hstart, kernel_h), 4775 ops.add(ops.index_expr(height, torch.int32), pad_h), 4776 ) 4777 wend = ops.minimum( 4778 ops.add(wstart, kernel_w), 4779 ops.add(ops.index_expr(width, torch.int32), pad_w), 4780 ) 4781 hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) 4782 wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) 4783 hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) 4784 wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) 4785 divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) 4786 return divide_factor 4787 4788 def fn(idx): 4789 *prefix, h, w = idx 4790 h = h + padding[0] 4791 w = w + padding[1] 4792 phstart = ops.index_expr( 4793 FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 4794 ) 4795 pwstart = ops.index_expr( 4796 FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 4797 ) 4798 phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) 4799 pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) 4800 4801 phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) 4802 pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) 4803 phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) 4804 pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) 4805 4806 gradient = None 4807 for ph_ in range(h_window_size): 4808 for pw_ in range(w_window_size): 4809 ph = ops.add(phstart, ops.constant(ph_, torch.int32)) 4810 pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) 4811 4812 if divisor_override is not None: 4813 scale = divisor_override 4814 elif count_include_pad or not had_padding: 4815 scale = kernel_size[0] * kernel_size[1] 4816 else: 4817 scale = compute_pool_size_without_padding(ph, pw) 4818 4819 part = ops.truediv( 4820 grad_loader( 4821 [ 4822 *prefix, 4823 ops.indirect_indexing( 4824 ops.minimum( 4825 ph, ops.sub(phend, ops.constant(1, torch.int32)) 4826 ), 4827 pooled_height, 4828 check=False, 4829 ), 4830 ops.indirect_indexing( 4831 ops.minimum( 4832 pw, ops.sub(pwend, ops.constant(1, torch.int32)) 4833 ), 4834 pooled_width, 4835 check=False, 4836 ), 4837 ] 4838 ), 4839 scale, 4840 ) 4841 4842 mask = ops.and_( 4843 ops.lt(ph, phend), 4844 ops.lt(pw, pwend), 4845 ) 4846 if gradient is None: 4847 gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) 4848 else: 4849 gradient = ops.where(mask, ops.add(gradient, part), gradient) 4850 assert gradient is not None 4851 return gradient 4852 4853 rv = Pointwise.create( 4854 device=grad_output.get_device(), 4855 dtype=dtype, 4856 inner_fn=fn, 4857 ranges=new_size, 4858 ) 4859 return rv 4860 4861 4862fallback_avg_pool3d_backward = fallback_handler( 4863 
aten.avg_pool3d_backward.default, add_to_fallback_set=False 4864) 4865 4866 4867@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) 4868def avg_pool3d_backward( 4869 grad_output, 4870 x, 4871 kernel_size, 4872 stride, 4873 padding, 4874 ceil_mode, 4875 count_include_pad, 4876 divisor_override=None, 4877): 4878 assert divisor_override is None or divisor_override != 0, "divisor must be not zero" 4879 if not stride: 4880 stride = kernel_size 4881 if not padding: 4882 padding = [0, 0, 0] 4883 4884 assert isinstance(grad_output, TensorBox) 4885 assert isinstance(x, TensorBox) 4886 assert len(kernel_size) == 3 4887 assert len(stride) == 3 4888 assert len(padding) == 3 4889 assert len(x.get_size()) in (4, 5) 4890 4891 grad_output.realize_hint() 4892 4893 *batch, depth, height, width = x.get_size() 4894 4895 d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode) 4896 h_out, ceil_mode_h = pooling_size( 4897 height, 1, kernel_size, stride, padding, ceil_mode 4898 ) 4899 w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode) 4900 4901 grad_loader = grad_output.make_loader() 4902 had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w 4903 4904 *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() 4905 new_size = list(x.get_size()) 4906 dtype = x.get_dtype() 4907 4908 d_window_size, h_window_size, w_window_size = ( 4909 max( 4910 max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) 4911 for d in range(kernel_size[i] * 2) 4912 ) 4913 for i in range(3) 4914 ) 4915 4916 window_size = d_window_size * h_window_size * w_window_size 4917 if window_size > 125: 4918 # Kernel size too big. Results in hard-to-optimize Triton code. 4919 return fallback_avg_pool3d_backward( 4920 grad_output, 4921 x, 4922 kernel_size, 4923 stride, 4924 padding, 4925 ceil_mode, 4926 count_include_pad, 4927 divisor_override, 4928 ) 4929 4930 def compute_pool_size_without_padding(pd, ph, pw): 4931 stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) 4932 pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) 4933 kernel_d, kernel_h, kernel_w = ( 4934 ops.constant(k, torch.int32) for k in kernel_size 4935 ) 4936 4937 dstart, hstart, wstart = ( 4938 ops.sub(ops.mul(p, s), pad) 4939 for p, s, pad in zip( 4940 [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] 4941 ) 4942 ) 4943 dend, hend, wend = ( 4944 ops.minimum( 4945 ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) 4946 ) 4947 for start, k, dim, pad in zip( 4948 [dstart, hstart, wstart], 4949 [kernel_d, kernel_h, kernel_w], 4950 [depth, height, width], 4951 [pad_d, pad_h, pad_w], 4952 ) 4953 ) 4954 dstart, hstart, wstart = ( 4955 ops.maximum(start, ops.constant(0, torch.int32)) 4956 for start in [dstart, hstart, wstart] 4957 ) 4958 dend, hend, wend = ( 4959 ops.minimum(end, ops.index_expr(dim, torch.int32)) 4960 for end, dim in zip([dend, hend, wend], [depth, height, width]) 4961 ) 4962 divide_factor = ops.mul( 4963 ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) 4964 ) 4965 return divide_factor 4966 4967 def fn(idx): 4968 *prefix, d, h, w = idx 4969 d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) 4970 4971 pdstart, phstart, pwstart = ( 4972 ops.index_expr(FloorDiv(v - k + s, s), torch.int32) 4973 for v, k, s in zip([d, h, w], kernel_size, stride) 4974 ) 4975 4976 pdend, phend, pwend = ( 4977 ops.index_expr(FloorDiv(v, s) + 1, torch.int32) 4978 for 
v, s in zip([d, h, w], stride) 4979 ) 4980 4981 pdstart, phstart, pwstart = ( 4982 ops.maximum(pstart, ops.constant(0, torch.int32)) 4983 for pstart in [pdstart, phstart, pwstart] 4984 ) 4985 pdend, phend, pwend = ( 4986 ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) 4987 for pend, pooled_dim in zip( 4988 [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] 4989 ) 4990 ) 4991 4992 gradient = None 4993 # Iterate over the 3D region to accumulate gradients 4994 for pd_ in range(d_window_size): 4995 for ph_ in range(h_window_size): 4996 for pw_ in range(w_window_size): 4997 pd, ph, pw = ( 4998 ops.add(pstart, ops.constant(p_, torch.int32)) 4999 for pstart, p_ in zip( 5000 [pdstart, phstart, pwstart], [pd_, ph_, pw_] 5001 ) 5002 ) 5003 5004 if divisor_override is not None: 5005 scale = divisor_override 5006 elif count_include_pad or not had_padding: 5007 scale = kernel_size[0] * kernel_size[1] * kernel_size[2] 5008 else: 5009 scale = compute_pool_size_without_padding(pd, ph, pw) 5010 5011 part = ops.truediv( 5012 grad_loader( 5013 [ 5014 *prefix, 5015 ops.indirect_indexing( 5016 ops.minimum( 5017 pd, ops.sub(pdend, ops.constant(1, torch.int32)) 5018 ), 5019 pooled_depth, 5020 check=False, 5021 ), 5022 ops.indirect_indexing( 5023 ops.minimum( 5024 ph, ops.sub(phend, ops.constant(1, torch.int32)) 5025 ), 5026 pooled_height, 5027 check=False, 5028 ), 5029 ops.indirect_indexing( 5030 ops.minimum( 5031 pw, ops.sub(pwend, ops.constant(1, torch.int32)) 5032 ), 5033 pooled_width, 5034 check=False, 5035 ), 5036 ] 5037 ), 5038 scale, 5039 ) 5040 5041 mask = ops.and_( 5042 ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), 5043 ops.lt(pw, pwend), 5044 ) 5045 if gradient is None: 5046 gradient = ops.where( 5047 mask, part, ops.constant(0.0, torch.float32) 5048 ) 5049 else: 5050 gradient = ops.where(mask, ops.add(gradient, part), gradient) 5051 assert gradient is not None 5052 return gradient 5053 5054 rv = Pointwise.create( 5055 device=grad_output.get_device(), 5056 dtype=dtype, 5057 inner_fn=fn, 5058 ranges=new_size, 5059 ) 5060 return rv 5061 5062 5063def _validate_reduction_axis(x, axis): 5064 size = x.get_size() 5065 if isinstance(axis, int): 5066 axis = [axis] 5067 elif not axis: 5068 axis = range(len(size)) 5069 if len(size) == 0: 5070 assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" 5071 return [] 5072 axis = list(axis) 5073 for i in range(len(axis)): 5074 if axis[i] < 0: 5075 axis[i] += len(size) if len(size) else 1 5076 assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0) 5077 assert len(set(axis)) == len(axis), "reduction axis not unique" 5078 return axis 5079 5080 5081def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): 5082 if dtype is not None: 5083 x = to_dtype(x, dtype) 5084 size = x.get_size() 5085 axis = set(_validate_reduction_axis(x, axis)) 5086 5087 kept_sizes = [] 5088 kept_idx = [] 5089 reduced_sizes = [] 5090 reduced_idx = [] 5091 for i in range(len(size)): 5092 if i in axis: 5093 reduced_idx.append(i) 5094 reduced_sizes.append(size[i]) 5095 else: 5096 kept_idx.append(i) 5097 kept_sizes.append(size[i]) 5098 5099 def loader(index, reduction_index): 5100 assert len(reduction_index) == len(reduced_idx) 5101 if keepdims: 5102 assert len(index) == len(size) 5103 index = [index[i] for i in kept_idx] 5104 assert len(index) == len(kept_idx) 5105 new_index = [None] * (len(index) + len(reduction_index)) 5106 for idx, var in itertools.chain( 5107 zip(kept_idx, index), zip(reduced_idx, reduction_index) 5108 ): 5109 
new_index[idx] = var 5110 return inner_loader(new_index) 5111 5112 if keepdims: 5113 new_size = list(size) 5114 for i in reduced_idx: 5115 new_size[i] = sympy.Integer(1) 5116 else: 5117 new_size = kept_sizes 5118 5119 inner_loader = x.make_loader() 5120 return dict( 5121 device=x.get_device(), 5122 dst_dtype=override_return_dtype or x.get_dtype(), 5123 src_dtype=x.get_dtype(), 5124 inner_fn=loader, 5125 ranges=new_size, 5126 reduction_ranges=reduced_sizes, 5127 ) 5128 5129 5130def make_reduction(reduction_type: str, override_return_dtype=None): 5131 def inner(x, axis=None, keepdims=False, *, dtype=None): 5132 kwargs = _make_reduction_inner( 5133 x, 5134 axis=axis, 5135 keepdims=keepdims, 5136 dtype=dtype, 5137 override_return_dtype=override_return_dtype, 5138 ) 5139 result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) 5140 if isinstance( 5141 result.data.data, Reduction 5142 ): # Only realize if reduction isn't unrolled 5143 result.realize() 5144 return result 5145 5146 return inner 5147 5148 5149def _make_scan_inner(x, *, axis, dtype): 5150 if dtype is not None: 5151 x = to_dtype(x, dtype) 5152 size = x.get_size() 5153 axis = _validate_dim(x, axis) 5154 5155 return dict( 5156 device=x.get_device(), 5157 dtypes=(x.get_dtype(),), 5158 inner_fns=(x.make_loader(),), 5159 size=x.get_size(), 5160 axis=axis, 5161 ) 5162 5163 5164@register_lowering(aten.mean) 5165def mean(x, axis=None, keepdim=False, *, dtype=None): 5166 if dtype is not None: 5167 x = to_dtype(x, dtype) 5168 size = x.get_size() 5169 axis = _validate_reduction_axis(x, axis) 5170 # compute in higher-precision until end of mean lowering 5171 output_dtype = x.get_dtype() 5172 if output_dtype in (torch.float16, torch.bfloat16): 5173 x = to_dtype(x, torch.float) 5174 sum_result = sum_(x, axis, keepdim) 5175 denom = sympy_product(size[i] for i in axis) 5176 denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) 5177 denom = ExpandView.create(denom, list(sum_result.get_size())) 5178 return to_dtype(div(sum_result, denom), output_dtype) 5179 5180 5181def var_mean_sum_(x, axis, correction, keepdim, return_mean): 5182 if correction is None: 5183 correction = 1 5184 5185 size = x.get_size() 5186 axis = _validate_reduction_axis(x, axis) 5187 x_mean = mean(x, axis, keepdim=True) 5188 if return_mean: 5189 x_mean.realize() 5190 5191 diffs = square(sub(x, x_mean)) 5192 sum_result = sum_(diffs, axis, keepdim) 5193 5194 denom = sympy_product(size[i] for i in axis) 5195 if correction: 5196 denom = sympy.Max(denom - correction, 0) 5197 denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) 5198 denom = ExpandView.create(denom, list(sum_result.get_size())) 5199 x_var = div(sum_result, denom) 5200 if not return_mean: 5201 return (x_var,) 5202 5203 x_mean = x_mean if keepdim else squeeze(x_mean, axis) 5204 return x_var, x_mean 5205 5206 5207def use_two_step_variance(x, axis, keepdim): 5208 # Instead of unrolling welford, just unroll the simpler two-step var 5209 axis = _validate_reduction_axis(x, axis) 5210 kwargs = _make_reduction_inner( 5211 x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None 5212 ) 5213 5214 ranges = kwargs["ranges"] 5215 reduction_numel = sympy_product(kwargs["reduction_ranges"]) 5216 return ( 5217 isinstance(reduction_numel, sympy.Integer) 5218 and int(reduction_numel) < config.unroll_reductions_threshold 5219 and sympy_product(ranges) != 1 5220 ) 5221 5222 5223def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): 5224 if correction is None: 5225 
correction = 1 5226 5227 kwargs = _make_reduction_inner( 5228 x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None 5229 ) 5230 loader = kwargs.pop("inner_fn") 5231 kwargs.pop("dst_dtype") 5232 kwargs.pop("src_dtype") 5233 5234 mean, m2, _ = ir.WelfordReduction.create( 5235 inner_fns=(loader,), 5236 reduction_type="welford_reduce", 5237 dtype=x.get_dtype(), 5238 **kwargs, 5239 ) 5240 m2.realize() 5241 5242 dtype = x.get_dtype() 5243 size = x.get_size() 5244 axis = _validate_reduction_axis(x, axis) 5245 rnumel = sympy_product(size[i] for i in axis) 5246 5247 def get_constant_or_index_expr(x, dtype): 5248 if isinstance(x, sympy.Expr) and not x.is_number: 5249 return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) 5250 return ops.constant(x, dtype) 5251 5252 def scale_fn(data): 5253 c = get_constant_or_index_expr(correction, dtype) 5254 N = get_constant_or_index_expr(rnumel, dtype) 5255 zero = ops.constant(0, dtype) 5256 return data / ops.maximum(zero, N - c) 5257 5258 var = make_pointwise(scale_fn)(m2) 5259 5260 if return_mean: 5261 mean.realize() 5262 return var, mean 5263 return (var,) 5264 5265 5266def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): 5267 out_dtype = x.get_dtype() 5268 compute_dtype = get_computation_dtype(out_dtype) 5269 x = to_dtype(x, compute_dtype, copy=False) 5270 kwargs = dict( 5271 x=x, 5272 axis=axis, 5273 correction=correction, 5274 keepdim=keepdim, 5275 return_mean=return_mean, 5276 ) 5277 output = ( 5278 var_mean_sum_(**kwargs) 5279 if use_two_step_variance(x, axis=axis, keepdim=keepdim) 5280 else var_mean_welford_(**kwargs) 5281 ) 5282 output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) 5283 return output[0] if not return_mean else output 5284 5285 5286@register_lowering([aten.var, prims.var]) 5287def var_(x, axis=None, *, correction=None, keepdim=False): 5288 return var_mean_helper_( 5289 x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False 5290 ) 5291 5292 5293@register_lowering(aten.var_mean) 5294def var_mean(x, axis=None, *, correction=None, keepdim=False): 5295 return var_mean_helper_( 5296 x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True 5297 ) 5298 5299 5300def pow_recursive(x, y, dtype): 5301 if y < 0: 5302 return pow_recursive(ops.reciprocal(x), -y, dtype) 5303 if y == 0: 5304 return ops.constant(1, dtype) 5305 if y == 1: 5306 return x 5307 5308 result = pow_recursive(x, y // 2, dtype) 5309 result = ops.mul(result, result) 5310 if (y % 2) == 1: 5311 result = ops.mul(result, x) 5312 return result 5313 5314 5315@make_pointwise 5316def pow_native(a, b): 5317 return ops.pow(a, b) 5318 5319 5320fallback_pow_tensor_tensor = fallback_handler( 5321 aten.pow.Tensor_Tensor, add_to_fallback_set=False 5322) 5323fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) 5324fallback_pow_tensor_scalar = fallback_handler( 5325 aten.pow.Tensor_Scalar, add_to_fallback_set=False 5326) 5327 5328 5329@register_lowering(aten.pow, broadcast=True) 5330def pow(a, b): 5331 if isinstance(b, float) and b == int(b): 5332 return pow(a, int(b)) 5333 elif isinstance(b, float) and b == 0.5: 5334 return sqrt(a) 5335 elif isinstance(b, int) and b == 1: 5336 return clone(a) 5337 5338 # Type promotion ensures all tensor arguments have the same type 5339 dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) 5340 is_integer_pow = is_integer_dtype(dtype) 5341 5342 # Optimize away small fixed powers, or for integers avoid falling back to ATen 5343 
embed_exponent = isinstance(b, int) and ( 5344 -32 < b < 32 or (is_integer_pow and b >= 0) 5345 ) 5346 if embed_exponent: 5347 loader = a.make_loader() 5348 5349 def fn(idx): 5350 return pow_recursive(loader(idx), b, a.get_dtype()) 5351 5352 return Pointwise.create( 5353 device=a.get_device(), 5354 dtype=a.get_dtype(), 5355 inner_fn=fn, 5356 ranges=a.get_size(), 5357 ) 5358 5359 if isinstance(a, Number): 5360 if a == 1: 5361 return full_like(b, 1) 5362 if a == 2 and is_float_dtype(b.get_dtype()): 5363 return exp2(b) 5364 5365 if is_integer_pow: 5366 # ops.pow doesn't work for integers 5367 if isinstance(a, Number): 5368 return fallback_pow_scalar(a, b) 5369 elif isinstance(b, Number): 5370 return fallback_pow_tensor_scalar(a, b) 5371 else: 5372 return fallback_pow_tensor_tensor(a, b) 5373 5374 return pow_native(a, b) 5375 5376 5377def mutate_to(changed, val, unsafe_alias=False): 5378 if isinstance(changed, TensorBox): 5379 changed_data = changed.data 5380 else: 5381 changed_data = changed 5382 if isinstance(val, TensorBox): 5383 val = val.data 5384 5385 if not isinstance(val, ir.StorageBox): 5386 # introduce a copy to handle views 5387 val = Pointwise.create( 5388 device=changed.get_device(), 5389 dtype=changed.get_dtype(), 5390 inner_fn=val.make_loader(), 5391 ranges=changed.get_size(), 5392 ).data 5393 assert isinstance(val, ir.StorageBox) 5394 5395 if isinstance(changed_data, ir.StorageBox) and not ( 5396 changed_data.is_input_buffer() 5397 # In AOTI, module parameters and buffers are not lifted as graph inputs 5398 or changed_data.is_module_buffer() 5399 or isinstance(changed_data.data, ir.NopKernel) 5400 ): 5401 # Fast path, just swing the data pointer 5402 val.realize() 5403 changed_data.data = val.data 5404 return changed 5405 5406 ir.MutationLayoutSHOULDREMOVE.realize_into( 5407 val, changed_data, unsafe_alias=unsafe_alias 5408 ) 5409 return changed 5410 5411 5412@register_lowering(aten.fill_) 5413def fill_(x, fill_value): 5414 return mutate_to(x, full_like(x, fill_value)) 5415 5416 5417@register_lowering(aten.copy_, type_promotion_kind=None) 5418def copy_(dst, src, non_blocking=False): 5419 if dst is src: 5420 # dst.copy_(dst) can happen from the reinplacing pass 5421 return dst 5422 src = to_device(src, dst.get_device()) 5423 src = to_dtype(src, dst.get_dtype()) 5424 src = expand(src, dst.get_size()) 5425 return mutate_to(dst, src) 5426 5427 5428@make_pointwise 5429def floordiv(a, b): 5430 return ops.floordiv(a, b) 5431 5432 5433@make_pointwise 5434def truncdiv(a, b): 5435 return ops.truncdiv(a, b) 5436 5437 5438@register_lowering(aten.div, broadcast=True) 5439def div_mode(a, b, rounding_mode=None): 5440 both_integer = is_integer_type(a) and is_integer_type(b) 5441 both_boolean = is_boolean_type(a) and is_boolean_type(b) 5442 5443 # floordiv and truncdiv need special handling for integer tensors on Triton, 5444 # see the discussion at https://github.com/openai/triton/issues/605 5445 if rounding_mode == "floor": 5446 assert not both_boolean, "floordiv operands can not be boolean at the same time" 5447 return floordiv(a, b) if both_integer else floor(div(a, b)) 5448 if rounding_mode == "trunc": 5449 assert not both_boolean, "truncdiv operands can not be boolean at the same time" 5450 return truncdiv(a, b) if both_integer else trunc(div(a, b)) 5451 return div(a, b) 5452 5453 5454@register_lowering([aten.mul], broadcast=True) 5455def mul(a, b): 5456 both_bool = is_boolean_type(a) and is_boolean_type(b) 5457 if both_bool: 5458 return logical_and(a, b) 5459 else: 5460 fn = 
ops_wrapper(aten.mul.__name__) 5461 return make_pointwise(fn)(a, b) 5462 5463 5464def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]: 5465 """Try convert an arbitrary IR node into an ir.Constant value""" 5466 5467 # First try unwrapping the IRNode to see if it is already an ir.Constant 5468 # Optional step, but avoids unnecessary inner_fn evaluation. 5469 if isinstance(x, ir.MutableBox): 5470 return get_constant_value(x.data) 5471 if isinstance(x, ir.BaseView): 5472 return get_constant_value(x.unwrap_view()) 5473 if isinstance(x, ir.Constant): 5474 return x 5475 5476 # If the unwrapped node is not an ir.Constant, try evaluating inner_fn 5477 # to see if the returned value is from an `ops.constant` call 5478 if not isinstance(x, ir.Loops): 5479 return None 5480 5481 handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) 5482 with V.set_ops_handler(handler), patch.object( 5483 ir.FlexibleLayout, "allow_indexing", True 5484 ): 5485 out = x.inner_fn(*x.inner_fn_args()) 5486 5487 assert isinstance(out, torch._inductor.virtualized.OpsValue) 5488 if isinstance(out.value, ir.Constant): 5489 return out.value 5490 return None 5491 5492 5493# NOTE: prims.div maps to a / b in C, so performs truncation division on 5494# integer inputs and true division for floating and complex inputs. 5495@register_lowering([prims.div], broadcast=True) 5496def div_prim(a, b): 5497 is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) 5498 5499 if is_integral: 5500 return truncdiv(a, b) 5501 5502 if (divisor := get_constant_value(b)) is not None: 5503 # Replace divide by constant with multiply by reciprocal 5504 if divisor.value == 0: 5505 reciprocal = math.copysign(float("inf"), divisor.value) 5506 else: 5507 reciprocal = 1.0 / divisor.value 5508 return mul(a, reciprocal) 5509 5510 def fn(*args): 5511 return ops.truediv(*args) 5512 5513 return make_pointwise(fn)(a, b) 5514 5515 5516@register_lowering( 5517 [aten.true_divide, aten.div.Tensor], 5518 broadcast=True, 5519 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 5520) 5521def div(a, b): 5522 a, b = promote_constants( 5523 (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT 5524 ) 5525 return div_prim(a, b) 5526 5527 5528@register_lowering([aten.fmod, prims.fmod], broadcast=True) 5529def fmod(a, b): 5530 is_integral = is_boolean_type(a) or is_integer_type(a) 5531 5532 if is_integral: 5533 5534 def fn(a, b): 5535 return ops.mod(a, b) 5536 5537 else: 5538 5539 def fn(a, b): 5540 return ops.fmod(a, b) 5541 5542 return make_pointwise(fn)(a, b) 5543 5544 5545@register_lowering(aten.rsqrt) 5546def rsqrt(x): 5547 dtype = x.get_dtype() 5548 if is_integer_dtype(dtype) or is_boolean_dtype(dtype): 5549 x = to_dtype(x, torch.get_default_dtype()) 5550 5551 def _rsqrt(x): 5552 return ops.rsqrt(x) 5553 5554 return make_pointwise(_rsqrt)(x) 5555 5556 5557@register_lowering([aten.sum, prims.sum]) 5558def sum_(x, axis=None, keepdims=False, *, dtype=None): 5559 if ( 5560 is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) 5561 ) and dtype is None: 5562 dtype = torch.int64 5563 5564 fn = make_reduction("sum", override_return_dtype=dtype) 5565 return fn(x, axis, keepdims, dtype=dtype) 5566 5567 5568fallback_cumsum = fallback_handler(aten.cumsum.default) 5569fallback_cumprod = fallback_handler(aten.cumprod.default) 5570fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) 5571fallback_cummax = fallback_handler(aten.cummax.default) 5572fallback_cummin = 


fallback_cumsum = fallback_handler(aten.cumsum.default)
fallback_cumprod = fallback_handler(aten.cumprod.default)
fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default)
fallback_cummax = fallback_handler(aten.cummax.default)
fallback_cummin = fallback_handler(aten.cummin.default)


@register_lowering(aten.cumsum)
def cumsum(x, axis=None, dtype=None):
    if (
        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
    ) and dtype is None:
        dtype = torch.int64

    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        dtype = dtype or x.get_dtype()
        return to_dtype(x, dtype, copy=True)

    def combine_fn(a_tuple, b_tuple):
        (a,) = a_tuple
        (b,) = b_tuple
        return (ops.add(a, b),)

    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
    if result is None:
        return fallback_cumsum(x, dim=axis, dtype=dtype)
    return result


@register_lowering(aten.cumprod)
def cumprod(x, axis=None, dtype=None):
    if (
        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
    ) and dtype is None:
        dtype = torch.int64

    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        dtype = dtype or x.get_dtype()
        return to_dtype(x, dtype, copy=True)

    def combine_fn(a_tuple, b_tuple):
        (a,) = a_tuple
        (b,) = b_tuple
        return (ops.mul(a, b),)

    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
    if result is None:
        return fallback_cumprod(x, dim=axis, dtype=dtype)
    return result


@register_lowering(aten.logcumsumexp)
def logcumsumexp(x, dim):
    def log_add_exp_helper(a_tuple, b_tuple):
        (a,) = a_tuple
        (b,) = b_tuple
        min_v = ops.minimum(a, b)
        max_v = ops.maximum(a, b)
        mask = (min_v != max_v) | (~ops.isinf(min_v))
        return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),)

    dtype = x.get_dtype()
    if len(x.get_size()) == 0:
        assert dim in [0, -1]
        return clone(x)

    kwargs = _make_scan_inner(x, axis=dim, dtype=dtype)
    (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper)
    if result is None:
        return fallback_logcumsumexp(x, dim=dim)
    return result
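

# The combine function in logcumsumexp relies on the numerically stable identity
#     log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b)))
# The mask keeps the result equal to `a` when both operands are the same
# infinity, which would otherwise produce inf - inf = nan inside exp().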


@register_lowering(aten.cummax, type_promotion_kind=None)
def cummax(x, axis=None):
    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        return clone(x), empty_like(x, dtype=torch.int64)

    dtype = x.get_dtype()
    combine_fn = ir.get_reduction_combine_fn(
        "argmax", dtype=dtype, arg_break_ties_left=False
    )

    min_value = (
        False
        if dtype is torch.bool
        else (
            torch.finfo(dtype).min
            if dtype.is_floating_point
            else torch.iinfo(dtype).min
        )
    )

    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    kwargs["dtypes"] = (dtype, torch.int64)
    kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex")
    values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn)  # type: ignore[arg-type] # next PR
    if values is None:
        return fallback_cummax(x, dim=axis)
    return values, indices


@register_lowering(aten.cummin, type_promotion_kind=None)
def cummin(x, axis=None):
    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        return clone(x), empty_like(x, dtype=torch.int64)

    dtype = x.get_dtype()
    combine_fn = ir.get_reduction_combine_fn(
        "argmin", dtype=dtype, arg_break_ties_left=False
    )

    max_value = (
        True
        if dtype is torch.bool
        else (
            torch.finfo(dtype).max
            if dtype.is_floating_point
            else torch.iinfo(dtype).max
        )
    )

    kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    kwargs["dtypes"] = (dtype, torch.int64)
    kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex")
    values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn)  # type: ignore[arg-type] # next PR
    if values is None:
        return fallback_cummin(x, dim=axis)
    return values, indices


@register_lowering(aten.prod)
def prod(x, axis=None, keepdims=False, *, dtype=None):
    if (
        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
    ) and dtype is None:
        dtype = torch.int64

    fn = make_reduction("prod", override_return_dtype=dtype)
    return fn(x, axis, keepdims, dtype=dtype)


@register_lowering(aten.any)
def reduce_any(x, dim=None, keepdim=False):
    x = to_dtype(x, torch.bool)
    return make_reduction("any")(x, axis=dim, keepdims=keepdim)


@register_lowering(aten.max, type_promotion_kind=None)
def reduce_max(x, dim=None, keepdim=False):
    if dim is not None:
        return (
            reduce_amax(x, axis=dim, keepdims=keepdim),
            reduce_argmax(x, axis=dim, keepdims=keepdim),
        )

    return reduce_amax(x, axis=None, keepdims=keepdim)


@register_lowering(aten.min, type_promotion_kind=None)
def reduce_min(x, dim=None, keepdim=False):
    if dim is not None:
        return (
            reduce_amin(x, axis=dim, keepdims=keepdim),
            reduce_argmin(x, axis=dim, keepdims=keepdim),
        )

    return reduce_amin(x, axis=None, keepdims=keepdim)


register_lowering(prims.xor_sum)(make_reduction("xor_sum"))
reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
reduce_argmax = register_lowering(aten.argmax)(
    make_reduction("argmax", override_return_dtype=torch.int64)
)
reduce_argmin = register_lowering(aten.argmin)(
    make_reduction("argmin", override_return_dtype=torch.int64)
)
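

# Illustrative sketch (not executed here): with a dim argument, aten.max and
# aten.min decompose into a value reduction plus an index reduction, matching
# eager semantics:
#
#     x = torch.randn(4, 8)
#     values, indices = torch.max(x, dim=1)
#     # lowered as (reduce_amax(x, axis=1), reduce_argmax(x, axis=1))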


add = register_pointwise(
    aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or"
)

sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False)


@register_lowering(aten.sort.stable, type_promotion_kind=None)
def sort_stable(x, *, stable=None, dim=-1, descending=False):
    if stable is None:
        stable = False

    shape = x.get_size()
    device = x.get_device()
    dim = canonicalize_dim(len(shape), dim)
    if len(shape) == 0:
        return clone(x), _full(0, device, torch.int64, shape)

    dim_size = shape[dim] if len(shape) else 1
    if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max):
        return sort_fallback(x, stable=stable, dim=dim, descending=descending)

    indices = iota(
        dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False
    )
    view_shape = [1] * len(shape)
    if len(shape):
        view_shape[dim] = dim_size
    indices = view(indices, view_shape)
    indices = expand(indices, shape)

    values, indices = ir.Sort.create(
        device=device,
        dtypes=(x.dtype, indices.dtype),
        inner_fns=(x.make_loader(), indices.make_loader()),
        size=shape,
        axis=dim,
        stable=stable,
        descending=descending,
    )
    if values is None:
        return sort_fallback(x, stable=stable, dim=dim, descending=descending)

    assert indices is not None
    return values, to_dtype(indices, torch.int64)


@register_lowering(aten.sort.default, type_promotion_kind=None)
def sort(x, dim=-1, descending=False):
    return sort_stable(x, stable=False, dim=dim, descending=descending)
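

# Note on sort_stable above: the in-kernel sort keeps its index channel in int16
# (converted to int64 only at the very end), so it is used only when the sorted
# dimension is statically known to be smaller than torch.iinfo(torch.int16).max;
# otherwise, or if ir.Sort.create declines to generate code, the lowering falls
# back to the eager aten.sort.stable kernel.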


def register_pointwise_numeric(op, name=None, triton_fallback=None):
    return register_pointwise(
        op,
        name=name,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        triton_fallback=triton_fallback,
    )


def register_pointwise_numeric_ldf64(op):
    return register_pointwise(
        op,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        use_libdevice_for_f64=True,
    )


exp = register_pointwise_numeric_ldf64(aten.exp)
exp2 = register_pointwise_numeric(aten.exp2)
expm1 = register_pointwise_numeric(aten.expm1)
relu = register_pointwise(aten.relu)
sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
square = register_pointwise(aten.square)
sub = register_pointwise(aten.sub, allow_alpha=True)
register_pointwise_numeric_ldf64(aten.cos)
register_pointwise_numeric_ldf64(aten.sin)
abs = register_pointwise(aten.abs)
bitwise_and = register_pointwise(aten.bitwise_and)
bitwise_left_shift = register_pointwise(aten.bitwise_left_shift)
bitwise_not = register_pointwise(
    aten.bitwise_not, override_fn_when_input_bool="logical_not"
)
bitwise_or = register_pointwise(aten.bitwise_or)
bitwise_right_shift = register_pointwise(aten.bitwise_right_shift)
bitwise_xor = register_pointwise(aten.bitwise_xor)
register_pointwise_numeric(aten.lgamma)
erf = register_pointwise_numeric(aten.erf)
register_lowering(
    aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)(erf)

register_pointwise_numeric(aten.log1p)
register_pointwise_numeric(aten.tan)
register_pointwise_numeric(aten.tanh)
register_pointwise_numeric_ldf64(aten.log)
logical_and = register_pointwise(
    aten.logical_and,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
logical_not = register_pointwise(
    aten.logical_not,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
logical_or = register_pointwise(
    aten.logical_or,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
logical_xor = register_pointwise(
    aten.logical_xor,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
maximum = register_pointwise(aten.maximum)
minimum = register_pointwise(aten.minimum)
register_lowering(aten.clamp_min)(maximum)
register_lowering(aten.clamp_max)(minimum)
neg = register_pointwise(aten.neg)
abs = register_pointwise(aten.abs)
reciprocal = register_pointwise_numeric(aten.reciprocal)
register_pointwise(aten.remainder)
sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity")
register_pointwise(aten.ceil)
register_pointwise(aten.signbit, override_return_dtype=torch.bool)

register_lowering(aten._neg_view)(neg)

register_pointwise(aten.le, override_return_dtype=torch.bool)
register_pointwise(aten.lt, override_return_dtype=torch.bool)
register_pointwise(aten.ge, override_return_dtype=torch.bool)
gt = register_pointwise(aten.gt, override_return_dtype=torch.bool)
register_pointwise(aten.eq, override_return_dtype=torch.bool)
register_pointwise(aten.ne, override_return_dtype=torch.bool)

register_pointwise_numeric(aten.cosh)
register_pointwise_numeric(aten.sinh)
register_pointwise_numeric(aten.acos)
register_pointwise_numeric(aten.acosh)
register_pointwise_numeric(aten.asin)
register_pointwise_numeric(aten.asinh)
register_pointwise_numeric(aten.atan2)
register_pointwise_numeric(aten.atan)
register_pointwise_numeric(aten.atanh)
register_pointwise_numeric(aten.copysign)
register_pointwise_numeric(aten.erfc)
register_pointwise_numeric(aten.erfinv)
register_pointwise_numeric(aten.hypot)
register_pointwise_numeric(aten.log10)
register_pointwise_numeric(aten.log2)
register_pointwise_numeric(aten.nextafter)

from .codegen.common import BackendFeature, pointwise_overrides_data


def _get_pointwise_overrides(ns, name):
    data = pointwise_overrides_data[name]
    op = getattr(ns, data.name, None)
    if op is None:
        return

    def make_triton_fallback(op):
        if data.triton is None:
            return fallback_handler(op)

    if isinstance(op, torch._ops.OpOverloadPacket):
        for olname in op.overloads():
            ol = getattr(op, olname)
            yield ol, data.type_promotion_kind, make_triton_fallback(ol)
    else:
        yield op, data.type_promotion_kind, make_triton_fallback(op)


for name in pointwise_overrides_data:
    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
        aten, name
    ):
        register_pointwise(
            op,
            name=name,
            type_promotion_kind=type_promotion_kind,
            triton_fallback=triton_fallback,
        )

    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
        prims, name
    ):
        register_pointwise(
            op,
            name=name,
            type_promotion_kind=type_promotion_kind,
            triton_fallback=triton_fallback,
        )
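

# Note on the loop above: pointwise_overrides_data (from codegen.common) supplies
# per-op codegen overrides; when an entry provides no Triton implementation
# (data.triton is None), make_triton_fallback attaches an eager fallback_handler
# for that overload, otherwise no Triton fallback is attached and the override is
# used directly.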


foreach_add_list = register_foreach_pointwise(
    aten._foreach_add.List, add, allow_alpha=True
)
foreach_add_scalar = register_foreach_pointwise(
    aten._foreach_add.Scalar, add, allow_alpha=True
)
register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)
foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
register_foreach_pointwise(aten._foreach_sub.List, sub)
register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
register_foreach_pointwise(aten._foreach_neg.default, neg)
register_foreach_pointwise(aten._foreach_abs.default, abs)
register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
register_foreach_pointwise(aten._foreach_sqrt, sqrt)
register_foreach_pointwise(aten._foreach_maximum.List, maximum)
register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
register_foreach_pointwise(aten._foreach_minimum.List, minimum)
register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum)
register_foreach_pointwise(aten._foreach_clamp_min.List, maximum)
register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum)
register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
register_foreach_pointwise(aten._foreach_sign, sign)
register_foreach_pointwise(aten._foreach_copy, copy)


# In-place foreach ops are only encountered as outputs of the graph; reinplacing
# their epilogue copies improves compile time by removing extra buffers sent to
# the scheduler.
def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op):
    inplaceable_foreach_ops[outplace_aten_op] = aten_op
    inplace_foreach_ops.add(aten_op)

    def fn(*args, **kwargs):
        results = outplace_op(*args, **kwargs)
        mut_results = []
        for arg, result in zip(args[0], results):
            mut_results.append(mutate_to(arg, result, unsafe_alias=True))

        return mut_results

    _register_foreach_lowering(aten_op, fn)


register_foreach_inplace(
    aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list
)
register_foreach_inplace(
    aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar
)
register_foreach_inplace(
    aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list
)
register_foreach_inplace(
    aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar
)
register_foreach_inplace(
    aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list
)
register_foreach_inplace(
    aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
)


def register_inplace(aten_op, outplace_op):
    @register_lowering(aten_op, type_promotion_kind=None)
    def fn(*args, **kwargs):
        result = outplace_op(*args, **kwargs)
        result = to_dtype(result, args[0].get_dtype())
        return mutate_to(args[0], result)

    return fn


register_inplace(aten.add_, add)
register_inplace(aten.bitwise_and_, bitwise_and)
register_inplace(aten.bitwise_left_shift_, bitwise_left_shift)
register_inplace(aten.bitwise_not_, bitwise_not)
register_inplace(aten.bitwise_or_, bitwise_or)
register_inplace(aten.bitwise_right_shift_, bitwise_right_shift)
register_inplace(aten.bitwise_xor_, bitwise_xor)
register_inplace(aten.mul_, mul)
register_inplace(aten.div_.Tensor, div)
register_inplace(aten.div_.Tensor_mode, div_mode)
register_inplace(aten.logical_and_, logical_and)
register_inplace(aten.logical_not_, logical_not)
register_inplace(aten.logical_or_, logical_or)
register_inplace(aten.logical_xor_, logical_xor)
register_inplace(aten.sub_, sub)
register_inplace(aten.relu_, relu)
register_inplace(aten.sigmoid_, sigmoid)


register_lowering(aten.__and__)(bitwise_and)
register_lowering(aten.__lshift__)(bitwise_left_shift)
register_lowering(aten.__or__)(bitwise_or)
register_lowering(aten.__rshift__)(bitwise_right_shift)
register_lowering(aten.__xor__)(bitwise_xor)

register_inplace(aten.__iand__, aten.__and__)
register_inplace(aten.__ilshift__, aten.__lshift__)
register_inplace(aten.__ior__, aten.__or__)
register_inplace(aten.__irshift__, aten.__rshift__)
register_inplace(aten.__ixor__, aten.__xor__)
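

# Illustrative sketch (not executed here): register_inplace lowers an in-place op
# by running its out-of-place counterpart and writing the result back into the
# first argument via mutate_to. For a compiled graph containing
#
#     x.add_(y)
#
# the lowering computes add(x, y), casts it to x's dtype, and mutates x's buffer.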


@register_lowering(aten.sym_constrain_range)
def sym_constrain_range(a, min=None, max=None):
    return None


@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
    val = V.graph.current_node.meta["val"]
    # Note [Can val be an int?]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~
    # In principle, someone could construct an FX graph where a call to
    # size/stride has a val that is a plain int (not a SymInt). However, we
    # maintain the invariant that this is not possible: if you are constructing
    # an FX graph with a call to size/stride that returns an int, but you KNOW
    # that the int must always be a constant, then you do not need to trace that
    # call at all (and can just constant-propagate the integer as is).
    assert isinstance(val, torch.SymInt)
    return val.node.expr


@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
    val = V.graph.current_node.meta["val"]
    # See Note [Can val be an int?]
    assert isinstance(val, torch.SymInt)
    return val.node.expr


@register_lowering(aten.sym_numel)
def sym_numel(a):
    return a.get_numel()


for method, func in magic_methods.items():
    register_lowering(method_to_operator(method))(func)
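

# Illustrative note: the sym_size/sym_stride lowerings above ignore the tensor
# argument and return the sympy expression stored on the FX node's SymInt meta
# value, so a dynamic dimension lowers to a symbol (e.g. s0) that downstream
# indexing can use directly; the loop over magic_methods then registers matching
# lowerings for the SymInt operators (add, mul, floordiv, ...).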


@register_lowering(aten._foobar)
def foobar(self, *args, **kwargs):
    raise NotImplementedError("Helpful for debugging")


@register_lowering(torch.ops._inductor_test.realize)
def _realize(x):
    x.realize()
    return clone(x)


@register_lowering(torch.ops.inductor.resize_storage_bytes_)
def resize_storage_bytes_(variable, new_size):
    variable.realize()
    ir.ResizeStorageBytes(variable, new_size)
    return variable


@register_lowering(torch.ops.aten.set_.source_Tensor)
def set__source_tensor(self, source_tensor):
    self.realize()
    source_tensor.realize()
    return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor))


if hasattr(torch.ops.fsdp, "set_"):

    @register_lowering(torch.ops.fsdp.set_.default)
    def fsdp_set_(self, source_tensor):
        self.realize()
        source_tensor.realize()
        ir.SetSourceTensorKernel(self, source_tensor)


@register_lowering(torch.ops.aten.resize)
def resize(x, size, *, memory_format=None):
    assert isinstance(x, TensorBox)
    assert isinstance(size, (list, tuple))

    if memory_format is None:
        memory_format = torch.contiguous_format
    if memory_format == torch.preserve_format:
        raise RuntimeError(f"unsupported memory format: {memory_format}")

    if memory_format == torch.channels_last:
        assert len(size) == 4
    if memory_format == torch.channels_last_3d:
        assert len(size) == 5

    old_numel = x.get_numel()
    dtype = x.get_dtype()
    device = x.get_device()

    if isinstance(x.data, ir.BaseView):
        x.data = x.data.unwrap_view()

    if (
        torch.are_deterministic_algorithms_enabled()
        and torch.utils.deterministic.fill_uninitialized_memory  # type: ignore[attr-defined]
    ):
        if is_float_dtype(dtype):
            uninitialized_val = float("nan")
        elif is_integer_dtype(dtype):
            uninitialized_val = torch.iinfo(dtype).max
        else:
            uninitialized_val = True
    else:
        # using zero as that is what empty does
        uninitialized_val = 0.0

    if V.graph.sizevars.statically_known_equals(old_numel, 0):  # type: ignore[arg-type]
        return full(size, uninitialized_val, dtype=dtype, device=device)

    x_flat = as_strided(
        x,
        [
            old_numel,
        ],
        [
            1,
        ],
    )
    flat_loader = x_flat.make_loader()
    out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format)
    out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer()

    def inner_fn(idx):
        flat_index = out_indexer(idx)
        flat_index_expr = ops.index_expr(flat_index, torch.int64)
        limit = ops.index_expr(old_numel, torch.int64)
        mask = ops.lt(flat_index_expr, limit)
        return ops.masked(mask, lambda: flat_loader([flat_index]), uninitialized_val)

    out = Pointwise.create(
        device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size)
    )
    return out


from torch._higher_order_ops.auto_functionalize import auto_functionalized


make_fallback(auto_functionalized)


@register_lowering(triton_kernel_wrapper_mutation)
def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs):
    from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

    constant_args = kernel_side_table.get_constant_args(constant_args_idx)
    ir.UserDefinedTritonKernel(
        kernel_idx=kernel_idx,
        grid=grid,
        kernel_args={**kwargs, **constant_args},
    )
    return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)}


@register_lowering(torch.ops.higher_order.cond)
def cond(pred, true_fn, false_fn, operands):
    if is_triton(pred) or any(map(is_triton, operands)):
        msg = "control flow operator: torch.cond."
        if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
            msg = f"{msg} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = msg

    result = ir.Conditional.create(pred, true_fn, false_fn, operands)
    return list(map(TensorBox.create, result))
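

# Illustrative note: torch.cond lowers its branch subgraphs through
# ir.Conditional and wraps every output in a TensorBox; when the predicate or any
# operand comes from a user-defined Triton kernel, cudagraphs are disabled for
# the graph and the offending stack trace is recorded as the reason (the same
# pattern is used for torch.while_loop below).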


@register_lowering(torch.ops.higher_order.while_loop)
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
    if any(map(is_triton, carried_inputs + additional_inputs)):
        msg = "control flow operator: torch.while_loop."
        if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
            msg = f"{msg} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = msg

    result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs)
    return list(map(TensorBox.create, result))


@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn: ir.Subgraph, input, dim: int):
    from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph

    subgraph_inputs = [
        InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
        for x in itertools.chain(input, input)
    ]
    lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs)  # type: ignore[var-annotated]

    def wrapped_combine_fn(lhs, rhs):
        return lowered_combine_fn(
            *pytree.tree_leaves(lhs),
            *pytree.tree_leaves(rhs),
        )

    kwargs = _make_scan_inner(input[0], axis=dim, dtype=None)
    kwargs["dtypes"] = tuple(x.get_dtype() for x in input)
    kwargs["inner_fns"] = tuple(x.make_loader() for x in input)
    result = ir.Scan.create(
        combine_fn=wrapped_combine_fn,
        can_fallback_to_aten=False,
        **kwargs,
    )
    if result[0] is None:
        raise RuntimeError("Unable to generate code for associative_scan op")
    return result


@register_lowering(torch.ops.prims._sink_tokens.default)
def _sink_tokens(tokens):
    return None


@register_lowering(torch.ops.higher_order.with_effects)
def with_effects(token, op, *args, **kwargs):
    result = ir.EffectfulKernel.create(op, *args, **kwargs)

    from torch._higher_order_ops.effects import get_effect_key

    effect_type = get_effect_key(op, args, kwargs)
    assert effect_type is not None
    effectful_kernel = V.graph.effectful_ops[effect_type]

    if result is None:
        return (effectful_kernel,)

    result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
    if not isinstance(result, (list, tuple)):
        return (effectful_kernel, result)
    else:
        return (effectful_kernel, *result)
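

# Illustrative note: with_effects is how side-effectful ops are threaded through
# the functionalized graph. The lowering above creates an EffectfulKernel, then
# returns the kernel recorded for that effect type in V.graph.effectful_ops in
# the token position, followed by the op's actual outputs, mirroring the
# (token, *results) structure produced by the higher-order op at trace time.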


try:
    import torch.distributed._functional_collectives

    _c10d_functional = torch.ops._c10d_functional

    @register_lowering(_c10d_functional.all_reduce)
    def _all_reduce(inp, reduce_op, group_name):
        inp = clone(inp)
        if config.reorder_for_compute_comm_overlap:
            # Horizontally fusing this clone often severely delays the
            # scheduling of the all_reduce_ node and almost never outperforms
            # scheduling the all_reduce_ earlier. In most cases the clone is
            # eliminated via in-place reuse anyway, so we tell the scheduler
            # not to fuse it.
            inp.realize()
            V.graph.no_fuse_buffer_names.add(inp.get_name())
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
        )
        return inp

    @register_lowering(_c10d_functional.all_reduce_)
    def _all_reduce_(inp, reduce_op, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
        )
        return inp

    @register_lowering(_c10d_functional.all_reduce_coalesced)
    def _all_reduce_coalesced(inputs, reduce_op, group_name):
        inputs = [clone(inp) for inp in inputs]
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

    @register_lowering(_c10d_functional.all_reduce_coalesced_)
    def _all_reduce_coalesced_(inputs, reduce_op, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_reduce_coalesced_.default,
            inputs,
            reduce_op,
            group_name,
        )
        return inputs

    @register_lowering(_c10d_functional.all_gather_into_tensor)
    def _all_gather_into_tensor(inp, group_size, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_gather_into_tensor.default,
                inp,
                group_size,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced)
    def _all_gather_into_tensor_coalesced(inputs, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_gather_into_tensor_coalesced.default,
                inputs,
                group_size,
                group_name,
            ),
        )

    @register_lowering(_c10d_functional.all_gather_into_tensor_out)
    def _all_gather_into_tensor_out(inp, group_size, group_name, *, out):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.all_gather_into_tensor_out.default,
            inp,
            group_size,
            group_name,
            out=out,
        )
        return out

    @register_lowering(_c10d_functional.reduce_scatter_tensor)
    def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.reduce_scatter_tensor.default,
                inp,
                reduce_op,
                group_size,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced)
    def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name):
        return pytree.tree_map(
            ir.TensorBox.create,
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.reduce_scatter_tensor_coalesced.default,
                inputs,
                reduce_op,
                group_size,
                group_name,
            ),
        )

    @register_lowering(_c10d_functional.all_to_all_single)
    def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                _c10d_functional.all_to_all_single.default,
                inp,
                output_split_sizes,
                input_split_sizes,
                group_name,
            )
        )

    @register_lowering(_c10d_functional.broadcast)
    def _broadcast(inp, src, group_name):
        inp = clone(inp)
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.broadcast_.default, inp, src, group_name
        )
        return inp
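
    # Illustrative note: the out-of-place functional collectives above
    # (all_reduce, all_reduce_coalesced, broadcast) are lowered as a clone
    # followed by an in-place collective kernel on the clone, while the
    # trailing-underscore variants mutate their input directly; the gather,
    # scatter, and all_to_all style ops instead allocate a fresh output via
    # create_out_of_place.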

    @register_lowering(_c10d_functional.broadcast_)
    def _broadcast_(inp, src, group_name):
        ir._CollectiveKernel.create_inplace(
            _c10d_functional.broadcast_.default, inp, src, group_name
        )
        return inp

    @register_lowering(_c10d_functional.wait_tensor)
    def _wait_tensor(inp):
        ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp)
        return inp

    @register_lowering(torch.ops._dtensor.shard_dim_alltoall)
    def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name):
        return ir.TensorBox.create(
            ir._CollectiveKernel.create_out_of_place(
                torch.ops._dtensor.shard_dim_alltoall.default,
                inp,
                gather_dim,
                shard_dim,
                group_name,
            )
        )

except (AttributeError, ImportError):
    log.info(
        "Inductor support for distributed collectives depends on building torch.distributed"
    )

# populate lowerings defined in kernel/*
from . import kernel


import_submodule(kernel)

from . import quantized_lowerings


quantized_lowerings.register_quantized_ops()
quantized_lowerings.register_woq_mm_ops()

from . import mkldnn_lowerings


mkldnn_lowerings.register_onednn_fusion_ops()

from . import jagged_lowerings


jagged_lowerings.register_jagged_ops()