xref: /aosp_15_r20/external/pytorch/test/test_dispatch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dispatch"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport itertools
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport re
6*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch._C as C
9*da0073e9SAndroid Build Coastguard Workerimport torch.utils.cpp_extension
10*da0073e9SAndroid Build Coastguard Workerfrom torch._python_dispatcher import PythonDispatcher
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker# TODO: Expand the dispatcher API to be a generic API for interfacing with
15*da0073e9SAndroid Build Coastguard Worker# the dispatcher from Python!
16*da0073e9SAndroid Build Coastguard Worker#
17*da0073e9SAndroid Build Coastguard Worker# These are exhaustive tests for commutativity of dispatch behavior.  If you're
18*da0073e9SAndroid Build Coastguard Worker# looking for more usage-info style tests, check op_registration_test.cpp
19*da0073e9SAndroid Build Coastguard Worker#
20*da0073e9SAndroid Build Coastguard Worker# Things not tested here:
21*da0073e9SAndroid Build Coastguard Worker#   - Listeners
22*da0073e9SAndroid Build Coastguard Worker#   - Top level namespace registrations
23*da0073e9SAndroid Build Coastguard Worker#   - Fallback
24*da0073e9SAndroid Build Coastguard Worker#   - Exotic overloads of CppFunction/schema
25*da0073e9SAndroid Build Coastguard Worker#
26*da0073e9SAndroid Build Coastguard Worker# Things not directly tested here:
27*da0073e9SAndroid Build Coastguard Worker#   - Internal state of Dispatcher makes sense.  This is indirectly
28*da0073e9SAndroid Build Coastguard Worker#     tested by the invariant testing
29*da0073e9SAndroid Build Coastguard Worker
# Snapshot of an operator's dispatcher state at one point of a run_ops
# sequence: ``state`` is the C._dispatch_dump output (or an error message when
# a registration raised), ``table`` is the C._dispatch_dump_table output, and
# ``provenance`` describes which ctor/dtor sequence first produced the entry.
Result = namedtuple("Result", ["state", "table", "provenance"])

# Dispatch keys presumably passed to extract_dispatch_table_with_keys by the
# computed-table tests: Undefined, three backends, and the Autograd keys.
dispatch_keys_to_check = (
    "Undefined",
    # backend keys
    "CPU",
    "CUDA",
    "XLA",
    # autograd keys
    "AutogradOther",
    "AutogradCPU",
    "AutogradCUDA",
    "AutogradXLA",
)
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker
def extract_dispatch_table_with_keys(table, dispatch_keys):
    """Return the lines of ``table`` that belong to the given dispatch keys.

    Lines are grouped by key, in the order the keys are given; within one key
    they keep their order in ``table``.  Matching is a plain prefix test
    (``line.startswith(key)``), so a key also selects any line whose text
    merely begins with it.  Every selected line is terminated by a newline.

    For in-tree backend fallbacks, the ``registered at <path>FallbackKernel.cpp...[``
    part of a line is replaced by ``registered in pytorch framework [`` so the
    result does not depend on source paths or line numbers.

    Args:
        table: multi-line dispatch table text (presumably the output of
            ``C._dispatch_dump_table``).
        dispatch_keys: iterable of dispatch key names to keep.

    Returns:
        The selected, masked lines concatenated, each followed by a newline;
        an empty string when nothing matches.
    """
    # mask out file:line info for in-tree backend fallback
    fallback_location = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
    table_entries = table.split("\n")
    selected = [
        fallback_location.sub("registered in pytorch framework [", entry) + "\n"
        for key in dispatch_keys
        for entry in table_entries
        if entry.startswith(key)
    ]
    # One join instead of repeated string concatenation, which can be
    # quadratic in the number of selected lines.
    return "".join(selected)
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Workerclass TestDispatch(TestCase):
58*da0073e9SAndroid Build Coastguard Worker    namespace_index = 0
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker    def test_all_invariants(self):
61*da0073e9SAndroid Build Coastguard Worker        # Check that the regular stuff is OK!
62*da0073e9SAndroid Build Coastguard Worker        C._dispatch_check_all_invariants()
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker    # You probably don't want to call this directly; if your constructors
65*da0073e9SAndroid Build Coastguard Worker    # don't commute, you can still run commute with a fixed ctor_order
66*da0073e9SAndroid Build Coastguard Worker    # so that you can test that the destructors still commute
67*da0073e9SAndroid Build Coastguard Worker    def run_ops(
68*da0073e9SAndroid Build Coastguard Worker        self,
69*da0073e9SAndroid Build Coastguard Worker        name,
70*da0073e9SAndroid Build Coastguard Worker        ops,
71*da0073e9SAndroid Build Coastguard Worker        ctor_order=None,
72*da0073e9SAndroid Build Coastguard Worker        dtor_order=None,
73*da0073e9SAndroid Build Coastguard Worker        results=None,
74*da0073e9SAndroid Build Coastguard Worker        expect_raises=False,
75*da0073e9SAndroid Build Coastguard Worker    ):
76*da0073e9SAndroid Build Coastguard Worker        """
77*da0073e9SAndroid Build Coastguard Worker        Given a list of operator registrations, run the registrations in the
78*da0073e9SAndroid Build Coastguard Worker        order specified by ctor_order, and then run the deregistrations in
79*da0073e9SAndroid Build Coastguard Worker        dtor_order.
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker        If results is specified, intermediate results are checked for consistency
82*da0073e9SAndroid Build Coastguard Worker        with results stored in results (and stored in results if this is the
83*da0073e9SAndroid Build Coastguard Worker        first time we've seen them).  Results are expected to be equivalent
84*da0073e9SAndroid Build Coastguard Worker        modulo commutativity and inverses (thus, results is keyed on a frozenset
85*da0073e9SAndroid Build Coastguard Worker        of in effect registrations from ops).  Results stores namedtuple
86*da0073e9SAndroid Build Coastguard Worker        Result[state, table, provenance], where state is a string that contains
87*da0073e9SAndroid Build Coastguard Worker        non-derived kernel registered or error message if it doesn't pass;
88*da0073e9SAndroid Build Coastguard Worker        table is a string that contains computed dispatch table entries;
89*da0073e9SAndroid Build Coastguard Worker        provenance is a string that describes how exactly we got this string.
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker        If expect_raises is True, it is not an error to raise an exception.  Instead,
92*da0073e9SAndroid Build Coastguard Worker        we'll store the exception string (instead of the dispatcher state)
93*da0073e9SAndroid Build Coastguard Worker        in results.  In principle we should flag these differently, but it's
94*da0073e9SAndroid Build Coastguard Worker        very obvious when you get an error in one case but not another.
95*da0073e9SAndroid Build Coastguard Worker        """
96*da0073e9SAndroid Build Coastguard Worker        # By allocating every test into a fresh namespace, this makes it less
97*da0073e9SAndroid Build Coastguard Worker        # likely that a bug in the testing framework will result in tests
98*da0073e9SAndroid Build Coastguard Worker        # interfering with each other
99*da0073e9SAndroid Build Coastguard Worker        self.__class__.namespace_index += 1
100*da0073e9SAndroid Build Coastguard Worker        if results is None:
101*da0073e9SAndroid Build Coastguard Worker            results = {}
102*da0073e9SAndroid Build Coastguard Worker        if ctor_order is None:
103*da0073e9SAndroid Build Coastguard Worker            ctor_order = list(range(len(ops)))
104*da0073e9SAndroid Build Coastguard Worker        if dtor_order is None:
105*da0073e9SAndroid Build Coastguard Worker            dtor_order = list(reversed(ctor_order))
106*da0073e9SAndroid Build Coastguard Worker        # Refs which retain the c10::Module object so we can explicitly control
107*da0073e9SAndroid Build Coastguard Worker        # when each deregistration happens (deregistration occurs when the
108*da0073e9SAndroid Build Coastguard Worker        # object gets deallocated).
109*da0073e9SAndroid Build Coastguard Worker        refs = [None] * len(ops)
110*da0073e9SAndroid Build Coastguard Worker        # Keep track of the set "in effect" registrations
111*da0073e9SAndroid Build Coastguard Worker        active_ops = set()
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker        # double underscore to make it less likely we conflict with something
114*da0073e9SAndroid Build Coastguard Worker        # else
115*da0073e9SAndroid Build Coastguard Worker        test_namespace = f"__test{self.namespace_index}__"
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker        def check_invariants(actual_provenance):
118*da0073e9SAndroid Build Coastguard Worker            C._dispatch_check_invariants(name)
119*da0073e9SAndroid Build Coastguard Worker            # Normalize the test namespace so that expected outputs are stable
120*da0073e9SAndroid Build Coastguard Worker            actual_state = C._dispatch_dump(f"{test_namespace}::{name}").replace(
121*da0073e9SAndroid Build Coastguard Worker                test_namespace, "test"
122*da0073e9SAndroid Build Coastguard Worker            )
123*da0073e9SAndroid Build Coastguard Worker            actual_table = C._dispatch_dump_table(f"{test_namespace}::{name}").replace(
124*da0073e9SAndroid Build Coastguard Worker                test_namespace, "test"
125*da0073e9SAndroid Build Coastguard Worker            )
126*da0073e9SAndroid Build Coastguard Worker            expected_state, expected_table, expected_provenance = results.setdefault(
127*da0073e9SAndroid Build Coastguard Worker                frozenset(active_ops),
128*da0073e9SAndroid Build Coastguard Worker                Result(actual_state, actual_table, actual_provenance),
129*da0073e9SAndroid Build Coastguard Worker            )
130*da0073e9SAndroid Build Coastguard Worker            self.assertMultiLineEqual(
131*da0073e9SAndroid Build Coastguard Worker                expected_state,
132*da0073e9SAndroid Build Coastguard Worker                actual_state,
133*da0073e9SAndroid Build Coastguard Worker                f"expected from {expected_provenance}; actual from {actual_provenance}",
134*da0073e9SAndroid Build Coastguard Worker            )
135*da0073e9SAndroid Build Coastguard Worker            self.assertMultiLineEqual(
136*da0073e9SAndroid Build Coastguard Worker                expected_table,
137*da0073e9SAndroid Build Coastguard Worker                actual_table,
138*da0073e9SAndroid Build Coastguard Worker                f"expected from {expected_provenance}; actual from {actual_provenance}",
139*da0073e9SAndroid Build Coastguard Worker            )
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker        results.setdefault(frozenset(), Result("", "", "hardcoded initial state"))
142*da0073e9SAndroid Build Coastguard Worker        check_invariants("initial state")
143*da0073e9SAndroid Build Coastguard Worker        # In the order specified by ctor_order, run registrations
144*da0073e9SAndroid Build Coastguard Worker        set_to_report = frozenset(range(len(ops)))
145*da0073e9SAndroid Build Coastguard Worker        for i, op_ix in enumerate(ctor_order):
146*da0073e9SAndroid Build Coastguard Worker            # It would be better to DEF here, but because we manage
147*da0073e9SAndroid Build Coastguard Worker            # lifetime of multiple registrations with multiple Library
148*da0073e9SAndroid Build Coastguard Worker            # references (refs), we can't deal with the strict checking
149*da0073e9SAndroid Build Coastguard Worker            # from DEF.
150*da0073e9SAndroid Build Coastguard Worker            refs[op_ix] = C._dispatch_library("FRAGMENT", test_namespace, "")
151*da0073e9SAndroid Build Coastguard Worker            active_ops.add(op_ix)
152*da0073e9SAndroid Build Coastguard Worker            try:
153*da0073e9SAndroid Build Coastguard Worker                ops[op_ix](refs[op_ix])
154*da0073e9SAndroid Build Coastguard Worker                check_invariants(f"running ctors {ctor_order[:i + 1]}")
155*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as e:
156*da0073e9SAndroid Build Coastguard Worker                if not expect_raises:
157*da0073e9SAndroid Build Coastguard Worker                    raise
158*da0073e9SAndroid Build Coastguard Worker                actual = str(e).replace(test_namespace, "test")
159*da0073e9SAndroid Build Coastguard Worker                actual = actual.split("\nException raised from ")[0]
160*da0073e9SAndroid Build Coastguard Worker                expected, _, expected_provenance = results.setdefault(
161*da0073e9SAndroid Build Coastguard Worker                    frozenset(active_ops),
162*da0073e9SAndroid Build Coastguard Worker                    Result(
163*da0073e9SAndroid Build Coastguard Worker                        actual, "", f"error after running ctors {ctor_order[:i + 1]}"
164*da0073e9SAndroid Build Coastguard Worker                    ),
165*da0073e9SAndroid Build Coastguard Worker                )
166*da0073e9SAndroid Build Coastguard Worker                self.assertMultiLineEqual(expected, actual, expected_provenance)
167*da0073e9SAndroid Build Coastguard Worker                set_to_report = frozenset(active_ops)
168*da0073e9SAndroid Build Coastguard Worker                active_ops.remove(op_ix)
169*da0073e9SAndroid Build Coastguard Worker                # NB: this finally test asserts that if a registrations fails,
170*da0073e9SAndroid Build Coastguard Worker                # the dispatcher is left in the same state *that it was before*!
171*da0073e9SAndroid Build Coastguard Worker                check_invariants(
172*da0073e9SAndroid Build Coastguard Worker                    f"running ctors {ctor_order[:i]} and then failing to run ctor {op_ix} "
173*da0073e9SAndroid Build Coastguard Worker                    "(did this failure leave the dispatcher in a wedged state? "
174*da0073e9SAndroid Build Coastguard Worker                    "it shouldn't!)"
175*da0073e9SAndroid Build Coastguard Worker                )
176*da0073e9SAndroid Build Coastguard Worker                break
177*da0073e9SAndroid Build Coastguard Worker        last_ctor = i
178*da0073e9SAndroid Build Coastguard Worker        if expect_raises and len(active_ops) == len(ops):
179*da0073e9SAndroid Build Coastguard Worker            # Destroy references first, as some test frameworks (like pytest)
180*da0073e9SAndroid Build Coastguard Worker            # will retain references in the exception raised by assertTrue! EW!
181*da0073e9SAndroid Build Coastguard Worker            refs = None
182*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
183*da0073e9SAndroid Build Coastguard Worker                False,
184*da0073e9SAndroid Build Coastguard Worker                "expected exception to be raised, but nothing was raised "
185*da0073e9SAndroid Build Coastguard Worker                f"(after running ctors {ctor_order})",
186*da0073e9SAndroid Build Coastguard Worker            )
187*da0073e9SAndroid Build Coastguard Worker        # In the order specified by dtor_order, run deregistrations
188*da0073e9SAndroid Build Coastguard Worker        for i, op_ix in enumerate(dtor_order):
189*da0073e9SAndroid Build Coastguard Worker            # Trigger a destruction
190*da0073e9SAndroid Build Coastguard Worker            refs[op_ix] = None
191*da0073e9SAndroid Build Coastguard Worker            # discard not remove, since we may not have actually deregistered
192*da0073e9SAndroid Build Coastguard Worker            # anything if there was an error raised
193*da0073e9SAndroid Build Coastguard Worker            if expect_raises:
194*da0073e9SAndroid Build Coastguard Worker                active_ops.discard(op_ix)
195*da0073e9SAndroid Build Coastguard Worker            else:
196*da0073e9SAndroid Build Coastguard Worker                active_ops.remove(op_ix)
197*da0073e9SAndroid Build Coastguard Worker            check_invariants(
198*da0073e9SAndroid Build Coastguard Worker                f"running ctors {ctor_order[:last_ctor + 1]}, then running dtors {dtor_order[:i + 1]}"
199*da0073e9SAndroid Build Coastguard Worker            )
200*da0073e9SAndroid Build Coastguard Worker        return results[set_to_report][0]
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker    # Operator registrations are commutative (as static initializers can
203*da0073e9SAndroid Build Coastguard Worker    # run in any order) and invertible (by deregistration).  (Subject
204*da0073e9SAndroid Build Coastguard Worker    # to some caveats: some legacy behavior in the system are not commutative--
205*da0073e9SAndroid Build Coastguard Worker    # we want to get rid of these!)
206*da0073e9SAndroid Build Coastguard Worker    #
207*da0073e9SAndroid Build Coastguard Worker    # So while in principle we could simply test a set of operations
208*da0073e9SAndroid Build Coastguard Worker    # by just running them one by one in the order specified by the user,
209*da0073e9SAndroid Build Coastguard Worker    # we can get more assurance about these extra properties by doing
210*da0073e9SAndroid Build Coastguard Worker    # more work:
211*da0073e9SAndroid Build Coastguard Worker    #
212*da0073e9SAndroid Build Coastguard Worker    # 1. Don't run the registrations once in a fixed order: run every possible
213*da0073e9SAndroid Build Coastguard Worker    #    permutation.  Similarly, run every permutation of deregistration order.
214*da0073e9SAndroid Build Coastguard Worker    #
215*da0073e9SAndroid Build Coastguard Worker    # 2. Don't just check the end state of the dispatcher: for every
216*da0073e9SAndroid Build Coastguard Worker    #    subset of operator registrations, ensure that the computed
217*da0073e9SAndroid Build Coastguard Worker    #    intermediate state is path independent.  One thing to note:
218*da0073e9SAndroid Build Coastguard Worker    #    in this function, we assume each operation is unique.  In general,
219*da0073e9SAndroid Build Coastguard Worker    #    there may be duplicated registrations, but these are usually
220*da0073e9SAndroid Build Coastguard Worker    #    idempotent or legacy.  We test for behavior here separately.
221*da0073e9SAndroid Build Coastguard Worker    #
222*da0073e9SAndroid Build Coastguard Worker    # NB: checking all permutations means this function is exponential in
223*da0073e9SAndroid Build Coastguard Worker    # the length of ops!  So don't pass too many ops to this function!
224*da0073e9SAndroid Build Coastguard Worker    def commute(self, name, ops, ctor_order=None, expect_raises=False):
225*da0073e9SAndroid Build Coastguard Worker        results = {}
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker        def go(ctor_order):
228*da0073e9SAndroid Build Coastguard Worker            for dtor_order in itertools.permutations(range(len(ops))):
229*da0073e9SAndroid Build Coastguard Worker                self.run_ops(
230*da0073e9SAndroid Build Coastguard Worker                    name,
231*da0073e9SAndroid Build Coastguard Worker                    ops,
232*da0073e9SAndroid Build Coastguard Worker                    ctor_order,
233*da0073e9SAndroid Build Coastguard Worker                    dtor_order,
234*da0073e9SAndroid Build Coastguard Worker                    results=results,
235*da0073e9SAndroid Build Coastguard Worker                    expect_raises=expect_raises,
236*da0073e9SAndroid Build Coastguard Worker                )
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker        if ctor_order is not None:
239*da0073e9SAndroid Build Coastguard Worker            go(ctor_order)
240*da0073e9SAndroid Build Coastguard Worker        else:
241*da0073e9SAndroid Build Coastguard Worker            for ctor_order in itertools.permutations(range(len(ops))):
242*da0073e9SAndroid Build Coastguard Worker                go(ctor_order)
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker        # Return the "full" Result namedtuple after all operations are run.
245*da0073e9SAndroid Build Coastguard Worker        # If this KeyErrors, that means that there did not exist any
246*da0073e9SAndroid Build Coastguard Worker        # ordering of ctors which got us to the "end".  That's an
247*da0073e9SAndroid Build Coastguard Worker        # error in test construction: it means you could have
248*da0073e9SAndroid Build Coastguard Worker        # factored the test into two smaller ones.
249*da0073e9SAndroid Build Coastguard Worker        return results[frozenset(range(len(ops)))]
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker    def test_def(self):
252*da0073e9SAndroid Build Coastguard Worker        state = self.commute(
253*da0073e9SAndroid Build Coastguard Worker            "foo",
254*da0073e9SAndroid Build Coastguard Worker            [
255*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
256*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
257*da0073e9SAndroid Build Coastguard Worker                # m.impl("test_def", [](const Tensor& x) { return x })
258*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo"),
259*da0073e9SAndroid Build Coastguard Worker                # m.impl("test_def", kCPU, [](const Tensor& x) { return x })
260*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", dispatch="CPU"),
261*da0073e9SAndroid Build Coastguard Worker                # m.impl("test_def", kAutograd, [](const Tensor& x) { return x })
262*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", dispatch="Autograd"),
263*da0073e9SAndroid Build Coastguard Worker                # m.impl("test_def", kAutogradCPU, [](const Tensor& x) { return x })
264*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", dispatch="AutogradCPU"),
265*da0073e9SAndroid Build Coastguard Worker            ],
266*da0073e9SAndroid Build Coastguard Worker        ).state
267*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
268*da0073e9SAndroid Build Coastguard Worker            state,
269*da0073e9SAndroid Build Coastguard Worker            """\
270*da0073e9SAndroid Build Coastguard Workername: test::foo
271*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
272*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
273*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
274*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
275*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
276*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
277*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
278*da0073e9SAndroid Build Coastguard Worker""",
279*da0073e9SAndroid Build Coastguard Worker        )
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker    def test_def_impl_schema_mismatch(self):
282*da0073e9SAndroid Build Coastguard Worker        # NB: an impl-impl mismatch is not reported eagerly; you'll find out
283*da0073e9SAndroid Build Coastguard Worker        # about it because one of them won't match with def
284*da0073e9SAndroid Build Coastguard Worker        state = self.commute(
285*da0073e9SAndroid Build Coastguard Worker            "foo",
286*da0073e9SAndroid Build Coastguard Worker            [
287*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x, Tensor y) -> Tensor")
288*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
289*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", [](const Tensor & x) { return x })
290*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo"),
291*da0073e9SAndroid Build Coastguard Worker            ],
292*da0073e9SAndroid Build Coastguard Worker            expect_raises=True,
293*da0073e9SAndroid Build Coastguard Worker        ).state
294*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
295*da0073e9SAndroid Build Coastguard Worker            state,
296*da0073e9SAndroid Build Coastguard Worker            """\
297*da0073e9SAndroid Build Coastguard WorkerInferred operator schema for a C++ kernel function doesn't match the expected function schema.
298*da0073e9SAndroid Build Coastguard Worker  operator: test::foo
299*da0073e9SAndroid Build Coastguard Worker  expected schema: test::foo(Tensor x, Tensor y) -> Tensor
300*da0073e9SAndroid Build Coastguard Worker    registered at /dev/null:0
301*da0073e9SAndroid Build Coastguard Worker  inferred schema: (Tensor _0) -> Tensor _0
302*da0073e9SAndroid Build Coastguard Worker    impl_t_t
303*da0073e9SAndroid Build Coastguard Worker  reason: The number of arguments is different. 2 vs 1.""",
304*da0073e9SAndroid Build Coastguard Worker        )
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker    def test_def_with_inference(self):
307*da0073e9SAndroid Build Coastguard Worker        state = self.commute(
308*da0073e9SAndroid Build Coastguard Worker            "foo",
309*da0073e9SAndroid Build Coastguard Worker            [
310*da0073e9SAndroid Build Coastguard Worker                # m.def("foo", [](const Tensor & x) { return x })
311*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_name_t_t("foo"),
312*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
313*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU"),
314*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
315*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "Autograd"),
316*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
317*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "AutogradCPU"),
318*da0073e9SAndroid Build Coastguard Worker            ],
319*da0073e9SAndroid Build Coastguard Worker        ).state
320*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
321*da0073e9SAndroid Build Coastguard Worker            state,
322*da0073e9SAndroid Build Coastguard Worker            """\
323*da0073e9SAndroid Build Coastguard Workername: test::foo
324*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor _0) -> Tensor _0
325*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
326*da0073e9SAndroid Build Coastguard Workeralias analysis kind: CONSERVATIVE
327*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
328*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
329*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
330*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
331*da0073e9SAndroid Build Coastguard Worker""",
332*da0073e9SAndroid Build Coastguard Worker        )
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker    def test_def_only(self):
335*da0073e9SAndroid Build Coastguard Worker        state = self.commute(
336*da0073e9SAndroid Build Coastguard Worker            "foo",
337*da0073e9SAndroid Build Coastguard Worker            [
338*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x, Tensor y) -> Tensor")
339*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
340*da0073e9SAndroid Build Coastguard Worker            ],
341*da0073e9SAndroid Build Coastguard Worker        ).state
342*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
343*da0073e9SAndroid Build Coastguard Worker            state,
344*da0073e9SAndroid Build Coastguard Worker            """\
345*da0073e9SAndroid Build Coastguard Workername: test::foo
346*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x, Tensor y) -> Tensor
347*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
348*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
349*da0073e9SAndroid Build Coastguard Worker""",
350*da0073e9SAndroid Build Coastguard Worker        )
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker    def test_impl_only(self):
353*da0073e9SAndroid Build Coastguard Worker        state = self.commute(
354*da0073e9SAndroid Build Coastguard Worker            "foo",
355*da0073e9SAndroid Build Coastguard Worker            [
356*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", [](const Tensor& x) { return x })
357*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo"),
358*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor& x) { return x })
359*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU"),
360*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x })
361*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "Autograd"),
362*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutogradCPU, [](const Tensor& x) { return x })
363*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "AutogradCPU"),
364*da0073e9SAndroid Build Coastguard Worker            ],
365*da0073e9SAndroid Build Coastguard Worker        ).state
366*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
367*da0073e9SAndroid Build Coastguard Worker            state,
368*da0073e9SAndroid Build Coastguard Worker            """\
369*da0073e9SAndroid Build Coastguard Workername: test::foo
370*da0073e9SAndroid Build Coastguard Workerschema: (none)
371*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
372*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
373*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
374*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
375*da0073e9SAndroid Build Coastguard Worker""",
376*da0073e9SAndroid Build Coastguard Worker        )
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker    def test_computed_table(self):
379*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
380*da0073e9SAndroid Build Coastguard Worker            "foo",
381*da0073e9SAndroid Build Coastguard Worker            [
382*da0073e9SAndroid Build Coastguard Worker                # m.def("foo", [](const Tensor & x) { return x })
383*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_name_t_t("foo"),
384*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
385*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
386*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCUDA, [](const Tensor & x) { return x })
387*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "XLA", debug="fn_xla"),
388*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
389*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
390*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
391*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "AutogradCPU", debug="fn_autogradcpu"),
392*da0073e9SAndroid Build Coastguard Worker            ],
393*da0073e9SAndroid Build Coastguard Worker        )
394*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
395*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
396*da0073e9SAndroid Build Coastguard Worker            state,
397*da0073e9SAndroid Build Coastguard Worker            """\
398*da0073e9SAndroid Build Coastguard Workername: test::foo
399*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor _0) -> Tensor _0
400*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
401*da0073e9SAndroid Build Coastguard Workeralias analysis kind: CONSERVATIVE
402*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
403*da0073e9SAndroid Build Coastguard WorkerXLA: fn_xla :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
404*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autogradcpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
405*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
406*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
407*da0073e9SAndroid Build Coastguard Worker""",
408*da0073e9SAndroid Build Coastguard Worker        )
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
411*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
412*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
413*da0073e9SAndroid Build Coastguard Worker        )
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
416*da0073e9SAndroid Build Coastguard Worker            extracted_table,
417*da0073e9SAndroid Build Coastguard Worker            """\
418*da0073e9SAndroid Build Coastguard WorkerUndefined: default_def_name_t_t [math kernel]
419*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel]
420*da0073e9SAndroid Build Coastguard WorkerCUDA: default_def_name_t_t [math kernel]
421*da0073e9SAndroid Build Coastguard WorkerXLA: fn_xla [kernel]
422*da0073e9SAndroid Build Coastguard WorkerAutogradOther: default_def_name_t_t [math kernel]
423*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autogradcpu [kernel]
424*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: default_def_name_t_t [math kernel]
425*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_autograd [autograd kernel]
426*da0073e9SAndroid Build Coastguard Worker""",
427*da0073e9SAndroid Build Coastguard Worker        )
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_cpu_math_autogradcpu_fallthrough(self):
430*da0073e9SAndroid Build Coastguard Worker        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
431*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
432*da0073e9SAndroid Build Coastguard Worker            "foo",
433*da0073e9SAndroid Build Coastguard Worker            [
434*da0073e9SAndroid Build Coastguard Worker                # m.def("foo", [](const Tensor & x) { return x })
435*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_name_t_t("foo"),
436*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
437*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU"),
438*da0073e9SAndroid Build Coastguard Worker            ],
439*da0073e9SAndroid Build Coastguard Worker        )
440*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
441*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
442*da0073e9SAndroid Build Coastguard Worker            state,
443*da0073e9SAndroid Build Coastguard Worker            """\
444*da0073e9SAndroid Build Coastguard Workername: test::foo
445*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor _0) -> Tensor _0
446*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
447*da0073e9SAndroid Build Coastguard Workeralias analysis kind: CONSERVATIVE
448*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
449*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
450*da0073e9SAndroid Build Coastguard Worker""",
451*da0073e9SAndroid Build Coastguard Worker        )
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
454*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
455*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
456*da0073e9SAndroid Build Coastguard Worker        )
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
459*da0073e9SAndroid Build Coastguard Worker            extracted_table,
460*da0073e9SAndroid Build Coastguard Worker            """\
461*da0073e9SAndroid Build Coastguard WorkerUndefined: default_def_name_t_t [math kernel]
462*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t [kernel]
463*da0073e9SAndroid Build Coastguard WorkerCUDA: default_def_name_t_t [math kernel]
464*da0073e9SAndroid Build Coastguard WorkerXLA: default_def_name_t_t [math kernel]
465*da0073e9SAndroid Build Coastguard WorkerAutogradOther: default_def_name_t_t [math kernel]
466*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: registered in pytorch framework [backend fallback]
467*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: default_def_name_t_t [math kernel]
468*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: default_def_name_t_t [math kernel]
469*da0073e9SAndroid Build Coastguard Worker""",
470*da0073e9SAndroid Build Coastguard Worker        )
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_math(self):
473*da0073e9SAndroid Build Coastguard Worker        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
474*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
475*da0073e9SAndroid Build Coastguard Worker            "foo",
476*da0073e9SAndroid Build Coastguard Worker            [
477*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
478*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
479*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
480*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd"),
481*da0073e9SAndroid Build Coastguard Worker            ],
482*da0073e9SAndroid Build Coastguard Worker        )
483*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
484*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
485*da0073e9SAndroid Build Coastguard Worker            state,
486*da0073e9SAndroid Build Coastguard Worker            """\
487*da0073e9SAndroid Build Coastguard Workername: test::foo
488*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
489*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
490*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
491*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
492*da0073e9SAndroid Build Coastguard Worker""",
493*da0073e9SAndroid Build Coastguard Worker        )
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
496*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
497*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
498*da0073e9SAndroid Build Coastguard Worker        )
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
501*da0073e9SAndroid Build Coastguard Worker            extracted_table,
502*da0073e9SAndroid Build Coastguard Worker            """\
503*da0073e9SAndroid Build Coastguard WorkerUndefined: impl_t_t [math kernel]
504*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t [math kernel]
505*da0073e9SAndroid Build Coastguard WorkerCUDA: impl_t_t [math kernel]
506*da0073e9SAndroid Build Coastguard WorkerXLA: impl_t_t [math kernel]
507*da0073e9SAndroid Build Coastguard WorkerAutogradOther: impl_t_t [math kernel]
508*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t [math kernel]
509*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: impl_t_t [math kernel]
510*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: impl_t_t [math kernel]
511*da0073e9SAndroid Build Coastguard Worker""",
512*da0073e9SAndroid Build Coastguard Worker        )
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_cpu_math(self):
515*da0073e9SAndroid Build Coastguard Worker        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
516*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
517*da0073e9SAndroid Build Coastguard Worker            "foo",
518*da0073e9SAndroid Build Coastguard Worker            [
519*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
520*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
521*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
522*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
523*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
524*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t(
525*da0073e9SAndroid Build Coastguard Worker                    "foo", "CompositeImplicitAutograd", debug="fn_math"
526*da0073e9SAndroid Build Coastguard Worker                ),
527*da0073e9SAndroid Build Coastguard Worker            ],
528*da0073e9SAndroid Build Coastguard Worker        )
529*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
530*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
531*da0073e9SAndroid Build Coastguard Worker            state,
532*da0073e9SAndroid Build Coastguard Worker            """\
533*da0073e9SAndroid Build Coastguard Workername: test::foo
534*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
535*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
536*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
537*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
538*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
539*da0073e9SAndroid Build Coastguard Worker""",
540*da0073e9SAndroid Build Coastguard Worker        )
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
543*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
544*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
545*da0073e9SAndroid Build Coastguard Worker        )
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
548*da0073e9SAndroid Build Coastguard Worker            extracted_table,
549*da0073e9SAndroid Build Coastguard Worker            """\
550*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_math [math kernel]
551*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel]
552*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_math [math kernel]
553*da0073e9SAndroid Build Coastguard WorkerXLA: fn_math [math kernel]
554*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_math [math kernel]
555*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: registered in pytorch framework [backend fallback]
556*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_math [math kernel]
557*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_math [math kernel]
558*da0073e9SAndroid Build Coastguard Worker""",
559*da0073e9SAndroid Build Coastguard Worker        )
560*da0073e9SAndroid Build Coastguard Worker
561*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_autograd(self):
562*da0073e9SAndroid Build Coastguard Worker        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
563*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
564*da0073e9SAndroid Build Coastguard Worker            "foo",
565*da0073e9SAndroid Build Coastguard Worker            [
566*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
567*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
568*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
569*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "Autograd"),
570*da0073e9SAndroid Build Coastguard Worker            ],
571*da0073e9SAndroid Build Coastguard Worker        )
572*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
573*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
574*da0073e9SAndroid Build Coastguard Worker            state,
575*da0073e9SAndroid Build Coastguard Worker            """\
576*da0073e9SAndroid Build Coastguard Workername: test::foo
577*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
578*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
579*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
580*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
581*da0073e9SAndroid Build Coastguard Worker""",
582*da0073e9SAndroid Build Coastguard Worker        )
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
585*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
586*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
587*da0073e9SAndroid Build Coastguard Worker        )
588*da0073e9SAndroid Build Coastguard Worker
589*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
590*da0073e9SAndroid Build Coastguard Worker            extracted_table,
591*da0073e9SAndroid Build Coastguard Worker            """\
592*da0073e9SAndroid Build Coastguard WorkerAutogradOther: impl_t_t [autograd kernel]
593*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t [autograd kernel]
594*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: impl_t_t [autograd kernel]
595*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: impl_t_t [autograd kernel]
596*da0073e9SAndroid Build Coastguard Worker""",
597*da0073e9SAndroid Build Coastguard Worker        )
598*da0073e9SAndroid Build Coastguard Worker
    # No catchAll registration appears in the test below: now that catchAll
    # maps to CompositeImplicitAutograd, registering to both catchAll and
    # CompositeImplicitAutograd would break commutativity.
601*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_cpu_autograd_math(self):
602*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
603*da0073e9SAndroid Build Coastguard Worker            "foo",
604*da0073e9SAndroid Build Coastguard Worker            [
605*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
606*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
607*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
608*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
609*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
610*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
611*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
612*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t(
613*da0073e9SAndroid Build Coastguard Worker                    "foo", "CompositeImplicitAutograd", debug="fn_math"
614*da0073e9SAndroid Build Coastguard Worker                ),
615*da0073e9SAndroid Build Coastguard Worker            ],
616*da0073e9SAndroid Build Coastguard Worker        )
617*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
618*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
619*da0073e9SAndroid Build Coastguard Worker            state,
620*da0073e9SAndroid Build Coastguard Worker            """\
621*da0073e9SAndroid Build Coastguard Workername: test::foo
622*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
623*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
624*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
625*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
626*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
627*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
628*da0073e9SAndroid Build Coastguard Worker""",
629*da0073e9SAndroid Build Coastguard Worker        )
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
632*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
633*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
634*da0073e9SAndroid Build Coastguard Worker        )
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
637*da0073e9SAndroid Build Coastguard Worker            extracted_table,
638*da0073e9SAndroid Build Coastguard Worker            """\
639*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_math [math kernel]
640*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel]
641*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_math [math kernel]
642*da0073e9SAndroid Build Coastguard WorkerXLA: fn_math [math kernel]
643*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_math [math kernel]
644*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autograd [autograd kernel]
645*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_math [math kernel]
646*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_math [math kernel]
647*da0073e9SAndroid Build Coastguard Worker""",
648*da0073e9SAndroid Build Coastguard Worker        )
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_ambiguous_autogradother(self):
651*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
652*da0073e9SAndroid Build Coastguard Worker            "foo",
653*da0073e9SAndroid Build Coastguard Worker            [
654*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
655*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
656*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
657*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t(
658*da0073e9SAndroid Build Coastguard Worker                    "foo", "CompositeImplicitAutograd", debug="fn_math"
659*da0073e9SAndroid Build Coastguard Worker                ),
660*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kFPGA, [](const Tensor & x) { return x })
661*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "FPGA", debug="fn_fpga"),
662*da0073e9SAndroid Build Coastguard Worker            ],
663*da0073e9SAndroid Build Coastguard Worker        )
664*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
665*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
666*da0073e9SAndroid Build Coastguard Worker            state,
667*da0073e9SAndroid Build Coastguard Worker            """\
668*da0073e9SAndroid Build Coastguard Workername: test::foo
669*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
670*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
671*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
672*da0073e9SAndroid Build Coastguard WorkerFPGA: fn_fpga :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
673*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
674*da0073e9SAndroid Build Coastguard Worker""",
675*da0073e9SAndroid Build Coastguard Worker        )
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
678*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
679*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check + ("FPGA",)
680*da0073e9SAndroid Build Coastguard Worker        )
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
683*da0073e9SAndroid Build Coastguard Worker            extracted_table,
684*da0073e9SAndroid Build Coastguard Worker            """\
685*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_math [math kernel]
686*da0073e9SAndroid Build Coastguard WorkerCPU: fn_math [math kernel]
687*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_math [math kernel]
688*da0073e9SAndroid Build Coastguard WorkerXLA: fn_math [math kernel]
689*da0073e9SAndroid Build Coastguard WorkerAutogradOther: ambiguous_autogradother [ambiguous autogradother]
690*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_math [math kernel]
691*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_math [math kernel]
692*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_math [math kernel]
693*da0073e9SAndroid Build Coastguard WorkerFPGA: fn_fpga [kernel]
694*da0073e9SAndroid Build Coastguard Worker""",
695*da0073e9SAndroid Build Coastguard Worker        )
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_cpu_defaultbackend(self):
698*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
699*da0073e9SAndroid Build Coastguard Worker            "foo",
700*da0073e9SAndroid Build Coastguard Worker            [
701*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
702*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
703*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
704*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
705*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x })
706*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t(
707*da0073e9SAndroid Build Coastguard Worker                    "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend"
708*da0073e9SAndroid Build Coastguard Worker                ),
709*da0073e9SAndroid Build Coastguard Worker            ],
710*da0073e9SAndroid Build Coastguard Worker        )
711*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
712*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
713*da0073e9SAndroid Build Coastguard Worker            state,
714*da0073e9SAndroid Build Coastguard Worker            """\
715*da0073e9SAndroid Build Coastguard Workername: test::foo
716*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
717*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
718*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
719*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
720*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
721*da0073e9SAndroid Build Coastguard Worker""",
722*da0073e9SAndroid Build Coastguard Worker        )
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
725*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
726*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
727*da0073e9SAndroid Build Coastguard Worker        )
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
730*da0073e9SAndroid Build Coastguard Worker            extracted_table,
731*da0073e9SAndroid Build Coastguard Worker            """\
732*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_defaultbackend [default backend kernel]
733*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel]
734*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_defaultbackend [default backend kernel]
735*da0073e9SAndroid Build Coastguard WorkerXLA: fn_defaultbackend [default backend kernel]
736*da0073e9SAndroid Build Coastguard WorkerAutogradOther: registered in pytorch framework [backend fallback]
737*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: registered in pytorch framework [backend fallback]
738*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: registered in pytorch framework [backend fallback]
739*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: registered in pytorch framework [backend fallback]
740*da0073e9SAndroid Build Coastguard Worker""",
741*da0073e9SAndroid Build Coastguard Worker        )
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_cpu_autograd_defaultbackend(self):
744*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
745*da0073e9SAndroid Build Coastguard Worker            "foo",
746*da0073e9SAndroid Build Coastguard Worker            [
747*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
748*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
749*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
750*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
751*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
752*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
753*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x })
754*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t(
755*da0073e9SAndroid Build Coastguard Worker                    "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend"
756*da0073e9SAndroid Build Coastguard Worker                ),
757*da0073e9SAndroid Build Coastguard Worker            ],
758*da0073e9SAndroid Build Coastguard Worker        )
759*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
760*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
761*da0073e9SAndroid Build Coastguard Worker            state,
762*da0073e9SAndroid Build Coastguard Worker            """\
763*da0073e9SAndroid Build Coastguard Workername: test::foo
764*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
765*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
766*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
767*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
768*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
769*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
770*da0073e9SAndroid Build Coastguard Worker""",
771*da0073e9SAndroid Build Coastguard Worker        )
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
774*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
775*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check + ("FPGA",)
776*da0073e9SAndroid Build Coastguard Worker        )
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
779*da0073e9SAndroid Build Coastguard Worker            extracted_table,
780*da0073e9SAndroid Build Coastguard Worker            """\
781*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_defaultbackend [default backend kernel]
782*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel]
783*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_defaultbackend [default backend kernel]
784*da0073e9SAndroid Build Coastguard WorkerXLA: fn_defaultbackend [default backend kernel]
785*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_autograd [autograd kernel]
786*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autograd [autograd kernel]
787*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_autograd [autograd kernel]
788*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_autograd [autograd kernel]
789*da0073e9SAndroid Build Coastguard WorkerFPGA: fn_defaultbackend [default backend kernel]
790*da0073e9SAndroid Build Coastguard Worker""",
791*da0073e9SAndroid Build Coastguard Worker        )
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker    def test_computed_table_with_cpu_autograd_math_defaultbackend(self):
794*da0073e9SAndroid Build Coastguard Worker        result = self.commute(
795*da0073e9SAndroid Build Coastguard Worker            "foo",
796*da0073e9SAndroid Build Coastguard Worker            [
797*da0073e9SAndroid Build Coastguard Worker                # m.def("foo(Tensor x) -> Tensor")
798*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_("foo(Tensor x) -> Tensor"),
799*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
800*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
801*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
802*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
803*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
804*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t(
805*da0073e9SAndroid Build Coastguard Worker                    "foo", "CompositeImplicitAutograd", debug="fn_math"
806*da0073e9SAndroid Build Coastguard Worker                ),
807*da0073e9SAndroid Build Coastguard Worker                # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x })
808*da0073e9SAndroid Build Coastguard Worker                lambda m: m.impl_t_t(
809*da0073e9SAndroid Build Coastguard Worker                    "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend"
810*da0073e9SAndroid Build Coastguard Worker                ),
811*da0073e9SAndroid Build Coastguard Worker            ],
812*da0073e9SAndroid Build Coastguard Worker        )
813*da0073e9SAndroid Build Coastguard Worker        state, table = result.state, result.table
814*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
815*da0073e9SAndroid Build Coastguard Worker            state,
816*da0073e9SAndroid Build Coastguard Worker            """\
817*da0073e9SAndroid Build Coastguard Workername: test::foo
818*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor
819*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
820*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA
821*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
822*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
823*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
824*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
825*da0073e9SAndroid Build Coastguard Worker""",
826*da0073e9SAndroid Build Coastguard Worker        )
827*da0073e9SAndroid Build Coastguard Worker
828*da0073e9SAndroid Build Coastguard Worker        # computed dispatch table is too big, so we only check on a few entries we're interested in.
829*da0073e9SAndroid Build Coastguard Worker        extracted_table = extract_dispatch_table_with_keys(
830*da0073e9SAndroid Build Coastguard Worker            table, dispatch_keys_to_check
831*da0073e9SAndroid Build Coastguard Worker        )
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
834*da0073e9SAndroid Build Coastguard Worker            extracted_table,
835*da0073e9SAndroid Build Coastguard Worker            """\
836*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_defaultbackend [default backend kernel]
837*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel]
838*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_defaultbackend [default backend kernel]
839*da0073e9SAndroid Build Coastguard WorkerXLA: fn_defaultbackend [default backend kernel]
840*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_autograd [autograd kernel]
841*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autograd [autograd kernel]
842*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_autograd [autograd kernel]
843*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_autograd [autograd kernel]
844*da0073e9SAndroid Build Coastguard Worker""",
845*da0073e9SAndroid Build Coastguard Worker        )
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker    def test_multiple_def_error(self):
848*da0073e9SAndroid Build Coastguard Worker        ops = [
849*da0073e9SAndroid Build Coastguard Worker            # m.def("foo(Tensor x, Tensor y) -> Tensor")
850*da0073e9SAndroid Build Coastguard Worker            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
851*da0073e9SAndroid Build Coastguard Worker            # m.def("foo(Tensor x, Tensor y) -> Tensor")
852*da0073e9SAndroid Build Coastguard Worker            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
853*da0073e9SAndroid Build Coastguard Worker        ]
854*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
855*da0073e9SAndroid Build Coastguard Worker            self.commute("foo", ops, expect_raises=True).state,
856*da0073e9SAndroid Build Coastguard Worker            """Tried to register an operator (test::foo(Tensor x, Tensor y) -> Tensor) with the same name and overload """
857*da0073e9SAndroid Build Coastguard Worker            """name multiple times. Each overload's schema should only be registered with a single call to def(). """
858*da0073e9SAndroid Build Coastguard Worker            """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""",
859*da0073e9SAndroid Build Coastguard Worker        )
860*da0073e9SAndroid Build Coastguard Worker
861*da0073e9SAndroid Build Coastguard Worker    def test_def_with_explicit_alias(self):
862*da0073e9SAndroid Build Coastguard Worker        state = self.commute(
863*da0073e9SAndroid Build Coastguard Worker            "foo",
864*da0073e9SAndroid Build Coastguard Worker            [
865*da0073e9SAndroid Build Coastguard Worker                # m.def(torch::schema(
866*da0073e9SAndroid Build Coastguard Worker                #   "foo(Tensor x, Tensor y) -> Tensor",
867*da0073e9SAndroid Build Coastguard Worker                #   AliasAnalysisKind::PURE))
868*da0073e9SAndroid Build Coastguard Worker                lambda m: m.def_(
869*da0073e9SAndroid Build Coastguard Worker                    "foo(Tensor x, Tensor y) -> Tensor", alias="PURE_FUNCTION"
870*da0073e9SAndroid Build Coastguard Worker                )
871*da0073e9SAndroid Build Coastguard Worker            ],
872*da0073e9SAndroid Build Coastguard Worker        ).state
873*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
874*da0073e9SAndroid Build Coastguard Worker            state,
875*da0073e9SAndroid Build Coastguard Worker            """\
876*da0073e9SAndroid Build Coastguard Workername: test::foo
877*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x, Tensor y) -> Tensor
878*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0
879*da0073e9SAndroid Build Coastguard Workeralias analysis kind: PURE_FUNCTION
880*da0073e9SAndroid Build Coastguard Worker""",
881*da0073e9SAndroid Build Coastguard Worker        )
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker    def test_multiple_def_alias_defaulting(self):
884*da0073e9SAndroid Build Coastguard Worker        ops = [
885*da0073e9SAndroid Build Coastguard Worker            # m.def(torch::schema("foo(Tensor x) -> Tensor",
886*da0073e9SAndroid Build Coastguard Worker            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
887*da0073e9SAndroid Build Coastguard Worker            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
888*da0073e9SAndroid Build Coastguard Worker            # RegisterOperators().op("foo(Tensor x) -> Tensor")
889*da0073e9SAndroid Build Coastguard Worker            lambda m: m.def_legacy("foo(Tensor x) -> Tensor"),
890*da0073e9SAndroid Build Coastguard Worker        ]
891*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
892*da0073e9SAndroid Build Coastguard Worker            self.commute("foo", ops, expect_raises=True).state,
893*da0073e9SAndroid Build Coastguard Worker            """Tried to register an operator (test::foo(Tensor x) -> Tensor) with the same name and overload """
894*da0073e9SAndroid Build Coastguard Worker            """name multiple times. Each overload's schema should only be registered with a single call to def(). """
895*da0073e9SAndroid Build Coastguard Worker            """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""",
896*da0073e9SAndroid Build Coastguard Worker        )
897*da0073e9SAndroid Build Coastguard Worker
898*da0073e9SAndroid Build Coastguard Worker    def test_multiple_def_alias_mismatch(self):
899*da0073e9SAndroid Build Coastguard Worker        ops = [
900*da0073e9SAndroid Build Coastguard Worker            # m.def(torch::schema("foo(Tensor x) -> Tensor",
901*da0073e9SAndroid Build Coastguard Worker            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
902*da0073e9SAndroid Build Coastguard Worker            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
903*da0073e9SAndroid Build Coastguard Worker            # m.def(torch::schema("foo(Tensor x) -> Tensor",
904*da0073e9SAndroid Build Coastguard Worker            #                     c10::AliasAnalysisKind::CONSERVATIVE))
905*da0073e9SAndroid Build Coastguard Worker            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="CONSERVATIVE"),
906*da0073e9SAndroid Build Coastguard Worker        ]
907*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
908*da0073e9SAndroid Build Coastguard Worker            self.commute("foo", ops, expect_raises=True).state,
909*da0073e9SAndroid Build Coastguard Worker            """Tried to register an operator (test::foo(Tensor x) -> Tensor) with the same name and overload """
910*da0073e9SAndroid Build Coastguard Worker            """name multiple times. Each overload's schema should only be registered with a single call to def(). """
911*da0073e9SAndroid Build Coastguard Worker            """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""",
912*da0073e9SAndroid Build Coastguard Worker        )
913*da0073e9SAndroid Build Coastguard Worker
914*da0073e9SAndroid Build Coastguard Worker    def test_multiple_fallback(self):
915*da0073e9SAndroid Build Coastguard Worker        global_m = C._dispatch_library("IMPL", "_", "XLA")
916*da0073e9SAndroid Build Coastguard Worker        global_m.fallback_fallthrough()
917*da0073e9SAndroid Build Coastguard Worker        try:
918*da0073e9SAndroid Build Coastguard Worker            global_m.fallback_fallthrough()
919*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
920*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
921*da0073e9SAndroid Build Coastguard Worker                str(e),
922*da0073e9SAndroid Build Coastguard Worker                """Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration """
923*da0073e9SAndroid Build Coastguard Worker                """registered at /dev/null:0, new registration registered at /dev/null:0""",
924*da0073e9SAndroid Build Coastguard Worker            )
925*da0073e9SAndroid Build Coastguard Worker        else:
926*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(False)
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker    def test_overwrite_math(self):
929*da0073e9SAndroid Build Coastguard Worker        ops = [
930*da0073e9SAndroid Build Coastguard Worker            lambda m: m.impl_t_t("foo", debug="fn1"),
931*da0073e9SAndroid Build Coastguard Worker            lambda m: m.impl_t_t("foo", debug="fn2"),
932*da0073e9SAndroid Build Coastguard Worker        ]
933*da0073e9SAndroid Build Coastguard Worker        # Not commutative
934*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
935*da0073e9SAndroid Build Coastguard Worker            self.commute("foo", ops, ctor_order=(0, 1)).state,
936*da0073e9SAndroid Build Coastguard Worker            """\
937*da0073e9SAndroid Build Coastguard Workername: test::foo
938*da0073e9SAndroid Build Coastguard Workerschema: (none)
939*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn2 :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
940*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias] (inactive): fn1 :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
941*da0073e9SAndroid Build Coastguard Worker""",
942*da0073e9SAndroid Build Coastguard Worker        )
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker    # Definition: a dangling impl happens when someone does an impl() on a
945*da0073e9SAndroid Build Coastguard Worker    # function but not a def() for it. This is usually a bug, e.g. someone
946*da0073e9SAndroid Build Coastguard Worker    # misspelled an operator name, or someone registered an impl for an op that
947*da0073e9SAndroid Build Coastguard Worker    # no longer exists
948*da0073e9SAndroid Build Coastguard Worker    def test_find_dangling_impls(self):
949*da0073e9SAndroid Build Coastguard Worker        dangling_impls = C._dispatch_find_dangling_impls()
950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
951*da0073e9SAndroid Build Coastguard Worker            0,
952*da0073e9SAndroid Build Coastguard Worker            len(dangling_impls),
953*da0073e9SAndroid Build Coastguard Worker            msg=f"Expect zero dangling impls, but found: {dangling_impls}",
954*da0073e9SAndroid Build Coastguard Worker        )
955*da0073e9SAndroid Build Coastguard Worker
956*da0073e9SAndroid Build Coastguard Worker    def test_find_dangling_impls_ext(self):
957*da0073e9SAndroid Build Coastguard Worker        extension_path = os.path.join(
958*da0073e9SAndroid Build Coastguard Worker            os.path.dirname(os.path.abspath(__file__)),
959*da0073e9SAndroid Build Coastguard Worker            "cpp_extensions",
960*da0073e9SAndroid Build Coastguard Worker            "dangling_impl_extension.cpp",
961*da0073e9SAndroid Build Coastguard Worker        )
962*da0073e9SAndroid Build Coastguard Worker        module = torch.utils.cpp_extension.load(
963*da0073e9SAndroid Build Coastguard Worker            name="dangling_impl_extension",
964*da0073e9SAndroid Build Coastguard Worker            sources=[extension_path],
965*da0073e9SAndroid Build Coastguard Worker            extra_cflags=["-g"],
966*da0073e9SAndroid Build Coastguard Worker            verbose=True,
967*da0073e9SAndroid Build Coastguard Worker        )
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker        impls = C._dispatch_find_dangling_impls()
970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, len(impls))
971*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
972*da0073e9SAndroid Build Coastguard Worker            f"""\
973*da0073e9SAndroid Build Coastguard Workername: __test::foo
974*da0073e9SAndroid Build Coastguard Workerschema: (none)
975*da0073e9SAndroid Build Coastguard WorkerCPU: registered at {extension_path}:5 :: () -> () [ boxed unboxed ]
976*da0073e9SAndroid Build Coastguard Worker""",
977*da0073e9SAndroid Build Coastguard Worker            impls[0],
978*da0073e9SAndroid Build Coastguard Worker        )
979*da0073e9SAndroid Build Coastguard Worker
980*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_print_registrations_for_dispatch_key_invalid(self):
981*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
982*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "could not parse dispatch key: invalid_key"
983*da0073e9SAndroid Build Coastguard Worker        ):
984*da0073e9SAndroid Build Coastguard Worker            C._dispatch_print_registrations_for_dispatch_key("invalid_key")
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Workerclass TestPythonDispatcher(TestCase):
988*da0073e9SAndroid Build Coastguard Worker    def test_basic(self):
989*da0073e9SAndroid Build Coastguard Worker        dispatcher = PythonDispatcher()
990*da0073e9SAndroid Build Coastguard Worker        dispatcher.register(["CPU", "XLA", "Lazy", "CompositeImplicitAutograd"])
991*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
992*da0073e9SAndroid Build Coastguard Worker            dispatcher.dispatchTable(),
993*da0073e9SAndroid Build Coastguard Worker            """\
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table
996*da0073e9SAndroid Build Coastguard Workerkey             kernel
997*da0073e9SAndroid Build Coastguard Worker---------------------------
998*da0073e9SAndroid Build Coastguard WorkerCPU             fn_CPU [kernel]
999*da0073e9SAndroid Build Coastguard WorkerXLA             fn_XLA [kernel]
1000*da0073e9SAndroid Build Coastguard WorkerLazy            fn_Lazy [kernel]
1001*da0073e9SAndroid Build Coastguard WorkerFPGA            fn_CompositeImplicitAutograd [math kernel]
1002*da0073e9SAndroid Build Coastguard WorkerAutogradOther   fn_CompositeImplicitAutograd [math kernel]
1003*da0073e9SAndroid Build Coastguard WorkerAutogradCPU     [backend fallback]
1004*da0073e9SAndroid Build Coastguard WorkerAutogradXLA     [backend fallback]
1005*da0073e9SAndroid Build Coastguard WorkerAutogradLazy    [backend fallback]
1006*da0073e9SAndroid Build Coastguard Worker""",
1007*da0073e9SAndroid Build Coastguard Worker        )
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker    def test_math_autogradcpu(self):
1010*da0073e9SAndroid Build Coastguard Worker        dispatcher = PythonDispatcher()
1011*da0073e9SAndroid Build Coastguard Worker        dispatcher.register(
1012*da0073e9SAndroid Build Coastguard Worker            ["CPU", "XLA", "Lazy", "CompositeImplicitAutograd", "AutogradCPU"]
1013*da0073e9SAndroid Build Coastguard Worker        )
1014*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1015*da0073e9SAndroid Build Coastguard Worker            dispatcher.dispatchTable(),
1016*da0073e9SAndroid Build Coastguard Worker            """\
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table
1019*da0073e9SAndroid Build Coastguard Workerkey             kernel
1020*da0073e9SAndroid Build Coastguard Worker---------------------------
1021*da0073e9SAndroid Build Coastguard WorkerCPU             fn_CPU [kernel]
1022*da0073e9SAndroid Build Coastguard WorkerXLA             fn_XLA [kernel]
1023*da0073e9SAndroid Build Coastguard WorkerLazy            fn_Lazy [kernel]
1024*da0073e9SAndroid Build Coastguard WorkerFPGA            fn_CompositeImplicitAutograd [math kernel]
1025*da0073e9SAndroid Build Coastguard WorkerAutogradOther   fn_CompositeImplicitAutograd [math kernel]
1026*da0073e9SAndroid Build Coastguard WorkerAutogradCPU     fn_AutogradCPU [kernel]
1027*da0073e9SAndroid Build Coastguard WorkerAutogradXLA     [backend fallback]
1028*da0073e9SAndroid Build Coastguard WorkerAutogradLazy    [backend fallback]
1029*da0073e9SAndroid Build Coastguard Worker""",
1030*da0073e9SAndroid Build Coastguard Worker        )
1031*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1032*da0073e9SAndroid Build Coastguard Worker            dispatcher.registrations(),
1033*da0073e9SAndroid Build Coastguard Worker            """\
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard WorkerRegistered Kernels
1036*da0073e9SAndroid Build Coastguard Workerkey             kernel
1037*da0073e9SAndroid Build Coastguard Worker---------------------------
1038*da0073e9SAndroid Build Coastguard WorkerCPU             fn_CPU
1039*da0073e9SAndroid Build Coastguard WorkerXLA             fn_XLA
1040*da0073e9SAndroid Build Coastguard WorkerLazy            fn_Lazy
1041*da0073e9SAndroid Build Coastguard WorkerAutogradCPU     fn_AutogradCPU
1042*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
1043*da0073e9SAndroid Build Coastguard Worker""",
1044*da0073e9SAndroid Build Coastguard Worker        )
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker    def test_defaultbackend_autogradcpu(self):
1047*da0073e9SAndroid Build Coastguard Worker        dispatcher = PythonDispatcher()
1048*da0073e9SAndroid Build Coastguard Worker        dispatcher.register(
1049*da0073e9SAndroid Build Coastguard Worker            ["CPU", "XLA", "Lazy", "CompositeExplicitAutograd", "AutogradCPU"]
1050*da0073e9SAndroid Build Coastguard Worker        )
1051*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1052*da0073e9SAndroid Build Coastguard Worker            dispatcher.dispatchTable(),
1053*da0073e9SAndroid Build Coastguard Worker            """\
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table
1056*da0073e9SAndroid Build Coastguard Workerkey             kernel
1057*da0073e9SAndroid Build Coastguard Worker---------------------------
1058*da0073e9SAndroid Build Coastguard WorkerCPU             fn_CPU [kernel]
1059*da0073e9SAndroid Build Coastguard WorkerXLA             fn_XLA [kernel]
1060*da0073e9SAndroid Build Coastguard WorkerLazy            fn_Lazy [kernel]
1061*da0073e9SAndroid Build Coastguard WorkerFPGA            fn_CompositeExplicitAutograd [default backend kernel]
1062*da0073e9SAndroid Build Coastguard WorkerAutogradOther   [backend fallback]
1063*da0073e9SAndroid Build Coastguard WorkerAutogradCPU     fn_AutogradCPU [kernel]
1064*da0073e9SAndroid Build Coastguard WorkerAutogradXLA     [backend fallback]
1065*da0073e9SAndroid Build Coastguard WorkerAutogradLazy    [backend fallback]
1066*da0073e9SAndroid Build Coastguard Worker""",
1067*da0073e9SAndroid Build Coastguard Worker        )
1068*da0073e9SAndroid Build Coastguard Worker
1069*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1070*da0073e9SAndroid Build Coastguard Worker            dispatcher.registrations(),
1071*da0073e9SAndroid Build Coastguard Worker            """\
1072*da0073e9SAndroid Build Coastguard Worker
1073*da0073e9SAndroid Build Coastguard WorkerRegistered Kernels
1074*da0073e9SAndroid Build Coastguard Workerkey             kernel
1075*da0073e9SAndroid Build Coastguard Worker---------------------------
1076*da0073e9SAndroid Build Coastguard WorkerCPU             fn_CPU
1077*da0073e9SAndroid Build Coastguard WorkerXLA             fn_XLA
1078*da0073e9SAndroid Build Coastguard WorkerLazy            fn_Lazy
1079*da0073e9SAndroid Build Coastguard WorkerAutogradCPU     fn_AutogradCPU
1080*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias] fn_CompositeExplicitAutograd
1081*da0073e9SAndroid Build Coastguard Worker""",
1082*da0073e9SAndroid Build Coastguard Worker        )
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker    def test_autogradother(self):
1085*da0073e9SAndroid Build Coastguard Worker        dispatcher = PythonDispatcher()
1086*da0073e9SAndroid Build Coastguard Worker        dispatcher.register(["CPU", "FPGA", "CompositeImplicitAutograd"])
1087*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1088*da0073e9SAndroid Build Coastguard Worker            dispatcher.dispatchTable(),
1089*da0073e9SAndroid Build Coastguard Worker            """\
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table
1092*da0073e9SAndroid Build Coastguard Workerkey             kernel
1093*da0073e9SAndroid Build Coastguard Worker---------------------------
1094*da0073e9SAndroid Build Coastguard WorkerCPU             fn_CPU [kernel]
1095*da0073e9SAndroid Build Coastguard WorkerXLA             fn_CompositeImplicitAutograd [math kernel]
1096*da0073e9SAndroid Build Coastguard WorkerLazy            fn_CompositeImplicitAutograd [math kernel]
1097*da0073e9SAndroid Build Coastguard WorkerFPGA            fn_FPGA [kernel]
1098*da0073e9SAndroid Build Coastguard WorkerAutogradOther   ambiguous_autogradother [ambiguous autogradother]
1099*da0073e9SAndroid Build Coastguard WorkerAutogradCPU     [backend fallback]
1100*da0073e9SAndroid Build Coastguard WorkerAutogradXLA     fn_CompositeImplicitAutograd [math kernel]
1101*da0073e9SAndroid Build Coastguard WorkerAutogradLazy    fn_CompositeImplicitAutograd [math kernel]
1102*da0073e9SAndroid Build Coastguard Worker""",
1103*da0073e9SAndroid Build Coastguard Worker        )
1104*da0073e9SAndroid Build Coastguard Worker
1105*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1106*da0073e9SAndroid Build Coastguard Worker            dispatcher.registrations(),
1107*da0073e9SAndroid Build Coastguard Worker            """\
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard WorkerRegistered Kernels
1110*da0073e9SAndroid Build Coastguard Workerkey             kernel
1111*da0073e9SAndroid Build Coastguard Worker---------------------------
1112*da0073e9SAndroid Build Coastguard WorkerFPGA            fn_FPGA
1113*da0073e9SAndroid Build Coastguard WorkerCPU             fn_CPU
1114*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
1115*da0073e9SAndroid Build Coastguard Worker""",
1116*da0073e9SAndroid Build Coastguard Worker        )
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker    def test_duplicate_registrations(self):
1119*da0073e9SAndroid Build Coastguard Worker        dispatcher = PythonDispatcher()
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Overriden is not allowed"):
1122*da0073e9SAndroid Build Coastguard Worker            dispatcher.register(["CPU", "CPU"])
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker    def test_defaultbackend_math(self):
1125*da0073e9SAndroid Build Coastguard Worker        dispatcher = PythonDispatcher()
1126*da0073e9SAndroid Build Coastguard Worker
1127*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1128*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1129*da0073e9SAndroid Build Coastguard Worker            r"Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed",
1130*da0073e9SAndroid Build Coastguard Worker        ):
1131*da0073e9SAndroid Build Coastguard Worker            dispatcher.register(
1132*da0073e9SAndroid Build Coastguard Worker                ["CompositeExplicitAutograd", "CompositeImplicitAutograd"]
1133*da0073e9SAndroid Build Coastguard Worker            )
1134*da0073e9SAndroid Build Coastguard Worker
1135*da0073e9SAndroid Build Coastguard Worker    def test_quantized_structured_not_implemented(self):
1136*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros([1, 1, 1])
1137*da0073e9SAndroid Build Coastguard Worker        y = torch.zeros([1, 1, 1])
1138*da0073e9SAndroid Build Coastguard Worker        scale, zero_point = 1.0, 0
1139*da0073e9SAndroid Build Coastguard Worker        dtype = torch.qint8
1140*da0073e9SAndroid Build Coastguard Worker        qx = torch.quantize_per_tensor(x, scale, zero_point, dtype)
1141*da0073e9SAndroid Build Coastguard Worker        qy = torch.quantize_per_tensor(y, scale, zero_point, dtype)
1142*da0073e9SAndroid Build Coastguard Worker        # If bmm gets quantized support you need to update this to something
1143*da0073e9SAndroid Build Coastguard Worker        # else that is not implemented
1144*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
1145*da0073e9SAndroid Build Coastguard Worker            NotImplementedError,
1146*da0073e9SAndroid Build Coastguard Worker            "Could not run 'aten::bmm.out' with arguments from the 'QuantizedCPU' backend.",
1147*da0073e9SAndroid Build Coastguard Worker            lambda: torch.bmm(qx, qy),
1148*da0073e9SAndroid Build Coastguard Worker        )
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this file's test cases through PyTorch's shared test runner.
    run_tests()
1153