1# mypy: allow-untyped-defs 2import contextlib 3 4import warnings 5from dataclasses import dataclass 6from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type 7from typing_extensions import TypeGuard 8from collections import deque 9 10import torch 11import torchgen 12import torchgen.model 13from torch._C import ( 14 _get_dispatch_stack_at, 15 _len_torch_dispatch_stack, 16 _pop_torch_dispatch_stack, 17 _push_on_torch_dispatch_stack, 18 DispatchKey, 19) 20 21 22# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: 23# - We need a better user-facing api for _DisableTorchDispatch that 24# is able to selectively disable __torch_dispatch__ of a particular class. 25# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor) 26# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694) 27 28_is_in_torch_dispatch_mode = False 29_is_in_non_infra_torch_dispatch_mode = False 30 31def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool: 32 return _is_in_torch_dispatch_mode if include_infra_modes else _is_in_non_infra_torch_dispatch_mode 33 34 35class TorchDispatchMode: 36 """ 37 A ``TorchDispatchMode`` allows you to override the meaning of all 38 ``__torch_dispatch__`` overrideable functions within a dynamic scope, 39 without having to actually create a tensor subclass or manually 40 monkey-patch functions in the PyTorch API. Some common situations 41 where you should use a mode: 42 43 * You want to override the meaning of factory functions, or other 44 functions that do not otherwise take a tensor as an argument 45 (these cannot be overridden with tensor subclasses). 46 47 * You want to override the behavior of all functions without needing 48 to wrap your inputs in tensor subclasses; e.g., if you are just 49 interested in logging intermediate computations. 50 51 * You want to control the order of execution of various tensor 52 subclasses explicitly, rather than implicitly via the return of 53 ``NotImplemented``. 54 55 Independent subclasses of :class:`TorchDispatchMode` are compositional: 56 modes can be pushed onto a stack using ``with MyMode():``. 57 When you call functions in the PyTorch API inside your 58 ``__torch_dispatch__`` implementation, by default, they will forward on to 59 the next mode on the mode stack. If you want recursively call back into 60 your current ``__torch_dispatch__`` implementation, either explicitly 61 invoke ``self.__torch_dispatch__(...)``, or use the context manager 62 ``__torch_dispatch__(self)`` to make PyTorch 63 API self-referential (beware of infinite loops, in this case!) 64 """ 65 66 def __init__(self, _dispatch_key=None): 67 if _dispatch_key is not None: 68 assert isinstance(_dispatch_key, torch._C.DispatchKey) 69 self.__dict__["_dispatch_key"] = _dispatch_key 70 71 self.old_dispatch_mode_flags: Deque[bool] = deque() 72 self.old_non_infra_dispatch_mode_flags: Deque[bool] = deque() 73 74 def _lazy_init_old_dispatch_mode_flags(self): 75 if not hasattr(self, "old_dispatch_mode_flags"): 76 self.old_dispatch_mode_flags: Deque[bool] = deque() # type: ignore[no-redef] 77 78 if not hasattr(self, "old_non_infra_dispatch_mode_flags"): 79 self.old_non_infra_dispatch_mode_flags: Deque[bool] = deque() # type: ignore[no-redef] 80 81 82 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 83 raise NotImplementedError 84 85 def __enter__(self): 86 global _is_in_torch_dispatch_mode 87 global _is_in_non_infra_torch_dispatch_mode 88 # Previously, there wasn't any state in this class' constructor 89 # super calls were added to existing modes, but for any new modes 90 # this will replicate the previous behavior of not strictly needing 91 # to call super().__init__() 92 self._lazy_init_old_dispatch_mode_flags() 93 self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode) 94 _is_in_torch_dispatch_mode = True 95 self.old_non_infra_dispatch_mode_flags.append(_is_in_non_infra_torch_dispatch_mode) 96 _is_in_non_infra_torch_dispatch_mode = _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode() 97 _push_mode(self) 98 return self 99 100 def __exit__(self, exc_type, exc_val, exc_tb): 101 mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None) 102 if mb_dk_or_mode_key is None: 103 # Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch) 104 # We should probably revisit this. 105 mb_dk_or_mode_key = self.__dict__.get("_mode_key", None) 106 global _is_in_torch_dispatch_mode 107 _is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop() 108 global _is_in_non_infra_torch_dispatch_mode 109 _is_in_non_infra_torch_dispatch_mode = self.old_non_infra_dispatch_mode_flags.pop() 110 _pop_mode(mb_dk_or_mode_key) 111 112 @classmethod 113 def push(cls, *args, **kwargs): 114 warnings.warn( 115 "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`" 116 ) 117 instance = cls(*args, **kwargs) 118 return instance 119 120 @classmethod 121 def is_infra_mode(cls): 122 return False 123 124 125 126def _get_current_dispatch_mode(): 127 stack_len = _len_torch_dispatch_stack() 128 # Return a user mode on the stack if there are any 129 if stack_len > 0: 130 return _get_dispatch_stack_at(stack_len - 1) 131 return None 132 133 134def _detect_infra_mode(key): 135 assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY] 136 from torch._ops import _get_dispatch_mode_pre_dispatch 137 138 pre_dispatch_mode = _get_dispatch_mode_pre_dispatch( 139 key 140 ) 141 post_dispatch_mode = torch._C._get_dispatch_mode( 142 key 143 ) 144 145 assert (pre_dispatch_mode is None) or ( 146 post_dispatch_mode is None 147 ) 148 149 if pre_dispatch_mode is None: 150 return post_dispatch_mode 151 152 return pre_dispatch_mode 153 154 155def _unset_infra_mode(key): 156 from torch._ops import _get_dispatch_mode_pre_dispatch, unset_mode_pre_dispatch 157 158 pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key) 159 post_dispatch_mode = torch._C._get_dispatch_mode(key) 160 if pre_dispatch_mode and post_dispatch_mode: 161 raise AssertionError( 162 "Can't have active infra mode on both pre and post dispatch mode stack" 163 ) 164 165 if pre_dispatch_mode: 166 mode = unset_mode_pre_dispatch(key) 167 return mode 168 if post_dispatch_mode: 169 return torch._C._unset_dispatch_mode(key) 170 171 172def _disable_infra_mode(key): 173 assert key in ( 174 torch._C._TorchDispatchModeKey.FUNCTIONAL, 175 torch._C._TorchDispatchModeKey.PROXY, 176 ) 177 mode_unset = _unset_infra_mode(key) 178 try: 179 yield mode_unset 180 finally: 181 if mode_unset is not None: 182 _push_mode(mode_unset) 183 184 185def _get_current_dispatch_mode_stack(): 186 stack_len = _len_torch_dispatch_stack() 187 return [_get_dispatch_stack_at(i) for i in range(stack_len)] 188 189 190def _push_mode(mode: TorchDispatchMode): 191 k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None 192 assert k is None or k == torch._C.DispatchKey.PreDispatch 193 if k is None: 194 _push_on_torch_dispatch_stack(mode) 195 return 196 197 from torch._ops import _set_mode_pre_dispatch, get_cached_ops 198 199 # See Note [Not Caching Per-Dispatch-Key Mode Handlers] 200 # Clear the cache of every op that has been used so far, for this particular key. 201 ks = torch._C._functionality_to_backend_keys(k) 202 for op in get_cached_ops(): 203 for key in ks: 204 op._uncache_dispatch(key) 205 _set_mode_pre_dispatch(mode) 206 207 208def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None): 209 if k == torch._C.DispatchKey.PreDispatch: # type: ignore[attr-defined] 210 from torch._ops import _pop_mode_from_pre_dispatch 211 212 return _pop_mode_from_pre_dispatch() 213 214 if k is None or isinstance(k, torch._C._TorchDispatchModeKey): 215 return _pop_torch_dispatch_stack(k) 216 217 218@contextlib.contextmanager 219def _pop_mode_temporarily(k: Optional[DispatchKey] = None): 220 old = _pop_mode(k) 221 try: 222 yield old 223 finally: 224 _push_mode(old) 225 226 227@contextlib.contextmanager 228def _disable_current_modes(): 229 from torch._ops import ( 230 _len_torch_dispatch_stack_pre_dispatch, 231 _pop_mode_from_pre_dispatch, 232 ) 233 from torch._subclasses.functional_tensor import FunctionalTensorMode 234 from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode 235 from torch._subclasses.schema_check_mode import SchemaCheckMode 236 237 mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch() 238 old_pre_dispatch_modes = [ 239 _pop_mode_from_pre_dispatch() for _ in range(mode_len_pre_dispatch) 240 ] 241 242 has_proxy_mode_in_pre_dispatch = False 243 has_functional_mode_in_pre_dispatch = False 244 has_schema_check_mode_in_pre_dispatch = False 245 246 for i in old_pre_dispatch_modes: 247 if isinstance(i, ProxyTorchDispatchMode): 248 has_proxy_mode_in_pre_dispatch = True 249 if isinstance(i, FunctionalTensorMode): 250 has_functional_mode_in_pre_dispatch = True 251 if isinstance(i, SchemaCheckMode): 252 has_schema_check_mode_in_pre_dispatch = True 253 254 mode_len = _len_torch_dispatch_stack() 255 old_modes = [_pop_mode() for _ in range(mode_len)] 256 257 for old in old_modes: 258 if ( 259 isinstance(old, FunctionalTensorMode) 260 and has_functional_mode_in_pre_dispatch 261 ): 262 raise AssertionError( 263 "Can't have FunctionalMode available both in PreDispatch and Python Key" 264 ) 265 if isinstance(old, ProxyTorchDispatchMode) and has_proxy_mode_in_pre_dispatch: 266 raise AssertionError( 267 "Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key" 268 ) 269 if ( 270 isinstance(old, SchemaCheckMode) 271 and has_schema_check_mode_in_pre_dispatch 272 ): 273 raise AssertionError( 274 "Can't have SchemaCheckMode available both in PreDispatch and Python Key" 275 ) 276 277 # Manually disable proxy and fake modes, if any are active 278 try: 279 yield old_pre_dispatch_modes + old_modes 280 finally: 281 for mode in reversed(old_modes): 282 _push_mode(mode) 283 for mode in reversed(old_pre_dispatch_modes): 284 _push_mode(mode) 285 286 287class BaseTorchDispatchMode(TorchDispatchMode): 288 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 289 if kwargs is None: 290 kwargs = {} 291 return func(*args, **kwargs) 292 293 294# Subtypes which have __tensor_flatten__ and __tensor_unflatten__. 295class TensorWithFlatten(Protocol): 296 def __tensor_flatten__(self) -> Tuple[Sequence[str], object]: 297 ... 298 299 @staticmethod 300 def __tensor_unflatten__(inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int) -> torch.Tensor: 301 ... 302 303 # It would be really nice to be able to say that the return of 304 # is_traceable_wrapper_subclass() is Intersection[torch.Tensor, 305 # TensorWithFlatten] - but that doesn't exist. 306 307 shape: torch._C.Size 308 309 @overload 310 def stride(self, dim: None = None) -> Tuple[int, ...]: 311 ... 312 313 @overload 314 def stride(self, dim: int) -> int: 315 ... 316 317 def dim(self) -> int: 318 ... 319 320 @overload 321 def to( 322 self, 323 dtype: torch.types._dtype, 324 non_blocking: bool = False, 325 copy: bool = False, 326 *, 327 memory_format: Optional[torch.memory_format] = None 328 ) -> torch.Tensor: 329 ... 330 331 @overload 332 def to( 333 self, 334 device: Optional["torch._prims_common.DeviceLikeType"] = None, 335 dtype: Optional[torch.types._dtype] = None, 336 non_blocking: bool = False, 337 copy: bool = False, 338 *, 339 memory_format: Optional[torch.memory_format] = None 340 ) -> torch.Tensor: 341 ... 342 343 @overload 344 def to( 345 self, 346 other: torch.Tensor, 347 non_blocking: bool = False, 348 copy: bool = False, 349 *, 350 memory_format: Optional[torch.memory_format] = None 351 ) -> torch.Tensor: 352 ... 353 354 355 356 357def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]: 358 """ 359 Returns whether or not a tensor subclass that implements __torch_dispatch__ 360 is 'traceable' with torch.compile. 361 In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2, 362 It must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__. 363 It is also expected to obey some restrictions around traceability and aliasing: 364 * The subclass's __torch_dispatch__() implementation should desugar into pytorch 365 dispatcher operations that can be traced into a graph. 366 * The subclass should use return_and_correct_aliasing(). This is needed today to make 367 sure that torch.compile does the right thing in a few cases around input mutation 368 and output aliasing. 369 370 Expected magic method signatures: 371 attrs, ctx = t.__tensor_flatten__() 372 attrs: list of attribute name strings for inner tensors 373 ctx: dict containing any other subclass-specific metadata needed for unflattening 374 375 t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride) 376 inner_tensors: dict mapping attribute name -> tensor for each inner tensor 377 ctx: dict with subclass metadata in the form that __tensor_flatten__() produces 378 outer_size: expected (possibly symbolic) size that the returned subclass 379 instance should have. Note that this arg is useful for certain subclasses 380 that require the shape info to be constructed. In most cases, this arg can be 381 safely ignored. 382 outer_stride: expected (possibly symbolic) stride that the returned subclass 383 instance should have. Note that this arg is useful for certain subclasses 384 that require the stride info to be constructed. In most cases, this arg can be 385 safely ignored. 386 """ 387 is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor 388 return ( 389 is_subclass 390 and hasattr(t, "__tensor_flatten__") 391 and hasattr(t, "__tensor_unflatten__") 392 ) 393 394def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]: 395 """Same as above, but takes a type argument instead of an instance.""" 396 return (issubclass(t, torch.Tensor) and t != torch.Tensor 397 and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")) 398 399 400def transform_subclass(t, callback, outer_size=None, outer_stride=None): 401 """ 402 Given a traceable, wrapper tensor subclass ``t`` that implements 403 ``__torch_dispatch__`` and holds some inner tensors, 404 and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``, 405 `transform_subclass` will construct a fresh instance of the wrapper tensor subclass. 406 It will do so by grabbing each inner tensor attribute from the wrapper, 407 passing them into ``callback`` to get a transformed tensor, 408 and putting each transformed tensor into the fresh tensor subclass instance. 409 410 Note: this function will not handle ensuring that the fresh subclass 411 gets the same (autograd, and aliasing) metadata as the original tensor. 412 This is generally handled in other subsystems like AOTAutograd. 413 """ 414 outer_size = outer_size if outer_size is not None else t.size() 415 outer_stride = outer_stride if outer_stride is not None else t.stride() 416 417 attrs, ctx = t.__tensor_flatten__() 418 transformed_tensors_dict = {} 419 for attr in attrs: 420 transformed_tensors_dict[attr] = callback(attr, getattr(t, attr)) 421 sub = type(t).__tensor_unflatten__( 422 transformed_tensors_dict, ctx, outer_size, outer_stride 423 ) 424 425 # NB: Purposefully guard here to simplify the inner / outer symbols. 426 # Using sym_eq() for symbolic comparison can result in an expression that's too 427 # difficult to guard on, so we use == here. 428 assert sub.shape == outer_size, ( 429 f"Expected return value from {type(t)}__tensor_unflatten__() to have " 430 f"shape equal to {outer_size}, but got: {sub.shape}" 431 ) 432 assert sub.stride() == outer_stride, ( 433 f"Expected return value from {type(t)}__tensor_unflatten__() to have " 434 f"stride equal to {outer_stride}, but got: {sub.stride()}" 435 ) 436 437 return sub 438 439 440def _correct_storage_aliasing(func, schema_info, args, outs): 441 """ 442 Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema), 443 and the inputs/outputs to the OpOverload, 444 this function checks to see if func is a view operator 445 (by checking if any of the outputs in the op's schema 446 are immutable aliases of inputs). 447 If so, this function manually aliases the storage of the output tensor 448 with its corresponding input tensor alias. 449 It does this by unsafely overwriting the storage field of the output tensor 450 to be the same storage as the input. 451 """ 452 assert isinstance(func, torch._ops.OpOverload) 453 assert isinstance(args, tuple) 454 assert isinstance(outs, (list, tuple)) 455 flat_outs = torch.utils._pytree.tree_leaves(outs) 456 457 def alias_non_inplace_storage(arg, ret): 458 # This is hopefully a reasonable assert: 459 # subclasses that rely on this API for output aliasing 460 # should always return wrapper tensor subclasses for us to manually alias. 461 # in theory if a subclass that needs this API wants to sometimes return 462 # plain tensors, we could remove the assert and just not perform the aliasing, 463 # but it seems safer to learn more about this case first. 464 if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret): 465 ret_list = ret if isinstance(ret, list) else [ret] 466 for r in ret_list: 467 assert type(arg) == type( 468 r 469 ), f"""Called {str(func)} with input of type {type(arg)} 470and output of type {type(ret)}. But expected types to match.""" 471 # Need to call a non-dispatcher helper, because we explicitly do **not** 472 # want our subclass to intercept the set_() call. 473 # instead, our subclass should directly have its storage swapped out. 474 # we **explicitly** don't want to reset the sizes on ret, if the storage implies a size change. 475 # Why? 476 # The purpose of this API is *not* to change the size/strides of our output- we assume it's already correct. 477 # We just want to "fix up" the storage aliasing, without modifying or output's metadata. 478 # Example: out = inp.expand(inp.shape[0], inp.shape[0]) 479 # This requires swapping the storage of out to be the same as inp, 480 # but we do *not* want it to change the sizes/strides that were compute for out. 481 482 if isinstance(ret, list): 483 for r in ret: 484 torch._functionalize_unsafe_set(r, arg) 485 else: 486 assert isinstance(ret, torch.Tensor), f"type: {type(ret)}" 487 torch._functionalize_unsafe_set(ret, arg) 488 489 def is_read_only_alias_match(arg, ret): 490 shared_aliases = arg.alias_set & ret.alias_set 491 return len(shared_aliases) > 0 and not arg.is_write 492 493 num_args = len(func._schema.arguments) 494 num_returns = len(func._schema.returns) 495 for arg_idx in range(num_args): 496 for return_idx in range(num_returns): 497 if is_read_only_alias_match( 498 schema_info.args[arg_idx], schema_info.outs[return_idx] 499 ): 500 alias_non_inplace_storage(args[arg_idx], outs[return_idx]) 501 502 503# This abstracts over the fact that in return_and_correct_aliasing, 504# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy), 505# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested). 506@dataclass 507class AliasInfo: 508 alias_set: Set[str] 509 is_write: bool 510 name: Optional[str] 511 512 513@dataclass 514class SchemaInfo: 515 args: List[AliasInfo] 516 outs: List[AliasInfo] 517 518 519# Can't import torch._ops.OpOverload due to circular reference 520parsed_schema_map: Dict[Any, SchemaInfo] = {} 521 522 523# Given an OpOverload, returns schema information on it. 524# This is cached for efficiency, since it can involve running torchgen 525def get_alias_info(func) -> SchemaInfo: 526 if func in parsed_schema_map: 527 return parsed_schema_map[func] 528 # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations 529 # properly for some ops that output tensorlists) 530 if func.namespace == "aten": 531 torchgen_schema_str = str(func._schema) 532 assert torchgen_schema_str.startswith("aten::") 533 # remove the aten:: namespace, which is added by the torchscript parser, 534 # and torchgen doesn't know how to handle 535 torchgen_schema_str = torchgen_schema_str[6:] 536 import re 537 538 # the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1], 539 # which torchgen chokes on. 540 torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str) 541 torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str) 542 # for aten::rot90 543 torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]") 544 torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str) 545 arg_schemas = [ 546 AliasInfo( 547 alias_set=( 548 set() if a.annotation is None else set(a.annotation.alias_set) 549 ), 550 is_write=a.annotation is not None and a.annotation.is_write, 551 name=a.name, 552 ) 553 for a in torchgen_schema.arguments.flat_all 554 ] 555 out_schemas = [ 556 AliasInfo( 557 alias_set=( 558 set() if a.annotation is None else set(a.annotation.alias_set) 559 ), 560 is_write=a.annotation is not None and a.annotation.is_write, 561 name=a.name, 562 ) 563 for a in torchgen_schema.returns 564 ] 565 else: 566 # For non-aten ops, torchgen is untested so we rely on torchscript schema parsing 567 arg_schemas = [ 568 AliasInfo( 569 alias_set=( 570 set() if a.alias_info is None else set(a.alias_info.before_set) 571 ), 572 is_write=a.alias_info is not None and a.alias_info.is_write, 573 name=a.name, 574 ) 575 for a in func._schema.arguments 576 ] 577 out_schemas = [ 578 AliasInfo( 579 alias_set=( 580 set() if a.alias_info is None else set(a.alias_info.before_set) 581 ), 582 is_write=a.alias_info is not None and a.alias_info.is_write, 583 name=a.name, 584 ) 585 for a in func._schema.returns 586 ] 587 schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas) 588 parsed_schema_map[func] = schema_info 589 return schema_info 590 591 592def return_and_correct_aliasing(func, args, kwargs, out): 593 """ 594 This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses 595 that would like to work with torch.compile. It ensures that the subclass 596 properly implements the aliasing behavior of every op, 597 which is needed for correctness in AOTAutograd. 598 This function will handle: 599 600 * When we see a view op, we will alias the storages of any 601 input and output tensor subclasses 602 603 * When we see an inplace or out= op, we will directly 604 return the corresponding input tensor, instead of returning 605 a (potentially) fresh output tensor. 606 """ 607 608 # Caching here because torchgen parsing is definitely not fast, and this function is called 609 # once for every op in the graph during functionalization. 610 schema_info = get_alias_info(func) 611 612 def get_write_alias(x): 613 if len(x.alias_set) == 0: 614 return None 615 alias_set = list(x.alias_set) 616 # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing 617 assert len(alias_set) == 1 618 if x.is_write: 619 return alias_set[0] 620 return None 621 622 def get_arg_from_alias(output_alias, schema_info, args, kwargs): 623 new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc] 624 func, args=args, kwargs=kwargs 625 ) 626 627 arg_indices = [ 628 i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set 629 ] 630 # For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments. 631 assert len(arg_indices) == 1 632 idx = arg_indices[0] 633 arg_info = schema_info.args[idx] 634 if arg_info.name is not None and arg_info.name in new_kwargs: 635 return new_kwargs[arg_info.name] 636 return new_args[idx] 637 638 # Fix up the storages of any outs so that they point to the same storage as the input, 639 # if func is a view op. 640 _correct_storage_aliasing( 641 func, schema_info, args, (out,) if not isinstance(out, tuple) else out 642 ) 643 644 # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's 645 # metadata is set correctly. 646 if torch.Tag.inplace_view in func.tags: 647 # no_dispatch() to make sure that we secretly change the metadata on the wrapper, 648 # but don't end up dispatching the op anywhere else. 649 mutated_args = [ 650 x 651 for i, x in enumerate(args) 652 if get_write_alias(schema_info.args[i]) is not None 653 ] 654 # Assumption: we have a very small number of inplace_view ops that follow a strict schema: 655 # there is only a single argument that gets its metadata mutated. 656 assert len(mutated_args) == 1 657 # This check exists because we generally *do* want to update the metadata of any wrapper subclasses, 658 # but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor. 659 # so we don't actually need to update the metadata (and attempting to do so causes errors) 660 from torch._subclasses.functional_tensor import FunctionalTensor 661 662 if not isinstance(mutated_args[0], FunctionalTensor): 663 with torch.utils._mode_utils.no_dispatch(): 664 # See Note: [Fake Tensor Dispatch Keys] 665 # we're borrowing the way it modifies dispatch key TLS. 666 meta_in_tls = torch._C._meta_in_tls_dispatch_include() 667 torch._C._set_meta_in_tls_dispatch_include(True) 668 try: 669 func(*args, **kwargs) 670 finally: 671 torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) 672 673 # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()). 674 675 # simple case: none of our outputs have mutable aliases, so we can return the output as-is 676 if not any(get_write_alias(r) is not None for r in schema_info.outs): 677 return out 678 679 # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" 680 if not all(get_write_alias(r) is not None for r in schema_info.outs): 681 raise RuntimeError("Unsupported schema: " + str(func._schema)) 682 683 if len(func._schema.returns) == 1: 684 return get_arg_from_alias( 685 get_write_alias(schema_info.outs[0]), schema_info, args, kwargs 686 ) 687 688 # In the multi-return case, all aten ops return a tuple / list, so cast accordingly. 689 outs_to_return = type(out)( 690 [ 691 ( 692 get_arg_from_alias( 693 get_write_alias(schema_info.outs[i]), schema_info, args, kwargs 694 ) 695 if get_write_alias(r) is not None 696 else o 697 ) 698 for ((i, r), o) in zip(enumerate(schema_info.outs), out) 699 ] 700 ) 701 return outs_to_return 702