# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.utils import disable_cache_limit


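# Exercises the dynamo config knobs used below: the dynamic-shape flags,
# the set of keys ignored for compilation, and config hashing.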
# NB: do NOT include this test class in test_dynamic_shapes.py


class ConfigTests(torch._dynamo.test_case.TestCase):
    @disable_cache_limit()
    def test_no_automatic_dynamic(self):
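        # With automatic_dynamic_shapes off and assume_static_by_default on,
        # every input shape is specialized into the graph, so each of the
        # 10 distinct sizes below should trigger its own compile.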
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_static = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=False, assume_static_by_default=True
        ):
            opt_fn = torch._dynamo.optimize(cnt_static)(fn)
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        self.assertEqual(cnt_static.frame_count, 10)

    @disable_cache_limit()
    def test_automatic_dynamic(self):
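        # With automatic dynamic shapes, the first call compiles a static
        # graph for size 2; the size-3 call recompiles with the dimension
        # marked dynamic, and that second graph then serves sizes 4-11.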
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=True, assume_static_by_default=True
        ):
            opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
            # NB: must not use sizes 0 or 1, as those are specialized
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        # two graphs now rather than 10: one static, one dynamic
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @disable_cache_limit()
    def test_no_assume_static_by_default(self):
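        # With assume_static_by_default off, dimensions are treated as
        # dynamic on the very first compile, so a single graph covers
        # every input size with no recompiles.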
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=True, assume_static_by_default=False
        ):
            opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
            # NB: must not use sizes 0 or 1, as those are specialized
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        # one graph now, as shapes are dynamic from the first compile
        self.assertEqual(cnt_dynamic.frame_count, 1)

    def test_config_compile_ignored(self):
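        # Keys in _compile_ignored_keys are excluded from the config hash,
        # so changing them never triggers a recompile; this pins down the
        # expected membership of that set.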
        # Remove from this list if no longer relevant
        dynamo_guarded_config_ignorelist = {
            "log_file_name",
            "verbose",
            "verify_correctness",  # will not affect model, will raise RuntimeError
            # (no silent change to compilation behaviour)
            "cache_size_limit",
            "accumulated_cache_size_limit",
            "replay_record_enabled",
            "cprofile",  # only wraps _compile, not graph
            "repro_after",
            "repro_level",
            "repro_forward_only",
            "repro_tolerance",
            "same_two_models_use_fp64",
            "error_on_recompile",  # safe because: will throw error
            "report_guard_failures",
            "base_dir",  # used for minifying / logging
            "DEBUG_DIR_VAR_NAME",
            "debug_dir_root",
        }
        for k in dynamo_guarded_config_ignorelist:
            assert k in torch._dynamo.config._compile_ignored_keys, k

    def test_config_hash(self):
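        # get_hash() should reflect only keys that affect compilation:
        # patching an ignored key (verbose) must leave the hash unchanged,
        # patching a regular key (dead_code_elimination) must change it,
        # and exiting the patch must restore the original hash.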
        config = torch._dynamo.config
        starting_hash = config.get_hash()

        with config.patch({"verbose": not config.verbose}):
            new_hash = config.get_hash()
            assert "verbose" in config._compile_ignored_keys
            assert new_hash == starting_hash

        new_hash = config.get_hash()
        assert new_hash == starting_hash

        with config.patch({"dead_code_elimination": not config.dead_code_elimination}):
            changed_hash = config.get_hash()
            assert "dead_code_elimination" not in config._compile_ignored_keys
            assert changed_hash != starting_hash

            # Test nested patch
            with config.patch({"verbose": not config.verbose}):
                inner_changed_hash = config.get_hash()
                assert inner_changed_hash == changed_hash
                assert inner_changed_hash != starting_hash

        newest_hash = config.get_hash()
        assert changed_hash != newest_hash
        assert newest_hash == starting_hash


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()