xref: /aosp_15_r20/external/pytorch/test/dynamo/test_guard_manager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2import functools
3import weakref
4
5import torch
6import torch._dynamo
7import torch._dynamo.test_case
8from torch._C._dynamo import guards
9from torch._dynamo.convert_frame import GlobalStateGuard
10from torch.testing._internal.common_utils import set_default_dtype
11
12
13RootGuardManager = guards.RootGuardManager
14DictGuardManager = guards.DictGuardManager
15DictSubclassGuardManager = guards.DictSubclassGuardManager
16GetAttrGuardAccessor = guards.GetAttrGuardAccessor
17GetItemGuardAccessor = guards.GetItemGuardAccessor
18TypeGuardAccessor = guards.TypeGuardAccessor
19OBJECT_ALIASING = guards.OBJECT_ALIASING
20install_object_aliasing_guard = guards.install_object_aliasing_guard
21NO_TENSOR_ALIASING = guards.NO_TENSOR_ALIASING
22install_no_tensor_aliasing_guard = guards.install_no_tensor_aliasing_guard
23
24
25x = torch.tensor(4)
26weakref_x = weakref.ref(x)
27
28default_mgr_enum = torch._dynamo.guards.GuardManagerType.GUARD_MANAGER
29
30
class Pair:
    """Plain two-attribute container (``x``, ``y``) used as a guarded global."""

    def __init__(self, x, y):
        self.x, self.y = x, y
35
36
# Module-level object guarded through the globals dict in ``test_globals``.
global_pair = Pair(x=torch.randn(4), y=1)
38
39
def id_type(x):
    """Return ``id()`` of the exact runtime type of *x* (the value TYPE_MATCH takes)."""
    exact_type = type(x)
    return id(exact_type)
42
43
def equals_match(x, expected):
    """Guard predicate: report whether ``x == expected``."""
    matched = x == expected
    return matched
46
47
def equals_match_verbose_code_parts(expected):
    """Build the one-element verbose code-parts list for ``equals_match``."""
    code_part = f"x == {expected}"
    return [code_part]
50
51
def ge_match(x, expected):
    """Guard predicate: report whether ``x >= expected``."""
    at_least = x >= expected
    return at_least
54
55
def ge_match_verbose_code_parts(expected):
    """Build the verbose code parts describing ``ge_match`` with *expected*.

    Returned as a one-element list phrased in terms of the guarded value
    ``x``, matching ``equals_match_verbose_code_parts``.  (Previously this
    returned a bare string reading ``"expected >= N"``.)
    """
    return [f"x >= {expected}"]
58
59
def less_match(x, expected):
    """Guard predicate: report whether ``x < expected``."""
    below = x < expected
    return below
62
63
def less_match_verbose_code_parts(expected):
    """Build the verbose code parts describing ``less_match`` with *expected*.

    Phrased in terms of the guarded value ``x`` (previously ``"expected < N"``),
    matching ``equals_match_verbose_code_parts``.
    """
    return [f"x < {expected}"]
66
67
class GuardManagerTests(torch._dynamo.test_case.TestCase):
    """Tests for the guard primitives exposed via ``torch._C._dynamo.guards``.

    Covers individual leaf guards, relational (aliasing) guards, and
    ``RootGuardManager`` trees built from attribute / item / dict / type /
    globals / lambda accessors.
    """

    def test_global_state_guard(self):
        """GLOBAL_STATE fails while default dtype or determinism differs from construction time."""
        guard = guards.GLOBAL_STATE(["global_state_check"])
        self.assertTrue(guard(None))
        with set_default_dtype(torch.double):
            self.assertFalse(guard(None))
            self.assertExpectedInline(
                str(guard.check_verbose(None)),
                """\
GuardDebugInfo(
result=0,
verbose_code_parts=['GLOBAL_STATE changed: default_dtype '],
num_guards_executed=0)
""",
            )
        self.assertTrue(guard(None))
        self.assertTrue(guard.check_verbose(None).result)
        _orig = torch.are_deterministic_algorithms_enabled()
        try:
            torch.use_deterministic_algorithms(not _orig)
            self.assertFalse(guard(None))
            self.assertExpectedInline(
                str(guard.check_verbose(None)),
                """\
GuardDebugInfo(
result=0,
verbose_code_parts=['GLOBAL_STATE changed: deterministic_algorithms '],
num_guards_executed=0)
""",
            )
        finally:
            torch.use_deterministic_algorithms(_orig)
        self.assertTrue(guard(None))
        self.assertTrue(guard.check_verbose(None).result)

    def test_global_state_reason(self):
        """GlobalStateGuard.reason() names the global state that changed."""
        # Named ``state_guard`` so it does not shadow the module-level ``guards``.
        with torch.enable_grad():
            state_guard = GlobalStateGuard()
        with torch.no_grad():
            self.assertIs(state_guard.check(), False)
            self.assertEqual(state_guard.reason(), "grad_mode ")

    def test_python_lambda_leaf_guard(self):
        """LAMBDA_GUARD evaluates an arbitrary Python predicate."""
        const_guard = guards.LAMBDA_GUARD(
            functools.partial(equals_match, expected=5),
            equals_match_verbose_code_parts(5),
        )
        self.assertTrue(const_guard(5))
        self.assertFalse(const_guard(4))
        self.assertFalse(const_guard("foo"))

    def test_type_guard(self):
        """TYPE_MATCH compares the id of the value's exact type."""
        foo = 4
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"])

        self.assertTrue(guard(5))
        self.assertTrue(guard(4))
        self.assertFalse(guard("foo"))

        foo = {"a": 1}
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard({}))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        foo = Foo(1, 2)

        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard({}))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

    def test_id_guard(self):
        """ID_MATCH passes only for the identical object, not for equal copies."""
        foo = 4
        guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"])

        self.assertTrue(guard(foo))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        foo = {"a": 1}
        guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard({"a": 1}))
        self.assertFalse(guard({}))
        self.assertFalse(guard(5))

    def test_equals_guard(self):
        """EQUALS_MATCH compares by value for ints, tuples, lists and types."""
        foo = 4
        guard = guards.EQUALS_MATCH(foo, ["x == 4"])

        self.assertTrue(guard(4))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        # tuple
        foo = (1, 2, 3)
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard((1, 2, 3)))
        self.assertFalse(guard((1, 2, 3, 4)))
        self.assertFalse(guard({}))

        # list
        foo = [1, 2, 3]
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard([1, 2, 3]))
        self.assertFalse(guard([1, 2, 3, 4]))

        # type
        foo = int
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard(int))
        self.assertFalse(guard(float))

    def test_default_device_guard(self):
        """DEFAULT_DEVICE fails once the default device is changed."""
        foo = 1
        guard = guards.DEFAULT_DEVICE(["cpu device"])
        self.assertTrue(guard(foo))

        try:
            torch.set_default_device("cuda")
            self.assertFalse(guard(foo))
        finally:
            torch.set_default_device(None)

    def test_data_ptr_match_guard(self):
        """DATA_PTR_MATCH rejects an equal tensor with different storage."""
        foo = torch.tensor([1, 2, 3])
        guard = guards.DATA_PTR_MATCH(foo, ["x.data_ptr() == foo.data_ptr()"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard(torch.tensor([1, 2, 3])))

    def test_length_check_guard(self):
        """LENGTH_CHECK compares len() against the recorded length."""
        foo = [1, 2, 3]
        guard = guards.LENGTH_CHECK(len(foo), ["len(x) == len(foo)"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard([]))

    def test_no_hasattr_guard(self):
        """NO_HASATTR fails when the object has the named attribute."""

        class Bar:
            def __init__(self) -> None:
                self.bar = 2

        bar = Bar()

        class Foo:
            def __init__(self) -> None:
                self.foo = 2

        foo = Foo()

        guard = guards.NO_HASATTR("foo", ["hasattr(x, 'foo') == False"])
        self.assertTrue(guard(bar))
        self.assertFalse(guard(foo))

    def test_tensor_aliasing_guard(self):
        """OBJECT_ALIASING is shared by both managers and requires ``x is y``."""
        guard_manager = RootGuardManager()

        a = torch.randn(3, 4)

        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        f_locals = Foo(a, a)

        x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum)
        y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum)
        install_object_aliasing_guard(x_guard_mgr, y_guard_mgr, ["x is y"])

        # Check structure
        x_guards = x_guard_mgr.get_leaf_guards()
        y_guards = y_guard_mgr.get_leaf_guards()
        self.assertEqual(len(x_guards), 1)
        self.assertEqual(len(y_guards), 1)
        self.assertIsInstance(x_guards[0], OBJECT_ALIASING)
        self.assertIsInstance(y_guards[0], OBJECT_ALIASING)
        # Check that the two guards are the same object
        self.assertIs(x_guards[0], y_guards[0])

        f_locals_unaliased = Foo(torch.randn(3, 4), torch.randn(3, 4))
        self.assertEqual(len(x_guard_mgr.get_leaf_guards()), 1)
        self.assertEqual(len(y_guard_mgr.get_leaf_guards()), 1)
        self.assertTrue(guard_manager.check(f_locals))

        self.assertFalse(guard_manager.check(f_locals_unaliased))

    def test_dict_version_guard(self):
        """DICT_VERSION fails for a copy and after the original is mutated."""
        foo = {"a": 1, "b": 2}
        guard = guards.DICT_VERSION(foo, ["x.version == foo.version"])

        self.assertTrue(guard(foo))
        self.assertFalse(guard(dict(foo)))
        foo["a"] = 2
        self.assertFalse(guard(foo))
        self.assertFalse(guard({"a": 1, "b": 2}))
        self.assertFalse(guard({}))

    def test_dynamic_indices_guard(self):
        """DYNAMIC_INDICES checks ``_dynamo_dynamic_indices`` against an allowed set."""
        guard1 = guards.DYNAMIC_INDICES(set(), ["x.size(0) == y.size(0)"])
        guard2 = guards.DYNAMIC_INDICES({0, 1}, ["x.size(0) == y.size(0)"])

        x = torch.randn(4)
        self.assertTrue(guard1(x))
        self.assertTrue(guard2(x))

        x._dynamo_dynamic_indices = {0}
        self.assertFalse(guard1(x))
        self.assertTrue(guard2(x))

        x._dynamo_dynamic_indices = {2}
        self.assertFalse(guard1(x))
        self.assertFalse(guard2(x))

    def test_tensor_match_guard(self):
        """The tensor-match guard fails on a stride change and reports it verbosely."""
        guard_manager = RootGuardManager()
        x = torch.randn(4, 4)
        size = list(x.size())
        stride = list(x.stride())
        guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"])
        self.assertTrue(guard_manager.check(x))
        self.assertTrue(guard_manager.check_verbose(x).result)
        self.assertTrue(guard_manager.check(torch.randn(4, 4)))
        self.assertTrue(guard_manager.check_verbose(torch.randn(4, 4)).result)
        self.assertFalse(guard_manager.check(x.t_()))

        x = torch.randn(4, 4)
        x.t_()
        debug_info = guard_manager.check_verbose(x)
        self.assertIn(
            "tensor 'x' stride mismatch", debug_info.verbose_code_parts[0]
        )

    def test_no_tensor_aliasing_guard(self):
        """NO_TENSOR_ALIASING fails when any two guarded tensors are the same object."""
        guard_manager = RootGuardManager()

        a = torch.randn(3, 4)

        class Foo:
            def __init__(self, x, y, z):
                self.x = x
                self.y = y
                self.z = z

        f_locals = Foo(a, a, a)

        x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum)
        y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum)
        z_guard_mgr = guard_manager.getattr_manager("z", "", a, default_mgr_enum)
        install_no_tensor_aliasing_guard(
            [x_guard_mgr, y_guard_mgr, z_guard_mgr],
            ["x", "y", "z"],
            ["no_aliasing(x, y, z)"],
        )

        # Check structure
        x_guards = x_guard_mgr.get_leaf_guards()
        y_guards = y_guard_mgr.get_leaf_guards()
        z_guards = z_guard_mgr.get_leaf_guards()
        self.assertEqual(len(x_guards), 1)
        self.assertEqual(len(y_guards), 1)
        self.assertEqual(len(z_guards), 1)
        self.assertIsInstance(x_guards[0], NO_TENSOR_ALIASING)
        self.assertIsInstance(y_guards[0], NO_TENSOR_ALIASING)
        self.assertIsInstance(z_guards[0], NO_TENSOR_ALIASING)
        # Check that the three managers share one guard object
        self.assertIs(x_guards[0], y_guards[0])
        self.assertIs(y_guards[0], z_guards[0])
        self.assertFalse(guard_manager.check(f_locals))
        self.assertFalse(guard_manager.check_verbose(f_locals).result)

        f_locals_unaliased = Foo(
            torch.randn(3, 4),
            torch.randn(3, 4),
            torch.randn(3, 4),
        )
        self.assertTrue(guard_manager.check(f_locals_unaliased))
        self.assertTrue(guard_manager.check_verbose(f_locals_unaliased).result)
        # Check that hash map is cleared.
        self.assertTrue(guard_manager.check(f_locals_unaliased))

        f_locals_unaliased = Foo(
            a,
            torch.randn(3, 4),
            a,
        )
        self.assertFalse(guard_manager.check(f_locals_unaliased))
        self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result)

    def test_weakref_alive_guard(self):
        """NOT_NONE fails once the weakly-referenced tensor is gone."""
        x = torch.rand(3, 4)
        weakref_x = weakref.ref(x)

        guard = guards.NOT_NONE(["weakref_x is not None"])
        self.assertTrue(guard(weakref_x()))
        del x
        self.assertFalse(guard(weakref_x()))

    def test_guard_manager_leaf_guard(self):
        """A manager passes only when all of its leaf guards pass."""
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"])
        guard_manager.add_lambda_guard(
            functools.partial(ge_match, expected=5),
            ge_match_verbose_code_parts(expected=5),
        )
        guard_manager.add_lambda_guard(
            functools.partial(less_match, expected=10),
            less_match_verbose_code_parts(expected=10),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 3)
        self.assertEqual(len(guard_manager.get_accessors()), 0)
        self.assertTrue(guard_manager.check(6))
        self.assertFalse(guard_manager.check(4))
        self.assertFalse(guard_manager.check("foo"))

    def test_attr_guard_manager(self):
        """getattr_manager creates (and then reuses) per-attribute child managers."""

        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        foo = Foo(1, 2)
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
        guard_manager.getattr_manager("x", "x", 1, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo.x),
            equals_match_verbose_code_parts(foo.x),
        )
        guard_manager.getattr_manager("y", "y", 2, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo.y),
            equals_match_verbose_code_parts(foo.y),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
        # 2 child managers, one for x and one for y
        self.assertEqual(len(guard_manager.get_accessors()), 2)
        self.assertIsInstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor)
        self.assertIsInstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor)
        # Check leaf guards on child managers (asking again returns the
        # existing manager, which already holds one guard)
        self.assertEqual(
            len(
                guard_manager.getattr_manager(
                    attr="x",
                    source="x",
                    example_value=None,
                    guard_manager_enum=default_mgr_enum,
                ).get_leaf_guards()
            ),
            1,
        )
        self.assertEqual(
            len(
                guard_manager.getattr_manager(
                    "y", "y", None, default_mgr_enum
                ).get_leaf_guards()
            ),
            1,
        )

        self.assertTrue(guard_manager.check(foo))
        self.assertFalse(guard_manager.check(Foo(3, 4)))
        self.assertFalse(guard_manager.check("foo"))

    def test_item_guard_manager(self):
        """getitem_manager creates (and then reuses) per-index child managers."""
        foo = [1, 2]
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(foo), ["type(x) == list"])
        guard_manager.getitem_manager(0, "", 1, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo[0]),
            equals_match_verbose_code_parts(foo[0]),
        )
        guard_manager.getitem_manager(1, "", 2, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo[1]),
            equals_match_verbose_code_parts(foo[1]),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
        # 2 child managers, one for index 0 and one for index 1
        self.assertEqual(len(guard_manager.get_accessors()), 2)
        self.assertIsInstance(guard_manager.get_accessors()[0], GetItemGuardAccessor)
        self.assertIsInstance(guard_manager.get_accessors()[1], GetItemGuardAccessor)
        # Check leaf guards on child managers
        self.assertEqual(
            len(
                guard_manager.getitem_manager(
                    0, "", None, default_mgr_enum
                ).get_leaf_guards()
            ),
            1,
        )
        self.assertEqual(
            len(
                guard_manager.getitem_manager(
                    1, "", None, default_mgr_enum
                ).get_leaf_guards()
            ),
            1,
        )

        self.assertTrue(guard_manager.check(foo))
        self.assertFalse(guard_manager.check([3, 4]))
        self.assertFalse(guard_manager.check("foo"))

    def test_dict_getitem_accessor(self):
        """dict_getitem_manager guards individual dict values by key."""
        foo = {
            "a": 1,
            "b": 2,
        }

        guards_manager = RootGuardManager()
        guards_manager.add_type_match_guard(id_type(foo), ["type(x) == dict"])
        guards_manager.dict_getitem_manager(
            "a", "", 1, default_mgr_enum
        ).add_equals_match_guard(1, ["a == 1"])
        guards_manager.dict_getitem_manager(
            "b", "", 2, default_mgr_enum
        ).add_equals_match_guard(2, ["b == 2"])

        self.assertTrue(guards_manager.check(foo))
        self.assertFalse(guards_manager.check({"a": 1, "b": 3}))

    def test_globals(self):
        """globals_dict_manager lets a guard read a module global by name."""
        guard_manager = RootGuardManager()
        gpair_mgr = guard_manager.globals_dict_manager(
            globals(), "", None, default_mgr_enum
        ).getitem_manager("global_pair", "", global_pair, default_mgr_enum)

        gpair_mgr.add_lambda_guard(
            lambda x: isinstance(x, Pair)
            and isinstance(x.x, torch.Tensor)
            and isinstance(x.y, int),
            "global guard fail",
        )

        # Mutating the shared module-level object must be undone so later
        # tests (or a re-run of this one) see the original value.
        original_y = global_pair.y
        try:
            self.assertTrue(guard_manager.check(global_pair))
            global_pair.y = "foo"
            self.assertFalse(guard_manager.check(global_pair))
        finally:
            global_pair.y = original_y

    def test_type_manager(self):
        """type_manager plus attr/item accessors can guard on ``type(obj).__mro__``."""
        guard_manager = RootGuardManager()

        class A:
            a = 4

        class B(A):
            # Deliberately empty: only the MRO (B, A, object) matters here.
            pass

        foo = B()
        f_locals = {"foo": foo}

        # len(type(foo).__mro__) == 3, i.e. (B, A, object)
        foo_mgr = guard_manager.getitem_manager("foo", "", foo, default_mgr_enum)
        type_manager = foo_mgr.type_manager("", type(foo), default_mgr_enum)
        self.assertIsInstance(foo_mgr.get_accessors()[0], TypeGuardAccessor)
        mro_manager = type_manager.getattr_manager(
            "__mro__", "", type(foo).__mro__, default_mgr_enum
        )
        self.assertIsInstance(type_manager.get_accessors()[0], GetAttrGuardAccessor)
        mro_manager.add_length_check_guard(
            3,
            "Expected len(type(foo).__mro__) == 3",
        )

        # type(foo).__mro__[1].a == 4, i.e. the class attribute A.a
        item_manager = mro_manager.getitem_manager(
            1, "", type(foo).__mro__[1], default_mgr_enum
        )
        self.assertIsInstance(mro_manager.get_accessors()[0], GetItemGuardAccessor)
        attr_manager = item_manager.getattr_manager(
            "a", "", type(foo).__mro__[1].a, default_mgr_enum
        )
        self.assertIsInstance(item_manager.get_accessors()[0], GetAttrGuardAccessor)
        attr_manager.add_lambda_guard(
            lambda x: x == 4,
            "Expected value 4",
        )

        self.assertTrue(guard_manager.check(f_locals))

    def test_tuple_iterator_getitem(self):
        """tuple_iterator_getitem_manager indexes relative to the iterator position."""
        a = (1, 2, 3, 4, 5, 6)
        foo = iter(a)
        next(foo)  # foo points at index=1

        guard_manager = RootGuardManager()
        # Check a[3] which is tuple_iterator_getitem(foo, 2)
        guard_manager.add_tuple_iterator_length_guard(
            5, id_type(iter(())), ["len == 5"]
        )
        guard_manager.tuple_iterator_getitem_manager(
            2, "", foo, default_mgr_enum
        ).add_equals_match_guard(a[3], ["x==4"])

        # Check that type match works
        self.assertFalse(guard_manager.check(False))

        self.assertTrue(guard_manager.check(foo))

        # Check that index error fails gracefully
        b = (1, 2)
        b_foo = iter(b)
        self.assertFalse(guard_manager.check(b_foo))

    def test_global_weakref(self):
        """global_weakref_manager dereferences a weakref stored in module globals."""
        global x, weakref_x
        guard_manager = RootGuardManager()
        globals_manager = guard_manager.globals_dict_manager(
            globals(), "", None, default_mgr_enum
        )
        weakref_manager = globals_manager.global_weakref_manager(
            "weakref_x", "", None, default_mgr_enum
        )

        weakref_manager.add_lambda_guard(
            lambda x: isinstance(x, torch.Tensor),
            "global weakref fail",
        )

        try:
            self.assertTrue(guard_manager.check(None))
            # Drop the module-level strong reference so the weakref dies.
            del x
            self.assertFalse(guard_manager.check(None))
        finally:
            # Rebind the module globals so this test can be re-run.
            x = torch.tensor(4)
            weakref_x = weakref.ref(x)

    def test_lambda_manager(self):
        """lambda_manager guards a computed value; accessor exceptions fail the check."""
        a = (1, 1, 3, 4, 5, 6)

        guard_manager = RootGuardManager()

        # Guard on the value produced by an arbitrary accessor function.
        foo_mgr = guard_manager.lambda_manager(
            lambda x: x[2], "", None, default_mgr_enum
        )
        foo_mgr.add_lambda_guard(
            lambda x: x == 3,
            "Expected value 3",
        )
        self.assertTrue(guard_manager.check(a))

        # An exception raised by the accessor makes the check fail rather than
        # propagate, and check_verbose reports the exception message.
        guard_manager = RootGuardManager()

        def fn(x):
            raise AssertionError("Test")

        guard_manager.lambda_manager(fn, "", None, default_mgr_enum)

        self.assertFalse(guard_manager.check(None))
        debug_info = guard_manager.check_verbose(None)
        self.assertFalse(debug_info.result)
        self.assertIn("Test", debug_info.verbose_code_parts[0])

    def test_dict_contains_guard(self):
        """DICT_CONTAINS checks presence (True) or absence (False) of a key."""
        foo = {"a": 1, "b": 2}
        guard = guards.DICT_CONTAINS(True, "a", ["has a"])

        self.assertTrue(guard(foo))
        self.assertTrue(guard({"a": 1, "b": 2}))
        self.assertFalse(guard({"b": 2, "c": 3}))
        self.assertFalse(guard({}))

        guard = guards.DICT_CONTAINS(False, "c", ["not has c"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard({"a": 1, "b": 2}))
        self.assertFalse(guard({"b": 2, "c": 3}))
        self.assertTrue(guard({}))

    def test_dict_guard_manager(self):
        """DictGuardManager guards selected key/value positions plus dict length."""
        root = RootGuardManager()

        def nothing():
            pass

        f_locals = {
            "d": {"a": 1, nothing: {"z": 3}, 100: torch.randn(4)},
        }

        # its a getitem_manager just for f_locals. But the child guard manager
        # should be a DictGuardManager.
        dict_mgr = root.getitem_manager(
            "d",
            "",
            f_locals["d"],
            torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER,
        )
        self.assertIsInstance(dict_mgr, DictGuardManager)

        self.assertTrue(root.check(f_locals))

        # Check that no one can add a leaf guard
        with self.assertRaises(RuntimeError):
            dict_mgr.add_id_match_guard(id_type(f_locals), "id match")

        # Check that no one can add an arbitrary accessor
        with self.assertRaises(RuntimeError):
            dict_mgr.getitem_manager("a", "", f_locals["d"]["a"])

        # Check that it fails with different length dict
        f_locals_prime = {
            "d": {"a": 1, "b": 2},
        }
        self.assertFalse(root.check(f_locals_prime))

        # Add key-value manager ("a" : 1)
        self.assertTrue(root.check(f_locals))
        dict_mgr.get_key_manager(0, "", "a", default_mgr_enum).add_equals_match_guard(
            "a",
            ["dict.keys()[0] == a"],
        )
        self.assertTrue(root.check(f_locals))
        dict_mgr.get_value_manager(0, "", 1, default_mgr_enum).add_equals_match_guard(
            1, ["d[0] == 1"]
        )
        self.assertTrue(root.check(f_locals))

        # Add key-value manager (nothing : {"z" : 3})
        self.assertTrue(root.check(f_locals))
        dict_mgr.get_key_manager(1, "", nothing, default_mgr_enum).add_lambda_guard(
            lambda x: x is nothing, ["x is nothing"]
        )
        self.assertTrue(root.check(f_locals))
        value_mgr = dict_mgr.get_value_manager(
            1,
            "",
            f_locals["d"][nothing],
            torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER,
        )
        self.assertIsInstance(value_mgr, DictGuardManager)
        self.assertTrue(root.check(f_locals))

        # Check structure
        # Check that we are only guarding on two keys. This is common in
        # LazyVariableTracker.
        self.assertEqual(len(dict_mgr.get_key_value_managers()), 2)

        f_locals["d"]["a"] = 2
        self.assertFalse(root.check(f_locals))
        self.assertFalse(root.check_verbose(f_locals).result)

        f_locals["d"]["a"] = 1
        self.assertTrue(root.check(f_locals))

        f_locals["d"].pop(100)
        # fails because of len check
        self.assertFalse(root.check(f_locals))
736
737
if __name__ == "__main__":
    # Run this file's tests through dynamo's test-case runner.
    torch._dynamo.test_case.run_tests()
742