xref: /aosp_15_r20/external/pytorch/test/dynamo/test_profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport logging
3*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.utils
9*da0073e9SAndroid Build Coastguard Workerimport torch._logging
10*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import dynamo_timed
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TemporaryFileName
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerclass DynamoProfilerTests(torch._dynamo.test_case.TestCase):
15*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_timed_profiling_isolated(self):
16*da0073e9SAndroid Build Coastguard Worker        # dynamo_timed functions should appear in profile traces.
17*da0073e9SAndroid Build Coastguard Worker        def inner_fn(x):
18*da0073e9SAndroid Build Coastguard Worker            with dynamo_timed("inner_fn"):
19*da0073e9SAndroid Build Coastguard Worker                return x.sin()
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker        def outer_fn(x, y):
22*da0073e9SAndroid Build Coastguard Worker            return inner_fn(x) * y
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker        x, y = (torch.rand((2, 2)) for _ in range(2))
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(with_stack=False) as prof:
27*da0073e9SAndroid Build Coastguard Worker            outer_fn(x, y)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
30*da0073e9SAndroid Build Coastguard Worker            any("inner_fn (dynamo_timed)" in evt.name for evt in prof.events())
31*da0073e9SAndroid Build Coastguard Worker        )
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_timed_profiling_backend_compile(self):
34*da0073e9SAndroid Build Coastguard Worker        # dynamo_timed functions should appear in profile traces.
35*da0073e9SAndroid Build Coastguard Worker        # this checks whether these actually appear in actual dynamo execution.
36*da0073e9SAndroid Build Coastguard Worker        # "backend_compile" is just chosen as an example; if it gets renamed
37*da0073e9SAndroid Build Coastguard Worker        # this test can be replaced or deleted
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        fn_name = "call_user_compiler"
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
42*da0073e9SAndroid Build Coastguard Worker            return x.sin() * y.cos()
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker        x, y = (torch.rand((2, 2)) for _ in range(2))
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(with_stack=False) as prof:
47*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.optimize("aot_eager")(fn)(x, y)
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
50*da0073e9SAndroid Build Coastguard Worker            any(f"{fn_name} (dynamo_timed)" in evt.name for evt in prof.events())
51*da0073e9SAndroid Build Coastguard Worker        )
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
54*da0073e9SAndroid Build Coastguard Worker    def test_profile_dynamic_shapes_runtime(self):
55*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
56*da0073e9SAndroid Build Coastguard Worker            return x @ y + z
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn)
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker        inputs = [
61*da0073e9SAndroid Build Coastguard Worker            (torch.rand(a, b), torch.rand(b, c), torch.rand(a, c))
62*da0073e9SAndroid Build Coastguard Worker            for (a, b, c) in [(15, 16, 17), (15, 15, 16), (16, 16, 16)]
63*da0073e9SAndroid Build Coastguard Worker        ]
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        opt_fn(*inputs[0])
66*da0073e9SAndroid Build Coastguard Worker        opt_fn(*inputs[1])
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(record_shapes=True):
69*da0073e9SAndroid Build Coastguard Worker            opt_fn(*inputs[2])
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
72*da0073e9SAndroid Build Coastguard Worker    def test_profile_dynamic_shapes_compilation(self):
73*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
74*da0073e9SAndroid Build Coastguard Worker            return x @ y + z
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn)
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        inputs = (torch.rand(15, 16), torch.rand(16, 17), torch.rand(15, 17))
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(record_shapes=True):
81*da0073e9SAndroid Build Coastguard Worker            opt_fn(*inputs)
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
84*da0073e9SAndroid Build Coastguard Worker    def test_profile_dynamic_shapes_list_compilation(self):
85*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
86*da0073e9SAndroid Build Coastguard Worker            return torch.cat([x, y], dim=0) + z
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn)
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker        inputs = (torch.rand(4, 16), torch.rand(12, 16), torch.rand(16, 16))
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(record_shapes=True):
93*da0073e9SAndroid Build Coastguard Worker            opt_fn(*inputs)
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker    def test_execution_trace_dynamic_shapes(self):
96*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
97*da0073e9SAndroid Build Coastguard Worker            return x @ y + z
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker        et = torch.profiler.ExecutionTraceObserver()
100*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager")
101*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.rand((4, 4)) for _ in range(3)]
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
104*da0073e9SAndroid Build Coastguard Worker            et.register_callback(fname)
105*da0073e9SAndroid Build Coastguard Worker            et.start()
106*da0073e9SAndroid Build Coastguard Worker            out = opt_fn(*inputs)
107*da0073e9SAndroid Build Coastguard Worker            et.stop()
108*da0073e9SAndroid Build Coastguard Worker            et.unregister_callback()
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker    def test_profiler_cache_lookup(self):
111*da0073e9SAndroid Build Coastguard Worker        def fn(x):
112*da0073e9SAndroid Build Coastguard Worker            y = x**2
113*da0073e9SAndroid Build Coastguard Worker            y = y + 2
114*da0073e9SAndroid Build Coastguard Worker            z = y**3
115*da0073e9SAndroid Build Coastguard Worker            return z
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker        for profiler, get_events in (
118*da0073e9SAndroid Build Coastguard Worker            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
119*da0073e9SAndroid Build Coastguard Worker            (torch.profiler.profiler.profile, lambda prof: prof.events()),
120*da0073e9SAndroid Build Coastguard Worker        ):
121*da0073e9SAndroid Build Coastguard Worker            x = torch.randn((2, 2), requires_grad=True)
122*da0073e9SAndroid Build Coastguard Worker            ref = fn(x)
123*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch.compile(fn, backend="aot_eager")
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker            # warmup
126*da0073e9SAndroid Build Coastguard Worker            opt_fn(x)
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker            with profiler() as prof:
129*da0073e9SAndroid Build Coastguard Worker                res = opt_fn(x)
130*da0073e9SAndroid Build Coastguard Worker            events = list(
131*da0073e9SAndroid Build Coastguard Worker                filter(
132*da0073e9SAndroid Build Coastguard Worker                    lambda event: "TorchDynamo Cache Lookup" in event.name,
133*da0073e9SAndroid Build Coastguard Worker                    get_events(prof),
134*da0073e9SAndroid Build Coastguard Worker                )
135*da0073e9SAndroid Build Coastguard Worker            )
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
138*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
139*da0073e9SAndroid Build Coastguard Worker                len(events) == 1,
140*da0073e9SAndroid Build Coastguard Worker                "Expected one lookup profiler event for one opt_fn run",
141*da0073e9SAndroid Build Coastguard Worker            )
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker    def test_profiler_cache_lookup_profiler_step(self):
144*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
145*da0073e9SAndroid Build Coastguard Worker            return torch.add(torch.sub(x, y), z)
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker        (
150*da0073e9SAndroid Build Coastguard Worker            x,
151*da0073e9SAndroid Build Coastguard Worker            y,
152*da0073e9SAndroid Build Coastguard Worker            z,
153*da0073e9SAndroid Build Coastguard Worker        ) = (torch.rand(4, 4) for _ in range(3))
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        prof = torch.profiler.profile(
156*da0073e9SAndroid Build Coastguard Worker            schedule=torch.profiler.schedule(wait=2, warmup=2, active=2, repeat=1)
157*da0073e9SAndroid Build Coastguard Worker        )
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
160*da0073e9SAndroid Build Coastguard Worker            opt_fn(x, y, z)
161*da0073e9SAndroid Build Coastguard Worker            prof.step()
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
164*da0073e9SAndroid Build Coastguard Worker            any(e.name == "TorchDynamo Cache Lookup" for e in prof.events())
165*da0073e9SAndroid Build Coastguard Worker        )
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker    def test_profiler_dynamo_compiled_region(self):
168*da0073e9SAndroid Build Coastguard Worker        torch._logging.set_logs(dynamo=logging.INFO)
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
171*da0073e9SAndroid Build Coastguard Worker            r = y.sum(dim=1)
172*da0073e9SAndroid Build Coastguard Worker            print(r.shape)
173*da0073e9SAndroid Build Coastguard Worker            return x * r
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker        fn_c = torch.compile(fn)
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(record_shapes=True) as prof:
178*da0073e9SAndroid Build Coastguard Worker            fn_c(
179*da0073e9SAndroid Build Coastguard Worker                torch.randn(10),
180*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, 10),
181*da0073e9SAndroid Build Coastguard Worker            )
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker            fn_c(
184*da0073e9SAndroid Build Coastguard Worker                torch.randn(10),
185*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, 15),
186*da0073e9SAndroid Build Coastguard Worker            )
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker        for e in prof.events():
189*da0073e9SAndroid Build Coastguard Worker            if e.name == "Torch-Compiled Region":
190*da0073e9SAndroid Build Coastguard Worker                print(e.kwinputs)
191*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
192*da0073e9SAndroid Build Coastguard Worker            any(
193*da0073e9SAndroid Build Coastguard Worker                e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "0/0_1"
194*da0073e9SAndroid Build Coastguard Worker                for e in prof.events()
195*da0073e9SAndroid Build Coastguard Worker            )
196*da0073e9SAndroid Build Coastguard Worker        )
197*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
198*da0073e9SAndroid Build Coastguard Worker            any(
199*da0073e9SAndroid Build Coastguard Worker                e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "1/0"
200*da0073e9SAndroid Build Coastguard Worker                for e in prof.events()
201*da0073e9SAndroid Build Coastguard Worker            )
202*da0073e9SAndroid Build Coastguard Worker        )
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed directly (not via pytest): hand off to Dynamo's test runner.
    from torch._dynamo.test_case import run_tests as _dynamo_run_tests

    _dynamo_run_tests()
209