# Owner(s): ["module: dynamo"] from unittest.mock import patch import torch import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import unsupported from torch._dynamo.utils import ifdynstaticdefault globalmod = torch.nn.ReLU() def indirectly_unsupported(a, b): c = a + b return unsupported(a, c) class SubGraphTests(torch._dynamo.test_case.TestCase): def _common(self, fn, frame_count, op_count): torch._dynamo.reset() v1 = torch.ones(10) v2 = torch.ones(10) * -2.0 correct1 = fn(v1, v2) correct2 = fn(v2, v1) cnt = torch._dynamo.testing.CompileCounter() opt_fn = torch._dynamo.optimize(cnt)(fn) r1 = opt_fn(v1, v2) r2 = opt_fn(v2, v1) self.assertTrue(torch._dynamo.testing.same(r1, correct1)) self.assertTrue(torch._dynamo.testing.same(r2, correct2)) self.assertEqual( cnt.frame_count, frame_count, f"actual {cnt.frame_count} != expected {frame_count}", ) self.assertEqual(cnt.op_count, op_count) def test_control_flow1(self): def fn(a, b): c1 = a - b c2 = b - a if c1.sum() > c2.sum(): return c1 else: return c2 self._common(fn, 1, 5) def test_control_flow2(self): def fn(a, b): if a.sum() > b.sum(): return 1 else: return 2 self._common(fn, 1, 3) def test_control_flow3(self): def fn(a, b): c1 = a - b c2 = b - a m = globalmod if c1.sum() > c2.sum(): return m(c1) else: return m(c2) self._common(fn, 3, 7) def test_control_flow4(self): def fn(a, b): tmp1 = a.sum() > b.sum() and a.sum() > 0 if tmp1: return 1 else: return 2 self._common(fn, 3, 5) def test_control_flow5(self): def fn(a, b): tmp1 = a.sum() > b.sum() and a.sum() > 0 tmp2 = a.sum() < b.sum() or b.sum() > 0 if tmp1 and tmp2: return 1, tmp1, tmp2 else: return 2, tmp1, tmp2 self._common(fn, 6, 13) def test_capi_call1(self): def fn(a, b): c1 = a - b c2 = b - a return unsupported(c1, c2) self._common(fn, 1, 2) def test_capi_call2(self): def fn(a, b): c1 = a - b c2 = b - a return a - (b - unsupported(c1, c2)) self._common(fn, 2, 4) def test_capi_call3(self): def fn(a, b): c1 = a - b c2 = b - a return torch._dynamo.testing.unsupported(c1, c2) self._common(fn, 1, 2) def test_indirect_unsupported1(self): def fn(a, b): c1 = a - b c2 = b - a return indirectly_unsupported(c1, c2) self._common(fn, 2, 3) def test_indirect_unsupported2(self): def fn(a, b): local_const1 = 7 local_const2 = 22 c1 = a - b c2 = b - a return local_const1 / (local_const2 - indirectly_unsupported(c1, c2)) self._common(fn, 3, 5) def test_indirect_unsupported3(self): def fn(a, b): args = [a - b, b - a] return indirectly_unsupported(*args) self._common(fn, 2, 3) def test_stack_state1(self): def fn(a, b): t1 = 1.23 * a t2 = 4.56 * a c1 = a - b c2 = b - a return t1 / (t2 - unsupported(c1, c2)) self._common(fn, 2, 6) def test_stack_state2(self): def fn(a, b): t1 = 1.23 * a t2 = 4.56 * a c1 = a - b c2 = b - a return t1 / (t2 - indirectly_unsupported(c1, c2)) self._common(fn, 3, 7) def test_multigraph(self): def fn(a, b): x = a + b x = x / 2.0 if x.sum() < 0: return x * -1.0 return x self._common(fn, 2, 5) def test_extended_args(self): too_many_adds = "+".join(["a", "b"] * 256) source = ( f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)" ) self._common(eval(source), 3, 1026) def test_resume1(self): def fn(a, b): x = a + b x = x / 2.0 x = x + 2.0 x = unsupported(x, a) x = x + 2.0 x = x + 2.0 x = x + 2.0 return x self._common(fn, 2, 6) def test_resume2(self): def fn(a, b): x = a + b x = x / 2.0 x = x + 2.0 x = indirectly_unsupported(x, a) x = x + 2.0 x = x + 2.0 x = x + 2.0 return x self._common(fn, 3, 7) def test_resume3(self): def 

    def test_resume3(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume4(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(a=x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume5(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            print(x)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    def test_start1(self):
        def fn(a, b):
            print(a)
            x = a + b
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start2(self):
        def fn(a, b):
            x = indirectly_unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 4)

    def test_start3(self):
        def fn(a, b):
            x = unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start4(self):
        def fn(a, b, check):
            if check:
                return a + b + 10
            else:
                return a + b - 10

        v1 = torch.randn(10)
        v2 = torch.randn(10)
        f = torch.zeros(1, dtype=torch.int32)
        t = torch.ones(1, dtype=torch.int32)
        correct1 = fn(v1, v2, t)
        correct2 = fn(v1, v2, f)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2, t)
        r2 = opt_fn(v1, v2, f)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 4)

    def test_resume_freevars(self):
        c1 = torch.randn(10)
        c2 = torch.randn(10)

        def fn(a, b):
            x = a + b + (c1 - c2)
            x = unsupported(x, x)
            return x + (c1 - c2)

        self._common(fn, 2, 5)

    def test_restore_state(self):
        def fn(a, b):
            len_ = len
            x = a + b
            x = torch.add(unsupported(x, x), 1)
            return a * x + len_(b)

        self._common(fn, 2, 4)
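
    # The tests below check that non-tensor local state (range objects,
    # iterators, lists) is faithfully reconstructed on the far side of a
    # graph break, so that the resume function observes the same values the
    # eager execution would have.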

    def test_restore_range(self):
        def fn(a, b):
            x = a + b
            rng = range(3, 8, 2)
            x = unsupported(x, x)
            for i in rng:
                x = x + i
            return x

        # We don't specialize on range with dynamic shapes, which
        # means we fail to unroll the loop.
        # TODO: Consider forcing specialization when we iterate over
        # the loop
        self._common(fn, ifdynstaticdefault(2, 1), ifdynstaticdefault(4, 1))

    def test_restore_range_iter(self):
        def fn(a, b):
            x = a + b
            rng = iter(range(3, 8, 2))
            x = unsupported(x, x)
            x += next(rng)
            return x, list(rng)

        self._common(fn, 2, 2)

    def test_pop_after_resume(self):
        def fn(a, b):
            tmp = [a + 1, b + 2, a + b]
            x = a
            x = unsupported(x, x)
            for i in range(3):
                x += tmp.pop(-1)
            return x

        self._common(fn, 2, 6)

    @patch("torch._dynamo.config.assume_static_by_default", False)
    def test_dynamic_getitem(self):
        def fn(a, b):
            return a[b.size(0) - 1]

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for i in range(3, 12):
            opt_fn(torch.randn(i), torch.randn(i))
        # just one graph
        self.assertEqual(cnt.frame_count, 1)

    def test_dynamic_kwarg(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        start = 2
        end = 12
        steps = end - start
        for i in range(start, end):
            opt_fn(torch.randn(i), torch.randn(i))
        self.assertEqual(cnt_dynamic.frame_count, 1)

    def test_dynamic_duck_size(self):
        def fn(a, b):
            if a.size(0) == b.size(0):
                return a + b
            else:
                return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, x), fn(x, x))
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)
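
    # Illustrative sketch (not part of the original suite): with dynamic=True,
    # sizes are traced symbolically, so a purely elementwise function should
    # compile a single graph that is reused across input lengths, mirroring
    # test_dynamic_kwarg above. The asserted frame count is an assumption
    # inferred from that test, not an independently verified value.
    def test_dynamic_elementwise_sketch(self):
        def fn(a):
            return a * 2 + 1

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        for i in range(2, 8):
            x = torch.randn(i)
            self.assertTrue(torch._dynamo.testing.same(opt_fn(x), fn(x)))
        self.assertEqual(cnt_dynamic.frame_count, 1)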

    def test_dynamic_order_dependence(self):
        def fn(a, b):
            return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(opt_fn(x, x), fn(x, x))
        # NB: This COULD validly be 2, but we don't test disjointness in the
        # guards for when x and y didn't duck size together, so we end up
        # with a generic graph that also works when x and y happen to duck
        # size together.
        self.assertEqual(cnt_dynamic.frame_count, 2)

        torch._dynamo.reset()
        cnt_dynamic.frame_count = 0
        self.assertEqual(opt_fn(x, x), fn(x, x))  # this overspecializes!
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    def test_dynamic_zero_inference(self):
        def fn(a):
            if a.size(0) != 0:
                return a * 2
            else:
                return a + 1

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(0)
        y = torch.randn(2)
        self.assertEqual(opt_fn(y), fn(y))
        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_no_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 1, 5)  # item gets DCE'd

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 2, 5)

    def test_resume_paths_join(self):
        def fn(x, c1, c2, c3):
            x = x + 1
            if c1:
                x = x + 2
            x = x + 3
            if c2:
                x = x + 4
            x = x + 5
            if c3:
                x = x + 6
            return x + 7

        v1 = torch.randn(10)
        t = torch.Tensor([True])
        f = torch.Tensor([False])
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for a in (t, f):
            for b in (t, f):
                for c in (t, f):
                    opt_fn(v1, a, b, c)

        # checking here we don't create 2^n graphs
        self.assertEqual(cnt.frame_count, 7)
        self.assertEqual(cnt.op_count, 10)

    def test_resume_with_no_grad1(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
            x = x + 3
            return x

        self._common(fn, 2, 9)
        torch._dynamo.reset()
        with torch.no_grad():
            self._common(fn, 2, 5)

    def test_resume_with_no_grad2(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
                x.sum().tolist()  # graph break
                x = x + 3
            x = x + 4
            return x

        self._common(fn, 3, 13)

    def test_resume_with_no_grad3(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                with torch.no_grad():
                    x = x + 1
                    with torch.enable_grad():
                        x.sum().tolist()  # graph break
                        x = x[0] + 2
                    x = x + 3
            x = x + 4
            return x

        self._common(fn, 2, 11)

    def test_resume_tuple_iterator(self):
        def fn(a, b):
            x = a + b
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        self._common(fn, 2, 8)

    def test_tuple_iterator_return(self):
        def fn(x):
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            return x, it

        v1 = torch.randn(10)
        v2, it2 = fn(v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        v3, it3 = opt_fn(v1)
        v4, it4 = opt_fn(v1)
        self.assertEqual(v2.tolist(), v3.tolist())
        self.assertEqual(v2.tolist(), v4.tolist())
        self.assertEqual(list(it2), list(it3))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 6)

    def test_tuple_iterator_mutate(self):
        def fn(x, it):
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        v1 = torch.randn(10)
        it1 = iter(tuple(range(10)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
        self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])

    def test_enumerate_not_break_graph(self):
        def fn(a, b):
            for i, x in enumerate(a.shape):
                b = b + x
            for i, x in enumerate(b.shape, 8):
                b = b + x * i
            return b

        self._common(fn, 1, ifdynstaticdefault(2, 3))
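
    # Illustrative sketch (not part of the original suite): an explicit
    # torch._dynamo.graph_break() should split a frame the same way the
    # unsupported() helper does above: one graph for the prefix and one
    # resume graph for the remainder. The expected counts (2 frames, 3 ops)
    # are assumptions inferred from test_resume1, not verified values.
    def test_explicit_graph_break_sketch(self):
        def fn(a, b):
            x = a + b
            x = x + 2.0
            torch._dynamo.graph_break()
            x = x + 2.0
            return x

        self._common(fn, 2, 3)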


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()