xref: /aosp_15_r20/external/pytorch/test/dynamo/test_reorder_logs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2import io
3import warnings
4from unittest.mock import patch
5
6import torch
7import torch._dynamo
8import torch._dynamo.test_case
9import torch._dynamo.testing
10from torch._dynamo.testing import same
11from torch._dynamo.utils import counters
12
13
14class ReorderLogsTests(torch._dynamo.test_case.TestCase):
15    def test_dont_reorder_print(self):
16        def f(x):
17            x = x + x
18            print("moo")
19            x = x * x
20            return x
21
22        counters.clear()
23        x = torch.randn(3, 3)
24        opt_f = torch.compile(backend="eager")(f)
25        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
26            opt_out = opt_f(x)
27            printed_output = mock_stdout.getvalue().strip()
28            orig_out = f(x)
29
30        self.assertTrue(same(orig_out, opt_out))
31        self.assertEqual(printed_output, "moo")
32        self.assertEqual(len(counters["graph_break"]), 1)
33
34    @torch._dynamo.config.patch(reorderable_logging_functions={print})
35    def test_reorder_print(self):
36        def f(x):
37            print("moo")
38            x1 = x + x
39            print(x1)
40            x2 = x1 * x1
41            print(1, 2, 3)
42            x3 = x2 + x2
43            return (x1, x3)
44
45        x = torch.ones(3, 3)
46        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
47        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
48            opt_out = opt_f(x)
49            printed_output = mock_stdout.getvalue().strip()
50            orig_out = f(x)
51
52        self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3")
53        self.assertTrue(same(orig_out, opt_out))
54
55    @torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn})
56    def test_reorder_warnings(self):
57        import warnings
58
59        def f(x):
60            x1 = x + x
61            warnings.warn("moo")
62            x2 = x1 * x1
63            warnings.warn(f"{x2}")
64            x3 = x2 + x2
65            return x3
66
67        x = torch.ones(3, 3)
68        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
69        with warnings.catch_warnings(record=True) as w:
70            opt_out = opt_f(x)
71            warning_messages = [str(i.message) for i in w]
72            orig_out = f(x)
73
74        self.assertTrue(same(orig_out, opt_out))
75        self.assertIn("moo", warning_messages)
76
77    @torch._dynamo.config.patch(reorderable_logging_functions={print})
78    def test_reorder_print_graph_break(self):
79        def f(x):
80            x1 = x + x
81            print(f"res: {x1}")
82            x2 = x1 * x1
83            torch._dynamo.graph_break()
84            x3 = x2 + x2
85            print(1, 2, 3)
86            return x3
87
88        x = torch.ones(3, 3)
89        opt_f = torch.compile(backend="eager")(f)
90        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
91            opt_out = opt_f(x)
92            printed_output = mock_stdout.getvalue().strip()
93            orig_out = f(x)
94
95        self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3")
96        self.assertTrue(same(orig_out, opt_out))
97
98    def test_reorder_custom_log_fn(self):
99        custom_logs = []
100
101        def custom_log(s: str):
102            torch._dynamo.graph_break()
103            custom_logs.append(s)
104
105        def f(x):
106            custom_log("moo")
107            x1 = x + x
108            custom_log(f"{x1}")
109            return x + x
110
111        x = torch.ones(3, 3)
112        counters.clear()
113        with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}):
114            opt_f = torch.compile(backend="eager")(f)
115            opt_out = opt_f(x)
116
117        self.assertEqual(sum(counters["graph_break"].values()), 1)
118        self.assertEqual(custom_logs[0], "moo")
119        self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}")
120
121    @torch._dynamo.config.patch(reorderable_logging_functions={print})
122    def test_constant_mutation(self):
123        def f(x):
124            alist = [x]
125            alist.append(x + 1)
126            print(alist[-1])
127            alist[0].sum().item()  # graph break
128            res = alist.pop()
129            print(alist[-1])
130            res.sum().item()  # graph break
131            return res
132
133        inputs = (torch.tensor([1]),)
134        counters.clear()
135        opt_f = torch.compile(backend="eager")(f)
136        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
137            opt_out = opt_f(*inputs)
138            printed_output = mock_stdout.getvalue().strip()
139            orig_out = f(*inputs)
140
141        self.assertEqual(printed_output, "tensor([2])\ntensor([1])")
142        self.assertTrue(same(orig_out, opt_out))
143
144        graph_break_key = counters["graph_break"].keys()
145        self.assertEqual(len(graph_break_key), 1)
146        self.assertEqual(next(iter(graph_break_key)), "Tensor.item")
147
148
# Allow running this file directly as a test script.
if __name__ == "__main__":
    import torch._dynamo.test_case

    torch._dynamo.test_case.run_tests()
153