# Owner(s): ["module: dynamo"]
import logging
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
import torch._logging
from torch._dynamo.utils import dynamo_timed
from torch.testing._internal.common_utils import TemporaryFileName


class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
    def test_dynamo_timed_profiling_isolated(self):
        # dynamo_timed regions should appear in profile traces.
        def inner_fn(x):
            with dynamo_timed("inner_fn"):
                return x.sin()

        def outer_fn(x, y):
            return inner_fn(x) * y

        x, y = (torch.rand((2, 2)) for _ in range(2))

        with torch.profiler.profile(with_stack=False) as prof:
            outer_fn(x, y)

        self.assertTrue(
            any("inner_fn (dynamo_timed)" in evt.name for evt in prof.events())
        )

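    # A minimal sketch (ours, not part of the original suite) of inspecting the
    # same events by hand; it relies only on the dynamo_timed and profiler APIs
    # already used above, and the helper name is hypothetical.
    def _example_list_dynamo_timed_events(self):
        def inner_fn(x):
            with dynamo_timed("inner_fn"):
                return x.sin()

        with torch.profiler.profile(with_stack=False) as prof:
            inner_fn(torch.rand(2, 2))

        # Each dynamo_timed region surfaces as a "<name> (dynamo_timed)" event.
        return [evt.name for evt in prof.events() if "(dynamo_timed)" in evt.name]
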
    def test_dynamo_timed_profiling_backend_compile(self):
        # dynamo_timed regions should appear in profile traces.
        # This checks that they actually show up during real dynamo execution.
        # "call_user_compiler" is just chosen as an example; if it gets renamed
        # this test can be replaced or deleted.

        fn_name = "call_user_compiler"

        def fn(x, y):
            return x.sin() * y.cos()

        x, y = (torch.rand((2, 2)) for _ in range(2))

        with torch.profiler.profile(with_stack=False) as prof:
            torch._dynamo.optimize("aot_eager")(fn)(x, y)

        self.assertTrue(
            any(f"{fn_name} (dynamo_timed)" in evt.name for evt in prof.events())
        )

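    # A hedged sketch (ours) of a second view on the same data:
    # torch._dynamo.utils.compile_times() aggregates dynamo_timed regions
    # without the profiler; repr="str" returns a printable per-phase table.
    def _example_print_compile_times(self):
        def fn(x, y):
            return x.sin() * y.cos()

        torch._dynamo.optimize("aot_eager")(fn)(torch.rand(2, 2), torch.rand(2, 2))
        return torch._dynamo.utils.compile_times(repr="str")
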
    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
    def test_profile_dynamic_shapes_runtime(self):
        def fn(x, y, z):
            return x @ y + z

        opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn)

        inputs = [
            (torch.rand(a, b), torch.rand(b, c), torch.rand(a, c))
            for (a, b, c) in [(15, 16, 17), (15, 15, 16), (16, 16, 16)]
        ]

        # Warm up with two shape variants so compilation happens up front;
        # the profiled third call then measures only dynamic-shape runtime.
        opt_fn(*inputs[0])
        opt_fn(*inputs[1])

        with torch.profiler.profile(record_shapes=True):
            opt_fn(*inputs[2])

    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
    def test_profile_dynamic_shapes_compilation(self):
        def fn(x, y, z):
            return x @ y + z

        opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn)

        inputs = (torch.rand(15, 16), torch.rand(16, 17), torch.rand(15, 17))

        with torch.profiler.profile(record_shapes=True):
            opt_fn(*inputs)

    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
    def test_profile_dynamic_shapes_list_compilation(self):
        def fn(x, y, z):
            return torch.cat([x, y], dim=0) + z

        opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn)

        inputs = (torch.rand(4, 16), torch.rand(12, 16), torch.rand(16, 16))

        with torch.profiler.profile(record_shapes=True):
            opt_fn(*inputs)

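    # A small sketch (ours, not asserted by the tests above) of what
    # record_shapes=True buys: each profiler event carries its recorded input
    # shapes, so the dynamic dimensions can be read back off the trace.
    def _example_read_recorded_shapes(self):
        def fn(x, y, z):
            return x @ y + z

        opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager")
        with torch.profiler.profile(record_shapes=True) as prof:
            opt_fn(torch.rand(15, 16), torch.rand(16, 17), torch.rand(15, 17))

        # FunctionEvent.input_shapes is only populated when record_shapes=True.
        return {
            evt.name: evt.input_shapes for evt in prof.events() if evt.input_shapes
        }
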
    def test_execution_trace_dynamic_shapes(self):
        def fn(x, y, z):
            return x @ y + z

        et = torch.profiler.ExecutionTraceObserver()
        opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager")
        inputs = [torch.rand((4, 4)) for _ in range(3)]

        # Smoke test: capturing an execution trace of a dynamic-shapes
        # compiled function should not raise.
        with TemporaryFileName() as fname:
            et.register_callback(fname)
            et.start()
            opt_fn(*inputs)
            et.stop()
            et.unregister_callback()

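    # A hedged sketch (ours) of reading the captured trace back while the
    # temporary file still exists. The "nodes" key is an assumption about the
    # execution-trace JSON layout, not something the test above asserts.
    def _example_load_execution_trace(self):
        import json

        def fn(x, y, z):
            return x @ y + z

        et = torch.profiler.ExecutionTraceObserver()
        opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager")
        inputs = [torch.rand((4, 4)) for _ in range(3)]

        with TemporaryFileName() as fname:
            et.register_callback(fname)
            et.start()
            opt_fn(*inputs)
            et.stop()
            et.unregister_callback()
            # Load inside the context manager, before the file is deleted.
            with open(fname) as f:
                trace = json.load(f)
        return trace.get("nodes", [])  # assumed key, see comment above
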
    def test_profiler_cache_lookup(self):
        def fn(x):
            y = x**2
            y = y + 2
            z = y**3
            return z

        for profiler, get_events in (
            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
            (torch.profiler.profiler.profile, lambda prof: prof.events()),
        ):
            x = torch.randn((2, 2), requires_grad=True)
            ref = fn(x)
            opt_fn = torch.compile(fn, backend="aot_eager")

            # warmup
            opt_fn(x)

            with profiler() as prof:
                res = opt_fn(x)
            events = list(
                filter(
                    lambda event: "TorchDynamo Cache Lookup" in event.name,
                    get_events(prof),
                )
            )

            self.assertEqual(ref, res)
            self.assertEqual(
                len(events),
                1,
                "Expected one lookup profiler event for one opt_fn run",
            )

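    # A minimal sketch (ours) of quantifying the cache-lookup overhead asserted
    # above, using the public key_averages() summary of torch.profiler.
    def _example_cache_lookup_overhead(self):
        opt_fn = torch.compile(lambda x: x**2 + 2, backend="aot_eager")
        x = torch.randn(2, 2)
        opt_fn(x)  # warmup so the profiled run is a pure cache hit

        with torch.profiler.profile() as prof:
            opt_fn(x)
        # Aggregates events by name; the lookup shows up as its own row.
        return prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
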
    def test_profiler_cache_lookup_profiler_step(self):
        def fn(x, y, z):
            return torch.add(torch.sub(x, y), z)

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)

        x, y, z = (torch.rand(4, 4) for _ in range(3))

        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=2, warmup=2, active=2, repeat=1)
        )

        for _ in range(10):
            opt_fn(x, y, z)
            prof.step()

        self.assertTrue(
            any(e.name == "TorchDynamo Cache Lookup" for e in prof.events())
        )

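    # A hedged sketch (ours) of how the same schedule is normally paired with
    # on_trace_ready to export each completed cycle; the handler is the stock
    # torch.profiler.tensorboard_trace_handler, and logdir is a caller-supplied
    # directory (hypothetical parameter, not used by the test above).
    def _example_stepped_profile_export(self, logdir):
        opt_fn = torch._dynamo.optimize("aot_eager")(
            lambda x, y, z: torch.add(torch.sub(x, y), z)
        )
        x, y, z = (torch.rand(4, 4) for _ in range(3))

        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=2, warmup=2, active=2, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(logdir),
        ) as prof:
            for _ in range(10):
                opt_fn(x, y, z)
                prof.step()  # advance the schedule; trace flushes when active ends
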
    def test_profiler_dynamo_compiled_region(self):
        torch._logging.set_logs(dynamo=logging.INFO)

        def fn(x, y):
            r = y.sum(dim=1)
            # print() is untraceable, so it forces a graph break and splits
            # fn into two compiled regions.
            print(r.shape)
            return x * r

        fn_c = torch.compile(fn)

        with torch.profiler.profile(record_shapes=True) as prof:
            fn_c(
                torch.randn(10),
                torch.randn(10, 10),
            )

            fn_c(
                torch.randn(10),
                torch.randn(10, 15),
            )

        for e in prof.events():
            if e.name == "Torch-Compiled Region":
                print(e.kwinputs)
        self.assertTrue(
            any(
                e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "0/0_1"
                for e in prof.events()
            )
        )
        self.assertTrue(
            any(
                e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "1/0"
                for e in prof.events()
            )
        )

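    # A small sketch (ours, not asserted upstream) that generalizes the two
    # assertions above by collecting every compiled-region context in a profile.
    @staticmethod
    def _example_compiled_region_contexts(prof):
        # kwinputs carries the frame/graph context recorded on each region.
        return sorted(
            e.kwinputs["context"]
            for e in prof.events()
            if e.name == "Torch-Compiled Region"
        )
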

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()