# mypy: ignore-errors

import functools
import itertools
import math
import sys
from typing import Callable, Union

import torch
import torch._custom_op
import torch._logging
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
)
from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensor,
    in_kernel_invocation_manager,
    run_fallback_kernel,
    UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label


pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

op_implementations_dict = {}
op_implementations_checks = []


aten = torch._ops.ops.aten


def ordered_set(*items):
    return dict.fromkeys(items, True)


# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
    return device.type != "hpu"


_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)


_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.to.device,
    aten.to.prim_Device,
    aten.is_pinned.default,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)


def contains_tensor_types(type):
    tensor_type = torch._C.TensorType.get()
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    assert isinstance(func, OpOverload)
    schema = func._schema
    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return False
    # TODO: no real reason to restrict multiple outputs
    return (
        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
    )


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:
                register_op_impl(op)(op_impl)
        else:
            assert callable(run_impl_check)
            op_implementations_checks.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator
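

# Illustrative (commented) usage of register_op_impl; nothing here is actually
# registered by this example. The decorator accepts a single OpOverload, a
# list/tuple of overloads, or a predicate over OpOverloads:
#
#     @register_op_impl(aten.clone.default)                  # exact overload
#     def _clone_impl(fake_mode, func, *args, **kwargs):
#         with in_kernel_invocation_manager(fake_mode):
#             return func(*args, **kwargs)
#
#     @register_op_impl(lambda func: "fft" in func.name())   # predicate form
#     def _fft_impl(fake_mode, func, *args, **kwargs):
#         ...
#
# Exact overloads go into op_implementations_dict (O(1) lookup); predicates are
# appended to op_implementations_checks and scanned in registration order.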


@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)


@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")
    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)


@register_op_impl(aten.is_pinned.default)
def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    # we'll ignore the device argument because it is deprecated and not
    # actually used by is_pinned.
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp)
    return r
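

# A minimal commented sketch of what the constructor path above does (not
# executed here; assumes the public FakeTensorMode entry point):
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     with FakeTensorMode():
#         t = torch.empty(2, 3)        # FakeTensor backed by a meta tensor
#         u = torch.zeros_like(t)      # routed through `constructors` above
#     # u is a FakeTensor; the kernel ran on the "meta" device and the result
#     # was wrapped with the requested (or inherited) output device, here cpu.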


@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    inp = new_kwargs.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False


# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)


# Don't default to the default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)


@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # TODO: remove me
    return constructors(fake_mode, func, *args, **kwargs)


# index.Tensor is data-dependent only in some conditions
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)


def _unique(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    # Do not use a memo for unique_dim
    if dim is not None or (nnz := arg.unique_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            numel = arg.numel() if dim is None else arg.size(dim)
            if not has_free_symbols(numel):
                maxval = int(numel)

            _constrain_range_for_size(nnz, max=maxval)

        if dim is None:
            arg.unique_memo = nnz

    if dim is None:
        ret = [arg.new_empty((nnz,))]
    else:
        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]

    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
    if return_inverse or return_if_dim_and_cpu:
        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
    else:
        inverse = arg.new_empty(0)
    ret.append(inverse)

    if return_counts or return_if_dim_and_cpu:
        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
    else:
        counts = arg.new_empty(0)
    ret.append(counts)

    return tuple(ret)
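

# Commented sketch of when the _unique path above is exercised (assumes the
# dynamic-output-shape capture flag; not executed here):
#
#     import torch._dynamo.config as dynamo_config
#     dynamo_config.capture_dynamic_output_shape_ops = True
#
#     @torch.compile(fullgraph=True)
#     def f(x):
#         return torch.unique(x)
#
#     f(torch.tensor([1, 2, 2, 3]))
#
# Under fake tensors the number of unique elements is unknown, so the output
# length becomes an unbacked SymInt constrained to [0, x.numel()]; without a
# shape_env that allows dynamic output shapes, DynamicOutputShapeException is
# raised instead and the op cannot be faked.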


@register_op_impl(aten._unique2.default)
def unique2(
    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)


@register_op_impl(aten.unique_dim.default)
def unique_dim(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(
        fake_mode,
        func,
        arg,
        # normalize dim to be non-negative
        dim if dim >= 0 else dim % max(arg.ndim, 1),
        sorted,
        return_inverse,
        return_counts,
    )


@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    if output_size is None:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            raise DynamicOutputShapeException(func)

        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
    # TODO: consider a memo
    return repeats.new_empty(output_size)


@register_op_impl(torch.ops.aten.item.default)
@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    if (r := arg.item_memo) is not None:
        return r
    if fake_mode.shape_env is None or (
        not fake_mode.shape_env.allow_scalar_outputs
        and not fake_mode.allow_scalar_outputs
    ):
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
    if is_float_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symint()
    elif is_boolean_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
    arg.item_memo = r
    return r
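

# Commented sketch of the data-dependent scalar path handled by
# local_scalar_dense above (assumes the scalar-output capture flag; not
# executed here):
#
#     import torch._dynamo.config as dynamo_config
#     dynamo_config.capture_scalar_outputs = True
#
#     @torch.compile(fullgraph=True)
#     def f(x):
#         return x.sum().item() + 1
#
# Under FakeTensorMode, .item() cannot produce a concrete Python number, so the
# handler returns an unbacked SymInt/SymFloat/SymBool (memoized on the tensor);
# without a shape_env permitting scalar outputs it raises
# DataDependentOutputException instead.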


@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if (nnz := arg.nonzero_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            if not has_free_symbols(arg.numel()):
                maxval = int(arg.numel())

            _constrain_range_for_size(nnz, max=maxval)

        arg.nonzero_memo = nnz

    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
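

# The nonzero_memo above guarantees that repeated calls on the same fake tensor
# share one unbacked symbol. A commented sketch (assumes a ShapeEnv constructed
# with allow_dynamic_output_shape_ops=True; not executed here):
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     from torch.fx.experimental.symbolic_shapes import ShapeEnv
#
#     shape_env = ShapeEnv(allow_dynamic_output_shape_ops=True)
#     with FakeTensorMode(shape_env=shape_env):
#         x = torch.randn(8)
#         a = torch.nonzero(x)
#         b = torch.nonzero(x)
#     # a.shape[0] and b.shape[0] are the same unbacked SymInt, bounded above by
#     # x.numel() == 8, so downstream guards see a single consistent symbol.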


@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )
    from torch.utils._sympy.numbers import IntInfinity
    from torch.utils._sympy.value_ranges import bound_sympy

    # If the number of elements is expressed symbolically, derive a concrete
    # bound from its upper range. Otherwise, we can set maxval directly.
    if not has_free_symbols(self.numel()):
        num_elements = int(self.numel())
    else:
        prod_node = math.prod(self.shape).node
        prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range)
        if isinstance(prod_range.upper, IntInfinity):
            num_elements = sys.maxsize - 1
        else:
            num_elements = prod_range.upper
    if num_elements > 2:
        maxval = num_elements

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))


# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)


# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)


def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)
        if not is_noncontiguous_supported(out_device):
            out = out.new_empty(out.shape)

    if out is new_kwargs["input"]:
        return out  # copy_
    return FakeTensor(fake_mode, out, out_device)


_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    return op.namespace in _is_builtin_namespaces


def has_meta(func):
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")
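

# The is_builtin/has_meta predicates above gate the foreach handler that
# follows. A commented sketch of how they might be queried (illustrative only):
#
#     op = torch.ops.aten._foreach_add.List
#     if is_builtin(op) and "foreach" in op.name() and has_meta(op):
#         ...  # eligible for foreach_run_and_map_input_device below
#
# has_meta asks the dispatcher for a *computed* Meta kernel, so it also covers
# kernels derived from composite registrations, not only explicit Meta ones.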


@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    tensor_lists = []
    for arg in itertools.chain(args, kwargs.values()):
        if (
            isinstance(arg, (list, tuple))
            and len(arg)
            and isinstance(arg[0], torch.Tensor)
        ):
            tensor_lists.append(arg)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists
    out_fake = []

    for i, meta_t in enumerate(out_meta):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        out_fake.append(
            fake_mode.fake_tensor_converter.from_meta_and_device(
                fake_mode, meta_t, device
            )
        )

    return out_fake


# Don't default to the default device handling,
# since the op can take non-zero-sized cpu
# index tensors with a cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    # ensure nonzero call goes to fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)


# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)


# Takes in multiple devices; don't default to the default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)


# Same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]


@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values = new_kwargs["values"]
    self_device = new_kwargs["input"].fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    if func is aten.index_put_.default:
        return new_kwargs["input"]
    else:
        return out
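

# Commented sketch of why copy_/index_put need the multi-device handling above:
# a cuda self tensor may legitimately receive cpu scalars or indices, so the
# output device must follow `input` rather than the common-device rule
# (illustrative only; assumes a CUDA-capable setup and is not executed here):
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     with FakeTensorMode():
#         dst = torch.empty(4, device="cuda")
#         src = torch.ones(4)              # cpu
#         dst.copy_(src)                   # handled by multi_device_op_default
#     # dst stays a FakeTensor on the cuda device; no cross-device error.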


@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )


@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.is_pinned.default,
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"


@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input unsqueezing is done in Convolution.cpp we get a segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )


@register_op_impl(torch.ops.aten._pack_padded_sequence.default)
def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    new_batch_size = fake_mode.shape_env.create_unbacked_symint()

    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(new_batch_size)

    if not batch_first:
        # normalize inputs to have shape (batch_size, seq_len, *)
        inputs = inputs.transpose(0, 1)

    res_size = inputs.shape[1:]
    packed_data = inputs.new_empty(res_size)
    batch_size = inputs.new_empty((new_batch_size,))
    return (packed_data, batch_size)
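

# Commented sketch of the _pack_padded_sequence contract modeled above
# (illustrative; under fake tensors it requires dynamic output shape capture):
#
#     inputs = torch.randn(5, 3, 7)          # (seq_len, batch, feature)
#     lengths = torch.tensor([5, 4, 2])
#     data, batch_sizes = torch.ops.aten._pack_padded_sequence(
#         inputs, lengths, False
#     )
#     # In eager, data.shape[0] == lengths.sum(); under fake tensors the length
#     # of batch_sizes is data dependent on `lengths`, so it is modeled with an
#     # unbacked SymInt (new_batch_size above).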


FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
    def impl_decorator(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return impl_decorator


# infer_size_impl in ExpandUtils
def infer_size(a, b):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB. This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known. However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB test is now expect_true and we can defer it as a
        # runtime assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i}",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
    return tuple(expandedSizes)
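

# infer_size implements standard right-aligned broadcasting; a doctest-style
# sketch with plain ints (illustrative, not executed at import time):
#
#     >>> infer_size((5, 1, 3), (4, 3))
#     (5, 4, 3)
#     >>> infer_size((2, 1), (1, 7))
#     (2, 7)
#
# With SymInts, the guard_size_oblivious(size == 1) checks let unbacked but
# size-like dimensions pass without guarding, and the final sizeA == sizeB
# check is deferred as a runtime assert via torch._check.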


def make_fast_binary_impl(slow_ref):
    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        has_scalars = False
        has_tensors = False
        final_shape = None
        for op in operands:
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if len(shape) == 0:
                has_scalars = True
            else:
                has_tensors = True
                if final_shape is None:
                    final_shape = shape
                # TODO: Minor optimization: track if the shapes
                # were equal so you can skip the equality check
                # below if unnecessary
                final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                and guard_size_oblivious(sym_eq(op.shape, final_shape))
            ):
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        output_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )

        # check all tensors on same device
        # cpu scalars are assumed allowed
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python)
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl


# disable the python dispatcher to avoid decomposing detach() further
# (proxy_mode should still decompose detach() though)
def fast_detach(fake_mode, x):
    with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
        out = torch.ops.aten.detach.default(x)
    return FakeTensor(fake_mode, out, x.device)


@functools.lru_cache(None)
def get_fast_op_impls():
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(torch._refs.div)
    )
    register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
    return FAST_OP_IMPLEMENTATIONS
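

# Commented usage sketch: the fast impls registered above are looked up by
# overload before the generic meta/decomposition path (illustrative only, not
# executed here):
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#
#     fast_impls = get_fast_op_impls()
#     add_impl = fast_impls[torch.ops.aten.add.Tensor]
#     with FakeTensorMode() as mode:
#         x = torch.randn(2, 3)
#         y = torch.randn(2, 3)
#     out = add_impl(mode, x, y)    # contiguous fast path; FakeTensor of shape (2, 3)
#
# When the inputs broadcast nontrivially or mix layouts/devices, the impl falls
# back to the slow reference (here torch._refs.add) executed under the mode.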