1# Owner(s): ["module: dynamo"] 2import functools 3import weakref 4 5import torch 6import torch._dynamo 7import torch._dynamo.test_case 8from torch._C._dynamo import guards 9from torch._dynamo.convert_frame import GlobalStateGuard 10from torch.testing._internal.common_utils import set_default_dtype 11 12 13RootGuardManager = guards.RootGuardManager 14DictGuardManager = guards.DictGuardManager 15DictSubclassGuardManager = guards.DictSubclassGuardManager 16GetAttrGuardAccessor = guards.GetAttrGuardAccessor 17GetItemGuardAccessor = guards.GetItemGuardAccessor 18TypeGuardAccessor = guards.TypeGuardAccessor 19OBJECT_ALIASING = guards.OBJECT_ALIASING 20install_object_aliasing_guard = guards.install_object_aliasing_guard 21NO_TENSOR_ALIASING = guards.NO_TENSOR_ALIASING 22install_no_tensor_aliasing_guard = guards.install_no_tensor_aliasing_guard 23 24 25x = torch.tensor(4) 26weakref_x = weakref.ref(x) 27 28default_mgr_enum = torch._dynamo.guards.GuardManagerType.GUARD_MANAGER 29 30 31class Pair: 32 def __init__(self, x, y): 33 self.x = x 34 self.y = y 35 36 37global_pair = Pair(torch.randn(4), 1) 38 39 40def id_type(x): 41 return id(type(x)) 42 43 44def equals_match(x, expected): 45 return x == expected 46 47 48def equals_match_verbose_code_parts(expected): 49 return [f"x == {expected}"] 50 51 52def ge_match(x, expected): 53 return x >= expected 54 55 56def ge_match_verbose_code_parts(expected): 57 return f"expected >= {expected}" 58 59 60def less_match(x, expected): 61 return x < expected 62 63 64def less_match_verbose_code_parts(expected): 65 return [f"expected < {expected}"] 66 67 68class GuardManagerTests(torch._dynamo.test_case.TestCase): 69 def test_global_state_guard(self): 70 guard = guards.GLOBAL_STATE(["global_state_check"]) 71 self.assertTrue(guard(None)) 72 with set_default_dtype(torch.double): 73 self.assertFalse(guard(None)) 74 self.assertExpectedInline( 75 str(guard.check_verbose(None)), 76 """\ 77GuardDebugInfo( 78result=0, 79verbose_code_parts=['GLOBAL_STATE changed: default_dtype '], 80num_guards_executed=0) 81""", 82 ) 83 self.assertTrue(guard(None)) 84 self.assertTrue(guard.check_verbose(None).result) 85 _orig = torch.are_deterministic_algorithms_enabled() 86 try: 87 torch.use_deterministic_algorithms(not _orig) 88 self.assertFalse(guard(None)) 89 self.assertExpectedInline( 90 str(guard.check_verbose(None)), 91 """\ 92GuardDebugInfo( 93result=0, 94verbose_code_parts=['GLOBAL_STATE changed: deterministic_algorithms '], 95num_guards_executed=0) 96""", 97 ) 98 finally: 99 torch.use_deterministic_algorithms(_orig) 100 self.assertTrue(guard(None)) 101 self.assertTrue(guard.check_verbose(None).result) 102 103 def test_global_state_reason(self): 104 with torch.enable_grad(): 105 guards = GlobalStateGuard() 106 with torch.no_grad(): 107 self.assertIs(guards.check(), False) 108 self.assertEqual(guards.reason(), "grad_mode ") 109 110 def test_python_lambda_leaf_guard(self): 111 const_guard = guards.LAMBDA_GUARD( 112 functools.partial(equals_match, expected=5), 113 equals_match_verbose_code_parts(5), 114 ) 115 self.assertTrue(const_guard(5)) 116 self.assertFalse(const_guard(4)) 117 self.assertFalse(const_guard("foo")) 118 119 def test_type_guard(self): 120 foo = 4 121 guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"]) 122 123 self.assertTrue(guard(5)) 124 self.assertTrue(guard(4)) 125 self.assertFalse(guard("foo")) 126 127 foo = {"a": 1} 128 guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"]) 129 self.assertTrue(guard(foo)) 130 self.assertTrue(guard({})) 131 self.assertFalse(guard(5)) 132 self.assertFalse(guard("foo")) 133 134 class Foo: 135 def __init__(self, x, y): 136 self.x = x 137 self.y = y 138 139 foo = Foo(1, 2) 140 141 guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"]) 142 self.assertTrue(guard(foo)) 143 self.assertFalse(guard({})) 144 self.assertFalse(guard(5)) 145 self.assertFalse(guard("foo")) 146 147 def test_id_guard(self): 148 foo = 4 149 guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) 150 151 self.assertTrue(guard(foo)) 152 self.assertFalse(guard(5)) 153 self.assertFalse(guard("foo")) 154 155 foo = {"a": 1} 156 guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) 157 self.assertTrue(guard(foo)) 158 self.assertFalse(guard({"a": 1})) 159 self.assertFalse(guard({})) 160 self.assertFalse(guard(5)) 161 162 def test_equals_guard(self): 163 foo = 4 164 guard = guards.EQUALS_MATCH(foo, ["x == 4"]) 165 166 self.assertTrue(guard(4)) 167 self.assertFalse(guard(5)) 168 self.assertFalse(guard("foo")) 169 170 # tuple 171 foo = (1, 2, 3) 172 guard = guards.EQUALS_MATCH(foo, ["x == foo"]) 173 self.assertTrue(guard(foo)) 174 self.assertTrue(guard((1, 2, 3))) 175 self.assertFalse(guard((1, 2, 3, 4))) 176 self.assertFalse(guard({})) 177 178 # list 179 foo = [1, 2, 3] 180 guard = guards.EQUALS_MATCH(foo, ["x == foo"]) 181 self.assertTrue(guard(foo)) 182 self.assertTrue(guard([1, 2, 3])) 183 self.assertFalse(guard([1, 2, 3, 4])) 184 185 # type 186 foo = int 187 guard = guards.EQUALS_MATCH(foo, ["x == foo"]) 188 self.assertTrue(guard(foo)) 189 self.assertTrue(guard(int)) 190 self.assertFalse(guard(float)) 191 192 def test_default_device_guard(self): 193 foo = 1 194 guard = guards.DEFAULT_DEVICE(["cpu device"]) 195 self.assertTrue(guard(foo)) 196 197 try: 198 torch.set_default_device("cuda") 199 self.assertFalse(guard(foo)) 200 finally: 201 torch.set_default_device(None) 202 203 def test_data_ptr_match_guard(self): 204 foo = torch.tensor([1, 2, 3]) 205 guard = guards.DATA_PTR_MATCH(foo, ["x.data_ptr() == foo.data_ptr()"]) 206 self.assertTrue(guard(foo)) 207 self.assertFalse(guard(torch.tensor([1, 2, 3]))) 208 209 def test_length_check_guard(self): 210 foo = [1, 2, 3] 211 guard = guards.LENGTH_CHECK(len(foo), ["len(x) == len(foo)"]) 212 self.assertTrue(guard(foo)) 213 self.assertFalse(guard([])) 214 215 def test_no_hasattr_guard(self): 216 class Bar: 217 def __init__(self) -> None: 218 self.bar = 2 219 220 bar = Bar() 221 222 class Foo: 223 def __init__(self) -> None: 224 self.foo = 2 225 226 foo = Foo() 227 228 guard = guards.NO_HASATTR("foo", ["hasattr(x, 'foo') == False"]) 229 self.assertTrue(guard(bar)) 230 self.assertFalse(guard(foo)) 231 232 def test_tensor_aliasing_guard(self): 233 guard_manager = RootGuardManager() 234 235 a = torch.randn(3, 4) 236 237 class Foo: 238 def __init__(self, x, y): 239 self.x = x 240 self.y = y 241 242 f_locals = Foo(a, a) 243 244 x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum) 245 y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum) 246 install_object_aliasing_guard(x_guard_mgr, y_guard_mgr, ["x is y"]) 247 248 # Check structure 249 x_guards = x_guard_mgr.get_leaf_guards() 250 y_guards = y_guard_mgr.get_leaf_guards() 251 self.assertEqual(len(x_guards), 1) 252 self.assertEqual(len(y_guards), 1) 253 self.assertTrue(isinstance(x_guards[0], OBJECT_ALIASING)) 254 self.assertTrue(isinstance(y_guards[0], OBJECT_ALIASING)) 255 # Check that the two guards are the same object 256 self.assertTrue(x_guards[0] is y_guards[0]) 257 258 f_locals_unaliased = Foo(torch.randn(3, 4), torch.randn(3, 4)) 259 self.assertEqual(len(x_guard_mgr.get_leaf_guards()), 1) 260 self.assertEqual(len(y_guard_mgr.get_leaf_guards()), 1) 261 self.assertTrue(guard_manager.check(f_locals)) 262 263 self.assertFalse(guard_manager.check(f_locals_unaliased)) 264 265 def test_dict_version_guard(self): 266 foo = {"a": 1, "b": 2} 267 guard = guards.DICT_VERSION(foo, ["x.version == foo.version"]) 268 269 self.assertTrue(guard(foo)) 270 self.assertFalse(guard(dict(foo))) 271 foo["a"] = 2 272 self.assertFalse(guard(foo)) 273 self.assertFalse(guard({"a": 1, "b": 2})) 274 self.assertFalse(guard({})) 275 276 def test_dynamic_indices_guard(self): 277 guard1 = guards.DYNAMIC_INDICES(set(), ["x.size(0) == y.size(0)"]) 278 guard2 = guards.DYNAMIC_INDICES(set({0, 1}), ["x.size(0) == y.size(0)"]) 279 280 x = torch.randn(4) 281 self.assertTrue(guard1(x)) 282 self.assertTrue(guard2(x)) 283 284 x._dynamo_dynamic_indices = set({0}) 285 self.assertFalse(guard1(x)) 286 self.assertTrue(guard2(x)) 287 288 x._dynamo_dynamic_indices = set({2}) 289 self.assertFalse(guard1(x)) 290 self.assertFalse(guard2(x)) 291 292 def test_tensor_match_guard(self): 293 guard_manager = RootGuardManager() 294 x = torch.randn(4, 4) 295 size = list(x.size()) 296 stride = list(x.stride()) 297 guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"]) 298 self.assertTrue(guard_manager.check(x)) 299 self.assertTrue(guard_manager.check_verbose(x).result) 300 self.assertTrue(guard_manager.check(torch.randn(4, 4))) 301 self.assertTrue(guard_manager.check_verbose(torch.randn(4, 4)).result) 302 self.assertFalse(guard_manager.check(x.t_())) 303 304 x = torch.randn(4, 4) 305 x.t_() 306 debug_info = guard_manager.check_verbose(x) 307 print(debug_info.verbose_code_parts[0]) 308 self.assertTrue( 309 "tensor 'x' stride mismatch" in debug_info.verbose_code_parts[0] 310 ) 311 312 def test_no_tensor_aliasing_guard(self): 313 guard_manager = RootGuardManager() 314 315 a = torch.randn(3, 4) 316 317 class Foo: 318 def __init__(self, x, y, z): 319 self.x = x 320 self.y = y 321 self.z = z 322 323 f_locals = Foo(a, a, a) 324 325 x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum) 326 y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum) 327 z_guard_mgr = guard_manager.getattr_manager("z", "", a, default_mgr_enum) 328 install_no_tensor_aliasing_guard( 329 [x_guard_mgr, y_guard_mgr, z_guard_mgr], 330 ["x", "y", "z"], 331 ["no_aliasing(x, y, z)"], 332 ) 333 334 # Check structure 335 x_guards = x_guard_mgr.get_leaf_guards() 336 y_guards = y_guard_mgr.get_leaf_guards() 337 z_guards = z_guard_mgr.get_leaf_guards() 338 self.assertEqual(len(x_guards), 1) 339 self.assertEqual(len(y_guards), 1) 340 self.assertEqual(len(z_guards), 1) 341 self.assertTrue(isinstance(x_guards[0], NO_TENSOR_ALIASING)) 342 self.assertTrue(isinstance(y_guards[0], NO_TENSOR_ALIASING)) 343 self.assertTrue(isinstance(z_guards[0], NO_TENSOR_ALIASING)) 344 # Check that the two guards are the same object 345 self.assertTrue(x_guards[0] is y_guards[0] is z_guards[0]) 346 self.assertFalse(guard_manager.check(f_locals)) 347 self.assertFalse(guard_manager.check_verbose(f_locals).result) 348 349 f_locals_unaliased = Foo( 350 torch.randn(3, 4), 351 torch.randn(3, 4), 352 torch.randn(3, 4), 353 ) 354 self.assertTrue(guard_manager.check(f_locals_unaliased)) 355 self.assertTrue(guard_manager.check_verbose(f_locals_unaliased).result) 356 # Check that hash map is cleared. 357 self.assertTrue(guard_manager.check(f_locals_unaliased)) 358 359 f_locals_unaliased = Foo( 360 a, 361 torch.randn(3, 4), 362 a, 363 ) 364 self.assertFalse(guard_manager.check(f_locals_unaliased)) 365 self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result) 366 367 def test_weakref_alive_guard(self): 368 x = torch.rand(3, 4) 369 weakref_x = weakref.ref(x) 370 371 guard = guards.NOT_NONE(["weakref_x is not None"]) 372 self.assertTrue(guard(weakref_x())) 373 del x 374 self.assertFalse(guard(weakref_x())) 375 376 def test_guard_manager_leaf_guard(self): 377 guard_manager = RootGuardManager() 378 guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"]) 379 guard_manager.add_lambda_guard( 380 functools.partial(ge_match, expected=5), 381 ge_match_verbose_code_parts(expected=5), 382 ) 383 guard_manager.add_lambda_guard( 384 functools.partial(less_match, expected=10), 385 less_match_verbose_code_parts(expected=10), 386 ) 387 self.assertEqual(len(guard_manager.get_leaf_guards()), 3) 388 self.assertEqual(len(guard_manager.get_accessors()), 0) 389 self.assertTrue(guard_manager.check(6)) 390 self.assertFalse(guard_manager.check(4)) 391 self.assertFalse(guard_manager.check("foo")) 392 393 def test_attr_guard_manager(self): 394 class Foo: 395 def __init__(self, x, y): 396 self.x = x 397 self.y = y 398 399 foo = Foo(1, 2) 400 guard_manager = RootGuardManager() 401 guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) 402 guard_manager.getattr_manager("x", "x", 1, default_mgr_enum).add_lambda_guard( 403 functools.partial(equals_match, expected=foo.x), 404 equals_match_verbose_code_parts(foo.x), 405 ) 406 guard_manager.getattr_manager("y", "y", 2, default_mgr_enum).add_lambda_guard( 407 functools.partial(equals_match, expected=foo.y), 408 equals_match_verbose_code_parts(foo.y), 409 ) 410 self.assertEqual(len(guard_manager.get_leaf_guards()), 1) 411 # 2 child managers, one for x and one for y 412 self.assertEqual(len(guard_manager.get_accessors()), 2) 413 self.assertTrue( 414 isinstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor) 415 ) 416 self.assertTrue( 417 isinstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor) 418 ) 419 # Check leaf guards on child managers 420 self.assertEqual( 421 len( 422 guard_manager.getattr_manager( 423 attr="x", 424 source="x", 425 example_value=None, 426 guard_manager_enum=default_mgr_enum, 427 ).get_leaf_guards() 428 ), 429 1, 430 ) 431 self.assertEqual( 432 len( 433 guard_manager.getattr_manager( 434 "y", "y", None, default_mgr_enum 435 ).get_leaf_guards() 436 ), 437 1, 438 ) 439 440 self.assertTrue(guard_manager.check(foo)) 441 self.assertFalse(guard_manager.check(Foo(3, 4))) 442 self.assertFalse(guard_manager.check("foo")) 443 444 def test_item_guard_manager(self): 445 foo = [1, 2] 446 guard_manager = RootGuardManager() 447 guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) 448 guard_manager.getitem_manager(0, "", 1, default_mgr_enum).add_lambda_guard( 449 functools.partial(equals_match, expected=foo[0]), 450 equals_match_verbose_code_parts(foo[0]), 451 ) 452 guard_manager.getitem_manager(1, "", 2, default_mgr_enum).add_lambda_guard( 453 functools.partial(equals_match, expected=foo[1]), 454 equals_match_verbose_code_parts(foo[1]), 455 ) 456 self.assertEqual(len(guard_manager.get_leaf_guards()), 1) 457 # 2 child managers, one for x and one for y 458 self.assertEqual(len(guard_manager.get_accessors()), 2) 459 self.assertTrue( 460 isinstance(guard_manager.get_accessors()[0], GetItemGuardAccessor) 461 ) 462 self.assertTrue( 463 isinstance(guard_manager.get_accessors()[1], GetItemGuardAccessor) 464 ) 465 # Check leaf guards on child managers 466 self.assertEqual( 467 len( 468 guard_manager.getitem_manager( 469 0, "", None, default_mgr_enum 470 ).get_leaf_guards() 471 ), 472 1, 473 ) 474 self.assertEqual( 475 len( 476 guard_manager.getitem_manager( 477 1, "", None, default_mgr_enum 478 ).get_leaf_guards() 479 ), 480 1, 481 ) 482 483 self.assertTrue(guard_manager.check(foo)) 484 self.assertFalse(guard_manager.check([3, 4])) 485 self.assertFalse(guard_manager.check("foo")) 486 487 def test_dict_getitem_accessor(self): 488 foo = { 489 "a": 1, 490 "b": 2, 491 } 492 493 guards_manager = RootGuardManager() 494 guards_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"]) 495 guards_manager.dict_getitem_manager( 496 "a", "", 1, default_mgr_enum 497 ).add_equals_match_guard(1, ["a == 1"]) 498 guards_manager.dict_getitem_manager( 499 "b", "", 2, default_mgr_enum 500 ).add_equals_match_guard(2, ["b == 2"]) 501 502 self.assertTrue(guards_manager.check(foo)) 503 self.assertFalse(guards_manager.check({"a": 1, "b": 3})) 504 505 def test_globals(self): 506 global global_pair, Pair 507 guard_manager = RootGuardManager() 508 gpair_mgr = guard_manager.globals_dict_manager( 509 globals(), "", None, default_mgr_enum 510 ).getitem_manager("global_pair", "", global_pair, default_mgr_enum) 511 512 gpair_mgr.add_lambda_guard( 513 lambda x: isinstance(x, Pair) 514 and isinstance(x.x, torch.Tensor) 515 and isinstance(x.y, int), 516 "global guard fail", 517 ) 518 519 self.assertTrue(guard_manager.check(global_pair)) 520 global_pair.y = "foo" 521 self.assertFalse(guard_manager.check(global_pair)) 522 523 def test_type_manager(self): 524 guard_manager = RootGuardManager() 525 526 class A: 527 a = 4 528 529 class B(A): 530 def mul(self, x): 531 super().mul(x) 532 533 foo = B() 534 f_locals = {"foo": foo} 535 536 # len(type(foo).__mro__) == 2 537 foo_mgr = guard_manager.getitem_manager("foo", "", foo, default_mgr_enum) 538 type_manager = foo_mgr.type_manager("", type(foo), default_mgr_enum) 539 self.assertTrue(isinstance(foo_mgr.get_accessors()[0], TypeGuardAccessor)) 540 mro_manager = type_manager.getattr_manager( 541 "__mro__", "", type(foo).__mro__, default_mgr_enum 542 ) 543 self.assertTrue( 544 isinstance(type_manager.get_accessors()[0], GetAttrGuardAccessor) 545 ) 546 mro_manager.add_length_check_guard( 547 3, 548 "Expected len(type(foo).__mro__) == 3", 549 ) 550 551 # type(foo).__mro__[0].a = 4 552 item_manager = mro_manager.getitem_manager( 553 1, "", type(foo).__mro__[1], default_mgr_enum 554 ) 555 self.assertTrue( 556 isinstance(mro_manager.get_accessors()[0], GetItemGuardAccessor) 557 ) 558 attr_manager = item_manager.getattr_manager( 559 "a", "", type(foo).__mro__[0].a, default_mgr_enum 560 ) 561 self.assertTrue( 562 isinstance(item_manager.get_accessors()[0], GetAttrGuardAccessor) 563 ) 564 attr_manager.add_lambda_guard( 565 lambda x: x == 4, 566 "Expected value 4", 567 ) 568 569 self.assertTrue(guard_manager.check(f_locals)) 570 571 def test_tuple_iterator_getitem(self): 572 a = (1, 2, 3, 4, 5, 6) 573 foo = iter(a) 574 next(foo) # foo points at index=1 575 576 guard_manager = RootGuardManager() 577 # Check a[3] which is tuple_iterator_getitem(foo, 2) 578 guard_manager.add_tuple_iterator_length_guard( 579 5, id_type(iter(())), ["len == 5"] 580 ) 581 guard_manager.tuple_iterator_getitem_manager( 582 2, "", foo, default_mgr_enum 583 ).add_equals_match_guard(a[3], ["x==4"]) 584 585 # Check that type match works 586 self.assertFalse(guard_manager.check(False)) 587 588 self.assertTrue(guard_manager.check(foo)) 589 590 # Check that index error fails gracefully 591 b = (1, 2) 592 b_foo = iter(b) 593 self.assertFalse(guard_manager.check(b_foo)) 594 595 def test_global_weakref(self): 596 guard_manager = RootGuardManager() 597 globals_manager = guard_manager.globals_dict_manager( 598 globals(), "", None, default_mgr_enum 599 ) 600 weakref_manager = globals_manager.global_weakref_manager( 601 "weakref_x", "", None, default_mgr_enum 602 ) 603 604 weakref_manager.add_lambda_guard( 605 lambda x: isinstance(x, torch.Tensor), 606 "global weakref fail", 607 ) 608 609 self.assertTrue(guard_manager.check(None)) 610 global x 611 del x 612 self.assertFalse(guard_manager.check(None)) 613 614 def test_lambda_manager(self): 615 a = (1, 1, 3, 4, 5, 6) 616 617 guard_manager = RootGuardManager() 618 619 # Check that we can use the same accessor 620 foo_mgr = guard_manager.lambda_manager( 621 lambda x: x[2], "", None, default_mgr_enum 622 ) 623 foo_mgr.add_lambda_guard( 624 lambda x: x == 3, 625 "Expected value 3", 626 ) 627 self.assertTrue(guard_manager.check(a)) 628 629 # test that exception works 630 guard_manager = RootGuardManager() 631 632 def fn(x): 633 raise AssertionError("Test") 634 return x 635 636 foo_mgr = guard_manager.lambda_manager(fn, "", None, default_mgr_enum) 637 638 self.assertFalse(guard_manager.check(None)) 639 debug_info = guard_manager.check_verbose(None) 640 self.assertFalse(debug_info.result) 641 self.assertTrue("Test" in debug_info.verbose_code_parts[0]) 642 643 def test_dict_contains_guard(self): 644 foo = {"a": 1, "b": 2} 645 guard = guards.DICT_CONTAINS(True, "a", ["has a"]) 646 647 self.assertTrue(guard(foo)) 648 self.assertTrue(guard({"a": 1, "b": 2})) 649 self.assertFalse(guard({"b": 2, "c": 3})) 650 self.assertFalse(guard({})) 651 652 guard = guards.DICT_CONTAINS(False, "c", ["not has c"]) 653 self.assertTrue(guard(foo)) 654 self.assertTrue(guard({"a": 1, "b": 2})) 655 self.assertFalse(guard({"b": 2, "c": 3})) 656 self.assertTrue(guard({})) 657 658 def test_dict_guard_manager(self): 659 root = RootGuardManager() 660 661 def nothing(): 662 pass 663 664 f_locals = { 665 "d": {"a": 1, nothing: {"z": 3}, 100: torch.randn(4)}, 666 } 667 668 # its a getitem_manager just for f_locals. But the child guard manager 669 # should be a DictGuardManager. 670 dict_mgr = root.getitem_manager( 671 "d", 672 "", 673 f_locals["d"], 674 torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER, 675 ) 676 self.assertTrue(isinstance(dict_mgr, DictGuardManager)) 677 678 self.assertTrue(root.check(f_locals)) 679 680 # Check that no one can add a leaf guard 681 with self.assertRaises(RuntimeError): 682 dict_mgr.add_id_match_guard(id_type(f_locals), "id match") 683 684 # Check that no one can add an arbitrary accessor 685 with self.assertRaises(RuntimeError): 686 dict_mgr.getitem_manager("a", "", f_locals["d"]["a"]) 687 688 # Check that it fails with different length dict 689 f_locals_prime = { 690 "d": {"a": 1, "b": 2}, 691 } 692 self.assertFalse(root.check(f_locals_prime)) 693 694 # Add key-value manager ("a" : 1) 695 self.assertTrue(root.check(f_locals)) 696 dict_mgr.get_key_manager(0, "", "a", default_mgr_enum).add_equals_match_guard( 697 "a", 698 ["dict.keys()[0] == a"], 699 ) 700 self.assertTrue(root.check(f_locals)) 701 dict_mgr.get_value_manager(0, "", 1, default_mgr_enum).add_equals_match_guard( 702 1, ["d[0] == 1"] 703 ) 704 self.assertTrue(root.check(f_locals)) 705 706 # Add key-value manager (nothing : {"z" : 3}) 707 self.assertTrue(root.check(f_locals)) 708 dict_mgr.get_key_manager(1, "", nothing, default_mgr_enum).add_lambda_guard( 709 lambda x: x is nothing, ["x is nothing"] 710 ) 711 self.assertTrue(root.check(f_locals)) 712 value_mgr = dict_mgr.get_value_manager( 713 1, 714 "", 715 f_locals["d"][nothing], 716 torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER, 717 ) 718 self.assertTrue(isinstance(value_mgr, DictGuardManager)) 719 self.assertTrue(root.check(f_locals)) 720 721 # Check structure 722 # Check that we are only guarding on two keys. This is common in 723 # LazyVariableTracker. 724 self.assertEqual(len(dict_mgr.get_key_value_managers()), 2) 725 726 f_locals["d"]["a"] = 2 727 self.assertFalse(root.check(f_locals)) 728 self.assertFalse(root.check_verbose(f_locals).result) 729 730 f_locals["d"]["a"] = 1 731 self.assertTrue(root.check(f_locals)) 732 733 f_locals["d"].pop(100) 734 # fails because of len check 735 self.assertFalse(root.check(f_locals)) 736 737 738if __name__ == "__main__": 739 from torch._dynamo.test_case import run_tests 740 741 run_tests() 742