1# mypy: allow-untyped-defs 2import itertools 3import unittest.mock 4from contextlib import contextmanager 5from typing import Iterator 6 7import torch 8import torch._C 9import torch._ops 10import torch.utils._python_dispatch 11import torch.utils._pytree as pytree 12 13 14__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"] 15 16no_python_dispatcher = torch._C._DisablePythonDispatcher 17enable_python_dispatcher = torch._C._EnablePythonDispatcher 18enable_pre_dispatch = torch._C._EnablePreDispatch 19 20CROSSREF_FUNCTIONALIZE = False 21 22 23def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: 24 """ 25 Warning: the set of overloads this will report is very subtle. It is precisely 26 the set of torch.ops functions that have actually been accessed from Python 27 (e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT 28 from the set of registered operators, which will in general be a larger set, 29 as this would include all operators which we ran C++ static initializers or 30 Python operator registration on. This does not eagerly populate the list on 31 torch.ops.aten; this list is lazy! 32 33 In other words, this is good for traversing over everything that has an 34 OpOverload object allocated in Python. We use it for cache invalidation, but 35 don't rely on this list being complete. 36 37 Note that even if we did report all C++ registered overloads, this isn't guaranteed 38 to be complete either, as a subsequent lazy load of a library which triggers more 39 registrations could add more things to the set. 
40 """ 41 for ns in torch.ops: 42 packets = getattr(torch.ops, ns) 43 for op_name in packets: 44 packet = getattr(packets, op_name) 45 for overload in packet: 46 yield getattr(packet, overload) 47 48 49@contextmanager 50def suspend_functionalization(): 51 f_tls = torch._C._dispatch_tls_is_dispatch_key_included( 52 torch._C.DispatchKey.Functionalize 53 ) 54 f_rv = torch._C._functionalization_reapply_views_tls() 55 if f_tls: 56 torch._disable_functionalization() 57 try: 58 yield 59 finally: 60 if f_tls: 61 torch._enable_functionalization(reapply_views=f_rv) 62 63 64def check_tensor_metadata_matches(nv, rv, desc): 65 assert callable(desc) 66 assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}" 67 assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}" 68 same_strides, idx = torch._prims_common.check_significant_strides( 69 nv, rv, only_cuda=False 70 ) 71 assert ( 72 same_strides 73 ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" 74 75 76def check_metadata_matches(n, r, desc): 77 assert callable(desc) 78 n_vals, n_spec = pytree.tree_flatten(n) 79 r_vals, r_spec = pytree.tree_flatten(r) 80 # TODO: test the specs match; empirically sometimes we have a tuple 81 # on one side and a list on the other 82 assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" 83 for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): 84 if not isinstance(rv, torch.Tensor): 85 continue 86 check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}") 87 88 89class Lit: 90 def __init__(self, s): 91 self.s = s 92 93 def __repr__(self): 94 return self.s 95 96 97def _fmt(a: object) -> object: 98 if isinstance(a, torch.Tensor): 99 return Lit( 100 f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})" 101 ) 102 else: 103 return a 104 105 106def make_crossref_functionalize(op, final_key): 107 from torch._subclasses.fake_tensor import FakeTensorMode 108 109 # This case is pretty weird, 
suppress it for now 110 if op == torch.ops.aten.lift_fresh.default: 111 return final_key 112 113 def handler(*args, **kwargs): 114 fake_mode = FakeTensorMode() 115 116 def fakeify_defun(t): 117 if isinstance(t, torch.Tensor): 118 if torch._is_functional_tensor(t): 119 r = torch._from_functional_tensor(t) 120 # NB: This assumes that the inner tensor sizes/strides match 121 # the outer tensor sizes/strides. This doesn't necessarily have to 122 # be the case, see discussion at 123 # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456 124 assert t.size() == r.size() 125 assert t.stride() == r.stride() 126 else: 127 r = t 128 # TODO: suppress guards 129 return fake_mode.from_tensor(r) 130 return t 131 132 def maybe_detach(t): 133 if isinstance(t, torch.Tensor): 134 return t.detach() 135 else: 136 return t 137 138 # TODO: This probably does the wrong thing if you're running other 139 # substantive modes with the normal op outside here 140 with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization(): 141 f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs)) 142 orig_f_args, orig_f_kwargs = pytree.tree_map( 143 maybe_detach, (f_args, f_kwargs) 144 ) 145 with fake_mode: 146 f_r = op(*f_args, **f_kwargs) 147 r = op._op_dk(final_key, *args, **kwargs) 148 149 def desc(): 150 fmt_args = ", ".join( 151 itertools.chain( 152 (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args), 153 ( 154 f"{k}={pytree.tree_map(_fmt, v)}" 155 for k, v in orig_f_kwargs.items() 156 ), 157 ) 158 ) 159 return f"{op}({fmt_args})" 160 161 check_metadata_matches(f_r, r, desc) 162 return r 163 164 return handler 165 166 167# NB: enabling this is slow, don't do it in a hot loop. This is purely 168# for debugging purposes. 
@contextmanager
def enable_crossref_functionalize():
    """Run the body with functionalization cross-referencing switched on.

    Cached Functionalize-key dispatch entries of every Python-loaded overload
    are dropped on entry. The body then runs with the Python dispatcher enabled
    and ``torch._dispatch.python.CROSSREF_FUNCTIONALIZE`` patched to ``True``.
    The cached entries are dropped again on exit (even on error), so no stale
    cross-ref handlers survive the context.
    """
    functionalize_key = torch._C.DispatchKey.Functionalize

    def _drop_cached_functionalize_entries():
        # Force each loaded overload to recompute its Functionalize dispatch.
        for overload in all_py_loaded_overloads():
            overload._uncache_dispatch(functionalize_key)

    _drop_cached_functionalize_entries()
    try:
        with enable_python_dispatcher():
            with unittest.mock.patch(
                "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
            ):
                yield
    finally:
        _drop_cached_functionalize_entries()