# Owner(s): ["module: dynamo"]
import unittest
import weakref

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._logging
from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings


class RecompileUxTests(torch._dynamo.test_case.TestCase):
    # TODO(whc) dynamo actually recompiles one more time than the cache limit
    cache_limit = 1

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            torch._dynamo.config.patch("cache_size_limit", cls.cache_limit)
        )

    def test_drop_cache_on_skip(self):
        def model(x, i):
            return x + i

        attached = False
        triggered = False

        def trigger():
            nonlocal triggered
            triggered = True

        def compiler(gm, input):
            nonlocal attached
            f = gm.forward
            assert not attached
            # NB: making this a weakref.ref causes the cycle to no
            # longer be promptly GC'ed
            weakref.finalize(f, trigger)
            attached = True
            return f

        x = torch.randn(2)
        for i in range(2):
            opt_model = torch._dynamo.optimize(compiler)(model)
            opt_model(x, i)

        self.assertTrue(triggered)

    def test_loop_torture(self):
        def loop_torture(input, iters):
            out = input
            # randint itself causes one graph break
            for _ in range(iters):
                out += input
            return out

        compile_counter = torch._dynamo.testing.CompileCounter()
        for _ in range(10):
            x = torch.randn(3)
            iters = torch.randint(low=0, high=1000, size=())
            opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture)
            opt_loop_torture(x, iters)

        # Currently, we recompile each time;
        # we'd probably like to bail out quickly and warn instead.
        # TODO(whc) these checks fail on py37. Why?
        # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit)
        # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit)

        # compile_counter only sees frames that were fed to the backend compiler,
        # which is a subset of counters["frames"]["ok"] -- probably because
        # counters["frames"]["ok"] includes frames not containing torch ops?
        self.assertEqual(compile_counter.frame_count, self.cache_limit)

    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    def test_dynamic_input(self):
        def model(input):
            return input + input

        expected_recompiles = 2
        compile_counter = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch("cache_size_limit", expected_recompiles):
            with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
                for _ in range(10):
                    bsz = torch.randint(low=0, high=1000, size=())
                    x = torch.randn((bsz, 3, 4))
                    opt_model = torch._dynamo.optimize(compile_counter)(model)
                    opt_model(x)

        self.assertEqual(compile_counter.frame_count, expected_recompiles)
        self.assertEqual(len(logs.records), 1)
        print(logs.records[0])
        self.assertTrue(
            logs.records[0]
            .getMessage()
            .startswith("torch._dynamo hit config.cache_size_limit")
        )

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_nvfuser_guards(self):
        # We may want to model dynamo's guards closely enough after nvfuser's
        # ProfilingExecutor guards that dynamo is in charge of all recompilations
        # at the top level, which would let us simplify the underlying
        # torchscript executor.
        def func(a, b, c):
            return a + b * c

        a = torch.rand(3, 4, 5, device="cuda")
        b = torch.rand(3, 4, 5, device="cuda")
        b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5)
        b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1)
        c = torch.rand(3, 4, 5, device="cuda")
        compile_counter = torch._dynamo.testing.CompileCounter()

        with torch._dynamo.config.patch("cache_size_limit", 2):
            opt_func = torch._dynamo.optimize(compile_counter)(func)
            opt_func(a, b, c)  # warmup
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b, c)  # no guard fail or recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_v, c)  # a view should not cause nvfuser recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_p, c)  # a permutation should cause recompile
            self.assertEqual(compile_counter.frame_count, 2)

    def assert_single_log_contains(self, logs, contains_str):
        self.assertEqual(len(logs.records), 1)
        self.assertTrue(
            logs.records[0].getMessage().find(contains_str) > 0,
            msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
        )

    def test_verbose_tensor_check(self):
        def func(a):
            # Warning: choose a function here whose meta implementation lives
            # entirely in C++. If you pick a Python one, Dynamo will dive into
            # torch._refs, which is OK but will muddy up the warnings.
            return torch.add(a, 4)

        def cache_fail_test(cached_input, missed_input, expected_failure):
            # TODO(whc) maybe it's hacky to have a 'test within a test', but this seemed convenient
            torch._dynamo.reset()
            torch._dynamo.utils.counters.clear()
            opt_func = torch._dynamo.optimize("eager")(func)
            # warmup
            opt_func(cached_input)

            with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
                opt_func = torch._dynamo.optimize("eager")(func)
                opt_func(missed_input)
            self.assert_single_log_contains(logs, expected_failure)

        a = torch.rand(3, 4, 5)
        cache_fail_test(
            a,
            a[0:2, :, :],
            "tensor 'L['a']' size mismatch at index 0. expected 3, actual 2",
        )
        cache_fail_test(
            a,
            a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
            "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
        )
        cache_fail_test(
            a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
        )
        cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.")
        cache_fail_test(
            a,
            a.to(torch.float16),
            "tensor 'L['a']' dtype mismatch. expected Float, actual Half",
        )
        a_grad = a.clone()
        a_grad.requires_grad = True
        cache_fail_test(
            a,
            a_grad,
            "tensor 'L['a']' requires_grad mismatch. expected requires_grad=0",
        )

    def test_mismatched_type(self):
        a = torch.rand(3, 4, 5)
        b = torch.rand(3, 4, 5)

        def func(a, b):
            return a + b

        opt_func = torch._dynamo.optimize("eager")(func)
        # warmup
        opt_func(a, b)

        with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
            opt_func = torch._dynamo.optimize("eager")(func)
            opt_func(a, 1)
        self.assert_single_log_contains(
            logs,
            "expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
        )

    @torch._dynamo.config.patch("cache_size_limit", 32)
    def test_multiple_guard_fails(self):
        failure_reasons = []

        def guard_fail_fn(failure):
            failure_reasons.append(failure[0])

        def f(x):
            return torch.relu(x)

        opt_f = torch._dynamo.optimize(
            backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
        )(f)

        for i in range(5):
            failure_reasons.clear()
            opt_f(torch.randn(8 + i))

        failure_str = "\n".join(failure_reasons)
        for line in """\
tensor 'L['x']' size mismatch at index 0. expected 11, actual 12
tensor 'L['x']' size mismatch at index 0. expected 10, actual 12
tensor 'L['x']' size mismatch at index 0. expected 9, actual 12
tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
            "\n"
        ):
            self.assertIn(
                line,
                failure_str,
            )

    @torch._dynamo.config.patch("cache_size_limit", 32)
    def test_multiple_guard_fails_report_all(self):
        with log_settings(kwargs_to_settings(recompiles_verbose=True)):
            failure_reasons = []

            def guard_fail_fn(failure):
                failure_reasons.append(failure[0])

            def f(x):
                return torch.ones(len(x), x[-1])

            opt_f = torch._dynamo.optimize(
                backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
            )(f)

            opt_f([4, 5, 6])

            def filter_reasons():
                return "\n".join(
                    [
                        line
                        for line in "\n".join(failure_reasons).splitlines()
                        if not line.startswith("___check_type_id")
                    ]
                )

            failure_reasons.clear()
            opt_f([7, 8])

            for line in """\
len(L['x']) == 3""".split(
                "\n"
            ):
                self.assertIn(line, filter_reasons())

            failure_reasons.clear()
            opt_f([9])

            for line in """\
len(L['x']) == 2
len(L['x']) == 3""".split(
                "\n"
            ):
                self.assertIn(line, filter_reasons())


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()