# Owner(s): ["module: dynamo"] import functools import weakref import torch import torch._dynamo import torch._dynamo.test_case from torch._C._dynamo import guards from torch._dynamo.convert_frame import GlobalStateGuard from torch.testing._internal.common_utils import set_default_dtype RootGuardManager = guards.RootGuardManager DictGuardManager = guards.DictGuardManager DictSubclassGuardManager = guards.DictSubclassGuardManager GetAttrGuardAccessor = guards.GetAttrGuardAccessor GetItemGuardAccessor = guards.GetItemGuardAccessor TypeGuardAccessor = guards.TypeGuardAccessor OBJECT_ALIASING = guards.OBJECT_ALIASING install_object_aliasing_guard = guards.install_object_aliasing_guard NO_TENSOR_ALIASING = guards.NO_TENSOR_ALIASING install_no_tensor_aliasing_guard = guards.install_no_tensor_aliasing_guard x = torch.tensor(4) weakref_x = weakref.ref(x) default_mgr_enum = torch._dynamo.guards.GuardManagerType.GUARD_MANAGER class Pair: def __init__(self, x, y): self.x = x self.y = y global_pair = Pair(torch.randn(4), 1) def id_type(x): return id(type(x)) def equals_match(x, expected): return x == expected def equals_match_verbose_code_parts(expected): return [f"x == {expected}"] def ge_match(x, expected): return x >= expected def ge_match_verbose_code_parts(expected): return f"expected >= {expected}" def less_match(x, expected): return x < expected def less_match_verbose_code_parts(expected): return [f"expected < {expected}"] class GuardManagerTests(torch._dynamo.test_case.TestCase): def test_global_state_guard(self): guard = guards.GLOBAL_STATE(["global_state_check"]) self.assertTrue(guard(None)) with set_default_dtype(torch.double): self.assertFalse(guard(None)) self.assertExpectedInline( str(guard.check_verbose(None)), """\ GuardDebugInfo( result=0, verbose_code_parts=['GLOBAL_STATE changed: default_dtype '], num_guards_executed=0) """, ) self.assertTrue(guard(None)) self.assertTrue(guard.check_verbose(None).result) _orig = torch.are_deterministic_algorithms_enabled() try: torch.use_deterministic_algorithms(not _orig) self.assertFalse(guard(None)) self.assertExpectedInline( str(guard.check_verbose(None)), """\ GuardDebugInfo( result=0, verbose_code_parts=['GLOBAL_STATE changed: deterministic_algorithms '], num_guards_executed=0) """, ) finally: torch.use_deterministic_algorithms(_orig) self.assertTrue(guard(None)) self.assertTrue(guard.check_verbose(None).result) def test_global_state_reason(self): with torch.enable_grad(): guards = GlobalStateGuard() with torch.no_grad(): self.assertIs(guards.check(), False) self.assertEqual(guards.reason(), "grad_mode ") def test_python_lambda_leaf_guard(self): const_guard = guards.LAMBDA_GUARD( functools.partial(equals_match, expected=5), equals_match_verbose_code_parts(5), ) self.assertTrue(const_guard(5)) self.assertFalse(const_guard(4)) self.assertFalse(const_guard("foo")) def test_type_guard(self): foo = 4 guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"]) self.assertTrue(guard(5)) self.assertTrue(guard(4)) self.assertFalse(guard("foo")) foo = {"a": 1} guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"]) self.assertTrue(guard(foo)) self.assertTrue(guard({})) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) class Foo: def __init__(self, x, y): self.x = x self.y = y foo = Foo(1, 2) guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"]) self.assertTrue(guard(foo)) self.assertFalse(guard({})) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) def test_id_guard(self): foo = 4 guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) self.assertTrue(guard(foo)) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) foo = {"a": 1} guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) self.assertTrue(guard(foo)) self.assertFalse(guard({"a": 1})) self.assertFalse(guard({})) self.assertFalse(guard(5)) def test_equals_guard(self): foo = 4 guard = guards.EQUALS_MATCH(foo, ["x == 4"]) self.assertTrue(guard(4)) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) # tuple foo = (1, 2, 3) guard = guards.EQUALS_MATCH(foo, ["x == foo"]) self.assertTrue(guard(foo)) self.assertTrue(guard((1, 2, 3))) self.assertFalse(guard((1, 2, 3, 4))) self.assertFalse(guard({})) # list foo = [1, 2, 3] guard = guards.EQUALS_MATCH(foo, ["x == foo"]) self.assertTrue(guard(foo)) self.assertTrue(guard([1, 2, 3])) self.assertFalse(guard([1, 2, 3, 4])) # type foo = int guard = guards.EQUALS_MATCH(foo, ["x == foo"]) self.assertTrue(guard(foo)) self.assertTrue(guard(int)) self.assertFalse(guard(float)) def test_default_device_guard(self): foo = 1 guard = guards.DEFAULT_DEVICE(["cpu device"]) self.assertTrue(guard(foo)) try: torch.set_default_device("cuda") self.assertFalse(guard(foo)) finally: torch.set_default_device(None) def test_data_ptr_match_guard(self): foo = torch.tensor([1, 2, 3]) guard = guards.DATA_PTR_MATCH(foo, ["x.data_ptr() == foo.data_ptr()"]) self.assertTrue(guard(foo)) self.assertFalse(guard(torch.tensor([1, 2, 3]))) def test_length_check_guard(self): foo = [1, 2, 3] guard = guards.LENGTH_CHECK(len(foo), ["len(x) == len(foo)"]) self.assertTrue(guard(foo)) self.assertFalse(guard([])) def test_no_hasattr_guard(self): class Bar: def __init__(self) -> None: self.bar = 2 bar = Bar() class Foo: def __init__(self) -> None: self.foo = 2 foo = Foo() guard = guards.NO_HASATTR("foo", ["hasattr(x, 'foo') == False"]) self.assertTrue(guard(bar)) self.assertFalse(guard(foo)) def test_tensor_aliasing_guard(self): guard_manager = RootGuardManager() a = torch.randn(3, 4) class Foo: def __init__(self, x, y): self.x = x self.y = y f_locals = Foo(a, a) x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum) y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum) install_object_aliasing_guard(x_guard_mgr, y_guard_mgr, ["x is y"]) # Check structure x_guards = x_guard_mgr.get_leaf_guards() y_guards = y_guard_mgr.get_leaf_guards() self.assertEqual(len(x_guards), 1) self.assertEqual(len(y_guards), 1) self.assertTrue(isinstance(x_guards[0], OBJECT_ALIASING)) self.assertTrue(isinstance(y_guards[0], OBJECT_ALIASING)) # Check that the two guards are the same object self.assertTrue(x_guards[0] is y_guards[0]) f_locals_unaliased = Foo(torch.randn(3, 4), torch.randn(3, 4)) self.assertEqual(len(x_guard_mgr.get_leaf_guards()), 1) self.assertEqual(len(y_guard_mgr.get_leaf_guards()), 1) self.assertTrue(guard_manager.check(f_locals)) self.assertFalse(guard_manager.check(f_locals_unaliased)) def test_dict_version_guard(self): foo = {"a": 1, "b": 2} guard = guards.DICT_VERSION(foo, ["x.version == foo.version"]) self.assertTrue(guard(foo)) self.assertFalse(guard(dict(foo))) foo["a"] = 2 self.assertFalse(guard(foo)) self.assertFalse(guard({"a": 1, "b": 2})) self.assertFalse(guard({})) def test_dynamic_indices_guard(self): guard1 = guards.DYNAMIC_INDICES(set(), ["x.size(0) == y.size(0)"]) guard2 = guards.DYNAMIC_INDICES(set({0, 1}), ["x.size(0) == y.size(0)"]) x = torch.randn(4) self.assertTrue(guard1(x)) self.assertTrue(guard2(x)) x._dynamo_dynamic_indices = set({0}) self.assertFalse(guard1(x)) self.assertTrue(guard2(x)) x._dynamo_dynamic_indices = set({2}) self.assertFalse(guard1(x)) self.assertFalse(guard2(x)) def test_tensor_match_guard(self): guard_manager = RootGuardManager() x = torch.randn(4, 4) size = list(x.size()) stride = list(x.stride()) guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"]) self.assertTrue(guard_manager.check(x)) self.assertTrue(guard_manager.check_verbose(x).result) self.assertTrue(guard_manager.check(torch.randn(4, 4))) self.assertTrue(guard_manager.check_verbose(torch.randn(4, 4)).result) self.assertFalse(guard_manager.check(x.t_())) x = torch.randn(4, 4) x.t_() debug_info = guard_manager.check_verbose(x) print(debug_info.verbose_code_parts[0]) self.assertTrue( "tensor 'x' stride mismatch" in debug_info.verbose_code_parts[0] ) def test_no_tensor_aliasing_guard(self): guard_manager = RootGuardManager() a = torch.randn(3, 4) class Foo: def __init__(self, x, y, z): self.x = x self.y = y self.z = z f_locals = Foo(a, a, a) x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum) y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum) z_guard_mgr = guard_manager.getattr_manager("z", "", a, default_mgr_enum) install_no_tensor_aliasing_guard( [x_guard_mgr, y_guard_mgr, z_guard_mgr], ["x", "y", "z"], ["no_aliasing(x, y, z)"], ) # Check structure x_guards = x_guard_mgr.get_leaf_guards() y_guards = y_guard_mgr.get_leaf_guards() z_guards = z_guard_mgr.get_leaf_guards() self.assertEqual(len(x_guards), 1) self.assertEqual(len(y_guards), 1) self.assertEqual(len(z_guards), 1) self.assertTrue(isinstance(x_guards[0], NO_TENSOR_ALIASING)) self.assertTrue(isinstance(y_guards[0], NO_TENSOR_ALIASING)) self.assertTrue(isinstance(z_guards[0], NO_TENSOR_ALIASING)) # Check that the two guards are the same object self.assertTrue(x_guards[0] is y_guards[0] is z_guards[0]) self.assertFalse(guard_manager.check(f_locals)) self.assertFalse(guard_manager.check_verbose(f_locals).result) f_locals_unaliased = Foo( torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4), ) self.assertTrue(guard_manager.check(f_locals_unaliased)) self.assertTrue(guard_manager.check_verbose(f_locals_unaliased).result) # Check that hash map is cleared. self.assertTrue(guard_manager.check(f_locals_unaliased)) f_locals_unaliased = Foo( a, torch.randn(3, 4), a, ) self.assertFalse(guard_manager.check(f_locals_unaliased)) self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result) def test_weakref_alive_guard(self): x = torch.rand(3, 4) weakref_x = weakref.ref(x) guard = guards.NOT_NONE(["weakref_x is not None"]) self.assertTrue(guard(weakref_x())) del x self.assertFalse(guard(weakref_x())) def test_guard_manager_leaf_guard(self): guard_manager = RootGuardManager() guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"]) guard_manager.add_lambda_guard( functools.partial(ge_match, expected=5), ge_match_verbose_code_parts(expected=5), ) guard_manager.add_lambda_guard( functools.partial(less_match, expected=10), less_match_verbose_code_parts(expected=10), ) self.assertEqual(len(guard_manager.get_leaf_guards()), 3) self.assertEqual(len(guard_manager.get_accessors()), 0) self.assertTrue(guard_manager.check(6)) self.assertFalse(guard_manager.check(4)) self.assertFalse(guard_manager.check("foo")) def test_attr_guard_manager(self): class Foo: def __init__(self, x, y): self.x = x self.y = y foo = Foo(1, 2) guard_manager = RootGuardManager() guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) guard_manager.getattr_manager("x", "x", 1, default_mgr_enum).add_lambda_guard( functools.partial(equals_match, expected=foo.x), equals_match_verbose_code_parts(foo.x), ) guard_manager.getattr_manager("y", "y", 2, default_mgr_enum).add_lambda_guard( functools.partial(equals_match, expected=foo.y), equals_match_verbose_code_parts(foo.y), ) self.assertEqual(len(guard_manager.get_leaf_guards()), 1) # 2 child managers, one for x and one for y self.assertEqual(len(guard_manager.get_accessors()), 2) self.assertTrue( isinstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor) ) self.assertTrue( isinstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor) ) # Check leaf guards on child managers self.assertEqual( len( guard_manager.getattr_manager( attr="x", source="x", example_value=None, guard_manager_enum=default_mgr_enum, ).get_leaf_guards() ), 1, ) self.assertEqual( len( guard_manager.getattr_manager( "y", "y", None, default_mgr_enum ).get_leaf_guards() ), 1, ) self.assertTrue(guard_manager.check(foo)) self.assertFalse(guard_manager.check(Foo(3, 4))) self.assertFalse(guard_manager.check("foo")) def test_item_guard_manager(self): foo = [1, 2] guard_manager = RootGuardManager() guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) guard_manager.getitem_manager(0, "", 1, default_mgr_enum).add_lambda_guard( functools.partial(equals_match, expected=foo[0]), equals_match_verbose_code_parts(foo[0]), ) guard_manager.getitem_manager(1, "", 2, default_mgr_enum).add_lambda_guard( functools.partial(equals_match, expected=foo[1]), equals_match_verbose_code_parts(foo[1]), ) self.assertEqual(len(guard_manager.get_leaf_guards()), 1) # 2 child managers, one for x and one for y self.assertEqual(len(guard_manager.get_accessors()), 2) self.assertTrue( isinstance(guard_manager.get_accessors()[0], GetItemGuardAccessor) ) self.assertTrue( isinstance(guard_manager.get_accessors()[1], GetItemGuardAccessor) ) # Check leaf guards on child managers self.assertEqual( len( guard_manager.getitem_manager( 0, "", None, default_mgr_enum ).get_leaf_guards() ), 1, ) self.assertEqual( len( guard_manager.getitem_manager( 1, "", None, default_mgr_enum ).get_leaf_guards() ), 1, ) self.assertTrue(guard_manager.check(foo)) self.assertFalse(guard_manager.check([3, 4])) self.assertFalse(guard_manager.check("foo")) def test_dict_getitem_accessor(self): foo = { "a": 1, "b": 2, } guards_manager = RootGuardManager() guards_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) guards_manager.dict_getitem_manager( "a", "", 1, default_mgr_enum ).add_equals_match_guard(1, ["a == 1"]) guards_manager.dict_getitem_manager( "b", "", 2, default_mgr_enum ).add_equals_match_guard(2, ["b == 2"]) self.assertTrue(guards_manager.check(foo)) self.assertFalse(guards_manager.check({"a": 1, "b": 3})) def test_globals(self): global global_pair, Pair guard_manager = RootGuardManager() gpair_mgr = guard_manager.globals_dict_manager( globals(), "", None, default_mgr_enum ).getitem_manager("global_pair", "", global_pair, default_mgr_enum) gpair_mgr.add_lambda_guard( lambda x: isinstance(x, Pair) and isinstance(x.x, torch.Tensor) and isinstance(x.y, int), "global guard fail", ) self.assertTrue(guard_manager.check(global_pair)) global_pair.y = "foo" self.assertFalse(guard_manager.check(global_pair)) def test_type_manager(self): guard_manager = RootGuardManager() class A: a = 4 class B(A): def mul(self, x): super().mul(x) foo = B() f_locals = {"foo": foo} # len(type(foo).__mro__) == 2 foo_mgr = guard_manager.getitem_manager("foo", "", foo, default_mgr_enum) type_manager = foo_mgr.type_manager("", type(foo), default_mgr_enum) self.assertTrue(isinstance(foo_mgr.get_accessors()[0], TypeGuardAccessor)) mro_manager = type_manager.getattr_manager( "__mro__", "", type(foo).__mro__, default_mgr_enum ) self.assertTrue( isinstance(type_manager.get_accessors()[0], GetAttrGuardAccessor) ) mro_manager.add_length_check_guard( 3, "Expected len(type(foo).__mro__) == 3", ) # type(foo).__mro__[0].a = 4 item_manager = mro_manager.getitem_manager( 1, "", type(foo).__mro__[1], default_mgr_enum ) self.assertTrue( isinstance(mro_manager.get_accessors()[0], GetItemGuardAccessor) ) attr_manager = item_manager.getattr_manager( "a", "", type(foo).__mro__[0].a, default_mgr_enum ) self.assertTrue( isinstance(item_manager.get_accessors()[0], GetAttrGuardAccessor) ) attr_manager.add_lambda_guard( lambda x: x == 4, "Expected value 4", ) self.assertTrue(guard_manager.check(f_locals)) def test_tuple_iterator_getitem(self): a = (1, 2, 3, 4, 5, 6) foo = iter(a) next(foo) # foo points at index=1 guard_manager = RootGuardManager() # Check a[3] which is tuple_iterator_getitem(foo, 2) guard_manager.add_tuple_iterator_length_guard( 5, id_type(iter(())), ["len == 5"] ) guard_manager.tuple_iterator_getitem_manager( 2, "", foo, default_mgr_enum ).add_equals_match_guard(a[3], ["x==4"]) # Check that type match works self.assertFalse(guard_manager.check(False)) self.assertTrue(guard_manager.check(foo)) # Check that index error fails gracefully b = (1, 2) b_foo = iter(b) self.assertFalse(guard_manager.check(b_foo)) def test_global_weakref(self): guard_manager = RootGuardManager() globals_manager = guard_manager.globals_dict_manager( globals(), "", None, default_mgr_enum ) weakref_manager = globals_manager.global_weakref_manager( "weakref_x", "", None, default_mgr_enum ) weakref_manager.add_lambda_guard( lambda x: isinstance(x, torch.Tensor), "global weakref fail", ) self.assertTrue(guard_manager.check(None)) global x del x self.assertFalse(guard_manager.check(None)) def test_lambda_manager(self): a = (1, 1, 3, 4, 5, 6) guard_manager = RootGuardManager() # Check that we can use the same accessor foo_mgr = guard_manager.lambda_manager( lambda x: x[2], "", None, default_mgr_enum ) foo_mgr.add_lambda_guard( lambda x: x == 3, "Expected value 3", ) self.assertTrue(guard_manager.check(a)) # test that exception works guard_manager = RootGuardManager() def fn(x): raise AssertionError("Test") return x foo_mgr = guard_manager.lambda_manager(fn, "", None, default_mgr_enum) self.assertFalse(guard_manager.check(None)) debug_info = guard_manager.check_verbose(None) self.assertFalse(debug_info.result) self.assertTrue("Test" in debug_info.verbose_code_parts[0]) def test_dict_contains_guard(self): foo = {"a": 1, "b": 2} guard = guards.DICT_CONTAINS(True, "a", ["has a"]) self.assertTrue(guard(foo)) self.assertTrue(guard({"a": 1, "b": 2})) self.assertFalse(guard({"b": 2, "c": 3})) self.assertFalse(guard({})) guard = guards.DICT_CONTAINS(False, "c", ["not has c"]) self.assertTrue(guard(foo)) self.assertTrue(guard({"a": 1, "b": 2})) self.assertFalse(guard({"b": 2, "c": 3})) self.assertTrue(guard({})) def test_dict_guard_manager(self): root = RootGuardManager() def nothing(): pass f_locals = { "d": {"a": 1, nothing: {"z": 3}, 100: torch.randn(4)}, } # its a getitem_manager just for f_locals. But the child guard manager # should be a DictGuardManager. dict_mgr = root.getitem_manager( "d", "", f_locals["d"], torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER, ) self.assertTrue(isinstance(dict_mgr, DictGuardManager)) self.assertTrue(root.check(f_locals)) # Check that no one can add a leaf guard with self.assertRaises(RuntimeError): dict_mgr.add_id_match_guard(id_type(f_locals), "id match") # Check that no one can add an arbitrary accessor with self.assertRaises(RuntimeError): dict_mgr.getitem_manager("a", "", f_locals["d"]["a"]) # Check that it fails with different length dict f_locals_prime = { "d": {"a": 1, "b": 2}, } self.assertFalse(root.check(f_locals_prime)) # Add key-value manager ("a" : 1) self.assertTrue(root.check(f_locals)) dict_mgr.get_key_manager(0, "", "a", default_mgr_enum).add_equals_match_guard( "a", ["dict.keys()[0] == a"], ) self.assertTrue(root.check(f_locals)) dict_mgr.get_value_manager(0, "", 1, default_mgr_enum).add_equals_match_guard( 1, ["d[0] == 1"] ) self.assertTrue(root.check(f_locals)) # Add key-value manager (nothing : {"z" : 3}) self.assertTrue(root.check(f_locals)) dict_mgr.get_key_manager(1, "", nothing, default_mgr_enum).add_lambda_guard( lambda x: x is nothing, ["x is nothing"] ) self.assertTrue(root.check(f_locals)) value_mgr = dict_mgr.get_value_manager( 1, "", f_locals["d"][nothing], torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER, ) self.assertTrue(isinstance(value_mgr, DictGuardManager)) self.assertTrue(root.check(f_locals)) # Check structure # Check that we are only guarding on two keys. This is common in # LazyVariableTracker. self.assertEqual(len(dict_mgr.get_key_value_managers()), 2) f_locals["d"]["a"] = 2 self.assertFalse(root.check(f_locals)) self.assertFalse(root.check_verbose(f_locals).result) f_locals["d"]["a"] = 1 self.assertTrue(root.check(f_locals)) f_locals["d"].pop(100) # fails because of len check self.assertFalse(root.check(f_locals)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()