1# Owner(s): ["module: dynamo"] 2 3import torch 4import torch._dynamo.test_case 5import torch._dynamo.testing 6from torch._dynamo.utils import disable_cache_limit 7 8 9# NB: do NOT include this test class in test_dynamic_shapes.py 10 11 12class ConfigTests(torch._dynamo.test_case.TestCase): 13 @disable_cache_limit() 14 def test_no_automatic_dynamic(self): 15 def fn(a, b): 16 return a - b * 10 17 18 torch._dynamo.reset() 19 cnt_static = torch._dynamo.testing.CompileCounter() 20 with torch._dynamo.config.patch( 21 automatic_dynamic_shapes=False, assume_static_by_default=True 22 ): 23 opt_fn = torch._dynamo.optimize(cnt_static)(fn) 24 for i in range(2, 12): 25 opt_fn(torch.randn(i), torch.randn(i)) 26 self.assertEqual(cnt_static.frame_count, 10) 27 28 @disable_cache_limit() 29 def test_automatic_dynamic(self): 30 def fn(a, b): 31 return a - b * 10 32 33 torch._dynamo.reset() 34 cnt_dynamic = torch._dynamo.testing.CompileCounter() 35 with torch._dynamo.config.patch( 36 automatic_dynamic_shapes=True, assume_static_by_default=True 37 ): 38 opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn) 39 # NB: must not do 0, 1 as they specialized 40 for i in range(2, 12): 41 opt_fn(torch.randn(i), torch.randn(i)) 42 # two graphs now rather than 10 43 self.assertEqual(cnt_dynamic.frame_count, 2) 44 45 @disable_cache_limit() 46 def test_no_assume_static_by_default(self): 47 def fn(a, b): 48 return a - b * 10 49 50 torch._dynamo.reset() 51 cnt_dynamic = torch._dynamo.testing.CompileCounter() 52 with torch._dynamo.config.patch( 53 automatic_dynamic_shapes=True, assume_static_by_default=False 54 ): 55 opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn) 56 # NB: must not do 0, 1 as they specialized 57 for i in range(2, 12): 58 opt_fn(torch.randn(i), torch.randn(i)) 59 # one graph now, as we didn't wait for recompile 60 self.assertEqual(cnt_dynamic.frame_count, 1) 61 62 def test_config_compile_ignored(self): 63 # Remove from this list if no longer relevant 64 dynamo_guarded_config_ignorelist = { 65 "log_file_name", 66 "verbose", 67 "verify_correctness", # will not affect model, will raise RuntimeError 68 # (no silent change to compilation behaviour) 69 "cache_size_limit", 70 "accumulated_cache_size_limit", 71 "replay_record_enabled", 72 "cprofile", # only wraps _compile, not graph 73 "repro_after", 74 "repro_level", 75 "repro_forward_only", 76 "repro_tolerance", 77 "same_two_models_use_fp64", 78 "error_on_recompile", # safe because: will throw error 79 "report_guard_failures", 80 "base_dir", # used for minifying / logging 81 "DEBUG_DIR_VAR_NAME", 82 "debug_dir_root", 83 } 84 for k in dynamo_guarded_config_ignorelist: 85 assert k in torch._dynamo.config._compile_ignored_keys, k 86 87 def test_config_hash(self): 88 config = torch._dynamo.config 89 starting_hash = config.get_hash() 90 91 with config.patch({"verbose": not config.verbose}): 92 new_hash = config.get_hash() 93 assert "verbose" in config._compile_ignored_keys 94 assert new_hash == starting_hash 95 96 new_hash = config.get_hash() 97 assert new_hash == starting_hash 98 99 with config.patch({"dead_code_elimination": not config.dead_code_elimination}): 100 changed_hash = config.get_hash() 101 assert "dead_code_elimination" not in config._compile_ignored_keys 102 assert changed_hash != starting_hash 103 104 # Test nested patch 105 with config.patch({"verbose": not config.verbose}): 106 inner_changed_hash = config.get_hash() 107 assert inner_changed_hash == changed_hash 108 assert inner_changed_hash != starting_hash 109 110 newest_hash = config.get_hash() 111 assert changed_hash != newest_hash 112 assert newest_hash == starting_hash 113 114 115if __name__ == "__main__": 116 from torch._dynamo.test_case import run_tests 117 118 run_tests() 119