# Owner(s): ["module: dispatch"]

import itertools
import os
import re
from collections import namedtuple

import torch._C as C
import torch.utils.cpp_extension
from torch._python_dispatcher import PythonDispatcher
from torch.testing._internal.common_utils import run_tests, TestCase


# TODO: Expand the dispatcher API to be a generic API for interfacing with
# the dispatcher from Python!
#
# These are exhaustive tests for commutativity of dispatch behavior.  If you're
# looking for more usage-info style tests, check op_registration_test.cpp
#
# Things not tested here:
#   - Listeners
#   - Top level namespace registrations
#   - Fallback
#   - Exotic overloads of CppFunction/schema
#
# Things not directly tested here:
#   - Internal state of Dispatcher makes sense.  This is indirectly
#     tested by the invariant testing

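# A minimal sketch (illustrative only; not used by the tests below) of the
# raw hooks this suite drives: the _dispatch_dump and _dispatch_dump_table
# bindings take a qualified "namespace::name" string and return printable
# dispatcher state.
def _example_dump_op(qualified_name="aten::add"):
    # Both bindings return plain strings suitable for golden-value comparison.
    return C._dispatch_dump(qualified_name), C._dispatch_dump_table(qualified_name)

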
Result = namedtuple("Result", "state table provenance")

dispatch_keys_to_check = (
    "Undefined",
    "CPU",
    "CUDA",
    "XLA",
    "AutogradOther",
    "AutogradCPU",
    "AutogradCUDA",
    "AutogradXLA",
)


def extract_dispatch_table_with_keys(table, dispatch_keys):
    extracted = ""
    table_entries = table.split("\n")
    regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
    for k in dispatch_keys:
        for t in table_entries:
            if t.startswith(k):
                # mask out file:line info for in-tree backend fallback
                entry = regex.sub("registered in pytorch framework [", t)
                extracted += entry + "\n"
    return extracted

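# A usage sketch of the helper above (illustrative only): given a dump that
# contains the line "CPU: fn_cpu [kernel]",
#   extract_dispatch_table_with_keys(dump, ("CPU",))
# returns "CPU: fn_cpu [kernel]\n", with any in-tree FallbackKernel.cpp
# file:line provenance normalized to "registered in pytorch framework [".
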

class TestDispatch(TestCase):
    namespace_index = 0

    def test_all_invariants(self):
        # Check that the regular stuff is OK!
        C._dispatch_check_all_invariants()

    # You probably don't want to call this directly; if your constructors
    # don't commute, you can still run commute with a fixed ctor_order
    # so that you can test that the destructors still commute
    def run_ops(
        self,
        name,
        ops,
        ctor_order=None,
        dtor_order=None,
        results=None,
        expect_raises=False,
    ):
        """
        Given a list of operator registrations, run the registrations in the
        order specified by ctor_order, and then run the deregistrations in
        dtor_order.

        If results is specified, intermediate results are checked for consistency
        with results stored in results (and stored in results if this is the
        first time we've seen them).  Results are expected to be equivalent
        modulo commutativity and inverses (thus, results is keyed on a frozenset
        of the registrations from ops that are in effect).  results stores the
        namedtuple Result(state, table, provenance), where state is a string
        containing the non-derived kernels registered (or the error message if
        registration failed); table is a string containing the computed dispatch
        table entries; and provenance is a string describing how we arrived at
        this result.

        If expect_raises is True, it is not an error to raise an exception.
        Instead, we store the exception string (rather than the dispatcher
        state) in results.  In principle we should flag these differently, but
        it's very obvious when you get an error in one case but not another.
        """
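        # Illustrative shape of `results` (a sketch): with two ops, the keys
        # seen are frozenset(), {0}, {1}, and {0, 1}; each maps to the
        # Result(state, table, provenance) captured the first time that set of
        # registrations was in effect, and every later visit to the same set
        # must reproduce it exactly.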
        # By allocating every test into a fresh namespace, this makes it less
        # likely that a bug in the testing framework will result in tests
        # interfering with each other
        self.__class__.namespace_index += 1
        if results is None:
            results = {}
        if ctor_order is None:
            ctor_order = list(range(len(ops)))
        if dtor_order is None:
            dtor_order = list(reversed(ctor_order))
        # Refs which retain the c10::Module object so we can explicitly control
        # when each deregistration happens (deregistration occurs when the
        # object gets deallocated).
        refs = [None] * len(ops)
        # Keep track of the set of "in effect" registrations
        active_ops = set()

        # double underscore to make it less likely we conflict with something
        # else
        test_namespace = f"__test{self.namespace_index}__"

        def check_invariants(actual_provenance):
            C._dispatch_check_invariants(name)
            # Normalize the test namespace so that expected outputs are stable
            actual_state = C._dispatch_dump(f"{test_namespace}::{name}").replace(
                test_namespace, "test"
            )
            actual_table = C._dispatch_dump_table(f"{test_namespace}::{name}").replace(
                test_namespace, "test"
            )
            expected_state, expected_table, expected_provenance = results.setdefault(
                frozenset(active_ops),
                Result(actual_state, actual_table, actual_provenance),
            )
            self.assertMultiLineEqual(
                expected_state,
                actual_state,
                f"expected from {expected_provenance}; actual from {actual_provenance}",
            )
            self.assertMultiLineEqual(
                expected_table,
                actual_table,
                f"expected from {expected_provenance}; actual from {actual_provenance}",
            )

        results.setdefault(frozenset(), Result("", "", "hardcoded initial state"))
        check_invariants("initial state")
        # In the order specified by ctor_order, run registrations
        set_to_report = frozenset(range(len(ops)))
        for i, op_ix in enumerate(ctor_order):
            # It would be better to DEF here, but because we manage the
            # lifetime of multiple registrations with multiple Library
            # references (refs), we can't deal with the strict checking
            # from DEF.
            refs[op_ix] = C._dispatch_library("FRAGMENT", test_namespace, "")
            active_ops.add(op_ix)
            try:
                ops[op_ix](refs[op_ix])
                check_invariants(f"running ctors {ctor_order[:i + 1]}")
            except RuntimeError as e:
                if not expect_raises:
                    raise
                actual = str(e).replace(test_namespace, "test")
                actual = actual.split("\nException raised from ")[0]
                expected, _, expected_provenance = results.setdefault(
                    frozenset(active_ops),
                    Result(
                        actual, "", f"error after running ctors {ctor_order[:i + 1]}"
                    ),
                )
                self.assertMultiLineEqual(expected, actual, expected_provenance)
                set_to_report = frozenset(active_ops)
                active_ops.remove(op_ix)
                # NB: this check asserts that if a registration fails, the
                # dispatcher is left in the same state *that it was in before*!
                check_invariants(
                    f"running ctors {ctor_order[:i]} and then failing to run ctor {op_ix} "
                    "(did this failure leave the dispatcher in a wedged state? "
                    "it shouldn't!)"
                )
                break
        last_ctor = i
        if expect_raises and len(active_ops) == len(ops):
            # Destroy references first, as some test frameworks (like pytest)
            # will retain references in the exception raised by assertTrue! EW!
            refs = None
            self.assertTrue(
                False,
                "expected exception to be raised, but nothing was raised "
                f"(after running ctors {ctor_order})",
            )
        # In the order specified by dtor_order, run deregistrations
        for i, op_ix in enumerate(dtor_order):
            # Trigger a destruction
            refs[op_ix] = None
            # discard, not remove, since we may not have actually deregistered
            # anything if an error was raised
            if expect_raises:
                active_ops.discard(op_ix)
            else:
                active_ops.remove(op_ix)
            check_invariants(
                f"running ctors {ctor_order[:last_ctor + 1]}, then running dtors {dtor_order[:i + 1]}"
            )
        return results[set_to_report][0]

    # Operator registrations are commutative (as static initializers can
    # run in any order) and invertible (by deregistration).  (Subject
    # to some caveats: some legacy behaviors in the system are not
    # commutative -- we want to get rid of these!)
    #
    # So while in principle we could simply test a set of operations
    # by just running them one by one in the order specified by the user,
    # we can get more assurance about these extra properties by doing
    # more work:
    #
    # 1. Don't run the registrations once in a fixed order: run every possible
    #    permutation.  Similarly, run every permutation of deregistration order.
    #
    # 2. Don't just check the end state of the dispatcher: for every
    #    subset of operator registrations, ensure that the computed
    #    intermediate state is path independent.  One thing to note:
    #    in this function, we assume each operation is unique.  In general,
    #    there may be duplicated registrations, but these are usually
    #    idempotent or legacy; that behavior is tested separately.
    #
    # NB: checking all permutations means this function is factorial in
    # the length of ops!  So don't pass too many ops to this function!
    # (See the cost sketch after this method.)
    def commute(self, name, ops, ctor_order=None, expect_raises=False):
        results = {}

        def go(ctor_order):
            for dtor_order in itertools.permutations(range(len(ops))):
                self.run_ops(
                    name,
                    ops,
                    ctor_order,
                    dtor_order,
                    results=results,
                    expect_raises=expect_raises,
                )

        if ctor_order is not None:
            go(ctor_order)
        else:
            for ctor_order in itertools.permutations(range(len(ops))):
                go(ctor_order)

        # Return the "full" Result namedtuple after all operations are run.
        # If this KeyErrors, that means that there did not exist any
        # ordering of ctors which got us to the "end".  That's an
        # error in test construction: it means you could have
        # factored the test into two smaller ones.
        return results[frozenset(range(len(ops)))]

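    # A back-of-the-envelope sketch of commute()'s cost (illustrative only):
    # run_ops executes once per (ctor_order, dtor_order) permutation pair, so
    # a list of n ops triggers (n!)**2 full register/deregister cycles, e.g.
    #
    #   import math
    #   math.factorial(5) ** 2  # == 14400
    #
    # which is why the tests below keep their ops lists short.
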
    def test_def(self):
        state = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo"),
                # m.impl("foo", kCPU, [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo", dispatch="CPU"),
                # m.impl("foo", kAutograd, [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo", dispatch="Autograd"),
                # m.impl("foo", kAutogradCPU, [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo", dispatch="AutogradCPU"),
            ],
        ).state
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
AutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
Autograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

    def test_def_impl_schema_mismatch(self):
        # NB: an impl-impl mismatch is not reported eagerly; you'll find out
        # about it because one of them won't match with def
        state = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x, Tensor y) -> Tensor")
                lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
                # m.impl("foo", [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo"),
            ],
            expect_raises=True,
        ).state
        self.assertExpectedInline(
            state,
            """\
Inferred operator schema for a C++ kernel function doesn't match the expected function schema.
  operator: test::foo
  expected schema: test::foo(Tensor x, Tensor y) -> Tensor
    registered at /dev/null:0
  inferred schema: (Tensor _0) -> Tensor _0
    impl_t_t
  reason: The number of arguments is different. 2 vs 1.""",
        )

    def test_def_with_inference(self):
        state = self.commute(
            "foo",
            [
                # m.def("foo", [](const Tensor & x) { return x })
                lambda m: m.def_name_t_t("foo"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU"),
                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "Autograd"),
                # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "AutogradCPU"),
            ],
        ).state
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor _0) -> Tensor _0
debug: registered at /dev/null:0
alias analysis kind: CONSERVATIVE
CPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
AutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
Autograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

    def test_def_only(self):
        state = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x, Tensor y) -> Tensor")
                lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            ],
        ).state
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x, Tensor y) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
""",
        )

    def test_impl_only(self):
        state = self.commute(
            "foo",
            [
                # m.impl("foo", [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo"),
                # m.impl("foo", torch::kCPU, [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo", "CPU"),
                # m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo", "Autograd"),
                # m.impl("foo", torch::kAutogradCPU, [](const Tensor& x) { return x })
                lambda m: m.impl_t_t("foo", "AutogradCPU"),
            ],
        ).state
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: (none)
CPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
AutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
Autograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

    def test_computed_table(self):
        result = self.commute(
            "foo",
            [
                # m.def("foo", [](const Tensor & x) { return x })
                lambda m: m.def_name_t_t("foo"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
                # m.impl("foo", torch::kXLA, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "XLA", debug="fn_xla"),
                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
                # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "AutogradCPU", debug="fn_autogradcpu"),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor _0) -> Tensor _0
debug: registered at /dev/null:0
alias analysis kind: CONSERVATIVE
CPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
XLA: fn_xla :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
AutogradCPU: fn_autogradcpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
Autograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: default_def_name_t_t [math kernel]
CPU: fn_cpu [kernel]
CUDA: default_def_name_t_t [math kernel]
XLA: fn_xla [kernel]
AutogradOther: default_def_name_t_t [math kernel]
AutogradCPU: fn_autogradcpu [kernel]
AutogradCUDA: default_def_name_t_t [math kernel]
AutogradXLA: fn_autograd [autograd kernel]
""",
        )

    def test_computed_table_with_cpu_math_autogradcpu_fallthrough(self):
        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
        result = self.commute(
            "foo",
            [
                # m.def("foo", [](const Tensor & x) { return x })
                lambda m: m.def_name_t_t("foo"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU"),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor _0) -> Tensor _0
debug: registered at /dev/null:0
alias analysis kind: CONSERVATIVE
CPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: default_def_name_t_t [math kernel]
CPU: impl_t_t [kernel]
CUDA: default_def_name_t_t [math kernel]
XLA: default_def_name_t_t [math kernel]
AutogradOther: default_def_name_t_t [math kernel]
AutogradCPU: registered in pytorch framework [backend fallback]
AutogradCUDA: default_def_name_t_t [math kernel]
AutogradXLA: default_def_name_t_t [math kernel]
""",
        )

    def test_computed_table_with_math(self):
        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd"),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: impl_t_t [math kernel]
CPU: impl_t_t [math kernel]
CUDA: impl_t_t [math kernel]
XLA: impl_t_t [math kernel]
AutogradOther: impl_t_t [math kernel]
AutogradCPU: impl_t_t [math kernel]
AutogradCUDA: impl_t_t [math kernel]
AutogradXLA: impl_t_t [math kernel]
""",
        )

    def test_computed_table_with_cpu_math(self):
        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t(
                    "foo", "CompositeImplicitAutograd", debug="fn_math"
                ),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: fn_math [math kernel]
CPU: fn_cpu [kernel]
CUDA: fn_math [math kernel]
XLA: fn_math [math kernel]
AutogradOther: fn_math [math kernel]
AutogradCPU: registered in pytorch framework [backend fallback]
AutogradCUDA: fn_math [math kernel]
AutogradXLA: fn_math [math kernel]
""",
        )

    def test_computed_table_with_autograd(self):
        global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "Autograd"),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
Autograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
AutogradOther: impl_t_t [autograd kernel]
AutogradCPU: impl_t_t [autograd kernel]
AutogradCUDA: impl_t_t [autograd kernel]
AutogradXLA: impl_t_t [autograd kernel]
""",
        )

    # Now that catchAll maps to CompositeImplicitAutograd, registering to both
    # catchAll and CompositeImplicitAutograd breaks commutativity.
    def test_computed_table_with_cpu_autograd_math(self):
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t(
                    "foo", "CompositeImplicitAutograd", debug="fn_math"
                ),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
Autograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: fn_math [math kernel]
CPU: fn_cpu [kernel]
CUDA: fn_math [math kernel]
XLA: fn_math [math kernel]
AutogradOther: fn_math [math kernel]
AutogradCPU: fn_autograd [autograd kernel]
AutogradCUDA: fn_math [math kernel]
AutogradXLA: fn_math [math kernel]
""",
        )

    def test_computed_table_with_ambiguous_autogradother(self):
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t(
                    "foo", "CompositeImplicitAutograd", debug="fn_math"
                ),
                # m.impl("foo", torch::kFPGA, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "FPGA", debug="fn_fpga"),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
FPGA: fn_fpga :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check + ("FPGA",)
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: fn_math [math kernel]
CPU: fn_math [math kernel]
CUDA: fn_math [math kernel]
XLA: fn_math [math kernel]
AutogradOther: ambiguous_autogradother [ambiguous autogradother]
AutogradCPU: fn_math [math kernel]
AutogradCUDA: fn_math [math kernel]
AutogradXLA: fn_math [math kernel]
FPGA: fn_fpga [kernel]
""",
        )

    def test_computed_table_with_cpu_defaultbackend(self):
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
                # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t(
                    "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend"
                ),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: fn_defaultbackend [default backend kernel]
CPU: fn_cpu [kernel]
CUDA: fn_defaultbackend [default backend kernel]
XLA: fn_defaultbackend [default backend kernel]
AutogradOther: registered in pytorch framework [backend fallback]
AutogradCPU: registered in pytorch framework [backend fallback]
AutogradCUDA: registered in pytorch framework [backend fallback]
AutogradXLA: registered in pytorch framework [backend fallback]
""",
        )

    def test_computed_table_with_cpu_autograd_defaultbackend(self):
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
                # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t(
                    "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend"
                ),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
Autograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check + ("FPGA",)
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: fn_defaultbackend [default backend kernel]
CPU: fn_cpu [kernel]
CUDA: fn_defaultbackend [default backend kernel]
XLA: fn_defaultbackend [default backend kernel]
AutogradOther: fn_autograd [autograd kernel]
AutogradCPU: fn_autograd [autograd kernel]
AutogradCUDA: fn_autograd [autograd kernel]
AutogradXLA: fn_autograd [autograd kernel]
FPGA: fn_defaultbackend [default backend kernel]
""",
        )

    def test_computed_table_with_cpu_autograd_math_defaultbackend(self):
        result = self.commute(
            "foo",
            [
                # m.def("foo(Tensor x) -> Tensor")
                lambda m: m.def_("foo(Tensor x) -> Tensor"),
                # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"),
                # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"),
                # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t(
                    "foo", "CompositeImplicitAutograd", debug="fn_math"
                ),
                # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x })
                lambda m: m.impl_t_t(
                    "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend"
                ),
            ],
        )
        state, table = result.state, result.table
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
Autograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

        # The computed dispatch table is too big, so we only check a few
        # entries we're interested in.
        extracted_table = extract_dispatch_table_with_keys(
            table, dispatch_keys_to_check
        )

        self.assertExpectedInline(
            extracted_table,
            """\
Undefined: fn_defaultbackend [default backend kernel]
CPU: fn_cpu [kernel]
CUDA: fn_defaultbackend [default backend kernel]
XLA: fn_defaultbackend [default backend kernel]
AutogradOther: fn_autograd [autograd kernel]
AutogradCPU: fn_autograd [autograd kernel]
AutogradCUDA: fn_autograd [autograd kernel]
AutogradXLA: fn_autograd [autograd kernel]
""",
        )

    def test_multiple_def_error(self):
        ops = [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, expect_raises=True).state,
            """Tried to register an operator (test::foo(Tensor x, Tensor y) -> Tensor) with the same name and overload """
            """name multiple times. Each overload's schema should only be registered with a single call to def(). """
            """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""",
        )

    def test_def_with_explicit_alias(self):
        state = self.commute(
            "foo",
            [
                # m.def(torch::schema(
                #   "foo(Tensor x, Tensor y) -> Tensor",
                #   c10::AliasAnalysisKind::PURE_FUNCTION))
                lambda m: m.def_(
                    "foo(Tensor x, Tensor y) -> Tensor", alias="PURE_FUNCTION"
                )
            ],
        ).state
        self.assertExpectedInline(
            state,
            """\
name: test::foo
schema: test::foo(Tensor x, Tensor y) -> Tensor
debug: registered at /dev/null:0
alias analysis kind: PURE_FUNCTION
""",
        )

    def test_multiple_def_alias_defaulting(self):
        ops = [
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
            # RegisterOperators().op("foo(Tensor x) -> Tensor")
            lambda m: m.def_legacy("foo(Tensor x) -> Tensor"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, expect_raises=True).state,
            """Tried to register an operator (test::foo(Tensor x) -> Tensor) with the same name and overload """
            """name multiple times. Each overload's schema should only be registered with a single call to def(). """
            """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""",
        )

    def test_multiple_def_alias_mismatch(self):
        ops = [
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::CONSERVATIVE))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="CONSERVATIVE"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, expect_raises=True).state,
            """Tried to register an operator (test::foo(Tensor x) -> Tensor) with the same name and overload """
            """name multiple times. Each overload's schema should only be registered with a single call to def(). """
            """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""",
        )

    def test_multiple_fallback(self):
        global_m = C._dispatch_library("IMPL", "_", "XLA")
        global_m.fallback_fallthrough()
        try:
            global_m.fallback_fallthrough()
        except RuntimeError as e:
            self.assertExpectedInline(
                str(e),
                """Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration """
                """registered at /dev/null:0, new registration registered at /dev/null:0""",
            )
        else:
            self.assertTrue(False)

    def test_overwrite_math(self):
        ops = [
            lambda m: m.impl_t_t("foo", debug="fn1"),
            lambda m: m.impl_t_t("foo", debug="fn2"),
        ]
        # Not commutative
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1)).state,
            """\
name: test::foo
schema: (none)
CompositeImplicitAutograd[alias]: fn2 :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
CompositeImplicitAutograd[alias] (inactive): fn1 :: (Tensor _0) -> Tensor _0 [ boxed unboxed ]
""",
        )

    # Definition: a dangling impl happens when someone does an impl() on a
    # function but not a def() for it.  This is usually a bug, e.g. someone
    # misspelled an operator name, or someone registered an impl for an op
    # that no longer exists (see the sketch below).
    def test_find_dangling_impls(self):
        dangling_impls = C._dispatch_find_dangling_impls()
        self.assertEqual(
            0,
            len(dangling_impls),
            msg=f"Expect zero dangling impls, but found: {dangling_impls}",
        )
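
    # A sketch of how a dangling impl arises (cf. test_impl_only above):
    # registering m.impl_t_t("foo") in a FRAGMENT library without a matching
    # m.def_(...) leaves an entry whose dump reads "schema: (none)"; entries
    # like that are what C._dispatch_find_dangling_impls reports.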

    def test_find_dangling_impls_ext(self):
        extension_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "cpp_extensions",
            "dangling_impl_extension.cpp",
        )
        module = torch.utils.cpp_extension.load(
            name="dangling_impl_extension",
            sources=[extension_path],
            extra_cflags=["-g"],
            verbose=True,
        )

        impls = C._dispatch_find_dangling_impls()
        self.assertEqual(1, len(impls))
        self.assertEqual(
            f"""\
name: __test::foo
schema: (none)
CPU: registered at {extension_path}:5 :: () -> () [ boxed unboxed ]
""",
            impls[0],
        )

    def test_dispatch_print_registrations_for_dispatch_key_invalid(self):
        with self.assertRaisesRegex(
            RuntimeError, "could not parse dispatch key: invalid_key"
        ):
            C._dispatch_print_registrations_for_dispatch_key("invalid_key")


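# PythonDispatcher (torch/_python_dispatcher.py) is a pure-Python simulation
# of how the dispatcher computes an op's dispatch table from its
# registrations.  A minimal usage sketch:
#
#   dispatcher = PythonDispatcher()
#   dispatcher.register(["CPU", "CompositeImplicitAutograd"])
#   print(dispatcher.dispatchTable())   # the computed table, as asserted below
#   print(dispatcher.registrations())   # the raw registrations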
class TestPythonDispatcher(TestCase):
    def test_basic(self):
        dispatcher = PythonDispatcher()
        dispatcher.register(["CPU", "XLA", "Lazy", "CompositeImplicitAutograd"])
        self.assertExpectedInline(
            dispatcher.dispatchTable(),
            """\

Computed Dispatch Table
key             kernel
---------------------------
CPU             fn_CPU [kernel]
XLA             fn_XLA [kernel]
Lazy            fn_Lazy [kernel]
FPGA            fn_CompositeImplicitAutograd [math kernel]
AutogradOther   fn_CompositeImplicitAutograd [math kernel]
AutogradCPU     [backend fallback]
AutogradXLA     [backend fallback]
AutogradLazy    [backend fallback]
""",
        )

    def test_math_autogradcpu(self):
        dispatcher = PythonDispatcher()
        dispatcher.register(
            ["CPU", "XLA", "Lazy", "CompositeImplicitAutograd", "AutogradCPU"]
        )
        self.assertExpectedInline(
            dispatcher.dispatchTable(),
            """\

Computed Dispatch Table
key             kernel
---------------------------
CPU             fn_CPU [kernel]
XLA             fn_XLA [kernel]
Lazy            fn_Lazy [kernel]
FPGA            fn_CompositeImplicitAutograd [math kernel]
AutogradOther   fn_CompositeImplicitAutograd [math kernel]
AutogradCPU     fn_AutogradCPU [kernel]
AutogradXLA     [backend fallback]
AutogradLazy    [backend fallback]
""",
        )
        self.assertExpectedInline(
            dispatcher.registrations(),
            """\

Registered Kernels
key             kernel
---------------------------
CPU             fn_CPU
XLA             fn_XLA
Lazy            fn_Lazy
AutogradCPU     fn_AutogradCPU
CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
""",
        )

    def test_defaultbackend_autogradcpu(self):
        dispatcher = PythonDispatcher()
        dispatcher.register(
            ["CPU", "XLA", "Lazy", "CompositeExplicitAutograd", "AutogradCPU"]
        )
        self.assertExpectedInline(
            dispatcher.dispatchTable(),
            """\

Computed Dispatch Table
key             kernel
---------------------------
CPU             fn_CPU [kernel]
XLA             fn_XLA [kernel]
Lazy            fn_Lazy [kernel]
FPGA            fn_CompositeExplicitAutograd [default backend kernel]
AutogradOther   [backend fallback]
AutogradCPU     fn_AutogradCPU [kernel]
AutogradXLA     [backend fallback]
AutogradLazy    [backend fallback]
""",
        )

        self.assertExpectedInline(
            dispatcher.registrations(),
            """\

Registered Kernels
key             kernel
---------------------------
CPU             fn_CPU
XLA             fn_XLA
Lazy            fn_Lazy
AutogradCPU     fn_AutogradCPU
CompositeExplicitAutograd[alias] fn_CompositeExplicitAutograd
""",
        )

    def test_autogradother(self):
        dispatcher = PythonDispatcher()
        dispatcher.register(["CPU", "FPGA", "CompositeImplicitAutograd"])
        self.assertExpectedInline(
            dispatcher.dispatchTable(),
            """\

Computed Dispatch Table
key             kernel
---------------------------
CPU             fn_CPU [kernel]
XLA             fn_CompositeImplicitAutograd [math kernel]
Lazy            fn_CompositeImplicitAutograd [math kernel]
FPGA            fn_FPGA [kernel]
AutogradOther   ambiguous_autogradother [ambiguous autogradother]
AutogradCPU     [backend fallback]
AutogradXLA     fn_CompositeImplicitAutograd [math kernel]
AutogradLazy    fn_CompositeImplicitAutograd [math kernel]
""",
        )

        self.assertExpectedInline(
            dispatcher.registrations(),
            """\

Registered Kernels
key             kernel
---------------------------
FPGA            fn_FPGA
CPU             fn_CPU
CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
""",
        )

    def test_duplicate_registrations(self):
        dispatcher = PythonDispatcher()

        with self.assertRaisesRegex(RuntimeError, r"Overriden is not allowed"):
            dispatcher.register(["CPU", "CPU"])

    def test_defaultbackend_math(self):
        dispatcher = PythonDispatcher()

        with self.assertRaisesRegex(
            RuntimeError,
            r"Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed",
        ):
            dispatcher.register(
                ["CompositeExplicitAutograd", "CompositeImplicitAutograd"]
            )

    def test_quantized_structured_not_implemented(self):
        x = torch.zeros([1, 1, 1])
        y = torch.zeros([1, 1, 1])
        scale, zero_point = 1.0, 0
        dtype = torch.qint8
        qx = torch.quantize_per_tensor(x, scale, zero_point, dtype)
        qy = torch.quantize_per_tensor(y, scale, zero_point, dtype)
        # If bmm gains quantized support, update this to use some other op
        # that is not implemented
        self.assertRaisesRegex(
            NotImplementedError,
            "Could not run 'aten::bmm.out' with arguments from the 'QuantizedCPU' backend.",
            lambda: torch.bmm(qx, qy),
        )


if __name__ == "__main__":
    run_tests()