# Owner(s): ["module: dynamo"] import unittest import weakref import torch import torch._dynamo import torch._dynamo.config import torch._dynamo.test_case import torch._dynamo.testing import torch._logging from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings class RecompileUxTests(torch._dynamo.test_case.TestCase): # TODO(whc) dynamo actually recompiles one more time than the cache limit cache_limit = 1 @classmethod def setUpClass(cls): super().setUpClass() cls._exit_stack.enter_context( torch._dynamo.config.patch("cache_size_limit", cls.cache_limit) ) def test_drop_cache_on_skip(self): def model(x, i): return x + i attached = False triggered = False def trigger(): nonlocal triggered triggered = True def compiler(gm, input): nonlocal attached f = gm.forward assert not attached # NB: making this a weakref.ref causes the cycle to no # longer be promptly GC'ed weakref.finalize(f, trigger) attached = True return f x = torch.randn(2) for i in range(2): opt_model = torch._dynamo.optimize(compiler)(model) opt_model(x, i) self.assertTrue(triggered) def test_loop_torture(self): def loop_torture(input, iters): out = input # randint itself causes one graph break for _ in range(iters): out += input return out compile_counter = torch._dynamo.testing.CompileCounter() for _ in range(10): x = torch.randn(3) iters = torch.randint(low=0, high=1000, size=()) opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture) opt_loop_torture(x, iters) # Currently, we recompile each time, # We'd probably like to bail out quickly and warn # TODO(whc) these checks fail on py37. Why? # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit) # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit) # compile_counter only sees frames that were fed to the backend compiler, # which is a subset of counters["frames"]["ok"] -- probably because # counters["frames"]["ok"] includes frames not containing torch ops? 
self.assertEqual(compile_counter.frame_count, self.cache_limit) @torch._dynamo.config.patch("automatic_dynamic_shapes", False) def test_dynamic_input(self): def model(input): return input + input expected_recompiles = 2 compile_counter = torch._dynamo.testing.CompileCounter() with torch._dynamo.config.patch("cache_size_limit", expected_recompiles): with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: for _ in range(10): bsz = torch.randint(low=0, high=1000, size=()) x = torch.randn((bsz, 3, 4)) opt_model = torch._dynamo.optimize(compile_counter)(model) opt_model(x) self.assertEqual(compile_counter.frame_count, expected_recompiles) self.assertEqual(len(logs.records), 1) print(logs.records[0]) self.assertTrue( logs.records[0] .getMessage() .startswith("torch._dynamo hit config.cache_size_limit") ) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_nvfuser_guards(self): # we may want to model dynamo's guards sufficiently after nvfuser's ProfilingExecutor guards # such that we ensure dynamo is in charge of all the recompilations at the top level, # and we could thus simplify the underlying torchscript executor def func(a, b, c): return a + b * c a = torch.rand(3, 4, 5, device="cuda") b = torch.rand(3, 4, 5, device="cuda") b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5) b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1) c = torch.rand(3, 4, 5, device="cuda") compile_counter = torch._dynamo.testing.CompileCounter() with torch._dynamo.config.patch("cache_size_limit", 2): opt_func = torch._dynamo.optimize(compile_counter)(func) opt_func(a, b, c) # warmup self.assertEqual(compile_counter.frame_count, 1) opt_func(a, b, c) # no guard fail or recompile self.assertEqual(compile_counter.frame_count, 1) opt_func(a, b_v, c) # a view should not cause nvfuser recompile self.assertEqual(compile_counter.frame_count, 1) opt_func(a, b_p, c) # a permutation should cause recompile self.assertEqual(compile_counter.frame_count, 2) def assert_single_log_contains(self, logs, contains_str): self.assertEqual(len(logs.records), 1) self.assertTrue( logs.records[0].getMessage().find(contains_str) > 0, msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"', ) def test_verbose_tensor_check(self): def func(a): # Warning: choose a function here whose meta implementation lives # entirely in C++. If you do a Python one, Dynamo will dive into # torch._refs which is OK but it will muddy up the warnings return torch.add(a, 4) def cache_fail_test(cached_input, missed_input, expected_failure): # TODO(whc) maybe its hacky to have a 'test within a test' but this seemed convenient torch._dynamo.reset() torch._dynamo.utils.counters.clear() opt_func = torch._dynamo.optimize("eager")(func) # warmup opt_func(cached_input) with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: opt_func = torch._dynamo.optimize("eager")(func) opt_func(missed_input) self.assert_single_log_contains(logs, expected_failure) a = torch.rand(3, 4, 5) cache_fail_test( a, a[0:2, :, :], "tensor 'L['a']' size mismatch at index 0. expected 3, actual 2", ) cache_fail_test( a, a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)), "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1", ) cache_fail_test( a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2" ) cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.") cache_fail_test( a, a.to(torch.float16), "tensor 'L['a']' dtype mismatch. 
expected Float, actual Half", ) a_grad = a.clone() a_grad.requires_grad = True cache_fail_test( a, a_grad, "tensor 'L['a']' requires_grad mismatch. expected requires_grad=0", ) def test_mismatched_type(self): a = torch.rand(3, 4, 5) b = torch.rand(3, 4, 5) def func(a, b): return a + b opt_func = torch._dynamo.optimize("eager")(func) # warmup opt_func(a, b) with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: opt_func = torch._dynamo.optimize("eager")(func) opt_func(a, 1) self.assert_single_log_contains( logs, "expected type of 'L['b']' to be a tensor type, ' but found ", ) @torch._dynamo.config.patch("cache_size_limit", 32) def test_multiple_guard_fails(self): failure_reasons = [] def guard_fail_fn(failure): failure_reasons.append(failure[0]) def f(x): return torch.relu(x) opt_f = torch._dynamo.optimize( backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False )(f) for i in range(5): failure_reasons.clear() opt_f(torch.randn(8 + i)) failure_str = "\n".join(failure_reasons) for line in """\ tensor 'L['x']' size mismatch at index 0. expected 11, actual 12 tensor 'L['x']' size mismatch at index 0. expected 10, actual 12 tensor 'L['x']' size mismatch at index 0. expected 9, actual 12 tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split( "\n" ): self.assertIn( line, failure_str, ) @torch._dynamo.config.patch("cache_size_limit", 32) def test_multiple_guard_fails_report_all(self): with log_settings(kwargs_to_settings(recompiles_verbose=True)): failure_reasons = [] def guard_fail_fn(failure): failure_reasons.append(failure[0]) def f(x): return torch.ones(len(x), x[-1]) opt_f = torch._dynamo.optimize( backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False )(f) opt_f([4, 5, 6]) def filter_reasons(): return "\n".join( [ line for line in "\n".join(failure_reasons).splitlines() if not line.startswith("___check_type_id") ] ) failure_reasons.clear() opt_f([7, 8]) for line in """\ len(L['x']) == 3""".split( "\n" ): self.assertIn(line, filter_reasons()) failure_reasons.clear() opt_f([9]) for line in """\ len(L['x']) == 2 len(L['x']) == 3""".split( "\n" ): self.assertIn(line, filter_reasons()) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()