1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dispatch"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport itertools 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport re 6*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch._C as C 9*da0073e9SAndroid Build Coastguard Workerimport torch.utils.cpp_extension 10*da0073e9SAndroid Build Coastguard Workerfrom torch._python_dispatcher import PythonDispatcher 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker# TODO: Expand the dispatcher API to be a generic API for interfacing with 15*da0073e9SAndroid Build Coastguard Worker# the dispatcher from Python! 16*da0073e9SAndroid Build Coastguard Worker# 17*da0073e9SAndroid Build Coastguard Worker# These are exhaustive tests for commutativity of dispatch behavior. If you're 18*da0073e9SAndroid Build Coastguard Worker# looking for more usage-info style tests, check op_registration_test.cpp 19*da0073e9SAndroid Build Coastguard Worker# 20*da0073e9SAndroid Build Coastguard Worker# Things not tested here: 21*da0073e9SAndroid Build Coastguard Worker# - Listeners 22*da0073e9SAndroid Build Coastguard Worker# - Top level namespace registrations 23*da0073e9SAndroid Build Coastguard Worker# - Fallback 24*da0073e9SAndroid Build Coastguard Worker# - Exotic overloads of CppFunction/schema 25*da0073e9SAndroid Build Coastguard Worker# 26*da0073e9SAndroid Build Coastguard Worker# Things not directly tested here: 27*da0073e9SAndroid Build Coastguard Worker# - Internal state of Dispatcher makes sense. This is indirectly 28*da0073e9SAndroid Build Coastguard Worker# tested by the invariant testing 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard WorkerResult = namedtuple("Result", "state table provenance") 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerdispatch_keys_to_check = ( 33*da0073e9SAndroid Build Coastguard Worker "Undefined", 34*da0073e9SAndroid Build Coastguard Worker "CPU", 35*da0073e9SAndroid Build Coastguard Worker "CUDA", 36*da0073e9SAndroid Build Coastguard Worker "XLA", 37*da0073e9SAndroid Build Coastguard Worker "AutogradOther", 38*da0073e9SAndroid Build Coastguard Worker "AutogradCPU", 39*da0073e9SAndroid Build Coastguard Worker "AutogradCUDA", 40*da0073e9SAndroid Build Coastguard Worker "AutogradXLA", 41*da0073e9SAndroid Build Coastguard Worker) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Workerdef extract_dispatch_table_with_keys(table, dispatch_keys): 45*da0073e9SAndroid Build Coastguard Worker extracted = "" 46*da0073e9SAndroid Build Coastguard Worker table_entries = table.split("\n") 47*da0073e9SAndroid Build Coastguard Worker regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)") 48*da0073e9SAndroid Build Coastguard Worker for k in dispatch_keys: 49*da0073e9SAndroid Build Coastguard Worker for t in table_entries: 50*da0073e9SAndroid Build Coastguard Worker if t.startswith(k): 51*da0073e9SAndroid Build Coastguard Worker # mask out file:line info for in-tree backend fallback 52*da0073e9SAndroid Build Coastguard Worker entry = regex.sub("registered in pytorch framework [", t) 53*da0073e9SAndroid Build Coastguard Worker extracted += entry + "\n" 54*da0073e9SAndroid Build Coastguard Worker return extracted 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Workerclass TestDispatch(TestCase): 58*da0073e9SAndroid Build Coastguard Worker namespace_index = 0 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker def test_all_invariants(self): 61*da0073e9SAndroid Build Coastguard Worker # Check that the regular stuff is OK! 62*da0073e9SAndroid Build Coastguard Worker C._dispatch_check_all_invariants() 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker # You probably don't want to call this directly; if your constructors 65*da0073e9SAndroid Build Coastguard Worker # don't commute, you can still run commute with a fixed ctor_order 66*da0073e9SAndroid Build Coastguard Worker # so that you can test that the destructors still commute 67*da0073e9SAndroid Build Coastguard Worker def run_ops( 68*da0073e9SAndroid Build Coastguard Worker self, 69*da0073e9SAndroid Build Coastguard Worker name, 70*da0073e9SAndroid Build Coastguard Worker ops, 71*da0073e9SAndroid Build Coastguard Worker ctor_order=None, 72*da0073e9SAndroid Build Coastguard Worker dtor_order=None, 73*da0073e9SAndroid Build Coastguard Worker results=None, 74*da0073e9SAndroid Build Coastguard Worker expect_raises=False, 75*da0073e9SAndroid Build Coastguard Worker ): 76*da0073e9SAndroid Build Coastguard Worker """ 77*da0073e9SAndroid Build Coastguard Worker Given a list of operator registrations, run the registrations in the 78*da0073e9SAndroid Build Coastguard Worker order specified by ctor_order, and then run the deregistrations in 79*da0073e9SAndroid Build Coastguard Worker dtor_order. 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker If results is specified, intermediate results are checked for consistency 82*da0073e9SAndroid Build Coastguard Worker with results stored in results (and stored in results if this is the 83*da0073e9SAndroid Build Coastguard Worker first time we've seen them). Results are expected to be equivalent 84*da0073e9SAndroid Build Coastguard Worker modulo commutativity and inverses (thus, results is keyed on a frozenset 85*da0073e9SAndroid Build Coastguard Worker of in effect registrations from ops). Results stores namedtuple 86*da0073e9SAndroid Build Coastguard Worker Result[state, table, provenance], where state is a string that contains 87*da0073e9SAndroid Build Coastguard Worker non-derived kernel registered or error message if it doesn't pass; 88*da0073e9SAndroid Build Coastguard Worker table is a string that contains computed dispatch table entries; 89*da0073e9SAndroid Build Coastguard Worker provenance is a string that describes how exactly we got this string. 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker If expect_raises is True, it is not an error to raise an exception. Instead, 92*da0073e9SAndroid Build Coastguard Worker we'll store the exception string (instead of the dispatcher state) 93*da0073e9SAndroid Build Coastguard Worker in results. In principle we should flag these differently, but it's 94*da0073e9SAndroid Build Coastguard Worker very obvious when you get an error in one case but not another. 95*da0073e9SAndroid Build Coastguard Worker """ 96*da0073e9SAndroid Build Coastguard Worker # By allocating every test into a fresh namespace, this makes it less 97*da0073e9SAndroid Build Coastguard Worker # likely that a bug in the testing framework will result in tests 98*da0073e9SAndroid Build Coastguard Worker # interfering with each other 99*da0073e9SAndroid Build Coastguard Worker self.__class__.namespace_index += 1 100*da0073e9SAndroid Build Coastguard Worker if results is None: 101*da0073e9SAndroid Build Coastguard Worker results = {} 102*da0073e9SAndroid Build Coastguard Worker if ctor_order is None: 103*da0073e9SAndroid Build Coastguard Worker ctor_order = list(range(len(ops))) 104*da0073e9SAndroid Build Coastguard Worker if dtor_order is None: 105*da0073e9SAndroid Build Coastguard Worker dtor_order = list(reversed(ctor_order)) 106*da0073e9SAndroid Build Coastguard Worker # Refs which retain the c10::Module object so we can explicitly control 107*da0073e9SAndroid Build Coastguard Worker # when each deregistration happens (deregistration occurs when the 108*da0073e9SAndroid Build Coastguard Worker # object gets deallocated). 109*da0073e9SAndroid Build Coastguard Worker refs = [None] * len(ops) 110*da0073e9SAndroid Build Coastguard Worker # Keep track of the set "in effect" registrations 111*da0073e9SAndroid Build Coastguard Worker active_ops = set() 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker # double underscore to make it less likely we conflict with something 114*da0073e9SAndroid Build Coastguard Worker # else 115*da0073e9SAndroid Build Coastguard Worker test_namespace = f"__test{self.namespace_index}__" 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker def check_invariants(actual_provenance): 118*da0073e9SAndroid Build Coastguard Worker C._dispatch_check_invariants(name) 119*da0073e9SAndroid Build Coastguard Worker # Normalize the test namespace so that expected outputs are stable 120*da0073e9SAndroid Build Coastguard Worker actual_state = C._dispatch_dump(f"{test_namespace}::{name}").replace( 121*da0073e9SAndroid Build Coastguard Worker test_namespace, "test" 122*da0073e9SAndroid Build Coastguard Worker ) 123*da0073e9SAndroid Build Coastguard Worker actual_table = C._dispatch_dump_table(f"{test_namespace}::{name}").replace( 124*da0073e9SAndroid Build Coastguard Worker test_namespace, "test" 125*da0073e9SAndroid Build Coastguard Worker ) 126*da0073e9SAndroid Build Coastguard Worker expected_state, expected_table, expected_provenance = results.setdefault( 127*da0073e9SAndroid Build Coastguard Worker frozenset(active_ops), 128*da0073e9SAndroid Build Coastguard Worker Result(actual_state, actual_table, actual_provenance), 129*da0073e9SAndroid Build Coastguard Worker ) 130*da0073e9SAndroid Build Coastguard Worker self.assertMultiLineEqual( 131*da0073e9SAndroid Build Coastguard Worker expected_state, 132*da0073e9SAndroid Build Coastguard Worker actual_state, 133*da0073e9SAndroid Build Coastguard Worker f"expected from {expected_provenance}; actual from {actual_provenance}", 134*da0073e9SAndroid Build Coastguard Worker ) 135*da0073e9SAndroid Build Coastguard Worker self.assertMultiLineEqual( 136*da0073e9SAndroid Build Coastguard Worker expected_table, 137*da0073e9SAndroid Build Coastguard Worker actual_table, 138*da0073e9SAndroid Build Coastguard Worker f"expected from {expected_provenance}; actual from {actual_provenance}", 139*da0073e9SAndroid Build Coastguard Worker ) 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker results.setdefault(frozenset(), Result("", "", "hardcoded initial state")) 142*da0073e9SAndroid Build Coastguard Worker check_invariants("initial state") 143*da0073e9SAndroid Build Coastguard Worker # In the order specified by ctor_order, run registrations 144*da0073e9SAndroid Build Coastguard Worker set_to_report = frozenset(range(len(ops))) 145*da0073e9SAndroid Build Coastguard Worker for i, op_ix in enumerate(ctor_order): 146*da0073e9SAndroid Build Coastguard Worker # It would be better to DEF here, but because we manage 147*da0073e9SAndroid Build Coastguard Worker # lifetime of multiple registrations with multiple Library 148*da0073e9SAndroid Build Coastguard Worker # references (refs), we can't deal with the strict checking 149*da0073e9SAndroid Build Coastguard Worker # from DEF. 150*da0073e9SAndroid Build Coastguard Worker refs[op_ix] = C._dispatch_library("FRAGMENT", test_namespace, "") 151*da0073e9SAndroid Build Coastguard Worker active_ops.add(op_ix) 152*da0073e9SAndroid Build Coastguard Worker try: 153*da0073e9SAndroid Build Coastguard Worker ops[op_ix](refs[op_ix]) 154*da0073e9SAndroid Build Coastguard Worker check_invariants(f"running ctors {ctor_order[:i + 1]}") 155*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 156*da0073e9SAndroid Build Coastguard Worker if not expect_raises: 157*da0073e9SAndroid Build Coastguard Worker raise 158*da0073e9SAndroid Build Coastguard Worker actual = str(e).replace(test_namespace, "test") 159*da0073e9SAndroid Build Coastguard Worker actual = actual.split("\nException raised from ")[0] 160*da0073e9SAndroid Build Coastguard Worker expected, _, expected_provenance = results.setdefault( 161*da0073e9SAndroid Build Coastguard Worker frozenset(active_ops), 162*da0073e9SAndroid Build Coastguard Worker Result( 163*da0073e9SAndroid Build Coastguard Worker actual, "", f"error after running ctors {ctor_order[:i + 1]}" 164*da0073e9SAndroid Build Coastguard Worker ), 165*da0073e9SAndroid Build Coastguard Worker ) 166*da0073e9SAndroid Build Coastguard Worker self.assertMultiLineEqual(expected, actual, expected_provenance) 167*da0073e9SAndroid Build Coastguard Worker set_to_report = frozenset(active_ops) 168*da0073e9SAndroid Build Coastguard Worker active_ops.remove(op_ix) 169*da0073e9SAndroid Build Coastguard Worker # NB: this finally test asserts that if a registrations fails, 170*da0073e9SAndroid Build Coastguard Worker # the dispatcher is left in the same state *that it was before*! 171*da0073e9SAndroid Build Coastguard Worker check_invariants( 172*da0073e9SAndroid Build Coastguard Worker f"running ctors {ctor_order[:i]} and then failing to run ctor {op_ix} " 173*da0073e9SAndroid Build Coastguard Worker "(did this failure leave the dispatcher in a wedged state? " 174*da0073e9SAndroid Build Coastguard Worker "it shouldn't!)" 175*da0073e9SAndroid Build Coastguard Worker ) 176*da0073e9SAndroid Build Coastguard Worker break 177*da0073e9SAndroid Build Coastguard Worker last_ctor = i 178*da0073e9SAndroid Build Coastguard Worker if expect_raises and len(active_ops) == len(ops): 179*da0073e9SAndroid Build Coastguard Worker # Destroy references first, as some test frameworks (like pytest) 180*da0073e9SAndroid Build Coastguard Worker # will retain references in the exception raised by assertTrue! EW! 181*da0073e9SAndroid Build Coastguard Worker refs = None 182*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 183*da0073e9SAndroid Build Coastguard Worker False, 184*da0073e9SAndroid Build Coastguard Worker "expected exception to be raised, but nothing was raised " 185*da0073e9SAndroid Build Coastguard Worker f"(after running ctors {ctor_order})", 186*da0073e9SAndroid Build Coastguard Worker ) 187*da0073e9SAndroid Build Coastguard Worker # In the order specified by dtor_order, run deregistrations 188*da0073e9SAndroid Build Coastguard Worker for i, op_ix in enumerate(dtor_order): 189*da0073e9SAndroid Build Coastguard Worker # Trigger a destruction 190*da0073e9SAndroid Build Coastguard Worker refs[op_ix] = None 191*da0073e9SAndroid Build Coastguard Worker # discard not remove, since we may not have actually deregistered 192*da0073e9SAndroid Build Coastguard Worker # anything if there was an error raised 193*da0073e9SAndroid Build Coastguard Worker if expect_raises: 194*da0073e9SAndroid Build Coastguard Worker active_ops.discard(op_ix) 195*da0073e9SAndroid Build Coastguard Worker else: 196*da0073e9SAndroid Build Coastguard Worker active_ops.remove(op_ix) 197*da0073e9SAndroid Build Coastguard Worker check_invariants( 198*da0073e9SAndroid Build Coastguard Worker f"running ctors {ctor_order[:last_ctor + 1]}, then running dtors {dtor_order[:i + 1]}" 199*da0073e9SAndroid Build Coastguard Worker ) 200*da0073e9SAndroid Build Coastguard Worker return results[set_to_report][0] 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker # Operator registrations are commutative (as static initializers can 203*da0073e9SAndroid Build Coastguard Worker # run in any order) and invertible (by deregistration). (Subject 204*da0073e9SAndroid Build Coastguard Worker # to some caveats: some legacy behavior in the system are not commutative-- 205*da0073e9SAndroid Build Coastguard Worker # we want to get rid of these!) 206*da0073e9SAndroid Build Coastguard Worker # 207*da0073e9SAndroid Build Coastguard Worker # So while in principle we could simply test a set of operations 208*da0073e9SAndroid Build Coastguard Worker # by just running them one by one in the order specified by the user, 209*da0073e9SAndroid Build Coastguard Worker # we can get more assurance about these extra properties by doing 210*da0073e9SAndroid Build Coastguard Worker # more work: 211*da0073e9SAndroid Build Coastguard Worker # 212*da0073e9SAndroid Build Coastguard Worker # 1. Don't run the registrations once in a fixed order: run every possible 213*da0073e9SAndroid Build Coastguard Worker # permutation. Similarly, run every permutation of deregistration order. 214*da0073e9SAndroid Build Coastguard Worker # 215*da0073e9SAndroid Build Coastguard Worker # 2. Don't just check the end state of the dispatcher: for every 216*da0073e9SAndroid Build Coastguard Worker # subset of operator registrations, ensure that the computed 217*da0073e9SAndroid Build Coastguard Worker # intermediate state is path independent. One thing to note: 218*da0073e9SAndroid Build Coastguard Worker # in this function, we assume each operation is unique. In general, 219*da0073e9SAndroid Build Coastguard Worker # there may be duplicated registrations, but these are usually 220*da0073e9SAndroid Build Coastguard Worker # idempotent or legacy. We test for behavior here separately. 221*da0073e9SAndroid Build Coastguard Worker # 222*da0073e9SAndroid Build Coastguard Worker # NB: checking all permutations means this function is exponential in 223*da0073e9SAndroid Build Coastguard Worker # the length of ops! So don't pass too many ops to this function! 224*da0073e9SAndroid Build Coastguard Worker def commute(self, name, ops, ctor_order=None, expect_raises=False): 225*da0073e9SAndroid Build Coastguard Worker results = {} 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker def go(ctor_order): 228*da0073e9SAndroid Build Coastguard Worker for dtor_order in itertools.permutations(range(len(ops))): 229*da0073e9SAndroid Build Coastguard Worker self.run_ops( 230*da0073e9SAndroid Build Coastguard Worker name, 231*da0073e9SAndroid Build Coastguard Worker ops, 232*da0073e9SAndroid Build Coastguard Worker ctor_order, 233*da0073e9SAndroid Build Coastguard Worker dtor_order, 234*da0073e9SAndroid Build Coastguard Worker results=results, 235*da0073e9SAndroid Build Coastguard Worker expect_raises=expect_raises, 236*da0073e9SAndroid Build Coastguard Worker ) 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker if ctor_order is not None: 239*da0073e9SAndroid Build Coastguard Worker go(ctor_order) 240*da0073e9SAndroid Build Coastguard Worker else: 241*da0073e9SAndroid Build Coastguard Worker for ctor_order in itertools.permutations(range(len(ops))): 242*da0073e9SAndroid Build Coastguard Worker go(ctor_order) 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker # Return the "full" Result namedtuple after all operations are run. 245*da0073e9SAndroid Build Coastguard Worker # If this KeyErrors, that means that there did not exist any 246*da0073e9SAndroid Build Coastguard Worker # ordering of ctors which got us to the "end". That's an 247*da0073e9SAndroid Build Coastguard Worker # error in test construction: it means you could have 248*da0073e9SAndroid Build Coastguard Worker # factored the test into two smaller ones. 249*da0073e9SAndroid Build Coastguard Worker return results[frozenset(range(len(ops)))] 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker def test_def(self): 252*da0073e9SAndroid Build Coastguard Worker state = self.commute( 253*da0073e9SAndroid Build Coastguard Worker "foo", 254*da0073e9SAndroid Build Coastguard Worker [ 255*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 256*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 257*da0073e9SAndroid Build Coastguard Worker # m.impl("test_def", [](const Tensor& x) { return x }) 258*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo"), 259*da0073e9SAndroid Build Coastguard Worker # m.impl("test_def", kCPU, [](const Tensor& x) { return x }) 260*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", dispatch="CPU"), 261*da0073e9SAndroid Build Coastguard Worker # m.impl("test_def", kAutograd, [](const Tensor& x) { return x }) 262*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", dispatch="Autograd"), 263*da0073e9SAndroid Build Coastguard Worker # m.impl("test_def", kAutogradCPU, [](const Tensor& x) { return x }) 264*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", dispatch="AutogradCPU"), 265*da0073e9SAndroid Build Coastguard Worker ], 266*da0073e9SAndroid Build Coastguard Worker ).state 267*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 268*da0073e9SAndroid Build Coastguard Worker state, 269*da0073e9SAndroid Build Coastguard Worker """\ 270*da0073e9SAndroid Build Coastguard Workername: test::foo 271*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 272*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 273*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 274*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 275*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 276*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 277*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 278*da0073e9SAndroid Build Coastguard Worker""", 279*da0073e9SAndroid Build Coastguard Worker ) 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker def test_def_impl_schema_mismatch(self): 282*da0073e9SAndroid Build Coastguard Worker # NB: an impl-impl mismatch is not reported eagerly; you'll find out 283*da0073e9SAndroid Build Coastguard Worker # about it because one of them won't match with def 284*da0073e9SAndroid Build Coastguard Worker state = self.commute( 285*da0073e9SAndroid Build Coastguard Worker "foo", 286*da0073e9SAndroid Build Coastguard Worker [ 287*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x, Tensor y) -> Tensor") 288*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"), 289*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", [](const Tensor & x) { return x }) 290*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo"), 291*da0073e9SAndroid Build Coastguard Worker ], 292*da0073e9SAndroid Build Coastguard Worker expect_raises=True, 293*da0073e9SAndroid Build Coastguard Worker ).state 294*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 295*da0073e9SAndroid Build Coastguard Worker state, 296*da0073e9SAndroid Build Coastguard Worker """\ 297*da0073e9SAndroid Build Coastguard WorkerInferred operator schema for a C++ kernel function doesn't match the expected function schema. 298*da0073e9SAndroid Build Coastguard Worker operator: test::foo 299*da0073e9SAndroid Build Coastguard Worker expected schema: test::foo(Tensor x, Tensor y) -> Tensor 300*da0073e9SAndroid Build Coastguard Worker registered at /dev/null:0 301*da0073e9SAndroid Build Coastguard Worker inferred schema: (Tensor _0) -> Tensor _0 302*da0073e9SAndroid Build Coastguard Worker impl_t_t 303*da0073e9SAndroid Build Coastguard Worker reason: The number of arguments is different. 2 vs 1.""", 304*da0073e9SAndroid Build Coastguard Worker ) 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker def test_def_with_inference(self): 307*da0073e9SAndroid Build Coastguard Worker state = self.commute( 308*da0073e9SAndroid Build Coastguard Worker "foo", 309*da0073e9SAndroid Build Coastguard Worker [ 310*da0073e9SAndroid Build Coastguard Worker # m.def("foo", [](const Tensor & x) { return x }) 311*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_name_t_t("foo"), 312*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 313*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU"), 314*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) 315*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "Autograd"), 316*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x }) 317*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "AutogradCPU"), 318*da0073e9SAndroid Build Coastguard Worker ], 319*da0073e9SAndroid Build Coastguard Worker ).state 320*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 321*da0073e9SAndroid Build Coastguard Worker state, 322*da0073e9SAndroid Build Coastguard Worker """\ 323*da0073e9SAndroid Build Coastguard Workername: test::foo 324*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor _0) -> Tensor _0 325*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 326*da0073e9SAndroid Build Coastguard Workeralias analysis kind: CONSERVATIVE 327*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 328*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 329*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 330*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 331*da0073e9SAndroid Build Coastguard Worker""", 332*da0073e9SAndroid Build Coastguard Worker ) 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker def test_def_only(self): 335*da0073e9SAndroid Build Coastguard Worker state = self.commute( 336*da0073e9SAndroid Build Coastguard Worker "foo", 337*da0073e9SAndroid Build Coastguard Worker [ 338*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x, Tensor y) -> Tensor") 339*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"), 340*da0073e9SAndroid Build Coastguard Worker ], 341*da0073e9SAndroid Build Coastguard Worker ).state 342*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 343*da0073e9SAndroid Build Coastguard Worker state, 344*da0073e9SAndroid Build Coastguard Worker """\ 345*da0073e9SAndroid Build Coastguard Workername: test::foo 346*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x, Tensor y) -> Tensor 347*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 348*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 349*da0073e9SAndroid Build Coastguard Worker""", 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker def test_impl_only(self): 353*da0073e9SAndroid Build Coastguard Worker state = self.commute( 354*da0073e9SAndroid Build Coastguard Worker "foo", 355*da0073e9SAndroid Build Coastguard Worker [ 356*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", [](const Tensor& x) { return x }) 357*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo"), 358*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor& x) { return x }) 359*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU"), 360*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x }) 361*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "Autograd"), 362*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutogradCPU, [](const Tensor& x) { return x }) 363*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "AutogradCPU"), 364*da0073e9SAndroid Build Coastguard Worker ], 365*da0073e9SAndroid Build Coastguard Worker ).state 366*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 367*da0073e9SAndroid Build Coastguard Worker state, 368*da0073e9SAndroid Build Coastguard Worker """\ 369*da0073e9SAndroid Build Coastguard Workername: test::foo 370*da0073e9SAndroid Build Coastguard Workerschema: (none) 371*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 372*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 373*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 374*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 375*da0073e9SAndroid Build Coastguard Worker""", 376*da0073e9SAndroid Build Coastguard Worker ) 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker def test_computed_table(self): 379*da0073e9SAndroid Build Coastguard Worker result = self.commute( 380*da0073e9SAndroid Build Coastguard Worker "foo", 381*da0073e9SAndroid Build Coastguard Worker [ 382*da0073e9SAndroid Build Coastguard Worker # m.def("foo", [](const Tensor & x) { return x }) 383*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_name_t_t("foo"), 384*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 385*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), 386*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCUDA, [](const Tensor & x) { return x }) 387*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "XLA", debug="fn_xla"), 388*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) 389*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), 390*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x }) 391*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "AutogradCPU", debug="fn_autogradcpu"), 392*da0073e9SAndroid Build Coastguard Worker ], 393*da0073e9SAndroid Build Coastguard Worker ) 394*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 395*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 396*da0073e9SAndroid Build Coastguard Worker state, 397*da0073e9SAndroid Build Coastguard Worker """\ 398*da0073e9SAndroid Build Coastguard Workername: test::foo 399*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor _0) -> Tensor _0 400*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 401*da0073e9SAndroid Build Coastguard Workeralias analysis kind: CONSERVATIVE 402*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 403*da0073e9SAndroid Build Coastguard WorkerXLA: fn_xla :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 404*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autogradcpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 405*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 406*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 407*da0073e9SAndroid Build Coastguard Worker""", 408*da0073e9SAndroid Build Coastguard Worker ) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 411*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 412*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 413*da0073e9SAndroid Build Coastguard Worker ) 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 416*da0073e9SAndroid Build Coastguard Worker extracted_table, 417*da0073e9SAndroid Build Coastguard Worker """\ 418*da0073e9SAndroid Build Coastguard WorkerUndefined: default_def_name_t_t [math kernel] 419*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel] 420*da0073e9SAndroid Build Coastguard WorkerCUDA: default_def_name_t_t [math kernel] 421*da0073e9SAndroid Build Coastguard WorkerXLA: fn_xla [kernel] 422*da0073e9SAndroid Build Coastguard WorkerAutogradOther: default_def_name_t_t [math kernel] 423*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autogradcpu [kernel] 424*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: default_def_name_t_t [math kernel] 425*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_autograd [autograd kernel] 426*da0073e9SAndroid Build Coastguard Worker""", 427*da0073e9SAndroid Build Coastguard Worker ) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_cpu_math_autogradcpu_fallthrough(self): 430*da0073e9SAndroid Build Coastguard Worker global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") 431*da0073e9SAndroid Build Coastguard Worker result = self.commute( 432*da0073e9SAndroid Build Coastguard Worker "foo", 433*da0073e9SAndroid Build Coastguard Worker [ 434*da0073e9SAndroid Build Coastguard Worker # m.def("foo", [](const Tensor & x) { return x }) 435*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_name_t_t("foo"), 436*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 437*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU"), 438*da0073e9SAndroid Build Coastguard Worker ], 439*da0073e9SAndroid Build Coastguard Worker ) 440*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 441*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 442*da0073e9SAndroid Build Coastguard Worker state, 443*da0073e9SAndroid Build Coastguard Worker """\ 444*da0073e9SAndroid Build Coastguard Workername: test::foo 445*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor _0) -> Tensor _0 446*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 447*da0073e9SAndroid Build Coastguard Workeralias analysis kind: CONSERVATIVE 448*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 449*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 450*da0073e9SAndroid Build Coastguard Worker""", 451*da0073e9SAndroid Build Coastguard Worker ) 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 454*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 455*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 456*da0073e9SAndroid Build Coastguard Worker ) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 459*da0073e9SAndroid Build Coastguard Worker extracted_table, 460*da0073e9SAndroid Build Coastguard Worker """\ 461*da0073e9SAndroid Build Coastguard WorkerUndefined: default_def_name_t_t [math kernel] 462*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t [kernel] 463*da0073e9SAndroid Build Coastguard WorkerCUDA: default_def_name_t_t [math kernel] 464*da0073e9SAndroid Build Coastguard WorkerXLA: default_def_name_t_t [math kernel] 465*da0073e9SAndroid Build Coastguard WorkerAutogradOther: default_def_name_t_t [math kernel] 466*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: registered in pytorch framework [backend fallback] 467*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: default_def_name_t_t [math kernel] 468*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: default_def_name_t_t [math kernel] 469*da0073e9SAndroid Build Coastguard Worker""", 470*da0073e9SAndroid Build Coastguard Worker ) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_math(self): 473*da0073e9SAndroid Build Coastguard Worker global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") 474*da0073e9SAndroid Build Coastguard Worker result = self.commute( 475*da0073e9SAndroid Build Coastguard Worker "foo", 476*da0073e9SAndroid Build Coastguard Worker [ 477*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 478*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 479*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) 480*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd"), 481*da0073e9SAndroid Build Coastguard Worker ], 482*da0073e9SAndroid Build Coastguard Worker ) 483*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 484*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 485*da0073e9SAndroid Build Coastguard Worker state, 486*da0073e9SAndroid Build Coastguard Worker """\ 487*da0073e9SAndroid Build Coastguard Workername: test::foo 488*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 489*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 490*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 491*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 492*da0073e9SAndroid Build Coastguard Worker""", 493*da0073e9SAndroid Build Coastguard Worker ) 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 496*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 497*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 498*da0073e9SAndroid Build Coastguard Worker ) 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 501*da0073e9SAndroid Build Coastguard Worker extracted_table, 502*da0073e9SAndroid Build Coastguard Worker """\ 503*da0073e9SAndroid Build Coastguard WorkerUndefined: impl_t_t [math kernel] 504*da0073e9SAndroid Build Coastguard WorkerCPU: impl_t_t [math kernel] 505*da0073e9SAndroid Build Coastguard WorkerCUDA: impl_t_t [math kernel] 506*da0073e9SAndroid Build Coastguard WorkerXLA: impl_t_t [math kernel] 507*da0073e9SAndroid Build Coastguard WorkerAutogradOther: impl_t_t [math kernel] 508*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t [math kernel] 509*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: impl_t_t [math kernel] 510*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: impl_t_t [math kernel] 511*da0073e9SAndroid Build Coastguard Worker""", 512*da0073e9SAndroid Build Coastguard Worker ) 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_cpu_math(self): 515*da0073e9SAndroid Build Coastguard Worker global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") 516*da0073e9SAndroid Build Coastguard Worker result = self.commute( 517*da0073e9SAndroid Build Coastguard Worker "foo", 518*da0073e9SAndroid Build Coastguard Worker [ 519*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 520*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 521*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 522*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), 523*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) 524*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t( 525*da0073e9SAndroid Build Coastguard Worker "foo", "CompositeImplicitAutograd", debug="fn_math" 526*da0073e9SAndroid Build Coastguard Worker ), 527*da0073e9SAndroid Build Coastguard Worker ], 528*da0073e9SAndroid Build Coastguard Worker ) 529*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 530*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 531*da0073e9SAndroid Build Coastguard Worker state, 532*da0073e9SAndroid Build Coastguard Worker """\ 533*da0073e9SAndroid Build Coastguard Workername: test::foo 534*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 535*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 536*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 537*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 538*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 539*da0073e9SAndroid Build Coastguard Worker""", 540*da0073e9SAndroid Build Coastguard Worker ) 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 543*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 544*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 545*da0073e9SAndroid Build Coastguard Worker ) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 548*da0073e9SAndroid Build Coastguard Worker extracted_table, 549*da0073e9SAndroid Build Coastguard Worker """\ 550*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_math [math kernel] 551*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel] 552*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_math [math kernel] 553*da0073e9SAndroid Build Coastguard WorkerXLA: fn_math [math kernel] 554*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_math [math kernel] 555*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: registered in pytorch framework [backend fallback] 556*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_math [math kernel] 557*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_math [math kernel] 558*da0073e9SAndroid Build Coastguard Worker""", 559*da0073e9SAndroid Build Coastguard Worker ) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_autograd(self): 562*da0073e9SAndroid Build Coastguard Worker global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") 563*da0073e9SAndroid Build Coastguard Worker result = self.commute( 564*da0073e9SAndroid Build Coastguard Worker "foo", 565*da0073e9SAndroid Build Coastguard Worker [ 566*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 567*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 568*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) 569*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "Autograd"), 570*da0073e9SAndroid Build Coastguard Worker ], 571*da0073e9SAndroid Build Coastguard Worker ) 572*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 573*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 574*da0073e9SAndroid Build Coastguard Worker state, 575*da0073e9SAndroid Build Coastguard Worker """\ 576*da0073e9SAndroid Build Coastguard Workername: test::foo 577*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 578*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 579*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 580*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: impl_t_t :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 581*da0073e9SAndroid Build Coastguard Worker""", 582*da0073e9SAndroid Build Coastguard Worker ) 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 585*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 586*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 587*da0073e9SAndroid Build Coastguard Worker ) 588*da0073e9SAndroid Build Coastguard Worker 589*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 590*da0073e9SAndroid Build Coastguard Worker extracted_table, 591*da0073e9SAndroid Build Coastguard Worker """\ 592*da0073e9SAndroid Build Coastguard WorkerAutogradOther: impl_t_t [autograd kernel] 593*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: impl_t_t [autograd kernel] 594*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: impl_t_t [autograd kernel] 595*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: impl_t_t [autograd kernel] 596*da0073e9SAndroid Build Coastguard Worker""", 597*da0073e9SAndroid Build Coastguard Worker ) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker # Now that catchAll maps to CompositeImplicitAutograd, registering to both 600*da0073e9SAndroid Build Coastguard Worker # catchAll and CompositeImplicitAutograd breaks commutativity. 601*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_cpu_autograd_math(self): 602*da0073e9SAndroid Build Coastguard Worker result = self.commute( 603*da0073e9SAndroid Build Coastguard Worker "foo", 604*da0073e9SAndroid Build Coastguard Worker [ 605*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 606*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 607*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 608*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), 609*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) 610*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), 611*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) 612*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t( 613*da0073e9SAndroid Build Coastguard Worker "foo", "CompositeImplicitAutograd", debug="fn_math" 614*da0073e9SAndroid Build Coastguard Worker ), 615*da0073e9SAndroid Build Coastguard Worker ], 616*da0073e9SAndroid Build Coastguard Worker ) 617*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 618*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 619*da0073e9SAndroid Build Coastguard Worker state, 620*da0073e9SAndroid Build Coastguard Worker """\ 621*da0073e9SAndroid Build Coastguard Workername: test::foo 622*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 623*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 624*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 625*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 626*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 627*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 628*da0073e9SAndroid Build Coastguard Worker""", 629*da0073e9SAndroid Build Coastguard Worker ) 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 632*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 633*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 634*da0073e9SAndroid Build Coastguard Worker ) 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 637*da0073e9SAndroid Build Coastguard Worker extracted_table, 638*da0073e9SAndroid Build Coastguard Worker """\ 639*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_math [math kernel] 640*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel] 641*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_math [math kernel] 642*da0073e9SAndroid Build Coastguard WorkerXLA: fn_math [math kernel] 643*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_math [math kernel] 644*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autograd [autograd kernel] 645*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_math [math kernel] 646*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_math [math kernel] 647*da0073e9SAndroid Build Coastguard Worker""", 648*da0073e9SAndroid Build Coastguard Worker ) 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_ambiguous_autogradother(self): 651*da0073e9SAndroid Build Coastguard Worker result = self.commute( 652*da0073e9SAndroid Build Coastguard Worker "foo", 653*da0073e9SAndroid Build Coastguard Worker [ 654*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 655*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 656*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) 657*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t( 658*da0073e9SAndroid Build Coastguard Worker "foo", "CompositeImplicitAutograd", debug="fn_math" 659*da0073e9SAndroid Build Coastguard Worker ), 660*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kFPGA, [](const Tensor & x) { return x }) 661*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "FPGA", debug="fn_fpga"), 662*da0073e9SAndroid Build Coastguard Worker ], 663*da0073e9SAndroid Build Coastguard Worker ) 664*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 665*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 666*da0073e9SAndroid Build Coastguard Worker state, 667*da0073e9SAndroid Build Coastguard Worker """\ 668*da0073e9SAndroid Build Coastguard Workername: test::foo 669*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 670*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 671*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 672*da0073e9SAndroid Build Coastguard WorkerFPGA: fn_fpga :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 673*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 674*da0073e9SAndroid Build Coastguard Worker""", 675*da0073e9SAndroid Build Coastguard Worker ) 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 678*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 679*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check + ("FPGA",) 680*da0073e9SAndroid Build Coastguard Worker ) 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 683*da0073e9SAndroid Build Coastguard Worker extracted_table, 684*da0073e9SAndroid Build Coastguard Worker """\ 685*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_math [math kernel] 686*da0073e9SAndroid Build Coastguard WorkerCPU: fn_math [math kernel] 687*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_math [math kernel] 688*da0073e9SAndroid Build Coastguard WorkerXLA: fn_math [math kernel] 689*da0073e9SAndroid Build Coastguard WorkerAutogradOther: ambiguous_autogradother [ambiguous autogradother] 690*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_math [math kernel] 691*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_math [math kernel] 692*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_math [math kernel] 693*da0073e9SAndroid Build Coastguard WorkerFPGA: fn_fpga [kernel] 694*da0073e9SAndroid Build Coastguard Worker""", 695*da0073e9SAndroid Build Coastguard Worker ) 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_cpu_defaultbackend(self): 698*da0073e9SAndroid Build Coastguard Worker result = self.commute( 699*da0073e9SAndroid Build Coastguard Worker "foo", 700*da0073e9SAndroid Build Coastguard Worker [ 701*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 702*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 703*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 704*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), 705*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x }) 706*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t( 707*da0073e9SAndroid Build Coastguard Worker "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend" 708*da0073e9SAndroid Build Coastguard Worker ), 709*da0073e9SAndroid Build Coastguard Worker ], 710*da0073e9SAndroid Build Coastguard Worker ) 711*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 712*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 713*da0073e9SAndroid Build Coastguard Worker state, 714*da0073e9SAndroid Build Coastguard Worker """\ 715*da0073e9SAndroid Build Coastguard Workername: test::foo 716*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 717*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 718*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 719*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 720*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 721*da0073e9SAndroid Build Coastguard Worker""", 722*da0073e9SAndroid Build Coastguard Worker ) 723*da0073e9SAndroid Build Coastguard Worker 724*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 725*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 726*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 727*da0073e9SAndroid Build Coastguard Worker ) 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 730*da0073e9SAndroid Build Coastguard Worker extracted_table, 731*da0073e9SAndroid Build Coastguard Worker """\ 732*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_defaultbackend [default backend kernel] 733*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel] 734*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_defaultbackend [default backend kernel] 735*da0073e9SAndroid Build Coastguard WorkerXLA: fn_defaultbackend [default backend kernel] 736*da0073e9SAndroid Build Coastguard WorkerAutogradOther: registered in pytorch framework [backend fallback] 737*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: registered in pytorch framework [backend fallback] 738*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: registered in pytorch framework [backend fallback] 739*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: registered in pytorch framework [backend fallback] 740*da0073e9SAndroid Build Coastguard Worker""", 741*da0073e9SAndroid Build Coastguard Worker ) 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_cpu_autograd_defaultbackend(self): 744*da0073e9SAndroid Build Coastguard Worker result = self.commute( 745*da0073e9SAndroid Build Coastguard Worker "foo", 746*da0073e9SAndroid Build Coastguard Worker [ 747*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 748*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 749*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 750*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), 751*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) 752*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), 753*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x }) 754*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t( 755*da0073e9SAndroid Build Coastguard Worker "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend" 756*da0073e9SAndroid Build Coastguard Worker ), 757*da0073e9SAndroid Build Coastguard Worker ], 758*da0073e9SAndroid Build Coastguard Worker ) 759*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 760*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 761*da0073e9SAndroid Build Coastguard Worker state, 762*da0073e9SAndroid Build Coastguard Worker """\ 763*da0073e9SAndroid Build Coastguard Workername: test::foo 764*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 765*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 766*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 767*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 768*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 769*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 770*da0073e9SAndroid Build Coastguard Worker""", 771*da0073e9SAndroid Build Coastguard Worker ) 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 774*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 775*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check + ("FPGA",) 776*da0073e9SAndroid Build Coastguard Worker ) 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 779*da0073e9SAndroid Build Coastguard Worker extracted_table, 780*da0073e9SAndroid Build Coastguard Worker """\ 781*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_defaultbackend [default backend kernel] 782*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel] 783*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_defaultbackend [default backend kernel] 784*da0073e9SAndroid Build Coastguard WorkerXLA: fn_defaultbackend [default backend kernel] 785*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_autograd [autograd kernel] 786*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autograd [autograd kernel] 787*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_autograd [autograd kernel] 788*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_autograd [autograd kernel] 789*da0073e9SAndroid Build Coastguard WorkerFPGA: fn_defaultbackend [default backend kernel] 790*da0073e9SAndroid Build Coastguard Worker""", 791*da0073e9SAndroid Build Coastguard Worker ) 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker def test_computed_table_with_cpu_autograd_math_defaultbackend(self): 794*da0073e9SAndroid Build Coastguard Worker result = self.commute( 795*da0073e9SAndroid Build Coastguard Worker "foo", 796*da0073e9SAndroid Build Coastguard Worker [ 797*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x) -> Tensor") 798*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor"), 799*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) 800*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), 801*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) 802*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), 803*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) 804*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t( 805*da0073e9SAndroid Build Coastguard Worker "foo", "CompositeImplicitAutograd", debug="fn_math" 806*da0073e9SAndroid Build Coastguard Worker ), 807*da0073e9SAndroid Build Coastguard Worker # m.impl("foo", torch::kCompositeExplicitAutograd, [](const Tensor & x) { return x }) 808*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t( 809*da0073e9SAndroid Build Coastguard Worker "foo", "CompositeExplicitAutograd", debug="fn_defaultbackend" 810*da0073e9SAndroid Build Coastguard Worker ), 811*da0073e9SAndroid Build Coastguard Worker ], 812*da0073e9SAndroid Build Coastguard Worker ) 813*da0073e9SAndroid Build Coastguard Worker state, table = result.state, result.table 814*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 815*da0073e9SAndroid Build Coastguard Worker state, 816*da0073e9SAndroid Build Coastguard Worker """\ 817*da0073e9SAndroid Build Coastguard Workername: test::foo 818*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x) -> Tensor 819*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 820*da0073e9SAndroid Build Coastguard Workeralias analysis kind: FROM_SCHEMA 821*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 822*da0073e9SAndroid Build Coastguard WorkerAutograd[alias]: fn_autograd :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 823*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 824*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 825*da0073e9SAndroid Build Coastguard Worker""", 826*da0073e9SAndroid Build Coastguard Worker ) 827*da0073e9SAndroid Build Coastguard Worker 828*da0073e9SAndroid Build Coastguard Worker # computed dispatch table is too big, so we only check on a few entries we're interested in. 829*da0073e9SAndroid Build Coastguard Worker extracted_table = extract_dispatch_table_with_keys( 830*da0073e9SAndroid Build Coastguard Worker table, dispatch_keys_to_check 831*da0073e9SAndroid Build Coastguard Worker ) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 834*da0073e9SAndroid Build Coastguard Worker extracted_table, 835*da0073e9SAndroid Build Coastguard Worker """\ 836*da0073e9SAndroid Build Coastguard WorkerUndefined: fn_defaultbackend [default backend kernel] 837*da0073e9SAndroid Build Coastguard WorkerCPU: fn_cpu [kernel] 838*da0073e9SAndroid Build Coastguard WorkerCUDA: fn_defaultbackend [default backend kernel] 839*da0073e9SAndroid Build Coastguard WorkerXLA: fn_defaultbackend [default backend kernel] 840*da0073e9SAndroid Build Coastguard WorkerAutogradOther: fn_autograd [autograd kernel] 841*da0073e9SAndroid Build Coastguard WorkerAutogradCPU: fn_autograd [autograd kernel] 842*da0073e9SAndroid Build Coastguard WorkerAutogradCUDA: fn_autograd [autograd kernel] 843*da0073e9SAndroid Build Coastguard WorkerAutogradXLA: fn_autograd [autograd kernel] 844*da0073e9SAndroid Build Coastguard Worker""", 845*da0073e9SAndroid Build Coastguard Worker ) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker def test_multiple_def_error(self): 848*da0073e9SAndroid Build Coastguard Worker ops = [ 849*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x, Tensor y) -> Tensor") 850*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"), 851*da0073e9SAndroid Build Coastguard Worker # m.def("foo(Tensor x, Tensor y) -> Tensor") 852*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"), 853*da0073e9SAndroid Build Coastguard Worker ] 854*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 855*da0073e9SAndroid Build Coastguard Worker self.commute("foo", ops, expect_raises=True).state, 856*da0073e9SAndroid Build Coastguard Worker """Tried to register an operator (test::foo(Tensor x, Tensor y) -> Tensor) with the same name and overload """ 857*da0073e9SAndroid Build Coastguard Worker """name multiple times. Each overload's schema should only be registered with a single call to def(). """ 858*da0073e9SAndroid Build Coastguard Worker """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""", 859*da0073e9SAndroid Build Coastguard Worker ) 860*da0073e9SAndroid Build Coastguard Worker 861*da0073e9SAndroid Build Coastguard Worker def test_def_with_explicit_alias(self): 862*da0073e9SAndroid Build Coastguard Worker state = self.commute( 863*da0073e9SAndroid Build Coastguard Worker "foo", 864*da0073e9SAndroid Build Coastguard Worker [ 865*da0073e9SAndroid Build Coastguard Worker # m.def(torch::schema( 866*da0073e9SAndroid Build Coastguard Worker # "foo(Tensor x, Tensor y) -> Tensor", 867*da0073e9SAndroid Build Coastguard Worker # AliasAnalysisKind::PURE)) 868*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_( 869*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x, Tensor y) -> Tensor", alias="PURE_FUNCTION" 870*da0073e9SAndroid Build Coastguard Worker ) 871*da0073e9SAndroid Build Coastguard Worker ], 872*da0073e9SAndroid Build Coastguard Worker ).state 873*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 874*da0073e9SAndroid Build Coastguard Worker state, 875*da0073e9SAndroid Build Coastguard Worker """\ 876*da0073e9SAndroid Build Coastguard Workername: test::foo 877*da0073e9SAndroid Build Coastguard Workerschema: test::foo(Tensor x, Tensor y) -> Tensor 878*da0073e9SAndroid Build Coastguard Workerdebug: registered at /dev/null:0 879*da0073e9SAndroid Build Coastguard Workeralias analysis kind: PURE_FUNCTION 880*da0073e9SAndroid Build Coastguard Worker""", 881*da0073e9SAndroid Build Coastguard Worker ) 882*da0073e9SAndroid Build Coastguard Worker 883*da0073e9SAndroid Build Coastguard Worker def test_multiple_def_alias_defaulting(self): 884*da0073e9SAndroid Build Coastguard Worker ops = [ 885*da0073e9SAndroid Build Coastguard Worker # m.def(torch::schema("foo(Tensor x) -> Tensor", 886*da0073e9SAndroid Build Coastguard Worker # c10::AliasAnalysisKind::PURE_FUNCTION)) 887*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"), 888*da0073e9SAndroid Build Coastguard Worker # RegisterOperators().op("foo(Tensor x) -> Tensor") 889*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_legacy("foo(Tensor x) -> Tensor"), 890*da0073e9SAndroid Build Coastguard Worker ] 891*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 892*da0073e9SAndroid Build Coastguard Worker self.commute("foo", ops, expect_raises=True).state, 893*da0073e9SAndroid Build Coastguard Worker """Tried to register an operator (test::foo(Tensor x) -> Tensor) with the same name and overload """ 894*da0073e9SAndroid Build Coastguard Worker """name multiple times. Each overload's schema should only be registered with a single call to def(). """ 895*da0073e9SAndroid Build Coastguard Worker """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""", 896*da0073e9SAndroid Build Coastguard Worker ) 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker def test_multiple_def_alias_mismatch(self): 899*da0073e9SAndroid Build Coastguard Worker ops = [ 900*da0073e9SAndroid Build Coastguard Worker # m.def(torch::schema("foo(Tensor x) -> Tensor", 901*da0073e9SAndroid Build Coastguard Worker # c10::AliasAnalysisKind::PURE_FUNCTION)) 902*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"), 903*da0073e9SAndroid Build Coastguard Worker # m.def(torch::schema("foo(Tensor x) -> Tensor", 904*da0073e9SAndroid Build Coastguard Worker # c10::AliasAnalysisKind::CONSERVATIVE)) 905*da0073e9SAndroid Build Coastguard Worker lambda m: m.def_("foo(Tensor x) -> Tensor", alias="CONSERVATIVE"), 906*da0073e9SAndroid Build Coastguard Worker ] 907*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 908*da0073e9SAndroid Build Coastguard Worker self.commute("foo", ops, expect_raises=True).state, 909*da0073e9SAndroid Build Coastguard Worker """Tried to register an operator (test::foo(Tensor x) -> Tensor) with the same name and overload """ 910*da0073e9SAndroid Build Coastguard Worker """name multiple times. Each overload's schema should only be registered with a single call to def(). """ 911*da0073e9SAndroid Build Coastguard Worker """Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0""", 912*da0073e9SAndroid Build Coastguard Worker ) 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker def test_multiple_fallback(self): 915*da0073e9SAndroid Build Coastguard Worker global_m = C._dispatch_library("IMPL", "_", "XLA") 916*da0073e9SAndroid Build Coastguard Worker global_m.fallback_fallthrough() 917*da0073e9SAndroid Build Coastguard Worker try: 918*da0073e9SAndroid Build Coastguard Worker global_m.fallback_fallthrough() 919*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 920*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 921*da0073e9SAndroid Build Coastguard Worker str(e), 922*da0073e9SAndroid Build Coastguard Worker """Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration """ 923*da0073e9SAndroid Build Coastguard Worker """registered at /dev/null:0, new registration registered at /dev/null:0""", 924*da0073e9SAndroid Build Coastguard Worker ) 925*da0073e9SAndroid Build Coastguard Worker else: 926*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False) 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker def test_overwrite_math(self): 929*da0073e9SAndroid Build Coastguard Worker ops = [ 930*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", debug="fn1"), 931*da0073e9SAndroid Build Coastguard Worker lambda m: m.impl_t_t("foo", debug="fn2"), 932*da0073e9SAndroid Build Coastguard Worker ] 933*da0073e9SAndroid Build Coastguard Worker # Not commutative 934*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 935*da0073e9SAndroid Build Coastguard Worker self.commute("foo", ops, ctor_order=(0, 1)).state, 936*da0073e9SAndroid Build Coastguard Worker """\ 937*da0073e9SAndroid Build Coastguard Workername: test::foo 938*da0073e9SAndroid Build Coastguard Workerschema: (none) 939*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias]: fn2 :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 940*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias] (inactive): fn1 :: (Tensor _0) -> Tensor _0 [ boxed unboxed ] 941*da0073e9SAndroid Build Coastguard Worker""", 942*da0073e9SAndroid Build Coastguard Worker ) 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker # Definition: a dangling impl happens when someone does an impl() on a 945*da0073e9SAndroid Build Coastguard Worker # function but not a def() for it. This is usually a bug, e.g. someone 946*da0073e9SAndroid Build Coastguard Worker # misspelled an operator name, or someone registered an impl for an op that 947*da0073e9SAndroid Build Coastguard Worker # no longer exists 948*da0073e9SAndroid Build Coastguard Worker def test_find_dangling_impls(self): 949*da0073e9SAndroid Build Coastguard Worker dangling_impls = C._dispatch_find_dangling_impls() 950*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 951*da0073e9SAndroid Build Coastguard Worker 0, 952*da0073e9SAndroid Build Coastguard Worker len(dangling_impls), 953*da0073e9SAndroid Build Coastguard Worker msg=f"Expect zero dangling impls, but found: {dangling_impls}", 954*da0073e9SAndroid Build Coastguard Worker ) 955*da0073e9SAndroid Build Coastguard Worker 956*da0073e9SAndroid Build Coastguard Worker def test_find_dangling_impls_ext(self): 957*da0073e9SAndroid Build Coastguard Worker extension_path = os.path.join( 958*da0073e9SAndroid Build Coastguard Worker os.path.dirname(os.path.abspath(__file__)), 959*da0073e9SAndroid Build Coastguard Worker "cpp_extensions", 960*da0073e9SAndroid Build Coastguard Worker "dangling_impl_extension.cpp", 961*da0073e9SAndroid Build Coastguard Worker ) 962*da0073e9SAndroid Build Coastguard Worker module = torch.utils.cpp_extension.load( 963*da0073e9SAndroid Build Coastguard Worker name="dangling_impl_extension", 964*da0073e9SAndroid Build Coastguard Worker sources=[extension_path], 965*da0073e9SAndroid Build Coastguard Worker extra_cflags=["-g"], 966*da0073e9SAndroid Build Coastguard Worker verbose=True, 967*da0073e9SAndroid Build Coastguard Worker ) 968*da0073e9SAndroid Build Coastguard Worker 969*da0073e9SAndroid Build Coastguard Worker impls = C._dispatch_find_dangling_impls() 970*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, len(impls)) 971*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 972*da0073e9SAndroid Build Coastguard Worker f"""\ 973*da0073e9SAndroid Build Coastguard Workername: __test::foo 974*da0073e9SAndroid Build Coastguard Workerschema: (none) 975*da0073e9SAndroid Build Coastguard WorkerCPU: registered at {extension_path}:5 :: () -> () [ boxed unboxed ] 976*da0073e9SAndroid Build Coastguard Worker""", 977*da0073e9SAndroid Build Coastguard Worker impls[0], 978*da0073e9SAndroid Build Coastguard Worker ) 979*da0073e9SAndroid Build Coastguard Worker 980*da0073e9SAndroid Build Coastguard Worker def test_dispatch_print_registrations_for_dispatch_key_invalid(self): 981*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 982*da0073e9SAndroid Build Coastguard Worker RuntimeError, "could not parse dispatch key: invalid_key" 983*da0073e9SAndroid Build Coastguard Worker ): 984*da0073e9SAndroid Build Coastguard Worker C._dispatch_print_registrations_for_dispatch_key("invalid_key") 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker 987*da0073e9SAndroid Build Coastguard Workerclass TestPythonDispatcher(TestCase): 988*da0073e9SAndroid Build Coastguard Worker def test_basic(self): 989*da0073e9SAndroid Build Coastguard Worker dispatcher = PythonDispatcher() 990*da0073e9SAndroid Build Coastguard Worker dispatcher.register(["CPU", "XLA", "Lazy", "CompositeImplicitAutograd"]) 991*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 992*da0073e9SAndroid Build Coastguard Worker dispatcher.dispatchTable(), 993*da0073e9SAndroid Build Coastguard Worker """\ 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table 996*da0073e9SAndroid Build Coastguard Workerkey kernel 997*da0073e9SAndroid Build Coastguard Worker--------------------------- 998*da0073e9SAndroid Build Coastguard WorkerCPU fn_CPU [kernel] 999*da0073e9SAndroid Build Coastguard WorkerXLA fn_XLA [kernel] 1000*da0073e9SAndroid Build Coastguard WorkerLazy fn_Lazy [kernel] 1001*da0073e9SAndroid Build Coastguard WorkerFPGA fn_CompositeImplicitAutograd [math kernel] 1002*da0073e9SAndroid Build Coastguard WorkerAutogradOther fn_CompositeImplicitAutograd [math kernel] 1003*da0073e9SAndroid Build Coastguard WorkerAutogradCPU [backend fallback] 1004*da0073e9SAndroid Build Coastguard WorkerAutogradXLA [backend fallback] 1005*da0073e9SAndroid Build Coastguard WorkerAutogradLazy [backend fallback] 1006*da0073e9SAndroid Build Coastguard Worker""", 1007*da0073e9SAndroid Build Coastguard Worker ) 1008*da0073e9SAndroid Build Coastguard Worker 1009*da0073e9SAndroid Build Coastguard Worker def test_math_autogradcpu(self): 1010*da0073e9SAndroid Build Coastguard Worker dispatcher = PythonDispatcher() 1011*da0073e9SAndroid Build Coastguard Worker dispatcher.register( 1012*da0073e9SAndroid Build Coastguard Worker ["CPU", "XLA", "Lazy", "CompositeImplicitAutograd", "AutogradCPU"] 1013*da0073e9SAndroid Build Coastguard Worker ) 1014*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1015*da0073e9SAndroid Build Coastguard Worker dispatcher.dispatchTable(), 1016*da0073e9SAndroid Build Coastguard Worker """\ 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table 1019*da0073e9SAndroid Build Coastguard Workerkey kernel 1020*da0073e9SAndroid Build Coastguard Worker--------------------------- 1021*da0073e9SAndroid Build Coastguard WorkerCPU fn_CPU [kernel] 1022*da0073e9SAndroid Build Coastguard WorkerXLA fn_XLA [kernel] 1023*da0073e9SAndroid Build Coastguard WorkerLazy fn_Lazy [kernel] 1024*da0073e9SAndroid Build Coastguard WorkerFPGA fn_CompositeImplicitAutograd [math kernel] 1025*da0073e9SAndroid Build Coastguard WorkerAutogradOther fn_CompositeImplicitAutograd [math kernel] 1026*da0073e9SAndroid Build Coastguard WorkerAutogradCPU fn_AutogradCPU [kernel] 1027*da0073e9SAndroid Build Coastguard WorkerAutogradXLA [backend fallback] 1028*da0073e9SAndroid Build Coastguard WorkerAutogradLazy [backend fallback] 1029*da0073e9SAndroid Build Coastguard Worker""", 1030*da0073e9SAndroid Build Coastguard Worker ) 1031*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1032*da0073e9SAndroid Build Coastguard Worker dispatcher.registrations(), 1033*da0073e9SAndroid Build Coastguard Worker """\ 1034*da0073e9SAndroid Build Coastguard Worker 1035*da0073e9SAndroid Build Coastguard WorkerRegistered Kernels 1036*da0073e9SAndroid Build Coastguard Workerkey kernel 1037*da0073e9SAndroid Build Coastguard Worker--------------------------- 1038*da0073e9SAndroid Build Coastguard WorkerCPU fn_CPU 1039*da0073e9SAndroid Build Coastguard WorkerXLA fn_XLA 1040*da0073e9SAndroid Build Coastguard WorkerLazy fn_Lazy 1041*da0073e9SAndroid Build Coastguard WorkerAutogradCPU fn_AutogradCPU 1042*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd 1043*da0073e9SAndroid Build Coastguard Worker""", 1044*da0073e9SAndroid Build Coastguard Worker ) 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Worker def test_defaultbackend_autogradcpu(self): 1047*da0073e9SAndroid Build Coastguard Worker dispatcher = PythonDispatcher() 1048*da0073e9SAndroid Build Coastguard Worker dispatcher.register( 1049*da0073e9SAndroid Build Coastguard Worker ["CPU", "XLA", "Lazy", "CompositeExplicitAutograd", "AutogradCPU"] 1050*da0073e9SAndroid Build Coastguard Worker ) 1051*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1052*da0073e9SAndroid Build Coastguard Worker dispatcher.dispatchTable(), 1053*da0073e9SAndroid Build Coastguard Worker """\ 1054*da0073e9SAndroid Build Coastguard Worker 1055*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table 1056*da0073e9SAndroid Build Coastguard Workerkey kernel 1057*da0073e9SAndroid Build Coastguard Worker--------------------------- 1058*da0073e9SAndroid Build Coastguard WorkerCPU fn_CPU [kernel] 1059*da0073e9SAndroid Build Coastguard WorkerXLA fn_XLA [kernel] 1060*da0073e9SAndroid Build Coastguard WorkerLazy fn_Lazy [kernel] 1061*da0073e9SAndroid Build Coastguard WorkerFPGA fn_CompositeExplicitAutograd [default backend kernel] 1062*da0073e9SAndroid Build Coastguard WorkerAutogradOther [backend fallback] 1063*da0073e9SAndroid Build Coastguard WorkerAutogradCPU fn_AutogradCPU [kernel] 1064*da0073e9SAndroid Build Coastguard WorkerAutogradXLA [backend fallback] 1065*da0073e9SAndroid Build Coastguard WorkerAutogradLazy [backend fallback] 1066*da0073e9SAndroid Build Coastguard Worker""", 1067*da0073e9SAndroid Build Coastguard Worker ) 1068*da0073e9SAndroid Build Coastguard Worker 1069*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1070*da0073e9SAndroid Build Coastguard Worker dispatcher.registrations(), 1071*da0073e9SAndroid Build Coastguard Worker """\ 1072*da0073e9SAndroid Build Coastguard Worker 1073*da0073e9SAndroid Build Coastguard WorkerRegistered Kernels 1074*da0073e9SAndroid Build Coastguard Workerkey kernel 1075*da0073e9SAndroid Build Coastguard Worker--------------------------- 1076*da0073e9SAndroid Build Coastguard WorkerCPU fn_CPU 1077*da0073e9SAndroid Build Coastguard WorkerXLA fn_XLA 1078*da0073e9SAndroid Build Coastguard WorkerLazy fn_Lazy 1079*da0073e9SAndroid Build Coastguard WorkerAutogradCPU fn_AutogradCPU 1080*da0073e9SAndroid Build Coastguard WorkerCompositeExplicitAutograd[alias] fn_CompositeExplicitAutograd 1081*da0073e9SAndroid Build Coastguard Worker""", 1082*da0073e9SAndroid Build Coastguard Worker ) 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker def test_autogradother(self): 1085*da0073e9SAndroid Build Coastguard Worker dispatcher = PythonDispatcher() 1086*da0073e9SAndroid Build Coastguard Worker dispatcher.register(["CPU", "FPGA", "CompositeImplicitAutograd"]) 1087*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1088*da0073e9SAndroid Build Coastguard Worker dispatcher.dispatchTable(), 1089*da0073e9SAndroid Build Coastguard Worker """\ 1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard WorkerComputed Dispatch Table 1092*da0073e9SAndroid Build Coastguard Workerkey kernel 1093*da0073e9SAndroid Build Coastguard Worker--------------------------- 1094*da0073e9SAndroid Build Coastguard WorkerCPU fn_CPU [kernel] 1095*da0073e9SAndroid Build Coastguard WorkerXLA fn_CompositeImplicitAutograd [math kernel] 1096*da0073e9SAndroid Build Coastguard WorkerLazy fn_CompositeImplicitAutograd [math kernel] 1097*da0073e9SAndroid Build Coastguard WorkerFPGA fn_FPGA [kernel] 1098*da0073e9SAndroid Build Coastguard WorkerAutogradOther ambiguous_autogradother [ambiguous autogradother] 1099*da0073e9SAndroid Build Coastguard WorkerAutogradCPU [backend fallback] 1100*da0073e9SAndroid Build Coastguard WorkerAutogradXLA fn_CompositeImplicitAutograd [math kernel] 1101*da0073e9SAndroid Build Coastguard WorkerAutogradLazy fn_CompositeImplicitAutograd [math kernel] 1102*da0073e9SAndroid Build Coastguard Worker""", 1103*da0073e9SAndroid Build Coastguard Worker ) 1104*da0073e9SAndroid Build Coastguard Worker 1105*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1106*da0073e9SAndroid Build Coastguard Worker dispatcher.registrations(), 1107*da0073e9SAndroid Build Coastguard Worker """\ 1108*da0073e9SAndroid Build Coastguard Worker 1109*da0073e9SAndroid Build Coastguard WorkerRegistered Kernels 1110*da0073e9SAndroid Build Coastguard Workerkey kernel 1111*da0073e9SAndroid Build Coastguard Worker--------------------------- 1112*da0073e9SAndroid Build Coastguard WorkerFPGA fn_FPGA 1113*da0073e9SAndroid Build Coastguard WorkerCPU fn_CPU 1114*da0073e9SAndroid Build Coastguard WorkerCompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd 1115*da0073e9SAndroid Build Coastguard Worker""", 1116*da0073e9SAndroid Build Coastguard Worker ) 1117*da0073e9SAndroid Build Coastguard Worker 1118*da0073e9SAndroid Build Coastguard Worker def test_duplicate_registrations(self): 1119*da0073e9SAndroid Build Coastguard Worker dispatcher = PythonDispatcher() 1120*da0073e9SAndroid Build Coastguard Worker 1121*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Overriden is not allowed"): 1122*da0073e9SAndroid Build Coastguard Worker dispatcher.register(["CPU", "CPU"]) 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Worker def test_defaultbackend_math(self): 1125*da0073e9SAndroid Build Coastguard Worker dispatcher = PythonDispatcher() 1126*da0073e9SAndroid Build Coastguard Worker 1127*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1128*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1129*da0073e9SAndroid Build Coastguard Worker r"Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed", 1130*da0073e9SAndroid Build Coastguard Worker ): 1131*da0073e9SAndroid Build Coastguard Worker dispatcher.register( 1132*da0073e9SAndroid Build Coastguard Worker ["CompositeExplicitAutograd", "CompositeImplicitAutograd"] 1133*da0073e9SAndroid Build Coastguard Worker ) 1134*da0073e9SAndroid Build Coastguard Worker 1135*da0073e9SAndroid Build Coastguard Worker def test_quantized_structured_not_implemented(self): 1136*da0073e9SAndroid Build Coastguard Worker x = torch.zeros([1, 1, 1]) 1137*da0073e9SAndroid Build Coastguard Worker y = torch.zeros([1, 1, 1]) 1138*da0073e9SAndroid Build Coastguard Worker scale, zero_point = 1.0, 0 1139*da0073e9SAndroid Build Coastguard Worker dtype = torch.qint8 1140*da0073e9SAndroid Build Coastguard Worker qx = torch.quantize_per_tensor(x, scale, zero_point, dtype) 1141*da0073e9SAndroid Build Coastguard Worker qy = torch.quantize_per_tensor(y, scale, zero_point, dtype) 1142*da0073e9SAndroid Build Coastguard Worker # If bmm gets quantized support you need to update this to something 1143*da0073e9SAndroid Build Coastguard Worker # else that is not implemented 1144*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1145*da0073e9SAndroid Build Coastguard Worker NotImplementedError, 1146*da0073e9SAndroid Build Coastguard Worker "Could not run 'aten::bmm.out' with arguments from the 'QuantizedCPU' backend.", 1147*da0073e9SAndroid Build Coastguard Worker lambda: torch.bmm(qx, qy), 1148*da0073e9SAndroid Build Coastguard Worker ) 1149*da0073e9SAndroid Build Coastguard Worker 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1152*da0073e9SAndroid Build Coastguard Worker run_tests() 1153