1# Owner(s): ["module: dynamo"] 2import io 3import warnings 4from unittest.mock import patch 5 6import torch 7import torch._dynamo 8import torch._dynamo.test_case 9import torch._dynamo.testing 10from torch._dynamo.testing import same 11from torch._dynamo.utils import counters 12 13 14class ReorderLogsTests(torch._dynamo.test_case.TestCase): 15 def test_dont_reorder_print(self): 16 def f(x): 17 x = x + x 18 print("moo") 19 x = x * x 20 return x 21 22 counters.clear() 23 x = torch.randn(3, 3) 24 opt_f = torch.compile(backend="eager")(f) 25 with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 26 opt_out = opt_f(x) 27 printed_output = mock_stdout.getvalue().strip() 28 orig_out = f(x) 29 30 self.assertTrue(same(orig_out, opt_out)) 31 self.assertEqual(printed_output, "moo") 32 self.assertEqual(len(counters["graph_break"]), 1) 33 34 @torch._dynamo.config.patch(reorderable_logging_functions={print}) 35 def test_reorder_print(self): 36 def f(x): 37 print("moo") 38 x1 = x + x 39 print(x1) 40 x2 = x1 * x1 41 print(1, 2, 3) 42 x3 = x2 + x2 43 return (x1, x3) 44 45 x = torch.ones(3, 3) 46 opt_f = torch.compile(backend="eager", fullgraph=True)(f) 47 with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 48 opt_out = opt_f(x) 49 printed_output = mock_stdout.getvalue().strip() 50 orig_out = f(x) 51 52 self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3") 53 self.assertTrue(same(orig_out, opt_out)) 54 55 @torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn}) 56 def test_reorder_warnings(self): 57 import warnings 58 59 def f(x): 60 x1 = x + x 61 warnings.warn("moo") 62 x2 = x1 * x1 63 warnings.warn(f"{x2}") 64 x3 = x2 + x2 65 return x3 66 67 x = torch.ones(3, 3) 68 opt_f = torch.compile(backend="eager", fullgraph=True)(f) 69 with warnings.catch_warnings(record=True) as w: 70 opt_out = opt_f(x) 71 warning_messages = [str(i.message) for i in w] 72 orig_out = f(x) 73 74 self.assertTrue(same(orig_out, opt_out)) 75 self.assertIn("moo", warning_messages) 76 77 @torch._dynamo.config.patch(reorderable_logging_functions={print}) 78 def test_reorder_print_graph_break(self): 79 def f(x): 80 x1 = x + x 81 print(f"res: {x1}") 82 x2 = x1 * x1 83 torch._dynamo.graph_break() 84 x3 = x2 + x2 85 print(1, 2, 3) 86 return x3 87 88 x = torch.ones(3, 3) 89 opt_f = torch.compile(backend="eager")(f) 90 with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 91 opt_out = opt_f(x) 92 printed_output = mock_stdout.getvalue().strip() 93 orig_out = f(x) 94 95 self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3") 96 self.assertTrue(same(orig_out, opt_out)) 97 98 def test_reorder_custom_log_fn(self): 99 custom_logs = [] 100 101 def custom_log(s: str): 102 torch._dynamo.graph_break() 103 custom_logs.append(s) 104 105 def f(x): 106 custom_log("moo") 107 x1 = x + x 108 custom_log(f"{x1}") 109 return x + x 110 111 x = torch.ones(3, 3) 112 counters.clear() 113 with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}): 114 opt_f = torch.compile(backend="eager")(f) 115 opt_out = opt_f(x) 116 117 self.assertEqual(sum(counters["graph_break"].values()), 1) 118 self.assertEqual(custom_logs[0], "moo") 119 self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}") 120 121 @torch._dynamo.config.patch(reorderable_logging_functions={print}) 122 def test_constant_mutation(self): 123 def f(x): 124 alist = [x] 125 alist.append(x + 1) 126 print(alist[-1]) 127 alist[0].sum().item() # graph break 128 res = alist.pop() 129 print(alist[-1]) 130 res.sum().item() # graph break 131 return res 132 133 inputs = (torch.tensor([1]),) 134 counters.clear() 135 opt_f = torch.compile(backend="eager")(f) 136 with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 137 opt_out = opt_f(*inputs) 138 printed_output = mock_stdout.getvalue().strip() 139 orig_out = f(*inputs) 140 141 self.assertEqual(printed_output, "tensor([2])\ntensor([1])") 142 self.assertTrue(same(orig_out, opt_out)) 143 144 graph_break_key = counters["graph_break"].keys() 145 self.assertEqual(len(graph_break_key), 1) 146 self.assertEqual(next(iter(graph_break_key)), "Tensor.item") 147 148 149if __name__ == "__main__": 150 from torch._dynamo.test_case import run_tests 151 152 run_tests() 153