# mypy: allow-untyped-defs
import itertools
import unittest.mock
from contextlib import contextmanager
from typing import Iterator

import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree


__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch
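
# NB: the three names above bind context-manager guards from torch._C; see
# enable_crossref_functionalize at the bottom of this file for real usage.
# A minimal sketch (illustrative only; aten.add is just an arbitrary op):
#
#   with enable_python_dispatcher():
#       # while the guard is active, dispatch for loaded overloads can be
#       # handled by the Python dispatcher
#       torch.ops.aten.add.Tensor(torch.randn(2), torch.randn(2))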

CROSSREF_FUNCTIONALIZE = False


def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle.  It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point).  This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as that set would include all operators for which C++ static initializers or
    Python operator registration have run.  This does not eagerly populate the
    list on torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python.  We use it for cache invalidation, but
    don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't guaranteed
    to be complete either, as a subsequent lazy load of a library which triggers more
    registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
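
# Usage sketch, mirroring what enable_crossref_functionalize below does
# (illustrative only): iterate the lazily populated overloads to invalidate
# their dispatch caches for a given key:
#
#   for op in all_py_loaded_overloads():
#       op._uncache_dispatch(torch._C.DispatchKey.Functionalize)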


@contextmanager
def suspend_functionalization():
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
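
# NB: suspend_functionalization temporarily turns functionalization off
# (remembering the reapply_views TLS setting so it can be restored on exit).
# It is used below in make_crossref_functionalize, roughly like this sketch
# (functional_tensor is a placeholder name):
#
#   with suspend_functionalization():
#       inner = torch._from_functional_tensor(functional_tensor)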


def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        same_strides
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically, sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
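
# Note on the desc argument: it is a zero-argument callable rather than a string
# so that the formatting work only happens when an assertion actually fails.
# Illustrative call, mirroring make_crossref_functionalize below (f_r and r are
# the fake and real results there):
#
#   check_metadata_matches(f_r, r, desc)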


class Lit:
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(
            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
        )
    else:
        return a
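
# Illustrative example of what _fmt produces: tensors are rendered as an
# unquoted torch.empty_strided(...) expression via Lit.__repr__, so desc()
# below prints a reproducible call rather than tensor contents, e.g.
#
#   repr(_fmt(torch.ones(2, 3)))
#   # -> torch.empty_strided((2, 3), (3, 1), dtype=torch.float32)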


def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler


# NB: enabling this is slow, don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
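
# Usage sketch (illustrative; model_fn and inputs are placeholders): wrap the
# region you want to debug, so that functionalized ops can be cross-checked
# (via make_crossref_functionalize above) against a fake-tensor reference run,
# asserting that output sizes, strides, and dtypes match:
#
#   with enable_crossref_functionalize():
#       out = model_fn(*inputs)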
181