xref: /aosp_15_r20/external/pytorch/test/profiler/test_profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: profiler"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport gc
5*da0073e9SAndroid Build Coastguard Workerimport json
6*da0073e9SAndroid Build Coastguard Workerimport mmap
7*da0073e9SAndroid Build Coastguard Workerimport os
8*da0073e9SAndroid Build Coastguard Workerimport pickle
9*da0073e9SAndroid Build Coastguard Workerimport random
10*da0073e9SAndroid Build Coastguard Workerimport re
11*da0073e9SAndroid Build Coastguard Workerimport struct
12*da0073e9SAndroid Build Coastguard Workerimport subprocess
13*da0073e9SAndroid Build Coastguard Workerimport sys
14*da0073e9SAndroid Build Coastguard Workerimport tempfile
15*da0073e9SAndroid Build Coastguard Workerimport threading
16*da0073e9SAndroid Build Coastguard Workerimport time
17*da0073e9SAndroid Build Coastguard Workerimport unittest
18*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field
19*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Optional
20*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Workerimport expecttest
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerimport torch
25*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
26*da0073e9SAndroid Build Coastguard Workerimport torch.optim
27*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data
28*da0073e9SAndroid Build Coastguard Workerfrom torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
29*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.profiler import KinetoStepTracker, profile as _profile
30*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.profiler_legacy import profile as _profile_legacy
31*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler import (
32*da0073e9SAndroid Build Coastguard Worker    _utils,
33*da0073e9SAndroid Build Coastguard Worker    DeviceType,
34*da0073e9SAndroid Build Coastguard Worker    kineto_available,
35*da0073e9SAndroid Build Coastguard Worker    profile,
36*da0073e9SAndroid Build Coastguard Worker    ProfilerAction,
37*da0073e9SAndroid Build Coastguard Worker    ProfilerActivity,
38*da0073e9SAndroid Build Coastguard Worker    record_function,
39*da0073e9SAndroid Build Coastguard Worker    supported_activities,
40*da0073e9SAndroid Build Coastguard Worker)
41*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler._pattern_matcher import (
42*da0073e9SAndroid Build Coastguard Worker    Conv2dBiasFollowedByBatchNorm2dPattern,
43*da0073e9SAndroid Build Coastguard Worker    ExtraCUDACopyPattern,
44*da0073e9SAndroid Build Coastguard Worker    ForLoopIndexingPattern,
45*da0073e9SAndroid Build Coastguard Worker    FP32MatMulPattern,
46*da0073e9SAndroid Build Coastguard Worker    GradNotSetToNonePattern,
47*da0073e9SAndroid Build Coastguard Worker    MatMulDimInFP16Pattern,
48*da0073e9SAndroid Build Coastguard Worker    NamePattern,
49*da0073e9SAndroid Build Coastguard Worker    OptimizerSingleTensorPattern,
50*da0073e9SAndroid Build Coastguard Worker    Pattern,
51*da0073e9SAndroid Build Coastguard Worker    report_all_anti_patterns,
52*da0073e9SAndroid Build Coastguard Worker    SynchronizedDataLoaderPattern,
53*da0073e9SAndroid Build Coastguard Worker)
54*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_MULTIGPU
55*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import skipCUDAVersionIn
56*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
57*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
58*da0073e9SAndroid Build Coastguard Worker    IS_ARM64,
59*da0073e9SAndroid Build Coastguard Worker    IS_JETSON,
60*da0073e9SAndroid Build Coastguard Worker    IS_LINUX,
61*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
62*da0073e9SAndroid Build Coastguard Worker    parametrize,
63*da0073e9SAndroid Build Coastguard Worker    run_tests,
64*da0073e9SAndroid Build Coastguard Worker    serialTest,
65*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
66*da0073e9SAndroid Build Coastguard Worker    TemporaryDirectoryName,
67*da0073e9SAndroid Build Coastguard Worker    TemporaryFileName,
68*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ASAN,
69*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_CROSSREF,
70*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_ROCM,
71*da0073e9SAndroid Build Coastguard Worker    TestCase,
72*da0073e9SAndroid Build Coastguard Worker)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker
# If tqdm is not shut down properly, it will leave its monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is turning off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
# Disable tqdm's background monitor thread if tqdm is importable; its
# lingering worker would otherwise appear in profiler output with an
# invalid TID. A missing tqdm is fine — there is then nothing to disable.
try:
    import tqdm
except ImportError:
    pass
else:
    tqdm.tqdm.monitor_interval = 0
87*da0073e9SAndroid Build Coastguard Worker
# psutil is an optional dependency: tests that inspect process memory
# (e.g. the CUDA leak check) are skipped when it is unavailable.
try:
    import psutil
except ModuleNotFoundError:
    psutil = None
    HAS_PSUTIL = False
else:
    HAS_PSUTIL = True
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
@unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestProfilerCUDA(TestCase):
    """CUDA-specific profiler tests (memory behavior, NVTX, CUDA graphs)."""

    @skipCUDAVersionIn([(11, 5)])  # https://github.com/pytorch/pytorch/issues/69023
    def test_mem_leak(self):
        """Checks that there's no memory leak when using profiler with CUDA"""
        t = torch.rand(1, 1).cuda()
        p = psutil.Process()
        # Track RSS over the last few iterations; early iterations are
        # ignored so one-time allocations don't look like a leak.
        last_rss = collections.deque(maxlen=5)
        for _ in range(10):
            with _profile(use_cuda=True):
                for _ in range(1024):
                    t = torch.mm(t, t)

            gc.collect()
            torch.cuda.empty_cache()
            last_rss.append(p.memory_info().rss)

        # with CUDA events leaking the increase in memory was ~7 MB between
        # profiler invocations above
        is_increasing = all(
            last_rss[idx] > last_rss[idx - 1] for idx in range(1, len(last_rss))
        )
        max_diff = -1
        for idx in range(1, len(last_rss)):
            max_diff = max(max_diff, last_rss[idx] - last_rss[idx - 1])
        # Only fail when RSS is monotonically increasing AND at least one
        # step grew by more than 100 KiB — tolerates allocator noise.
        self.assertTrue(
            not (is_increasing and max_diff > 100 * 1024),
            msg=f"memory usage is increasing, {str(last_rss)}",
        )

    def test_custom_module_input_op_ids(self):
        """Smoke test: emit_nvtx runs with record_shapes=True through a
        custom autograd Function's forward and backward."""

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return x

        def custom_layer(input_ten):
            return MyFunc.apply(input_ten)

        # Only testing that emit_nvtx runs when
        # record_shapes option is enabled.
        with torch.autograd.profiler.emit_nvtx(record_shapes=True):
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            s = custom_layer(z)
            q = s.sum()
            q.backward()

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_cudagraph_profiling_workaround(self):
        """Profiling after CUDA-graph capture must not hang or crash when
        _init_for_cuda_graphs() is called before graph creation."""
        # repro taken from #75504
        # Launch in a separate process to catch hanging/illegal memory errors
        # and to make sure CUPTI isn't already initialized.
        # check_call raises CalledProcessError on a non-zero exit, so no
        # return value needs to be inspected here.
        subprocess.check_call(
            [
                sys.executable,
                "-c",
                """
import os
import torch
from torch.profiler import ProfilerActivity, profile

def add_one(in_: torch.Tensor):
    return in_ + 1

sample_arg = torch.zeros(10, device="cuda").requires_grad_(True)

# add this before cuda graphs are created
torch.profiler._utils._init_for_cuda_graphs()

add_one_graphed = torch.cuda.graphs.make_graphed_callables(add_one, sample_args=(sample_arg,))
zeros = torch.zeros(10, device="cuda")
out = add_one_graphed(zeros)
assert out[0] == 1

with profile(activities=[ProfilerActivity.CPU]):
    add_one_graphed(zeros)

with profile(activities=[ProfilerActivity.CUDA]):
    add_one_graphed(zeros)
""",
            ],
            universal_newlines=True,
            timeout=60,
        )

        # ^ this will throw an exception if the script fails.
197*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
class TestProfilerITT(TestCase):
    """Tests for the Intel ITT (Instrumentation and Tracing Technology)
    profiler integration; skipped entirely when ITT is unavailable."""

    def test_custom_module_input_op_ids(self):
        """Smoke test: emit_itt runs with record_shapes=True through a
        custom autograd Function's forward and backward."""

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return x

        def custom_layer(input_ten):
            return MyFunc.apply(input_ten)

        # Only testing that emit_itt runs when
        # record_shapes option is enabled.
        with torch.autograd.profiler.emit_itt(record_shapes=True):
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            s = custom_layer(z)
            q = s.sum()
            q.backward()
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker@instantiate_parametrized_tests
227*da0073e9SAndroid Build Coastguard Workerclass TestProfiler(TestCase):
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    def test_source(self):
        """Checks that source code attribution works for eager, TS and autograd mode"""
        # avoid automatic inlining
        prev_opt = torch._C._get_graph_executor_optimize()
        torch._C._set_graph_executor_optimize(False)

        # TorchScript functions. The stack-attribution check below looks
        # for these exact names, so they must not be renamed.
        @torch.jit.script
        def ts_method_2(x, y):
            return torch.matmul(x, y)

        @torch.jit.script
        def ts_method_1(x, y, z):
            a = x + z
            w = ts_method_2(x, y) + a
            return w.sum()

        # Minimal nn.Module so module-call attribution is exercised too;
        # "DummyModule_0" is matched against the exported trace below.
        class DummyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=False
                )

            def forward(self, x):
                return self.conv(x)

        mod = DummyModule()

        # Plain-Python wrapper: the trace check asserts the module event's
        # "Python parent id" points at this frame.
        def call_module(x):
            return mod(x)

        # Run eager ops, TorchScript, a module call, and two backward
        # passes under the profiler with stack recording enabled.
        with _profile(
            with_stack=True,
            use_kineto=kineto_available(),
            experimental_config=_ExperimentalConfig(verbose=True),
        ) as p:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            w = ts_method_1(x, y, z)
            v = 2 * w
            v.backward()
            a = torch.randn(2, 3, 2, 2, requires_grad=True)
            b = call_module(a)
            c = b.sum()
            c.backward()

        # Every add op (forward or its AddBackward node) must carry a
        # Python stack pointing back into this file and into one of the
        # functions defined above.
        for e in p.function_events:
            if "aten::add" in e.name or "AddBackward" in e.name:
                self.assertTrue(any("test_profiler" in entry for entry in e.stack))
                self.assertTrue(
                    any(
                        (
                            "test_source" in entry
                            or "ts_method_1" in entry
                            or "ts_method_2" in entry
                        )
                        for entry in e.stack
                    )
                )

        # TODO: https://github.com/pytorch/kineto/issues/617
        # Exported chrome trace must link the module event to its Python
        # caller via the "Python parent id" / "Python id" args.
        if kineto_available() and not IS_WINDOWS:
            with TemporaryFileName(mode="w+") as fname:
                p.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]

                # Return the single trace event whose name matches
                # `pattern`; fails if there are zero or multiple matches.
                def extract(pattern: str):
                    matches = [e for e in events if re.search(pattern, e["name"])]
                    self.assertEqual(
                        len(matches), 1, repr([e["name"] for e in matches])
                    )
                    return matches[0]

                module_event = extract(r"DummyModule_0")
                wrapper_event = extract(r"call_module")
                self.assertEqual(
                    module_event["args"]["Python parent id"],
                    wrapper_event["args"]["Python id"],
                )

        # Restore the graph-executor setting so later tests are unaffected.
        torch._C._set_graph_executor_optimize(prev_opt)
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker    @parametrize(
316*da0073e9SAndroid Build Coastguard Worker        "name,thread_spec",
317*da0073e9SAndroid Build Coastguard Worker        {
318*da0073e9SAndroid Build Coastguard Worker            "basic": ((False, False),),
319*da0073e9SAndroid Build Coastguard Worker            "multiple_preexisting": ((False, False),) * 2,
320*da0073e9SAndroid Build Coastguard Worker            "open_in_scope": ((True, False),),
321*da0073e9SAndroid Build Coastguard Worker            "close_in_scope": ((False, True),),
322*da0073e9SAndroid Build Coastguard Worker            "complex": (
323*da0073e9SAndroid Build Coastguard Worker                # Large number of background threads
324*da0073e9SAndroid Build Coastguard Worker                (False, False),
325*da0073e9SAndroid Build Coastguard Worker                (False, False),
326*da0073e9SAndroid Build Coastguard Worker                (False, False),
327*da0073e9SAndroid Build Coastguard Worker                (False, False),
328*da0073e9SAndroid Build Coastguard Worker                # some of which finish during profiling
329*da0073e9SAndroid Build Coastguard Worker                (False, True),
330*da0073e9SAndroid Build Coastguard Worker                (False, True),
331*da0073e9SAndroid Build Coastguard Worker                # And the profiled section is also multithreaded
332*da0073e9SAndroid Build Coastguard Worker                (True, False),
333*da0073e9SAndroid Build Coastguard Worker                (True, True),
334*da0073e9SAndroid Build Coastguard Worker            ),
335*da0073e9SAndroid Build Coastguard Worker        }.items(),
336*da0073e9SAndroid Build Coastguard Worker        name_fn=lambda name, thread_spec: name,
337*da0073e9SAndroid Build Coastguard Worker    )
338*da0073e9SAndroid Build Coastguard Worker    @serialTest()
339*da0073e9SAndroid Build Coastguard Worker    @parametrize("work_in_main_thread", [True, False])
340*da0073e9SAndroid Build Coastguard Worker    def test_source_multithreaded(self, name, thread_spec, work_in_main_thread):
341*da0073e9SAndroid Build Coastguard Worker        """Test various threading configurations.
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker        `thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair is a
344*da0073e9SAndroid Build Coastguard Worker        thread. The first bool indicates if the thread should be started under
345*da0073e9SAndroid Build Coastguard Worker        the profiler context and the second is if it should be joined under the
346*da0073e9SAndroid Build Coastguard Worker        profiler context.
347*da0073e9SAndroid Build Coastguard Worker        """
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker        timeout = 15
350*da0073e9SAndroid Build Coastguard Worker        num_threads = len(thread_spec) + 1  # Main thread
351*da0073e9SAndroid Build Coastguard Worker        start_barrier = threading.Barrier(num_threads, timeout=timeout)
352*da0073e9SAndroid Build Coastguard Worker        end_barrier = threading.Barrier(num_threads, timeout=timeout)
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker        class Task(threading.Thread):
355*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
356*da0073e9SAndroid Build Coastguard Worker                self._end_gate = threading.Event()
357*da0073e9SAndroid Build Coastguard Worker                super().__init__(daemon=True)
358*da0073e9SAndroid Build Coastguard Worker                self.start()
359*da0073e9SAndroid Build Coastguard Worker                self.finished = False
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker            def run(self):
362*da0073e9SAndroid Build Coastguard Worker                self._run(self._end_gate)
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker            def release(self):
365*da0073e9SAndroid Build Coastguard Worker                self._end_gate.set()
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker            @staticmethod
368*da0073e9SAndroid Build Coastguard Worker            def _run(end_gate=None):
369*da0073e9SAndroid Build Coastguard Worker                def known_preexisting_function():
370*da0073e9SAndroid Build Coastguard Worker                    start_barrier.wait()
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker                # Fixed point that we can use to test capture of functions
373*da0073e9SAndroid Build Coastguard Worker                # which are already running when profiling is enabled.
374*da0073e9SAndroid Build Coastguard Worker                known_preexisting_function()
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker                model = torch.nn.Sequential(
377*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(10, 10),
378*da0073e9SAndroid Build Coastguard Worker                    torch.nn.ReLU(),
379*da0073e9SAndroid Build Coastguard Worker                )
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker                def invoked_during_run():
382*da0073e9SAndroid Build Coastguard Worker                    pass
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker                invoked_during_run()
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker                _ = model(torch.rand(4, 10))
387*da0073e9SAndroid Build Coastguard Worker                end_barrier.wait()
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker                if end_gate is not None:
390*da0073e9SAndroid Build Coastguard Worker                    end_gate.wait(timeout=timeout)
391*da0073e9SAndroid Build Coastguard Worker
392*da0073e9SAndroid Build Coastguard Worker        threads = {}
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker        def add_threads(context: bool):
395*da0073e9SAndroid Build Coastguard Worker            for idx, (start_under_profiler, _) in enumerate(thread_spec):
396*da0073e9SAndroid Build Coastguard Worker                if start_under_profiler == context:
397*da0073e9SAndroid Build Coastguard Worker                    assert idx not in threads
398*da0073e9SAndroid Build Coastguard Worker                    threads[idx] = Task()
399*da0073e9SAndroid Build Coastguard Worker
400*da0073e9SAndroid Build Coastguard Worker        def join_threads(context: bool):
401*da0073e9SAndroid Build Coastguard Worker            for idx, (_, end_under_profiler) in enumerate(thread_spec):
402*da0073e9SAndroid Build Coastguard Worker                if end_under_profiler == context:
403*da0073e9SAndroid Build Coastguard Worker                    threads[idx].release()
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker            for idx, (_, end_under_profiler) in enumerate(thread_spec):
406*da0073e9SAndroid Build Coastguard Worker                t = threads[idx]
407*da0073e9SAndroid Build Coastguard Worker                if end_under_profiler == context:
408*da0073e9SAndroid Build Coastguard Worker                    t.join(timeout=timeout)
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker        try:
411*da0073e9SAndroid Build Coastguard Worker            add_threads(False)
412*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.profile(with_stack=True) as prof:
413*da0073e9SAndroid Build Coastguard Worker                # Threads added while the profiler are running will not be observed
414*da0073e9SAndroid Build Coastguard Worker                # since there is no way to hook into Python's thread start call to
415*da0073e9SAndroid Build Coastguard Worker                # register the observer. These are here purely to verify safety.
416*da0073e9SAndroid Build Coastguard Worker                add_threads(True)
417*da0073e9SAndroid Build Coastguard Worker
418*da0073e9SAndroid Build Coastguard Worker                if work_in_main_thread:
419*da0073e9SAndroid Build Coastguard Worker                    Task._run()
420*da0073e9SAndroid Build Coastguard Worker                else:
421*da0073e9SAndroid Build Coastguard Worker                    start_barrier.wait()
422*da0073e9SAndroid Build Coastguard Worker                    end_barrier.wait()
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker                join_threads(True)
425*da0073e9SAndroid Build Coastguard Worker            join_threads(False)
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker        finally:
428*da0073e9SAndroid Build Coastguard Worker            # It is very important that we clean up everything because the
429*da0073e9SAndroid Build Coastguard Worker            # Python tracer will detect ALL active threads. (Even orphans from
430*da0073e9SAndroid Build Coastguard Worker            # prior failed tests.) If we don't clean up properly we can
431*da0073e9SAndroid Build Coastguard Worker            # contaminate subsequent tests.
432*da0073e9SAndroid Build Coastguard Worker            start_barrier.abort()
433*da0073e9SAndroid Build Coastguard Worker            end_barrier.abort()
434*da0073e9SAndroid Build Coastguard Worker            for t in threads.values():
435*da0073e9SAndroid Build Coastguard Worker                t.release()
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker            for t in threads.values():
438*da0073e9SAndroid Build Coastguard Worker                t.join(timeout=timeout)
439*da0073e9SAndroid Build Coastguard Worker
440*da0073e9SAndroid Build Coastguard Worker            for t in threads.values():
441*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(t.is_alive())
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker        roots = prof.profiler.kineto_results.experimental_event_tree()
444*da0073e9SAndroid Build Coastguard Worker        nodes = [
445*da0073e9SAndroid Build Coastguard Worker            node
446*da0073e9SAndroid Build Coastguard Worker            for node in _utils.traverse_dfs(roots)
447*da0073e9SAndroid Build Coastguard Worker            if isinstance(node.extra_fields, _ExtraFields_PyCall)
448*da0073e9SAndroid Build Coastguard Worker        ]
449*da0073e9SAndroid Build Coastguard Worker        tid_counts = collections.Counter([node.start_tid for node in nodes])
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker        prior_threads = sum(
452*da0073e9SAndroid Build Coastguard Worker            not start_under_profiler for start_under_profiler, _ in thread_spec
453*da0073e9SAndroid Build Coastguard Worker        )
454*da0073e9SAndroid Build Coastguard Worker        expected_threads = prior_threads + 1
455*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
456*da0073e9SAndroid Build Coastguard Worker            len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}"
457*da0073e9SAndroid Build Coastguard Worker        )
458*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(nodes), sum(tid_counts.values()))
459*da0073e9SAndroid Build Coastguard Worker
460*da0073e9SAndroid Build Coastguard Worker        # Profiler uses uint64_t max as a placeholder until TID can be determined.
461*da0073e9SAndroid Build Coastguard Worker        no_tid = 2**64 - 1
462*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(no_tid in tid_counts)
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Worker        worker_threads = prior_threads + (1 if work_in_main_thread else 0)
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker        observed_preexisting = [
467*da0073e9SAndroid Build Coastguard Worker            node.start_tid
468*da0073e9SAndroid Build Coastguard Worker            for node in nodes
469*da0073e9SAndroid Build Coastguard Worker            if "known_preexisting_function" in node.name
470*da0073e9SAndroid Build Coastguard Worker        ]
471*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(observed_preexisting), worker_threads)
472*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(observed_preexisting), len(set(observed_preexisting)))
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker        observed_during_run = [
475*da0073e9SAndroid Build Coastguard Worker            node.start_tid for node in nodes if "invoked_during_run" in node.name
476*da0073e9SAndroid Build Coastguard Worker        ]
477*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(observed_during_run), worker_threads)
478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(observed_during_run), len(set(observed_during_run)))
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    def payload(self, use_cuda=False):
481*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
482*da0073e9SAndroid Build Coastguard Worker        if use_cuda:
483*da0073e9SAndroid Build Coastguard Worker            x = x.cuda()
484*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10, 10)
485*da0073e9SAndroid Build Coastguard Worker        if use_cuda:
486*da0073e9SAndroid Build Coastguard Worker            y = y.cuda()
487*da0073e9SAndroid Build Coastguard Worker        z = torch.mm(x, y)
488*da0073e9SAndroid Build Coastguard Worker        z = z + y
489*da0073e9SAndroid Build Coastguard Worker        if use_cuda:
490*da0073e9SAndroid Build Coastguard Worker            z = z.cpu()
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker    def _check_stats(self, profiler_stats):
493*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(profiler_stats.profiling_window_duration_sec, 0)
494*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(profiler_stats.number_of_events, 0)
495*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(profiler_stats.profiler_prepare_call_duration_us, 0)
496*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(profiler_stats.profiler_enable_call_duration_us, 0)
497*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(profiler_stats.profiler_disable_call_duration_us, 0)
498*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(profiler_stats.parse_kineto_call_duration_us, 0)
499*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(
500*da0073e9SAndroid Build Coastguard Worker            profiler_stats.function_events_build_tree_call_duration_us, 0
501*da0073e9SAndroid Build Coastguard Worker        )
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not kineto_available(), "Kineto is required")
504*da0073e9SAndroid Build Coastguard Worker    def test_kineto(self):
505*da0073e9SAndroid Build Coastguard Worker        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
506*da0073e9SAndroid Build Coastguard Worker        with _profile(use_cuda=use_cuda, use_kineto=True):
507*da0073e9SAndroid Build Coastguard Worker            self.payload(use_cuda=use_cuda)
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker        # rerun to avoid initial start overhead
510*da0073e9SAndroid Build Coastguard Worker        with _profile(use_cuda=use_cuda, use_kineto=True) as p:
511*da0073e9SAndroid Build Coastguard Worker            self.payload(use_cuda=use_cuda)
512*da0073e9SAndroid Build Coastguard Worker
513*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("aten::mm" in str(p))
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker        output = p.key_averages().table(
516*da0073e9SAndroid Build Coastguard Worker            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
517*da0073e9SAndroid Build Coastguard Worker            row_limit=-1,
518*da0073e9SAndroid Build Coastguard Worker        )
519*da0073e9SAndroid Build Coastguard Worker        # print(output)
520*da0073e9SAndroid Build Coastguard Worker        found_gemm = False
521*da0073e9SAndroid Build Coastguard Worker        found_memcpy = False
522*da0073e9SAndroid Build Coastguard Worker        found_mm = False
523*da0073e9SAndroid Build Coastguard Worker        for e in p.function_events:
524*da0073e9SAndroid Build Coastguard Worker            if "aten::mm" in e.name:
525*da0073e9SAndroid Build Coastguard Worker                found_mm = True
526*da0073e9SAndroid Build Coastguard Worker            if "gemm" in e.name.lower() or "Cijk" in e.name:
527*da0073e9SAndroid Build Coastguard Worker                found_gemm = True
528*da0073e9SAndroid Build Coastguard Worker            if "memcpy" in e.name.lower():
529*da0073e9SAndroid Build Coastguard Worker                found_memcpy = True
530*da0073e9SAndroid Build Coastguard Worker        if use_cuda:
531*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(found_gemm)
532*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(found_memcpy)
533*da0073e9SAndroid Build Coastguard Worker        else:
534*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(found_mm)
535*da0073e9SAndroid Build Coastguard Worker        self._check_stats(p._stats)
536*da0073e9SAndroid Build Coastguard Worker        # p.export_chrome_trace("/tmp/test_trace.json")
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not kineto_available(), "Kineto is required")
539*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "Multiple GPUs needed")
540*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
541*da0073e9SAndroid Build Coastguard Worker    def test_kineto_multigpu(self):
542*da0073e9SAndroid Build Coastguard Worker        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
543*da0073e9SAndroid Build Coastguard Worker            for gpu_id in [0, 1]:
544*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(10, 10).cuda(gpu_id)
545*da0073e9SAndroid Build Coastguard Worker                y = torch.randn(10, 10).cuda(gpu_id)
546*da0073e9SAndroid Build Coastguard Worker                z = x.matmul(y)
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker        found_gemm_0 = False
549*da0073e9SAndroid Build Coastguard Worker        found_gemm_1 = False
550*da0073e9SAndroid Build Coastguard Worker        found_cuda = False
551*da0073e9SAndroid Build Coastguard Worker        for evt in prof.events():
552*da0073e9SAndroid Build Coastguard Worker            if "gemm" in evt.name.lower() and evt.device_type == DeviceType.CUDA:
553*da0073e9SAndroid Build Coastguard Worker                if evt.device_index == 0:
554*da0073e9SAndroid Build Coastguard Worker                    found_gemm_0 = True
555*da0073e9SAndroid Build Coastguard Worker                elif evt.device_index == 1:
556*da0073e9SAndroid Build Coastguard Worker                    found_gemm_1 = True
557*da0073e9SAndroid Build Coastguard Worker            if "cuda" in evt.name.lower() and evt.device_type == DeviceType.CPU:
558*da0073e9SAndroid Build Coastguard Worker                found_cuda = True
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_gemm_0)
561*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_gemm_1)
562*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_cuda)
563*da0073e9SAndroid Build Coastguard Worker        self._check_stats(prof._stats())
564*da0073e9SAndroid Build Coastguard Worker
    def test_memory_profiler(self):
        """End-to-end test of memory profiling.

        Checks that allocations/deallocations are attributed to the expected
        ops and user `record_function` scopes for CPU, CUDA, and MKLDNN
        tensors; that exported chrome traces contain instantaneous "[memory]"
        events; and that top-level frees appear under the "[memory]" key.
        """

        def run_profiler(tensor_creation_fn):
            # collecting allocs / deallocs: allocate inside one user scope,
            # free inside another, so both directions can be attributed.
            with _profile(
                profile_memory=True,
                record_shapes=True,
                use_kineto=kineto_available(),
            ) as prof:
                x = None
                with record_function("test_user_scope_alloc"):
                    x = tensor_creation_fn()
                with record_function("test_user_scope_dealloc"):
                    del x
            return prof.key_averages(group_by_input_shape=True)

        def check_metrics(stats, metric, allocs=None, deallocs=None):
            # Map each event key to its value for `metric`
            # (e.g. "cpu_memory_usage"); allocating entries must be positive,
            # deallocating entries negative.
            stat_metrics = {}
            # print(stats)
            for stat in stats:
                stat_metrics[stat.key] = getattr(stat, metric)
            # print(stat_metrics)
            if allocs is not None:
                for alloc_fn in allocs:
                    self.assertTrue(alloc_fn in stat_metrics)
                    self.assertGreater(
                        stat_metrics[alloc_fn], 0, f"alloc_fn = {alloc_fn}"
                    )
            if deallocs is not None:
                for dealloc_fn in deallocs:
                    self.assertTrue(dealloc_fn in stat_metrics)
                    self.assertLess(
                        stat_metrics[dealloc_fn], 0, f"alloc_fn = {dealloc_fn}"
                    )

        def create_cpu_tensor():
            return torch.rand(10, 10)

        def create_cuda_tensor():
            return torch.rand(10, 10).cuda()

        def create_mkldnn_tensor():
            return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

        stats = run_profiler(create_cpu_tensor)
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=[
                "aten::empty",
                "aten::rand",
                "test_user_scope_alloc",
            ],
            deallocs=[
                "test_user_scope_dealloc",
            ],
        )

        # Kineto path: memory events in the exported chrome trace must be
        # instantaneous (no "dur"/"cat") and carry full address/device info.
        if kineto_available():
            with TemporaryFileName(mode="w+") as fname:
                with profile(profile_memory=True) as prof:
                    x = None
                    with record_function("test_user_scope_alloc"):
                        x = create_cpu_tensor()
                    with record_function("test_user_scope_dealloc"):
                        del x
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    trace = json.load(f)
                    assert "traceEvents" in trace
                    events = trace["traceEvents"]
                    found_memory_events = False
                    for evt in events:
                        assert "name" in evt
                        if evt["name"] == "[memory]":
                            found_memory_events = True
                            assert "args" in evt
                            assert "Addr" in evt["args"]
                            assert "Device Type" in evt["args"]
                            assert "Device Id" in evt["args"]
                            assert "Bytes" in evt["args"]

                            # Memory should be an instantaneous event.
                            assert "dur" not in evt["args"]
                            assert "cat" not in evt["args"]
                    assert found_memory_events

        if torch.cuda.is_available():
            # Unprofiled call first — presumably to keep one-time CUDA
            # context/allocator setup out of the measured stats (TODO confirm).
            create_cuda_tensor()
            stats = run_profiler(create_cuda_tensor)
            check_metrics(
                stats,
                "device_memory_usage",
                allocs=[
                    "test_user_scope_alloc",
                    "aten::to",
                    "aten::empty_strided",
                ],
                deallocs=[
                    "test_user_scope_dealloc",
                ],
            )
            # The source tensor is still created on the CPU before the copy.
            check_metrics(
                stats,
                "cpu_memory_usage",
                allocs=[
                    "aten::rand",
                    "aten::empty",
                ],
            )

        if torch.backends.mkldnn.is_available():
            # Same warm-up pattern as the CUDA branch above.
            create_mkldnn_tensor()
            stats = run_profiler(create_mkldnn_tensor)
            check_metrics(
                stats,
                "cpu_memory_usage",
                allocs=[
                    "test_user_scope_alloc",
                    "aten::rand",
                    "aten::empty",
                    "aten::to_mkldnn",
                ],
                deallocs=[
                    "test_user_scope_dealloc",
                ],
            )

        # check top-level memory events
        with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
            x = torch.rand(10, 10)
            del x
            if torch.cuda.is_available():
                y = torch.rand(10, 10).cuda()
                del y
            gc.collect()
        stats = prof.key_averages(group_by_input_shape=True)
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=["aten::rand", "aten::empty"],
            deallocs=["[memory]"],
        )
        if torch.cuda.is_available():
            check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
711*da0073e9SAndroid Build Coastguard Worker        IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
712*da0073e9SAndroid Build Coastguard Worker    )
713*da0073e9SAndroid Build Coastguard Worker    def test_oom_tracing(self):
714*da0073e9SAndroid Build Coastguard Worker        def run_profiler(tensor_creation_fn):
715*da0073e9SAndroid Build Coastguard Worker            with _profile(profile_memory=True, record_shapes=True) as prof:
716*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, ".*[tT]ried to allocate.*"):
717*da0073e9SAndroid Build Coastguard Worker                    x = tensor_creation_fn()
718*da0073e9SAndroid Build Coastguard Worker                return prof
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker        def create_cuda_tensor_oom():
721*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda:0")
722*da0073e9SAndroid Build Coastguard Worker            return torch.empty(
723*da0073e9SAndroid Build Coastguard Worker                1024, 1024, 1024, 1024, dtype=torch.float32, device=device
724*da0073e9SAndroid Build Coastguard Worker            )
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker        def check_trace(fname):
727*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
728*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
729*da0073e9SAndroid Build Coastguard Worker                trace = json.load(f)
730*da0073e9SAndroid Build Coastguard Worker                self.assertTrue("traceEvents" in trace)
731*da0073e9SAndroid Build Coastguard Worker                events = trace["traceEvents"]
732*da0073e9SAndroid Build Coastguard Worker                found_out_of_memory_events = False
733*da0073e9SAndroid Build Coastguard Worker                for evt in events:
734*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue("name" in evt)
735*da0073e9SAndroid Build Coastguard Worker                    if evt["name"] == "[OutOfMemory]":
736*da0073e9SAndroid Build Coastguard Worker                        found_out_of_memory_events = True
737*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("args" in evt)
738*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("Device Type" in evt["args"])
739*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("Device Id" in evt["args"])
740*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("Bytes" in evt["args"])
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker                        # Memory should be an instantaneous event.
743*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("dur" not in evt["args"])
744*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("cat" not in evt["args"])
745*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(found_out_of_memory_events)
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
748*da0073e9SAndroid Build Coastguard Worker            with TemporaryFileName(mode="w+") as fname:
749*da0073e9SAndroid Build Coastguard Worker                prof = run_profiler(create_cuda_tensor_oom)
750*da0073e9SAndroid Build Coastguard Worker                check_trace(fname)
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not kineto_available(), "Kineto is required")
753*da0073e9SAndroid Build Coastguard Worker    def test_module_hierarchy(self):
754*da0073e9SAndroid Build Coastguard Worker        class A(nn.Module):
755*da0073e9SAndroid Build Coastguard Worker            def my_new_method(self, x):
756*da0073e9SAndroid Build Coastguard Worker                return x * 3
757*da0073e9SAndroid Build Coastguard Worker
758*da0073e9SAndroid Build Coastguard Worker            def forward_impl_(self, x, y):
759*da0073e9SAndroid Build Coastguard Worker                return self.my_new_method(x) + y
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
762*da0073e9SAndroid Build Coastguard Worker                y = y - 2
763*da0073e9SAndroid Build Coastguard Worker                return self.forward_impl_(x, y)
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker        class B(nn.Module):
766*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
767*da0073e9SAndroid Build Coastguard Worker                return x + 2
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker        class C(nn.Module):
770*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
771*da0073e9SAndroid Build Coastguard Worker                super().__init__()
772*da0073e9SAndroid Build Coastguard Worker                self.A0 = A()
773*da0073e9SAndroid Build Coastguard Worker                self.B0 = B()
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker            def call_b(self, x):
776*da0073e9SAndroid Build Coastguard Worker                return self.B0.forward(x)
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
779*da0073e9SAndroid Build Coastguard Worker                return self.A0.forward(x, y) + self.call_b(x)
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker        model = C()
782*da0073e9SAndroid Build Coastguard Worker        model = torch.jit.script(model)
783*da0073e9SAndroid Build Coastguard Worker        input_a = torch.rand(128, 128)
784*da0073e9SAndroid Build Coastguard Worker        input_b = torch.rand(128, 128)
785*da0073e9SAndroid Build Coastguard Worker        op_to_module_hierarchy = {}
786*da0073e9SAndroid Build Coastguard Worker        op_to_module_hierarchy["aten::sub"] = ["TOP(C)::forward.A0(A)::forward."]
787*da0073e9SAndroid Build Coastguard Worker        op_to_module_hierarchy["aten::mul"] = [
788*da0073e9SAndroid Build Coastguard Worker            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method."
789*da0073e9SAndroid Build Coastguard Worker        ]
790*da0073e9SAndroid Build Coastguard Worker        op_to_module_hierarchy["aten::add"] = [
791*da0073e9SAndroid Build Coastguard Worker            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.",
792*da0073e9SAndroid Build Coastguard Worker            "TOP(C)::forward.SELF(C)::call_b.B0(B)::forward.",
793*da0073e9SAndroid Build Coastguard Worker            "TOP(C)::forward.",
794*da0073e9SAndroid Build Coastguard Worker        ]
795*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
796*da0073e9SAndroid Build Coastguard Worker            with profile(
797*da0073e9SAndroid Build Coastguard Worker                activities=[torch.profiler.ProfilerActivity.CPU],
798*da0073e9SAndroid Build Coastguard Worker                with_modules=True,
799*da0073e9SAndroid Build Coastguard Worker            ) as prof:
800*da0073e9SAndroid Build Coastguard Worker                model(input_a, input_b)
801*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
802*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
803*da0073e9SAndroid Build Coastguard Worker                trace = json.load(f)
804*da0073e9SAndroid Build Coastguard Worker                assert "traceEvents" in trace
805*da0073e9SAndroid Build Coastguard Worker                events = trace["traceEvents"]
806*da0073e9SAndroid Build Coastguard Worker                found_memory_events = False
807*da0073e9SAndroid Build Coastguard Worker                for evt in events:
808*da0073e9SAndroid Build Coastguard Worker                    assert "name" in evt
809*da0073e9SAndroid Build Coastguard Worker                    if "args" in evt:
810*da0073e9SAndroid Build Coastguard Worker                        op_name = evt["name"]
811*da0073e9SAndroid Build Coastguard Worker                        if "Module Hierarchy" in evt["args"]:
812*da0073e9SAndroid Build Coastguard Worker                            hierarchy = evt["args"]["Module Hierarchy"]
813*da0073e9SAndroid Build Coastguard Worker                            if op_name in op_to_module_hierarchy:
814*da0073e9SAndroid Build Coastguard Worker                                assert hierarchy in op_to_module_hierarchy[op_name]
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker    def test_high_level_trace(self):
817*da0073e9SAndroid Build Coastguard Worker        """Checks that python side high level events are recorded."""
818*da0073e9SAndroid Build Coastguard Worker
819*da0073e9SAndroid Build Coastguard Worker        class RepeatedDataset(torch.utils.data.Dataset):
820*da0073e9SAndroid Build Coastguard Worker            def __init__(self, N, D_in, D_out):
821*da0073e9SAndroid Build Coastguard Worker                self.N = N
822*da0073e9SAndroid Build Coastguard Worker                self.x = torch.randn(N, D_in)
823*da0073e9SAndroid Build Coastguard Worker                self.y = torch.randn(N, D_out)
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
826*da0073e9SAndroid Build Coastguard Worker                return self.N
827*da0073e9SAndroid Build Coastguard Worker
828*da0073e9SAndroid Build Coastguard Worker            def __getitem__(self, idx):
829*da0073e9SAndroid Build Coastguard Worker                return self.x, self.y
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker        class TwoLayerNet(torch.nn.Module):
832*da0073e9SAndroid Build Coastguard Worker            def __init__(self, D_in, H, D_out):
833*da0073e9SAndroid Build Coastguard Worker                super().__init__()
834*da0073e9SAndroid Build Coastguard Worker                self.linear1 = torch.nn.Linear(D_in, H)
835*da0073e9SAndroid Build Coastguard Worker                self.linear2 = torch.nn.Linear(H, D_out)
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
838*da0073e9SAndroid Build Coastguard Worker                h_relu = self.linear1(x).clamp(min=0)
839*da0073e9SAndroid Build Coastguard Worker                y_pred = self.linear2(h_relu)
840*da0073e9SAndroid Build Coastguard Worker                return y_pred
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        class CustomSGD(torch.optim.SGD):
843*da0073e9SAndroid Build Coastguard Worker            def __init__(self, *args, **kwargs):
844*da0073e9SAndroid Build Coastguard Worker                super().__init__(*args, **kwargs)
845*da0073e9SAndroid Build Coastguard Worker
846*da0073e9SAndroid Build Coastguard Worker        def train():
847*da0073e9SAndroid Build Coastguard Worker            for _, data in enumerate(dataloader):
848*da0073e9SAndroid Build Coastguard Worker                x, y = data[0], data[1]
849*da0073e9SAndroid Build Coastguard Worker                y_pred = model(x)
850*da0073e9SAndroid Build Coastguard Worker                loss = criterion(y_pred, y)
851*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
852*da0073e9SAndroid Build Coastguard Worker                loss.backward()
853*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker        N, D_in, H, D_out = 8, 10, 5, 2
856*da0073e9SAndroid Build Coastguard Worker        model = TwoLayerNet(D_in, H, D_out)
857*da0073e9SAndroid Build Coastguard Worker        criterion = torch.nn.MSELoss(reduction="sum")
858*da0073e9SAndroid Build Coastguard Worker        optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
859*da0073e9SAndroid Build Coastguard Worker        ds = RepeatedDataset(N, D_in, D_out)
860*da0073e9SAndroid Build Coastguard Worker        dataloader = torch.utils.data.DataLoader(ds, batch_size=1)
861*da0073e9SAndroid Build Coastguard Worker
862*da0073e9SAndroid Build Coastguard Worker        try:
863*da0073e9SAndroid Build Coastguard Worker            train()
864*da0073e9SAndroid Build Coastguard Worker        except Exception:
865*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(False, "Expected no exception without profiling.")
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker        # Create multiple instances, expect each func is hooked only one time.
868*da0073e9SAndroid Build Coastguard Worker        # Nested wrappers(repeated patching) will make following test fail.
869*da0073e9SAndroid Build Coastguard Worker        optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4)
870*da0073e9SAndroid Build Coastguard Worker        dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1)
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker        def judge(expected_event_count, prof):
873*da0073e9SAndroid Build Coastguard Worker            actual_event_count = {}
874*da0073e9SAndroid Build Coastguard Worker            for e in prof.function_events:
875*da0073e9SAndroid Build Coastguard Worker                if "#" in e.name:
876*da0073e9SAndroid Build Coastguard Worker                    key = e.name
877*da0073e9SAndroid Build Coastguard Worker                    if key in expected_event_count.keys():
878*da0073e9SAndroid Build Coastguard Worker                        actual_event_count[key] = (
879*da0073e9SAndroid Build Coastguard Worker                            actual_event_count.setdefault(key, 0) + 1
880*da0073e9SAndroid Build Coastguard Worker                        )
881*da0073e9SAndroid Build Coastguard Worker            for key, count in expected_event_count.items():
882*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
883*da0073e9SAndroid Build Coastguard Worker                    (key in actual_event_count.keys())
884*da0073e9SAndroid Build Coastguard Worker                    and (count == actual_event_count[key])
885*da0073e9SAndroid Build Coastguard Worker                )
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker        with _profile(use_kineto=kineto_available()) as prof:
888*da0073e9SAndroid Build Coastguard Worker            train()
889*da0073e9SAndroid Build Coastguard Worker        expected_event_count = {
890*da0073e9SAndroid Build Coastguard Worker            # "+1" because the final iteration will enter __next__ but skip the loop body.
891*da0073e9SAndroid Build Coastguard Worker            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
892*da0073e9SAndroid Build Coastguard Worker            "Optimizer.step#SGD.step": N,
893*da0073e9SAndroid Build Coastguard Worker            "Optimizer.zero_grad#SGD.zero_grad": N,
894*da0073e9SAndroid Build Coastguard Worker        }
895*da0073e9SAndroid Build Coastguard Worker        judge(expected_event_count, prof)
896*da0073e9SAndroid Build Coastguard Worker
897*da0073e9SAndroid Build Coastguard Worker        # Test on pickle/unpickle. Expect to work in multi-processing.
898*da0073e9SAndroid Build Coastguard Worker        optimizer = pickle.loads(pickle.dumps(optimizer))
899*da0073e9SAndroid Build Coastguard Worker        with _profile(use_kineto=kineto_available()) as prof:
900*da0073e9SAndroid Build Coastguard Worker            train()
901*da0073e9SAndroid Build Coastguard Worker        judge(expected_event_count, prof)
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker        # Test on customized optimizer.
904*da0073e9SAndroid Build Coastguard Worker        optimizer = CustomSGD(model.parameters(), lr=1e-4)
905*da0073e9SAndroid Build Coastguard Worker        with _profile(use_kineto=kineto_available()) as prof:
906*da0073e9SAndroid Build Coastguard Worker            train()
907*da0073e9SAndroid Build Coastguard Worker        expected_event_count = {
908*da0073e9SAndroid Build Coastguard Worker            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
909*da0073e9SAndroid Build Coastguard Worker            "Optimizer.step#CustomSGD.step": N,
910*da0073e9SAndroid Build Coastguard Worker            "Optimizer.zero_grad#CustomSGD.zero_grad": N,
911*da0073e9SAndroid Build Coastguard Worker        }
912*da0073e9SAndroid Build Coastguard Worker        judge(expected_event_count, prof)
913*da0073e9SAndroid Build Coastguard Worker
914*da0073e9SAndroid Build Coastguard Worker    def test_flops(self):
915*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Sequential(
916*da0073e9SAndroid Build Coastguard Worker            nn.Conv2d(16, 33, 18),
917*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
918*da0073e9SAndroid Build Coastguard Worker            nn.Linear(243, 243),
919*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
920*da0073e9SAndroid Build Coastguard Worker        )
921*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(40, 16, 18, 260)
922*da0073e9SAndroid Build Coastguard Worker        nested_tensor = torch.nested.nested_tensor(
923*da0073e9SAndroid Build Coastguard Worker            [torch.randn((2, 5)), torch.randn((3, 5))], layout=torch.jagged
924*da0073e9SAndroid Build Coastguard Worker        )
925*da0073e9SAndroid Build Coastguard Worker        with _profile(
926*da0073e9SAndroid Build Coastguard Worker            record_shapes=True, with_flops=True, use_kineto=kineto_available()
927*da0073e9SAndroid Build Coastguard Worker        ) as prof:
928*da0073e9SAndroid Build Coastguard Worker            model(inputs)
929*da0073e9SAndroid Build Coastguard Worker            # test that nested tensor won't cause exception during flop compute
930*da0073e9SAndroid Build Coastguard Worker            nested_tensor = nested_tensor + nested_tensor
931*da0073e9SAndroid Build Coastguard Worker        profiler_output = prof.key_averages(group_by_input_shape=True).table(
932*da0073e9SAndroid Build Coastguard Worker            sort_by="cpu_time_total", row_limit=10
933*da0073e9SAndroid Build Coastguard Worker        )
934*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Total MFLOPs", profiler_output)
935*da0073e9SAndroid Build Coastguard Worker        if not (kineto_available() and torch.cuda.is_available()):
936*da0073e9SAndroid Build Coastguard Worker            return
937*da0073e9SAndroid Build Coastguard Worker
938*da0073e9SAndroid Build Coastguard Worker        with profile(
939*da0073e9SAndroid Build Coastguard Worker            activities=[
940*da0073e9SAndroid Build Coastguard Worker                torch.profiler.ProfilerActivity.CPU,
941*da0073e9SAndroid Build Coastguard Worker                torch.profiler.ProfilerActivity.CUDA,
942*da0073e9SAndroid Build Coastguard Worker            ],
943*da0073e9SAndroid Build Coastguard Worker            record_shapes=True,
944*da0073e9SAndroid Build Coastguard Worker            with_flops=True,
945*da0073e9SAndroid Build Coastguard Worker        ) as kineto_profiler:
946*da0073e9SAndroid Build Coastguard Worker            model(inputs)
947*da0073e9SAndroid Build Coastguard Worker        profiler_output = kineto_profiler.key_averages().table(
948*da0073e9SAndroid Build Coastguard Worker            sort_by="self_cuda_time_total", row_limit=-1
949*da0073e9SAndroid Build Coastguard Worker        )
950*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Total MFLOPs", profiler_output)
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker    def test_kineto_profiler_api(self):
953*da0073e9SAndroid Build Coastguard Worker        called_num = [0]
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
956*da0073e9SAndroid Build Coastguard Worker        with profile(activities=supported_activities()):
957*da0073e9SAndroid Build Coastguard Worker            self.payload(use_cuda=use_cuda)
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker        def trace_handler(p):
960*da0073e9SAndroid Build Coastguard Worker            output = p.key_averages().table(
961*da0073e9SAndroid Build Coastguard Worker                sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
962*da0073e9SAndroid Build Coastguard Worker                row_limit=-1,
963*da0073e9SAndroid Build Coastguard Worker            )
964*da0073e9SAndroid Build Coastguard Worker            # print(output)
965*da0073e9SAndroid Build Coastguard Worker            # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
966*da0073e9SAndroid Build Coastguard Worker            called_num[0] += 1
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker        initial_step = KinetoStepTracker.current_step()
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Worker        with profile(
971*da0073e9SAndroid Build Coastguard Worker            activities=supported_activities(),
972*da0073e9SAndroid Build Coastguard Worker            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
973*da0073e9SAndroid Build Coastguard Worker            on_trace_ready=trace_handler,
974*da0073e9SAndroid Build Coastguard Worker        ) as p:
975*da0073e9SAndroid Build Coastguard Worker            for idx in range(8):
976*da0073e9SAndroid Build Coastguard Worker                self.payload(use_cuda=use_cuda)
977*da0073e9SAndroid Build Coastguard Worker                p.step()
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called_num[0], 2)
980*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker        # case without schedule
983*da0073e9SAndroid Build Coastguard Worker        with profile(activities=supported_activities()) as p:
984*da0073e9SAndroid Build Coastguard Worker            self.payload(use_cuda=use_cuda)
985*da0073e9SAndroid Build Coastguard Worker            self.payload(use_cuda=use_cuda)
986*da0073e9SAndroid Build Coastguard Worker        output = p.key_averages().table(
987*da0073e9SAndroid Build Coastguard Worker            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
988*da0073e9SAndroid Build Coastguard Worker            row_limit=-1,
989*da0073e9SAndroid Build Coastguard Worker        )
990*da0073e9SAndroid Build Coastguard Worker        # print(output)
991*da0073e9SAndroid Build Coastguard Worker
992*da0073e9SAndroid Build Coastguard Worker        test_schedule = torch.profiler.schedule(
993*da0073e9SAndroid Build Coastguard Worker            skip_first=2, wait=1, warmup=1, active=2, repeat=2
994*da0073e9SAndroid Build Coastguard Worker        )
995*da0073e9SAndroid Build Coastguard Worker        test_schedule_expected_outputs = [
996*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
997*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
998*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
999*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.WARMUP,
1000*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.RECORD,
1001*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.RECORD_AND_SAVE,
1002*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
1003*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.WARMUP,
1004*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.RECORD,
1005*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.RECORD_AND_SAVE,
1006*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
1007*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
1008*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
1009*da0073e9SAndroid Build Coastguard Worker            ProfilerAction.NONE,
1010*da0073e9SAndroid Build Coastguard Worker        ]
1011*da0073e9SAndroid Build Coastguard Worker        for step in range(len(test_schedule_expected_outputs)):
1012*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker    def test_kineto_profiler_multiple_steppers(self):
1015*da0073e9SAndroid Build Coastguard Worker        niters = 8
1016*da0073e9SAndroid Build Coastguard Worker        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1017*da0073e9SAndroid Build Coastguard Worker        net = SimpleNet()
1018*da0073e9SAndroid Build Coastguard Worker        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
1019*da0073e9SAndroid Build Coastguard Worker        opt.zero_grad()
1020*da0073e9SAndroid Build Coastguard Worker        inputs = torch.rand(10)
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker        with profile(activities=supported_activities()):
1023*da0073e9SAndroid Build Coastguard Worker            self.payload(use_cuda=use_cuda)
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker        def optimizer_step():
1026*da0073e9SAndroid Build Coastguard Worker            """This simulates a step() hook in the optimizer"""
1027*da0073e9SAndroid Build Coastguard Worker            KinetoStepTracker.increment_step("yet_another_step")
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        initial_step = KinetoStepTracker.current_step()
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker        def run_batch():
1032*da0073e9SAndroid Build Coastguard Worker            out = net(inputs)
1033*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
1034*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1035*da0073e9SAndroid Build Coastguard Worker            opt.step()
1036*da0073e9SAndroid Build Coastguard Worker            # Manually call the hook. TODO: Remove this once we add the
1037*da0073e9SAndroid Build Coastguard Worker            # profiler step hooks in the Optimizer class that will get triggered above.
1038*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/issues/88446
1039*da0073e9SAndroid Build Coastguard Worker            optimizer_step()
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker        for idx in range(niters):
1042*da0073e9SAndroid Build Coastguard Worker            run_batch()
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker        with profile(
1045*da0073e9SAndroid Build Coastguard Worker            activities=supported_activities(),
1046*da0073e9SAndroid Build Coastguard Worker            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
1047*da0073e9SAndroid Build Coastguard Worker        ) as p:
1048*da0073e9SAndroid Build Coastguard Worker            for idx in range(niters):
1049*da0073e9SAndroid Build Coastguard Worker                run_batch()
1050*da0073e9SAndroid Build Coastguard Worker                p.step()
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 2 * niters)
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker    def test_export_stacks(self):
1055*da0073e9SAndroid Build Coastguard Worker        with _profile(
1056*da0073e9SAndroid Build Coastguard Worker            with_stack=True,
1057*da0073e9SAndroid Build Coastguard Worker            use_kineto=kineto_available(),
1058*da0073e9SAndroid Build Coastguard Worker            experimental_config=_ExperimentalConfig(verbose=True),
1059*da0073e9SAndroid Build Coastguard Worker        ) as p:
1060*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(10, 10)
1061*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(10, 10)
1062*da0073e9SAndroid Build Coastguard Worker            z = torch.mm(x, y)
1063*da0073e9SAndroid Build Coastguard Worker            z = z + y
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1066*da0073e9SAndroid Build Coastguard Worker            p.export_stacks(fname)
1067*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1068*da0073e9SAndroid Build Coastguard Worker                lines = f.readlines()
1069*da0073e9SAndroid Build Coastguard Worker            assert len(lines) > 0, "Empty stacks file"
1070*da0073e9SAndroid Build Coastguard Worker            for line in lines:
1071*da0073e9SAndroid Build Coastguard Worker                is_int = False
1072*da0073e9SAndroid Build Coastguard Worker                try:
1073*da0073e9SAndroid Build Coastguard Worker                    assert int(line.split(" ")[-1]) > 0, "Invalid stacks record"
1074*da0073e9SAndroid Build Coastguard Worker                    is_int = True
1075*da0073e9SAndroid Build Coastguard Worker                except ValueError:
1076*da0073e9SAndroid Build Coastguard Worker                    pass
1077*da0073e9SAndroid Build Coastguard Worker                assert is_int, "Invalid stacks record"
1078*da0073e9SAndroid Build Coastguard Worker
1079*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not kineto_available(), "Kineto is required")
1080*da0073e9SAndroid Build Coastguard Worker    def test_tensorboard_trace_handler(self):
1081*da0073e9SAndroid Build Coastguard Worker        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1082*da0073e9SAndroid Build Coastguard Worker        with _profile(use_cuda=use_cuda, use_kineto=True):
1083*da0073e9SAndroid Build Coastguard Worker            self.payload(use_cuda=use_cuda)
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker        with TemporaryDirectoryName() as dname:
1086*da0073e9SAndroid Build Coastguard Worker            with profile(
1087*da0073e9SAndroid Build Coastguard Worker                activities=[torch.profiler.ProfilerActivity.CPU]
1088*da0073e9SAndroid Build Coastguard Worker                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
1089*da0073e9SAndroid Build Coastguard Worker                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
1090*da0073e9SAndroid Build Coastguard Worker                on_trace_ready=torch.profiler.tensorboard_trace_handler(dname),
1091*da0073e9SAndroid Build Coastguard Worker            ) as p:
1092*da0073e9SAndroid Build Coastguard Worker                for _ in range(18):
1093*da0073e9SAndroid Build Coastguard Worker                    self.payload(use_cuda=use_cuda)
1094*da0073e9SAndroid Build Coastguard Worker                    p.step()
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(os.path.exists(dname))
1097*da0073e9SAndroid Build Coastguard Worker            file_num = 0
1098*da0073e9SAndroid Build Coastguard Worker            for file_name in os.listdir(dname):
1099*da0073e9SAndroid Build Coastguard Worker                parts = file_name.split(".")
1100*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(len(parts) > 4)
1101*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1102*da0073e9SAndroid Build Coastguard Worker                    parts[-4].isdigit() and int(parts[-4]) > 0,
1103*da0073e9SAndroid Build Coastguard Worker                    "Wrong tracing file name pattern",
1104*da0073e9SAndroid Build Coastguard Worker                )
1105*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(parts[-3:], ["pt", "trace", "json"])
1106*da0073e9SAndroid Build Coastguard Worker                file_num += 1
1107*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(file_num, 3)
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard Worker        # test case for gzip file format
1110*da0073e9SAndroid Build Coastguard Worker        with TemporaryDirectoryName() as dname:
1111*da0073e9SAndroid Build Coastguard Worker            p = profile(
1112*da0073e9SAndroid Build Coastguard Worker                activities=[torch.profiler.ProfilerActivity.CPU]
1113*da0073e9SAndroid Build Coastguard Worker                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
1114*da0073e9SAndroid Build Coastguard Worker                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
1115*da0073e9SAndroid Build Coastguard Worker                on_trace_ready=torch.profiler.tensorboard_trace_handler(
1116*da0073e9SAndroid Build Coastguard Worker                    dname, use_gzip=True
1117*da0073e9SAndroid Build Coastguard Worker                ),
1118*da0073e9SAndroid Build Coastguard Worker            )
1119*da0073e9SAndroid Build Coastguard Worker            p.start()
1120*da0073e9SAndroid Build Coastguard Worker            for _ in range(18):
1121*da0073e9SAndroid Build Coastguard Worker                self.payload(use_cuda=use_cuda)
1122*da0073e9SAndroid Build Coastguard Worker                p.step()
1123*da0073e9SAndroid Build Coastguard Worker            p.stop()
1124*da0073e9SAndroid Build Coastguard Worker
1125*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(os.path.exists(dname))
1126*da0073e9SAndroid Build Coastguard Worker            file_num = 0
1127*da0073e9SAndroid Build Coastguard Worker            for file_name in os.listdir(dname):
1128*da0073e9SAndroid Build Coastguard Worker                parts = file_name.split(".")
1129*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(len(parts) > 4)
1130*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1131*da0073e9SAndroid Build Coastguard Worker                    parts[-5].isdigit() and int(parts[-5]) > 0,
1132*da0073e9SAndroid Build Coastguard Worker                    "Wrong tracing file name pattern",
1133*da0073e9SAndroid Build Coastguard Worker                )
1134*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(parts[-4:], ["pt", "trace", "json", "gz"])
1135*da0073e9SAndroid Build Coastguard Worker                file_num += 1
1136*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(file_num, 3)
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not kineto_available(), "Kineto is required")
1139*da0073e9SAndroid Build Coastguard Worker    def test_profiler_metadata(self):
1140*da0073e9SAndroid Build Coastguard Worker        t1, t2 = torch.ones(1), torch.ones(1)
1141*da0073e9SAndroid Build Coastguard Worker        with profile() as prof:
1142*da0073e9SAndroid Build Coastguard Worker            torch.add(t1, t2)
1143*da0073e9SAndroid Build Coastguard Worker            prof.add_metadata("test_key1", "test_value1")
1144*da0073e9SAndroid Build Coastguard Worker            prof.add_metadata_json("test_key2", "[1,2,3]")
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1147*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1148*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1149*da0073e9SAndroid Build Coastguard Worker                trace = json.load(f)
1150*da0073e9SAndroid Build Coastguard Worker                assert "test_key1" in trace
1151*da0073e9SAndroid Build Coastguard Worker                assert trace["test_key1"] == "test_value1"
1152*da0073e9SAndroid Build Coastguard Worker                assert "test_key2" in trace
1153*da0073e9SAndroid Build Coastguard Worker                assert trace["test_key2"] == [1, 2, 3]
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker    def _test_profiler_tracing(self, use_kineto):
1156*da0073e9SAndroid Build Coastguard Worker        with _profile(use_kineto=use_kineto) as prof:
1157*da0073e9SAndroid Build Coastguard Worker            t1, t2 = torch.ones(1), torch.ones(1)
1158*da0073e9SAndroid Build Coastguard Worker            torch.add(t1, t2)
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1161*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1162*da0073e9SAndroid Build Coastguard Worker            # read the trace and expect valid json
1163*da0073e9SAndroid Build Coastguard Worker            # if the JSON generated by export_chrome_trace is not valid, this will throw and fail the test.
1164*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1165*da0073e9SAndroid Build Coastguard Worker                json.load(f)
1166*da0073e9SAndroid Build Coastguard Worker
1167*da0073e9SAndroid Build Coastguard Worker        # test empty trace
1168*da0073e9SAndroid Build Coastguard Worker        with _profile(use_kineto=use_kineto) as prof:
1169*da0073e9SAndroid Build Coastguard Worker            pass
1170*da0073e9SAndroid Build Coastguard Worker        # saving an empty trace
1171*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1172*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1173*da0073e9SAndroid Build Coastguard Worker            if use_kineto:
1174*da0073e9SAndroid Build Coastguard Worker                with open(fname) as f:
1175*da0073e9SAndroid Build Coastguard Worker                    contents = json.load(f)
1176*da0073e9SAndroid Build Coastguard Worker                    # Some builds may not have logger observer
1177*da0073e9SAndroid Build Coastguard Worker                    # so skip if not
1178*da0073e9SAndroid Build Coastguard Worker                    if "WARNING" in contents:
1179*da0073e9SAndroid Build Coastguard Worker                        found_empty_warning = False
1180*da0073e9SAndroid Build Coastguard Worker                        for warning in contents["WARNING"]:
1181*da0073e9SAndroid Build Coastguard Worker                            if "No Valid Trace Events" in warning:
1182*da0073e9SAndroid Build Coastguard Worker                                found_empty_warning = True
1183*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(found_empty_warning)
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker        # Same test but for cuda.
1186*da0073e9SAndroid Build Coastguard Worker        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1187*da0073e9SAndroid Build Coastguard Worker        if not use_cuda:
1188*da0073e9SAndroid Build Coastguard Worker            return
1189*da0073e9SAndroid Build Coastguard Worker
1190*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cuda:0")
1191*da0073e9SAndroid Build Coastguard Worker        with _profile(use_cuda=True, use_kineto=use_kineto) as prof:
1192*da0073e9SAndroid Build Coastguard Worker            t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
1193*da0073e9SAndroid Build Coastguard Worker            torch.add(t1, t2)
1194*da0073e9SAndroid Build Coastguard Worker
1195*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1196*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1197*da0073e9SAndroid Build Coastguard Worker            # Now validate the json
1198*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1199*da0073e9SAndroid Build Coastguard Worker                json.load(f)
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker    def test_profiler_tracing(self):
1202*da0073e9SAndroid Build Coastguard Worker        self._test_profiler_tracing(False)
1203*da0073e9SAndroid Build Coastguard Worker        if kineto_available():
1204*da0073e9SAndroid Build Coastguard Worker            self._test_profiler_tracing(True)
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker    def test_profiler_op_event_args(self):
1207*da0073e9SAndroid Build Coastguard Worker        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
1208*da0073e9SAndroid Build Coastguard Worker        with _profile(record_shapes=True) as prof:
1209*da0073e9SAndroid Build Coastguard Worker            a = torch.ones((64, 32), dtype=torch.float32)
1210*da0073e9SAndroid Build Coastguard Worker            c = torch.cat([a, a]).sin()
1211*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1212*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1213*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1214*da0073e9SAndroid Build Coastguard Worker                j = json.load(f)
1215*da0073e9SAndroid Build Coastguard Worker                op_events = [
1216*da0073e9SAndroid Build Coastguard Worker                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
1217*da0073e9SAndroid Build Coastguard Worker                ]
1218*da0073e9SAndroid Build Coastguard Worker                for e in op_events:
1219*da0073e9SAndroid Build Coastguard Worker                    args = e["args"]
1220*da0073e9SAndroid Build Coastguard Worker                    if e["name"] == "aten::ones":
1221*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
1222*da0073e9SAndroid Build Coastguard Worker                            args["Input type"],
1223*da0073e9SAndroid Build Coastguard Worker                            ["ScalarList", "Scalar", "", "", "Scalar"],
1224*da0073e9SAndroid Build Coastguard Worker                        )
1225*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
1226*da0073e9SAndroid Build Coastguard Worker                            args["Concrete Inputs"], ["[64, 32]", "6", "", "", "False"]
1227*da0073e9SAndroid Build Coastguard Worker                        )
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker                    if e["name"] == "aten::cat":
1230*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(args["Input Dims"], [[[64, 32], [64, 32]], []])
1231*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(args["Input type"], ["TensorList", "Scalar"])
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker                    # check that each op has record function id
1234*da0073e9SAndroid Build Coastguard Worker                    self.assertGreaterEqual(
1235*da0073e9SAndroid Build Coastguard Worker                        args.get("Record function id", -1),
1236*da0073e9SAndroid Build Coastguard Worker                        0,
1237*da0073e9SAndroid Build Coastguard Worker                        f"Failed finding record funciont for op = {e}",
1238*da0073e9SAndroid Build Coastguard Worker                    )
1239*da0073e9SAndroid Build Coastguard Worker
1240*da0073e9SAndroid Build Coastguard Worker    def test_profiler_strides(self):
1241*da0073e9SAndroid Build Coastguard Worker        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
1242*da0073e9SAndroid Build Coastguard Worker        base_tensor = torch.randn(1024, dtype=torch.float32)
1243*da0073e9SAndroid Build Coastguard Worker        a = base_tensor.as_strided((16, 16), (17, 1), 0)
1244*da0073e9SAndroid Build Coastguard Worker        b = base_tensor.as_strided((16, 16), (25, 2), 272)
1245*da0073e9SAndroid Build Coastguard Worker        with _profile(record_shapes=True) as prof:
1246*da0073e9SAndroid Build Coastguard Worker            c = torch.add(a, b)
1247*da0073e9SAndroid Build Coastguard Worker
1248*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1249*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1250*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1251*da0073e9SAndroid Build Coastguard Worker                j = json.load(f)
1252*da0073e9SAndroid Build Coastguard Worker                op_events = [
1253*da0073e9SAndroid Build Coastguard Worker                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
1254*da0073e9SAndroid Build Coastguard Worker                ]
1255*da0073e9SAndroid Build Coastguard Worker                for e in op_events:
1256*da0073e9SAndroid Build Coastguard Worker                    args = e["args"]
1257*da0073e9SAndroid Build Coastguard Worker                    if e["name"] == "aten::add":
1258*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(args["Input Strides"], [[17, 1], [25, 2], []])
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker    def test_profiler_fwd_bwd_link(self):
1261*da0073e9SAndroid Build Coastguard Worker        with _profile(use_kineto=True) as prof:
1262*da0073e9SAndroid Build Coastguard Worker            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
1263*da0073e9SAndroid Build Coastguard Worker                1, requires_grad=True
1264*da0073e9SAndroid Build Coastguard Worker            )
1265*da0073e9SAndroid Build Coastguard Worker            z = torch.add(t1, t2)
1266*da0073e9SAndroid Build Coastguard Worker            y = torch.ones(1)
1267*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
1268*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1269*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1270*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1271*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1272*da0073e9SAndroid Build Coastguard Worker                j = json.load(f)
1273*da0073e9SAndroid Build Coastguard Worker                events = j["traceEvents"]
1274*da0073e9SAndroid Build Coastguard Worker                ts_to_name = {}
1275*da0073e9SAndroid Build Coastguard Worker                flow_s_to_ts = {}
1276*da0073e9SAndroid Build Coastguard Worker                flow_f_to_ts = {}
1277*da0073e9SAndroid Build Coastguard Worker                for e in events:
1278*da0073e9SAndroid Build Coastguard Worker                    if e["ph"] == "X":
1279*da0073e9SAndroid Build Coastguard Worker                        ts_to_name[e["ts"]] = e["name"]
1280*da0073e9SAndroid Build Coastguard Worker                    if (
1281*da0073e9SAndroid Build Coastguard Worker                        "cat" in e
1282*da0073e9SAndroid Build Coastguard Worker                        and "name" in e
1283*da0073e9SAndroid Build Coastguard Worker                        and e["cat"] == "fwdbwd"
1284*da0073e9SAndroid Build Coastguard Worker                        and e["name"] == "fwdbwd"
1285*da0073e9SAndroid Build Coastguard Worker                    ):
1286*da0073e9SAndroid Build Coastguard Worker                        if e["ph"] == "s":
1287*da0073e9SAndroid Build Coastguard Worker                            flow_s_to_ts[e["id"]] = e["ts"]
1288*da0073e9SAndroid Build Coastguard Worker                        elif e["ph"] == "f":
1289*da0073e9SAndroid Build Coastguard Worker                            flow_f_to_ts[e["id"]] = e["ts"]
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(flow_s_to_ts), 2)
1292*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(flow_f_to_ts), 2)
1293*da0073e9SAndroid Build Coastguard Worker                self.assertIn(1, flow_s_to_ts)
1294*da0073e9SAndroid Build Coastguard Worker                self.assertIn(1, flow_f_to_ts)
1295*da0073e9SAndroid Build Coastguard Worker                self.assertIn(2, flow_s_to_ts)
1296*da0073e9SAndroid Build Coastguard Worker                self.assertIn(2, flow_f_to_ts)
1297*da0073e9SAndroid Build Coastguard Worker                s_ts_1 = flow_s_to_ts[1]
1298*da0073e9SAndroid Build Coastguard Worker                f_ts_1 = flow_f_to_ts[1]
1299*da0073e9SAndroid Build Coastguard Worker                s_ts_2 = flow_s_to_ts[2]
1300*da0073e9SAndroid Build Coastguard Worker                f_ts_2 = flow_f_to_ts[2]
1301*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1302*da0073e9SAndroid Build Coastguard Worker                    all(
1303*da0073e9SAndroid Build Coastguard Worker                        ts in ts_to_name.keys()
1304*da0073e9SAndroid Build Coastguard Worker                        for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]
1305*da0073e9SAndroid Build Coastguard Worker                    )
1306*da0073e9SAndroid Build Coastguard Worker                )
1307*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1308*da0073e9SAndroid Build Coastguard Worker                    ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits"
1309*da0073e9SAndroid Build Coastguard Worker                )
1310*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(ts_to_name[s_ts_2] == "aten::add")
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker    def test_profiler_disable_fwd_bwd_link(self):
1313*da0073e9SAndroid Build Coastguard Worker        try:
1314*da0073e9SAndroid Build Coastguard Worker            torch._C._profiler._set_fwd_bwd_enabled_val(False)
1315*da0073e9SAndroid Build Coastguard Worker
1316*da0073e9SAndroid Build Coastguard Worker            with _profile(use_kineto=True) as prof:
1317*da0073e9SAndroid Build Coastguard Worker                t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
1318*da0073e9SAndroid Build Coastguard Worker                    1, requires_grad=True
1319*da0073e9SAndroid Build Coastguard Worker                )
1320*da0073e9SAndroid Build Coastguard Worker                z = torch.add(t1, t2)
1321*da0073e9SAndroid Build Coastguard Worker                y = torch.ones(1)
1322*da0073e9SAndroid Build Coastguard Worker                loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
1323*da0073e9SAndroid Build Coastguard Worker                loss.backward()
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker            with TemporaryFileName(mode="w+") as fname:
1326*da0073e9SAndroid Build Coastguard Worker                prof.export_chrome_trace(fname)
1327*da0073e9SAndroid Build Coastguard Worker                with open(fname) as f:
1328*da0073e9SAndroid Build Coastguard Worker                    j = json.load(f)
1329*da0073e9SAndroid Build Coastguard Worker                    events = j["traceEvents"]
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker                    for e in events:
1332*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(e.get("cat", None), "fwdbwd")
1333*da0073e9SAndroid Build Coastguard Worker        finally:
1334*da0073e9SAndroid Build Coastguard Worker            torch._C._profiler._set_fwd_bwd_enabled_val(True)
1335*da0073e9SAndroid Build Coastguard Worker
1336*da0073e9SAndroid Build Coastguard Worker    # This test is broken on Windows, the likely reason is that kineto/CUPTI
1337*da0073e9SAndroid Build Coastguard Worker    # is not supported that particular environment. Once the CI stabilizes
1338*da0073e9SAndroid Build Coastguard Worker    # we can narrow the condition so Windows is checked as well (TODO)
1339*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not kineto_available(), "Kineto is required")
1340*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Test does not work on Windows")
1341*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
1342*da0073e9SAndroid Build Coastguard Worker    def test_profiler_cuda_sync_events(self):
1343*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cuda:0")
1344*da0073e9SAndroid Build Coastguard Worker        t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
1345*da0073e9SAndroid Build Coastguard Worker
1346*da0073e9SAndroid Build Coastguard Worker        def workload() -> None:
1347*da0073e9SAndroid Build Coastguard Worker            torch.add(t1, t2)
1348*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
1349*da0073e9SAndroid Build Coastguard Worker            torch.add(t1, t2)
1350*da0073e9SAndroid Build Coastguard Worker
1351*da0073e9SAndroid Build Coastguard Worker        def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None:
1352*da0073e9SAndroid Build Coastguard Worker            with _profile(
1353*da0073e9SAndroid Build Coastguard Worker                use_kineto=True,
1354*da0073e9SAndroid Build Coastguard Worker                use_cuda=True,
1355*da0073e9SAndroid Build Coastguard Worker                experimental_config=exp_config,
1356*da0073e9SAndroid Build Coastguard Worker            ) as prof:
1357*da0073e9SAndroid Build Coastguard Worker                workload()
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker            with TemporaryFileName(mode="w+") as fname:
1360*da0073e9SAndroid Build Coastguard Worker                # fname = "/tmp/kineto_out.json"
1361*da0073e9SAndroid Build Coastguard Worker                prof.export_chrome_trace(fname)
1362*da0073e9SAndroid Build Coastguard Worker                with open(fname) as f:
1363*da0073e9SAndroid Build Coastguard Worker                    j = json.load(f)
1364*da0073e9SAndroid Build Coastguard Worker                    cats = {e.get("cat", None) for e in j["traceEvents"]}
1365*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1366*da0073e9SAndroid Build Coastguard Worker                "cuda_sync" in cats,
1367*da0073e9SAndroid Build Coastguard Worker                "Expected to find cuda_sync event" f" found = {cats}",
1368*da0073e9SAndroid Build Coastguard Worker            )
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Worker        print("Testing enable_cuda_sync_events in _ExperimentalConfig")
1371*da0073e9SAndroid Build Coastguard Worker        trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True))
1372*da0073e9SAndroid Build Coastguard Worker
1373*da0073e9SAndroid Build Coastguard Worker        print("Testing _profiler._set_cuda_sync_enabled_val()")
1374*da0073e9SAndroid Build Coastguard Worker        try:
1375*da0073e9SAndroid Build Coastguard Worker            torch._C._profiler._set_cuda_sync_enabled_val(True)
1376*da0073e9SAndroid Build Coastguard Worker            trace_and_check(exp_config=None)
1377*da0073e9SAndroid Build Coastguard Worker        finally:
1378*da0073e9SAndroid Build Coastguard Worker            torch._C._profiler._set_cuda_sync_enabled_val(False)
1379*da0073e9SAndroid Build Coastguard Worker
1380*da0073e9SAndroid Build Coastguard Worker    def test_profiler_type(self):
1381*da0073e9SAndroid Build Coastguard Worker        profiler_type = torch._C._autograd._profiler_type
1382*da0073e9SAndroid Build Coastguard Worker        ActiveProfilerType = torch._C._profiler.ActiveProfilerType
1383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(profiler_type(), ActiveProfilerType.NONE)
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        # Autograd profiler
1386*da0073e9SAndroid Build Coastguard Worker        with _profile_legacy():
1387*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(profiler_type(), ActiveProfilerType.LEGACY)
1388*da0073e9SAndroid Build Coastguard Worker
1389*da0073e9SAndroid Build Coastguard Worker        # Kineto profiler
1390*da0073e9SAndroid Build Coastguard Worker        with profile():
1391*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(profiler_type(), ActiveProfilerType.KINETO)
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker    def test_profiler_correlation_id(self):
1394*da0073e9SAndroid Build Coastguard Worker        """
1395*da0073e9SAndroid Build Coastguard Worker        We expect the correlation_id to be unique across multiple invokation of the profiler,
1396*da0073e9SAndroid Build Coastguard Worker        So we will reuse id_uniqueness_set.
1397*da0073e9SAndroid Build Coastguard Worker        """
1398*da0073e9SAndroid Build Coastguard Worker        id_uniqueness_set = set()
1399*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Sequential(
1400*da0073e9SAndroid Build Coastguard Worker            nn.Conv2d(16, 33, 18),
1401*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
1402*da0073e9SAndroid Build Coastguard Worker            nn.Linear(243, 243),
1403*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
1404*da0073e9SAndroid Build Coastguard Worker        )
1405*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(40, 16, 18, 260)
1406*da0073e9SAndroid Build Coastguard Worker        uint32_max = 2**32 - 1
1407*da0073e9SAndroid Build Coastguard Worker        for i in range(5):
1408*da0073e9SAndroid Build Coastguard Worker            with profile() as prof:
1409*da0073e9SAndroid Build Coastguard Worker                model(inputs)
1410*da0073e9SAndroid Build Coastguard Worker            for event in prof.profiler.kineto_results.events():
1411*da0073e9SAndroid Build Coastguard Worker                corr_id = event.correlation_id()
1412*da0073e9SAndroid Build Coastguard Worker                if (corr_id) and event.device_type() == DeviceType.CPU:
1413*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(corr_id not in id_uniqueness_set)
1414*da0073e9SAndroid Build Coastguard Worker                    id_uniqueness_set.add(corr_id)
1415*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(corr_id < uint32_max)
1416*da0073e9SAndroid Build Coastguard Worker
1417*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_with_shapes(self):
1418*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(4, 4)
1419*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(4, 4)
1420*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(4, 4)
1421*da0073e9SAndroid Build Coastguard Worker        inp = torch.nested.nested_tensor([a, b])
1422*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(record_shapes=True) as prof:
1423*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.linear(inp, c, None)
1424*da0073e9SAndroid Build Coastguard Worker        for e in prof.events():
1425*da0073e9SAndroid Build Coastguard Worker            if e.name in ("aten::mm", "aten::addmm"):
1426*da0073e9SAndroid Build Coastguard Worker                # intentionally vague tests to protect against possible future changes
1427*da0073e9SAndroid Build Coastguard Worker                # of mm to addmm or other impl, or changing internal order of args
1428*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(len(e.input_shapes) > 0)
1429*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(len(e.input_shapes[0]) > 0)
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Worker    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
1432*da0073e9SAndroid Build Coastguard Worker    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
1433*da0073e9SAndroid Build Coastguard Worker    def test_kineto_profiler_with_environment_variable(self):
1434*da0073e9SAndroid Build Coastguard Worker        script = """
1435*da0073e9SAndroid Build Coastguard Workerimport torch
1436*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
1437*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler import supported_activities, profile
1438*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.profiler import KinetoStepTracker
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Workerclass SimpleNet(nn.Module):
1441*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
1442*da0073e9SAndroid Build Coastguard Worker        super().__init__()
1443*da0073e9SAndroid Build Coastguard Worker        self.fc1 = nn.Linear(10, 5)
1444*da0073e9SAndroid Build Coastguard Worker        self.fc2 = nn.Linear(5, 2)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker    def forward(self, x):
1447*da0073e9SAndroid Build Coastguard Worker        return self.fc2(self.fc1(x))
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker
1450*da0073e9SAndroid Build Coastguard Workerdef payload(use_cuda=False):
1451*da0073e9SAndroid Build Coastguard Worker    x = torch.randn(10, 10)
1452*da0073e9SAndroid Build Coastguard Worker    if use_cuda:
1453*da0073e9SAndroid Build Coastguard Worker        x = x.cuda()
1454*da0073e9SAndroid Build Coastguard Worker    y = torch.randn(10, 10)
1455*da0073e9SAndroid Build Coastguard Worker    if use_cuda:
1456*da0073e9SAndroid Build Coastguard Worker        y = y.cuda()
1457*da0073e9SAndroid Build Coastguard Worker    z = torch.mm(x, y)
1458*da0073e9SAndroid Build Coastguard Worker    z = z + y
1459*da0073e9SAndroid Build Coastguard Worker    if use_cuda:
1460*da0073e9SAndroid Build Coastguard Worker        z = z.cpu()
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Workerniters = 8
1463*da0073e9SAndroid Build Coastguard Workeruse_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1464*da0073e9SAndroid Build Coastguard Workernet = SimpleNet()
1465*da0073e9SAndroid Build Coastguard Workeropt = torch.optim.SGD(net.parameters(), lr=0.01)
1466*da0073e9SAndroid Build Coastguard Workeropt.zero_grad()
1467*da0073e9SAndroid Build Coastguard Workerinputs = torch.rand(10)
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Workerwith profile(activities=supported_activities()):
1470*da0073e9SAndroid Build Coastguard Worker    payload(use_cuda=use_cuda)
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Workerinitial_step = KinetoStepTracker.current_step()
1473*da0073e9SAndroid Build Coastguard Worker
1474*da0073e9SAndroid Build Coastguard Workerdef run_batch():
1475*da0073e9SAndroid Build Coastguard Worker    out = net(inputs)
1476*da0073e9SAndroid Build Coastguard Worker    loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
1477*da0073e9SAndroid Build Coastguard Worker    loss.backward()
1478*da0073e9SAndroid Build Coastguard Worker    opt.step()
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Workerfor _ in range(niters):
1481*da0073e9SAndroid Build Coastguard Worker    run_batch()
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Workerwith profile(
1484*da0073e9SAndroid Build Coastguard Worker    activities=supported_activities(),
1485*da0073e9SAndroid Build Coastguard Worker    schedule=torch.profiler.schedule(
1486*da0073e9SAndroid Build Coastguard Worker        wait=1,
1487*da0073e9SAndroid Build Coastguard Worker        warmup=1,
1488*da0073e9SAndroid Build Coastguard Worker        active=2),
1489*da0073e9SAndroid Build Coastguard Worker) as p:
1490*da0073e9SAndroid Build Coastguard Worker    for _ in range(niters):
1491*da0073e9SAndroid Build Coastguard Worker        run_batch()
1492*da0073e9SAndroid Build Coastguard Worker        p.step()
1493*da0073e9SAndroid Build Coastguard Workerassert KinetoStepTracker.current_step() == initial_step + 2 * niters
1494*da0073e9SAndroid Build Coastguard Worker"""
1495*da0073e9SAndroid Build Coastguard Worker        try:
1496*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output(
1497*da0073e9SAndroid Build Coastguard Worker                [sys.executable, "-W", "always", "-c", script],
1498*da0073e9SAndroid Build Coastguard Worker                cwd=os.path.dirname(os.path.realpath(__file__)),
1499*da0073e9SAndroid Build Coastguard Worker            )
1500*da0073e9SAndroid Build Coastguard Worker        except subprocess.CalledProcessError as e:
1501*da0073e9SAndroid Build Coastguard Worker            if e.returncode != 0:
1502*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1503*da0073e9SAndroid Build Coastguard Worker                    False,
1504*da0073e9SAndroid Build Coastguard Worker                    "Kineto is not working properly with the Dynolog environment variable",
1505*da0073e9SAndroid Build Coastguard Worker                )
1506*da0073e9SAndroid Build Coastguard Worker
1507*da0073e9SAndroid Build Coastguard Worker    def test_concrete_inputs_profiling(self):
1508*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 6)
1509*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True) as p:
1510*da0073e9SAndroid Build Coastguard Worker            y = x.as_strided([4, 3], [1, 4])
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker        found = False
1513*da0073e9SAndroid Build Coastguard Worker        for e in p.events():
1514*da0073e9SAndroid Build Coastguard Worker            if e.name in ("aten::as_strided"):
1515*da0073e9SAndroid Build Coastguard Worker                found = True
1516*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(len(e.input_shapes) > 0)
1517*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(len(e.concrete_inputs) > 0)
1518*da0073e9SAndroid Build Coastguard Worker                self.assertEqual([2, 6], e.input_shapes[0])
1519*da0073e9SAndroid Build Coastguard Worker                self.assertEqual([4, 3], e.concrete_inputs[1])
1520*da0073e9SAndroid Build Coastguard Worker                self.assertEqual([1, 4], e.concrete_inputs[2])
1521*da0073e9SAndroid Build Coastguard Worker
1522*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found, "Expected to find aten::as_strided but did not")
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker    def test_concrete_inputs_profiling_toggling(self):
1525*da0073e9SAndroid Build Coastguard Worker        try:
1526*da0073e9SAndroid Build Coastguard Worker            for before, after in [(True, False), (False, True)]:
1527*da0073e9SAndroid Build Coastguard Worker                x = torch.rand(2, 6)
1528*da0073e9SAndroid Build Coastguard Worker                torch._C._profiler._set_record_concrete_inputs_enabled_val(before)
1529*da0073e9SAndroid Build Coastguard Worker                with profile(record_shapes=True) as p:
1530*da0073e9SAndroid Build Coastguard Worker                    y = x.as_strided([4, 3], [1, 4])
1531*da0073e9SAndroid Build Coastguard Worker                    torch._C._profiler._set_record_concrete_inputs_enabled_val(after)
1532*da0073e9SAndroid Build Coastguard Worker
1533*da0073e9SAndroid Build Coastguard Worker                found = False
1534*da0073e9SAndroid Build Coastguard Worker                for e in p.events():
1535*da0073e9SAndroid Build Coastguard Worker                    if e.name in ("aten::as_strided"):
1536*da0073e9SAndroid Build Coastguard Worker                        found = True
1537*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(len(e.input_shapes))
1538*da0073e9SAndroid Build Coastguard Worker
1539*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(found, "Expected to find aten::as_strided but did not")
1540*da0073e9SAndroid Build Coastguard Worker        finally:
1541*da0073e9SAndroid Build Coastguard Worker            torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
1542*da0073e9SAndroid Build Coastguard Worker
1543*da0073e9SAndroid Build Coastguard Worker    def test_record_function_fast(self):
1544*da0073e9SAndroid Build Coastguard Worker        x, y = (torch.rand((4, 4)) for _ in range(2))
1545*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True) as p:
1546*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
1547*da0073e9SAndroid Build Coastguard Worker                # Test first with no optional args
1548*da0073e9SAndroid Build Coastguard Worker                with torch._C._profiler._RecordFunctionFast("add_test_fast_rf1"):
1549*da0073e9SAndroid Build Coastguard Worker                    x.add(y)
1550*da0073e9SAndroid Build Coastguard Worker
1551*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1552*da0073e9SAndroid Build Coastguard Worker            len([e for e in p.events() if e.name == "add_test_fast_rf1"]), 4
1553*da0073e9SAndroid Build Coastguard Worker        )
1554*da0073e9SAndroid Build Coastguard Worker        for e in p.events():
1555*da0073e9SAndroid Build Coastguard Worker            if e.name == "add_test_fast_rf1":
1556*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(e.input_shapes == [])
1557*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(e.kwinputs == {})
1558*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True) as p:
1559*da0073e9SAndroid Build Coastguard Worker            # add optional args
1560*da0073e9SAndroid Build Coastguard Worker            cm = torch._C._profiler._RecordFunctionFast(
1561*da0073e9SAndroid Build Coastguard Worker                "add_test_fast_rf2", [x, y], {"stream": 0, "grid": "lambda x : x + 1"}
1562*da0073e9SAndroid Build Coastguard Worker            )
1563*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
1564*da0073e9SAndroid Build Coastguard Worker                with cm:
1565*da0073e9SAndroid Build Coastguard Worker                    x.add(y)
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1568*da0073e9SAndroid Build Coastguard Worker            len([e for e in p.events() if e.name == "add_test_fast_rf2"]), 4
1569*da0073e9SAndroid Build Coastguard Worker        )
1570*da0073e9SAndroid Build Coastguard Worker
1571*da0073e9SAndroid Build Coastguard Worker        for e in p.events():
1572*da0073e9SAndroid Build Coastguard Worker            if e.name == "add_test_fast_rf2":
1573*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(e.input_shapes == [[4, 4], [4, 4]])
1574*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(e.kwinputs == {"stream": 0, "grid": "lambda x : x + 1"})
1575*da0073e9SAndroid Build Coastguard Worker
1576*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True) as p:
1577*da0073e9SAndroid Build Coastguard Worker            cm = torch._C._profiler._RecordFunctionFast(
1578*da0073e9SAndroid Build Coastguard Worker                "add_test_fast_rf3", input_values=["hi"], keyword_values={"hi": "hello"}
1579*da0073e9SAndroid Build Coastguard Worker            )
1580*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
1581*da0073e9SAndroid Build Coastguard Worker                try:
1582*da0073e9SAndroid Build Coastguard Worker                    with cm:
1583*da0073e9SAndroid Build Coastguard Worker                        x.add(y)
1584*da0073e9SAndroid Build Coastguard Worker                        raise ValueError
1585*da0073e9SAndroid Build Coastguard Worker                        x.relu()
1586*da0073e9SAndroid Build Coastguard Worker                except ValueError:
1587*da0073e9SAndroid Build Coastguard Worker                    pass
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1590*da0073e9SAndroid Build Coastguard Worker            len([e for e in p.events() if e.name == "add_test_fast_rf3"]), 4
1591*da0073e9SAndroid Build Coastguard Worker        )
1592*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(any((e.name and "relu" in e.name) for e in p.events()))
1593*da0073e9SAndroid Build Coastguard Worker
1594*da0073e9SAndroid Build Coastguard Worker        for e in p.events():
1595*da0073e9SAndroid Build Coastguard Worker            if e.name == "add_test_fast_rf3":
1596*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(e.input_shapes == [[]])
1597*da0073e9SAndroid Build Coastguard Worker
1598*da0073e9SAndroid Build Coastguard Worker        with profile() as p:
1599*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
1600*da0073e9SAndroid Build Coastguard Worker                with torch._C._profiler._RecordFunctionFast(
1601*da0073e9SAndroid Build Coastguard Worker                    "add_test_fast_rf4", [x, y]
1602*da0073e9SAndroid Build Coastguard Worker                ):
1603*da0073e9SAndroid Build Coastguard Worker                    x.add(y)
1604*da0073e9SAndroid Build Coastguard Worker                    with torch._C._profiler._RecordFunctionFast("add_test_fast_rf5"):
1605*da0073e9SAndroid Build Coastguard Worker                        x.relu()
1606*da0073e9SAndroid Build Coastguard Worker
1607*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1608*da0073e9SAndroid Build Coastguard Worker            len([e for e in p.events() if e.name == "add_test_fast_rf4"]), 4
1609*da0073e9SAndroid Build Coastguard Worker        )
1610*da0073e9SAndroid Build Coastguard Worker
1611*da0073e9SAndroid Build Coastguard Worker        for e in p.events():
1612*da0073e9SAndroid Build Coastguard Worker            if e.name == "add_test_fast_rf4":
1613*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(e.input_shapes == [])
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1616*da0073e9SAndroid Build Coastguard Worker            len([e for e in p.events() if e.name == "add_test_fast_rf5"]), 4
1617*da0073e9SAndroid Build Coastguard Worker        )
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True) as p:
1620*da0073e9SAndroid Build Coastguard Worker            # test optional args with tuple
1621*da0073e9SAndroid Build Coastguard Worker            cm = torch._C._profiler._RecordFunctionFast(
1622*da0073e9SAndroid Build Coastguard Worker                "add_test_fast_rf6",
1623*da0073e9SAndroid Build Coastguard Worker                (
1624*da0073e9SAndroid Build Coastguard Worker                    x,
1625*da0073e9SAndroid Build Coastguard Worker                    y,
1626*da0073e9SAndroid Build Coastguard Worker                ),
1627*da0073e9SAndroid Build Coastguard Worker            )
1628*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
1629*da0073e9SAndroid Build Coastguard Worker                with cm:
1630*da0073e9SAndroid Build Coastguard Worker                    x.add(y)
1631*da0073e9SAndroid Build Coastguard Worker
1632*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1633*da0073e9SAndroid Build Coastguard Worker            len([e for e in p.events() if e.name == "add_test_fast_rf6"]), 4
1634*da0073e9SAndroid Build Coastguard Worker        )
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker        for e in p.events():
1637*da0073e9SAndroid Build Coastguard Worker            if e.name == "add_test_fast_rf6":
1638*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(e.input_shapes == [[4, 4], [4, 4]])
1639*da0073e9SAndroid Build Coastguard Worker
1640*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1641*da0073e9SAndroid Build Coastguard Worker    def test_profiler_op_event_kwargs(self):
1642*da0073e9SAndroid Build Coastguard Worker        x, y = (torch.rand((4, 4)) for _ in range(2))
1643*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True) as p:
1644*da0073e9SAndroid Build Coastguard Worker            cm = torch._C._profiler._RecordFunctionFast(
1645*da0073e9SAndroid Build Coastguard Worker                "add_test_kwinputs",
1646*da0073e9SAndroid Build Coastguard Worker                [x, y],
1647*da0073e9SAndroid Build Coastguard Worker                {"stream": 0, "grid": "lambda x : x + 1", "debug": 'debug"'},
1648*da0073e9SAndroid Build Coastguard Worker            )
1649*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
1650*da0073e9SAndroid Build Coastguard Worker                with cm:
1651*da0073e9SAndroid Build Coastguard Worker                    x.add(y)
1652*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1653*da0073e9SAndroid Build Coastguard Worker            p.export_chrome_trace(fname)
1654*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1655*da0073e9SAndroid Build Coastguard Worker                j = json.load(f)
1656*da0073e9SAndroid Build Coastguard Worker                op_events = [
1657*da0073e9SAndroid Build Coastguard Worker                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
1658*da0073e9SAndroid Build Coastguard Worker                ]
1659*da0073e9SAndroid Build Coastguard Worker                for e in op_events:
1660*da0073e9SAndroid Build Coastguard Worker                    if e["name"] == "add_test_kwinputs":
1661*da0073e9SAndroid Build Coastguard Worker                        args = e["args"]
1662*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("stream" in args)
1663*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue("grid" in args)
1664*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(args["stream"] == 0)
1665*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(args["grid"] == "lambda x : x + 1")
1666*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(args["debug"] == "None")
1667*da0073e9SAndroid Build Coastguard Worker
1668*da0073e9SAndroid Build Coastguard Worker    def test_is_profiler_enabled(self):
1669*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
1670*da0073e9SAndroid Build Coastguard Worker
1671*da0073e9SAndroid Build Coastguard Worker        with profile() as p:
1672*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.autograd.profiler._is_profiler_enabled)
1673*da0073e9SAndroid Build Coastguard Worker
1674*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
1675*da0073e9SAndroid Build Coastguard Worker
1676*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.profiler.profile() as p:
1677*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.autograd.profiler._is_profiler_enabled)
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
1680*da0073e9SAndroid Build Coastguard Worker
1681*da0073e9SAndroid Build Coastguard Worker    def test_guarded_record_function_fast(self):
1682*da0073e9SAndroid Build Coastguard Worker        x, y = (torch.rand((4, 4)) for _ in range(2))
1683*da0073e9SAndroid Build Coastguard Worker
1684*da0073e9SAndroid Build Coastguard Worker        with profile() as p:
1685*da0073e9SAndroid Build Coastguard Worker            cm = torch._C._profiler._RecordFunctionFast("guarded_rff")
1686*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
1687*da0073e9SAndroid Build Coastguard Worker                if torch.autograd.profiler._is_profiler_enabled:
1688*da0073e9SAndroid Build Coastguard Worker                    with cm:
1689*da0073e9SAndroid Build Coastguard Worker                        x.add(y)
1690*da0073e9SAndroid Build Coastguard Worker                else:
1691*da0073e9SAndroid Build Coastguard Worker                    x.add(y)
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1694*da0073e9SAndroid Build Coastguard Worker            len([e for e in p.events() if e.name == "guarded_rff"]), 4
1695*da0073e9SAndroid Build Coastguard Worker        )
1696*da0073e9SAndroid Build Coastguard Worker
1697*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
1698*da0073e9SAndroid Build Coastguard Worker    def test_event_list(self):
1699*da0073e9SAndroid Build Coastguard Worker        # AFAIK event list is part of legacy profiler and/or used when kineto is not available.
1700*da0073e9SAndroid Build Coastguard Worker        # This test has basic sanity checks to test against obvious regressions.
1701*da0073e9SAndroid Build Coastguard Worker        x, y = (torch.rand((4, 4), requires_grad=True, device="cuda") for _ in range(2))
1702*da0073e9SAndroid Build Coastguard Worker        with profile(with_stack=True) as p:
1703*da0073e9SAndroid Build Coastguard Worker            z = (x @ y).relu().sum()
1704*da0073e9SAndroid Build Coastguard Worker            z.backward()
1705*da0073e9SAndroid Build Coastguard Worker
1706*da0073e9SAndroid Build Coastguard Worker        event_list = torch.autograd.profiler_util.EventList(p.events())
1707*da0073e9SAndroid Build Coastguard Worker        # event_list._build_tree()
1708*da0073e9SAndroid Build Coastguard Worker
1709*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1710*da0073e9SAndroid Build Coastguard Worker            event_list.export_chrome_trace(fname)
1711*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1712*da0073e9SAndroid Build Coastguard Worker                json.load(f)
1713*da0073e9SAndroid Build Coastguard Worker
1714*da0073e9SAndroid Build Coastguard Worker        event_list.table()
1715*da0073e9SAndroid Build Coastguard Worker
1716*da0073e9SAndroid Build Coastguard Worker    def _check_all_gpu_present(self, gpu_dict, max_gpu_count):
1717*da0073e9SAndroid Build Coastguard Worker        for i in range(0, max_gpu_count):
1718*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gpu_dict["GPU " + str(i)], 1)
1719*da0073e9SAndroid Build Coastguard Worker
1720*da0073e9SAndroid Build Coastguard Worker    # Do json sanity testing. Checks that all events are between profiler start and end
1721*da0073e9SAndroid Build Coastguard Worker    # also checks to see that GPU values are present in trace if cuda is used
1722*da0073e9SAndroid Build Coastguard Worker    def _validate_basic_json(self, traceEvents, cuda_available=False):
1723*da0073e9SAndroid Build Coastguard Worker        MAX_GPU_COUNT = 8
1724*da0073e9SAndroid Build Coastguard Worker        PROFILER_IDX = -4
1725*da0073e9SAndroid Build Coastguard Worker        RECORD_END = -1
1726*da0073e9SAndroid Build Coastguard Worker        RECORD_START = -2
1727*da0073e9SAndroid Build Coastguard Worker        traceEventProfiler = traceEvents[PROFILER_IDX]
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(traceEventProfiler["name"] == "PyTorch Profiler (0)")
1730*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(traceEvents[RECORD_END]["name"] == "Record Window End")
1731*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1732*da0073e9SAndroid Build Coastguard Worker            traceEvents[RECORD_START]["name"] == "Iteration Start: PyTorch Profiler"
1733*da0073e9SAndroid Build Coastguard Worker        )
1734*da0073e9SAndroid Build Coastguard Worker        # check that the profiler starts/ends within the record interval
1735*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1736*da0073e9SAndroid Build Coastguard Worker            traceEventProfiler["ts"],
1737*da0073e9SAndroid Build Coastguard Worker            traceEvents[RECORD_START]["ts"],
1738*da0073e9SAndroid Build Coastguard Worker            "Profiler starts before record!",
1739*da0073e9SAndroid Build Coastguard Worker        )
1740*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(
1741*da0073e9SAndroid Build Coastguard Worker            traceEventProfiler["ts"] + traceEventProfiler["dur"],
1742*da0073e9SAndroid Build Coastguard Worker            traceEvents[RECORD_END]["ts"],
1743*da0073e9SAndroid Build Coastguard Worker            "Profiler ends after record end!",
1744*da0073e9SAndroid Build Coastguard Worker        )
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker        gpu_dict = collections.defaultdict(int)
1747*da0073e9SAndroid Build Coastguard Worker        for i, traceEvent in enumerate(traceEvents):
1748*da0073e9SAndroid Build Coastguard Worker            if (
1749*da0073e9SAndroid Build Coastguard Worker                i == len(traceEvents) + RECORD_END
1750*da0073e9SAndroid Build Coastguard Worker                or i == len(traceEvents) + RECORD_START
1751*da0073e9SAndroid Build Coastguard Worker            ):
1752*da0073e9SAndroid Build Coastguard Worker                continue
1753*da0073e9SAndroid Build Coastguard Worker            # make sure all valid trace events are within the bounds of the profiler
1754*da0073e9SAndroid Build Coastguard Worker            if "ts" in traceEvent:
1755*da0073e9SAndroid Build Coastguard Worker                self.assertGreaterEqual(
1756*da0073e9SAndroid Build Coastguard Worker                    traceEvent["ts"],
1757*da0073e9SAndroid Build Coastguard Worker                    traceEventProfiler["ts"],
1758*da0073e9SAndroid Build Coastguard Worker                    "Trace event is out of bounds",
1759*da0073e9SAndroid Build Coastguard Worker                )
1760*da0073e9SAndroid Build Coastguard Worker            # some python events seem to go a little past record end probably because
1761*da0073e9SAndroid Build Coastguard Worker            # of some clock inaccuracies so just compare events ending to RECORD_END
1762*da0073e9SAndroid Build Coastguard Worker            if "dur" in traceEvent:
1763*da0073e9SAndroid Build Coastguard Worker                self.assertLessEqual(
1764*da0073e9SAndroid Build Coastguard Worker                    traceEvent["ts"] + traceEvent["dur"],
1765*da0073e9SAndroid Build Coastguard Worker                    traceEvents[RECORD_END]["ts"],
1766*da0073e9SAndroid Build Coastguard Worker                    "Trace event ends too late!",
1767*da0073e9SAndroid Build Coastguard Worker                )
1768*da0073e9SAndroid Build Coastguard Worker            gpu_value = traceEvent.get("args", {}).get("labels", None)
1769*da0073e9SAndroid Build Coastguard Worker            if gpu_value and "GPU" in gpu_value:
1770*da0073e9SAndroid Build Coastguard Worker                gpu_dict[gpu_value] += 1
1771*da0073e9SAndroid Build Coastguard Worker                # Max PID offset is 5M, based from pytorch/kineto include header:
1772*da0073e9SAndroid Build Coastguard Worker                # https://github.com/pytorch/kineto/blob/8681ff11e1fa54da39023076c5c43eddd87b7a8a/libkineto/include/output_base.h#L35
1773*da0073e9SAndroid Build Coastguard Worker                kExceedMaxPid = 5000000
1774*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1775*da0073e9SAndroid Build Coastguard Worker                    traceEvents[i + 1]["args"]["sort_index"]
1776*da0073e9SAndroid Build Coastguard Worker                    == kExceedMaxPid + int(gpu_value.split()[1])
1777*da0073e9SAndroid Build Coastguard Worker                )
1778*da0073e9SAndroid Build Coastguard Worker
1779*da0073e9SAndroid Build Coastguard Worker        # TODO add checking gpu count if cpuOnly_ is true or not
1780*da0073e9SAndroid Build Coastguard Worker
1781*da0073e9SAndroid Build Coastguard Worker    def _test_chrome_trace_basic_helper(self, with_cuda=False):
1782*da0073e9SAndroid Build Coastguard Worker        if with_cuda:
1783*da0073e9SAndroid Build Coastguard Worker            device = "cuda"
1784*da0073e9SAndroid Build Coastguard Worker        else:
1785*da0073e9SAndroid Build Coastguard Worker            device = "cpu"
1786*da0073e9SAndroid Build Coastguard Worker        x, y = (torch.rand(4, 4).to(device) for _ in range(2))
1787*da0073e9SAndroid Build Coastguard Worker
1788*da0073e9SAndroid Build Coastguard Worker        with profile(with_stack=True) as p:
1789*da0073e9SAndroid Build Coastguard Worker            torch.add(x, y)
1790*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1791*da0073e9SAndroid Build Coastguard Worker            p.export_chrome_trace(fname)
1792*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1793*da0073e9SAndroid Build Coastguard Worker                report = json.load(f)
1794*da0073e9SAndroid Build Coastguard Worker                self._validate_basic_json(report["traceEvents"], with_cuda)
1795*da0073e9SAndroid Build Coastguard Worker
1796*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not kineto_available(), "Kineto is required")
1797*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1798*da0073e9SAndroid Build Coastguard Worker    def test_basic_chrome_trace(self):
1799*da0073e9SAndroid Build Coastguard Worker        self._test_chrome_trace_basic_helper()
1800*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
1801*da0073e9SAndroid Build Coastguard Worker            self._test_chrome_trace_basic_helper(with_cuda=True)
1802*da0073e9SAndroid Build Coastguard Worker
1803*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1804*da0073e9SAndroid Build Coastguard Worker    def test_profiler_time_scale(self):
1805*da0073e9SAndroid Build Coastguard Worker        MARGIN_ERROR = 0.5
1806*da0073e9SAndroid Build Coastguard Worker        SEC_TO_US = 1000 * 1000
1807*da0073e9SAndroid Build Coastguard Worker        WAIT_TIME = 10
1808*da0073e9SAndroid Build Coastguard Worker        with profile() as p:
1809*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.record_function("test_span"):
1810*da0073e9SAndroid Build Coastguard Worker                for i in range(WAIT_TIME):
1811*da0073e9SAndroid Build Coastguard Worker                    torch.rand(4, 4)
1812*da0073e9SAndroid Build Coastguard Worker                    time.sleep(1)
1813*da0073e9SAndroid Build Coastguard Worker        events = p.events()
1814*da0073e9SAndroid Build Coastguard Worker
1815*da0073e9SAndroid Build Coastguard Worker        # make sure function events are scaled appropriately
1816*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(events[0].name == "test_span")
1817*da0073e9SAndroid Build Coastguard Worker        test_span = events[0]
1818*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(
1819*da0073e9SAndroid Build Coastguard Worker            test_span.cpu_time / SEC_TO_US,
1820*da0073e9SAndroid Build Coastguard Worker            WAIT_TIME - MARGIN_ERROR,
1821*da0073e9SAndroid Build Coastguard Worker            "event out of range",
1822*da0073e9SAndroid Build Coastguard Worker        )
1823*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(
1824*da0073e9SAndroid Build Coastguard Worker            test_span.cpu_time / SEC_TO_US,
1825*da0073e9SAndroid Build Coastguard Worker            WAIT_TIME + MARGIN_ERROR,
1826*da0073e9SAndroid Build Coastguard Worker            "event out of range",
1827*da0073e9SAndroid Build Coastguard Worker        )
1828*da0073e9SAndroid Build Coastguard Worker
1829*da0073e9SAndroid Build Coastguard Worker        # make sure tracing is scaled appropriately
1830*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1831*da0073e9SAndroid Build Coastguard Worker            p.export_chrome_trace(fname)
1832*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1833*da0073e9SAndroid Build Coastguard Worker                report = json.load(f)
1834*da0073e9SAndroid Build Coastguard Worker            events = report["traceEvents"]
1835*da0073e9SAndroid Build Coastguard Worker            for event in events:
1836*da0073e9SAndroid Build Coastguard Worker                if event["name"] == "test_span":
1837*da0073e9SAndroid Build Coastguard Worker                    self.assertGreaterEqual(
1838*da0073e9SAndroid Build Coastguard Worker                        event["dur"] / SEC_TO_US,
1839*da0073e9SAndroid Build Coastguard Worker                        WAIT_TIME - MARGIN_ERROR,
1840*da0073e9SAndroid Build Coastguard Worker                        "profiling out of range",
1841*da0073e9SAndroid Build Coastguard Worker                    )
1842*da0073e9SAndroid Build Coastguard Worker                    self.assertLessEqual(
1843*da0073e9SAndroid Build Coastguard Worker                        event["dur"] / SEC_TO_US,
1844*da0073e9SAndroid Build Coastguard Worker                        WAIT_TIME + MARGIN_ERROR,
1845*da0073e9SAndroid Build Coastguard Worker                        "profiling out of range",
1846*da0073e9SAndroid Build Coastguard Worker                    )
1847*da0073e9SAndroid Build Coastguard Worker
1848*da0073e9SAndroid Build Coastguard Worker    def _schedule_helper(self, warmup, active, repeat, acc_events=True):
1849*da0073e9SAndroid Build Coastguard Worker        with profile(
1850*da0073e9SAndroid Build Coastguard Worker            schedule=torch.profiler.schedule(
1851*da0073e9SAndroid Build Coastguard Worker                skip_first=0,
1852*da0073e9SAndroid Build Coastguard Worker                wait=0,
1853*da0073e9SAndroid Build Coastguard Worker                warmup=warmup,
1854*da0073e9SAndroid Build Coastguard Worker                active=active,
1855*da0073e9SAndroid Build Coastguard Worker                repeat=repeat,
1856*da0073e9SAndroid Build Coastguard Worker            ),
1857*da0073e9SAndroid Build Coastguard Worker            acc_events=acc_events,
1858*da0073e9SAndroid Build Coastguard Worker        ) as prof:
1859*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
1860*da0073e9SAndroid Build Coastguard Worker                torch.add(1, 2)
1861*da0073e9SAndroid Build Coastguard Worker                prof.step()
1862*da0073e9SAndroid Build Coastguard Worker        # print(prof.key_averages())
1863*da0073e9SAndroid Build Coastguard Worker        for ev in prof.key_averages():
1864*da0073e9SAndroid Build Coastguard Worker            if ev.key == "aten::add":
1865*da0073e9SAndroid Build Coastguard Worker                return ev.count
1866*da0073e9SAndroid Build Coastguard Worker        return 0
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1869*da0073e9SAndroid Build Coastguard Worker    def test_schedule_function_count(self):
1870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self._schedule_helper(warmup=0, active=1, repeat=1), 1)
1871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=0), 100)
1872*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=10), 50)
1873*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self._schedule_helper(warmup=1, active=5, repeat=0), 83)
1874*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self._schedule_helper(warmup=10, active=10, repeat=4), 40)
1875*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self._schedule_helper(warmup=50, active=1, repeat=0), 1)
1876*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1877*da0073e9SAndroid Build Coastguard Worker            self._schedule_helper(warmup=0, active=5, repeat=0, acc_events=False), 0
1878*da0073e9SAndroid Build Coastguard Worker        )
1879*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1880*da0073e9SAndroid Build Coastguard Worker            self._schedule_helper(warmup=10, active=10, repeat=4, acc_events=False), 10
1881*da0073e9SAndroid Build Coastguard Worker        )
1882*da0073e9SAndroid Build Coastguard Worker
1883*da0073e9SAndroid Build Coastguard Worker    def _step_helper_func(self, prof):
1884*da0073e9SAndroid Build Coastguard Worker        time.sleep(0.1)
1885*da0073e9SAndroid Build Coastguard Worker        torch.randn(1, 3, 224, 224)
1886*da0073e9SAndroid Build Coastguard Worker        prof.step()
1887*da0073e9SAndroid Build Coastguard Worker
1888*da0073e9SAndroid Build Coastguard Worker    def _partial_overlap(self, prof_step, step_helper_func):
1889*da0073e9SAndroid Build Coastguard Worker        p_start = prof_step["ts"]
1890*da0073e9SAndroid Build Coastguard Worker        p_end = prof_step["ts"] + prof_step["dur"]
1891*da0073e9SAndroid Build Coastguard Worker        h_start = step_helper_func["ts"]
1892*da0073e9SAndroid Build Coastguard Worker        h_end = step_helper_func["ts"] + step_helper_func["dur"]
1893*da0073e9SAndroid Build Coastguard Worker
1894*da0073e9SAndroid Build Coastguard Worker        if p_start < h_start and p_end < h_end and p_end > h_start:
1895*da0073e9SAndroid Build Coastguard Worker            return True
1896*da0073e9SAndroid Build Coastguard Worker        if p_start > h_start and p_start < h_end and p_end > h_end:
1897*da0073e9SAndroid Build Coastguard Worker            return True
1898*da0073e9SAndroid Build Coastguard Worker        return False
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1901*da0073e9SAndroid Build Coastguard Worker    def test_cpu_annotation_overlap(self):
1902*da0073e9SAndroid Build Coastguard Worker        with torch.profiler.profile(
1903*da0073e9SAndroid Build Coastguard Worker            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
1904*da0073e9SAndroid Build Coastguard Worker            record_shapes=True,
1905*da0073e9SAndroid Build Coastguard Worker            with_stack=True,
1906*da0073e9SAndroid Build Coastguard Worker            schedule=torch.profiler.schedule(wait=0, warmup=0, active=5, repeat=1),
1907*da0073e9SAndroid Build Coastguard Worker        ) as prof:
1908*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
1909*da0073e9SAndroid Build Coastguard Worker                self._step_helper_func(prof)
1910*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName(mode="w+") as fname:
1911*da0073e9SAndroid Build Coastguard Worker            prof.export_chrome_trace(fname)
1912*da0073e9SAndroid Build Coastguard Worker            prof_steps = []
1913*da0073e9SAndroid Build Coastguard Worker            step_helper_funcs = []
1914*da0073e9SAndroid Build Coastguard Worker            with open(fname) as f:
1915*da0073e9SAndroid Build Coastguard Worker                report = json.load(f)
1916*da0073e9SAndroid Build Coastguard Worker                for event in report["traceEvents"]:
1917*da0073e9SAndroid Build Coastguard Worker                    if "ProfilerStep" in event["name"]:
1918*da0073e9SAndroid Build Coastguard Worker                        prof_steps.append(event)
1919*da0073e9SAndroid Build Coastguard Worker                    if "step_helper_func" in event["name"]:
1920*da0073e9SAndroid Build Coastguard Worker                        step_helper_funcs.append(event)
1921*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(prof_steps), 5)
1922*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(step_helper_funcs), 5)
1923*da0073e9SAndroid Build Coastguard Worker            for i in range(0, len(step_helper_funcs)):
1924*da0073e9SAndroid Build Coastguard Worker                for j in range(0, len(step_helper_funcs)):
1925*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
1926*da0073e9SAndroid Build Coastguard Worker                        not self._partial_overlap(prof_steps[i], step_helper_funcs[j])
1927*da0073e9SAndroid Build Coastguard Worker                    )
1928*da0073e9SAndroid Build Coastguard Worker
1929*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1930*da0073e9SAndroid Build Coastguard Worker    def test_user_annotation(self):
1931*da0073e9SAndroid Build Coastguard Worker        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1932*da0073e9SAndroid Build Coastguard Worker        with profile(activities=supported_activities()) as p:
1933*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.record_function("test_user_annotation"):
1934*da0073e9SAndroid Build Coastguard Worker                self.payload(use_cuda=use_cuda)
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker        for evt in p.key_averages():
1937*da0073e9SAndroid Build Coastguard Worker            if evt.key == "test_user_annotation":
1938*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(evt.is_user_annotation)
1939*da0073e9SAndroid Build Coastguard Worker            else:
1940*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(evt.is_user_annotation)
1941*da0073e9SAndroid Build Coastguard Worker
1942*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
1943*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1944*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_toggle(self):
1945*da0073e9SAndroid Build Coastguard Worker        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
1946*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.record_function("test_user_annotation"):
1947*da0073e9SAndroid Build Coastguard Worker                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
1948*da0073e9SAndroid Build Coastguard Worker                torch.add(x, y)
1949*da0073e9SAndroid Build Coastguard Worker
1950*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any("aten" in e.name for e in p.events()))
1951*da0073e9SAndroid Build Coastguard Worker
1952*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any("cuda" in e.name for e in p.events()))
1953*da0073e9SAndroid Build Coastguard Worker
1954*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any("kernel" in e.name for e in p.events()))
1955*da0073e9SAndroid Build Coastguard Worker
1956*da0073e9SAndroid Build Coastguard Worker        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p1:
1957*da0073e9SAndroid Build Coastguard Worker            p1.toggle_collection_dynamic(False, [ProfilerActivity.CUDA])
1958*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.record_function("test_user_annotation"):
1959*da0073e9SAndroid Build Coastguard Worker                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
1960*da0073e9SAndroid Build Coastguard Worker                torch.add(x, y)
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any("aten" in e.name for e in p1.events()))
1963*da0073e9SAndroid Build Coastguard Worker
1964*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all("cuda" not in e.name for e in p1.events()))
1965*da0073e9SAndroid Build Coastguard Worker
1966*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all("kernel" not in e.name for e in p1.events()))
1967*da0073e9SAndroid Build Coastguard Worker
1968*da0073e9SAndroid Build Coastguard Worker        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p2:
1969*da0073e9SAndroid Build Coastguard Worker            p2.toggle_collection_dynamic(
1970*da0073e9SAndroid Build Coastguard Worker                False, [ProfilerActivity.CUDA, ProfilerActivity.CPU]
1971*da0073e9SAndroid Build Coastguard Worker            )
1972*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.record_function("test_user_annotation"):
1973*da0073e9SAndroid Build Coastguard Worker                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
1974*da0073e9SAndroid Build Coastguard Worker                torch.add(x, y)
1975*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(p2.events()) == 0)
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1978*da0073e9SAndroid Build Coastguard Worker    def test_lazy_build_tree(self):
1979*da0073e9SAndroid Build Coastguard Worker        with profile() as p:
1980*da0073e9SAndroid Build Coastguard Worker            self.payload()
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker        stats = p._stats()
1983*da0073e9SAndroid Build Coastguard Worker        # Test that the tree is not built
1984*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(stats.function_events_build_tree_call_duration_us, 0)
1985*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(stats.number_of_events, 0)
1986*da0073e9SAndroid Build Coastguard Worker
1987*da0073e9SAndroid Build Coastguard Worker        # Test that the tree is built on demand
1988*da0073e9SAndroid Build Coastguard Worker        p.events()
1989*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(stats.function_events_build_tree_call_duration_us, 0)
1990*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(stats.number_of_events, 0)
1991*da0073e9SAndroid Build Coastguard Worker
1992*da0073e9SAndroid Build Coastguard Worker
1993*da0073e9SAndroid Build Coastguard Workerclass SimpleNet(nn.Module):
1994*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
1995*da0073e9SAndroid Build Coastguard Worker        super().__init__()
1996*da0073e9SAndroid Build Coastguard Worker        self.fc1 = nn.Linear(10, 5)
1997*da0073e9SAndroid Build Coastguard Worker        self.fc2 = nn.Linear(5, 2)
1998*da0073e9SAndroid Build Coastguard Worker
1999*da0073e9SAndroid Build Coastguard Worker    def forward(self, x):
2000*da0073e9SAndroid Build Coastguard Worker        return self.fc2(self.fc1(x))
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
2004*da0073e9SAndroid Build Coastguard Workerclass MockKinetoEvent:
2005*da0073e9SAndroid Build Coastguard Worker    _name: str
2006*da0073e9SAndroid Build Coastguard Worker    _start_us: int
2007*da0073e9SAndroid Build Coastguard Worker    _duration_us: int
2008*da0073e9SAndroid Build Coastguard Worker    _linked_correlation_id: int
2009*da0073e9SAndroid Build Coastguard Worker    _device_type: int
2010*da0073e9SAndroid Build Coastguard Worker
2011*da0073e9SAndroid Build Coastguard Worker    @property
2012*da0073e9SAndroid Build Coastguard Worker    def name(self) -> str:
2013*da0073e9SAndroid Build Coastguard Worker        return self._name
2014*da0073e9SAndroid Build Coastguard Worker
2015*da0073e9SAndroid Build Coastguard Worker    def start_ns(self) -> int:
2016*da0073e9SAndroid Build Coastguard Worker        return self._start_us * 1000
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Worker    def duration_ns(self) -> int:
2019*da0073e9SAndroid Build Coastguard Worker        return self._duration_us * 1000
2020*da0073e9SAndroid Build Coastguard Worker
2021*da0073e9SAndroid Build Coastguard Worker    def linked_correlation_id(self) -> int:
2022*da0073e9SAndroid Build Coastguard Worker        return self._linked_correlation_id
2023*da0073e9SAndroid Build Coastguard Worker
2024*da0073e9SAndroid Build Coastguard Worker    def device_type(self) -> DeviceType:
2025*da0073e9SAndroid Build Coastguard Worker        return DeviceType.CUDA if self._device_type == 1 else DeviceType.CPU
2026*da0073e9SAndroid Build Coastguard Worker
2027*da0073e9SAndroid Build Coastguard Worker
2028*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
2029*da0073e9SAndroid Build Coastguard Workerclass MockProfilerEvent:
2030*da0073e9SAndroid Build Coastguard Worker    _name: str
2031*da0073e9SAndroid Build Coastguard Worker    id: int
2032*da0073e9SAndroid Build Coastguard Worker    start_time_ns: int
2033*da0073e9SAndroid Build Coastguard Worker    duration_time_ns: int
2034*da0073e9SAndroid Build Coastguard Worker    correlation_id: int = 0
2035*da0073e9SAndroid Build Coastguard Worker    children: List["MockProfilerEvent"] = field(default_factory=list)
2036*da0073e9SAndroid Build Coastguard Worker    parent: Optional["MockProfilerEvent"] = None
2037*da0073e9SAndroid Build Coastguard Worker
2038*da0073e9SAndroid Build Coastguard Worker    @property
2039*da0073e9SAndroid Build Coastguard Worker    def end_time_ns(self):
2040*da0073e9SAndroid Build Coastguard Worker        return self.start_time_ns + self.duration_time_ns
2041*da0073e9SAndroid Build Coastguard Worker
2042*da0073e9SAndroid Build Coastguard Worker    @property
2043*da0073e9SAndroid Build Coastguard Worker    def name(self) -> str:
2044*da0073e9SAndroid Build Coastguard Worker        return self._name
2045*da0073e9SAndroid Build Coastguard Worker
2046*da0073e9SAndroid Build Coastguard Worker    def __post__init__(self, parent, children):
2047*da0073e9SAndroid Build Coastguard Worker        object.__setattr__(self, "parent", parent)
2048*da0073e9SAndroid Build Coastguard Worker        object.__setattr__(self, "children", children)
2049*da0073e9SAndroid Build Coastguard Worker
2050*da0073e9SAndroid Build Coastguard Worker
2051*da0073e9SAndroid Build Coastguard Workerclass MockNode:
2052*da0073e9SAndroid Build Coastguard Worker    def __init__(self, name, children) -> None:
2053*da0073e9SAndroid Build Coastguard Worker        self.name = name
2054*da0073e9SAndroid Build Coastguard Worker        self.children = [MockNode(name, i) for name, i in children.items()]
2055*da0073e9SAndroid Build Coastguard Worker
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Workerclass TestExperimentalUtils(TestCase):
2058*da0073e9SAndroid Build Coastguard Worker    def make_tree(self) -> List[MockNode]:
2059*da0073e9SAndroid Build Coastguard Worker        tree = {
2060*da0073e9SAndroid Build Coastguard Worker            "root_0": {
2061*da0073e9SAndroid Build Coastguard Worker                "1": {"2": {}},
2062*da0073e9SAndroid Build Coastguard Worker                "3": {
2063*da0073e9SAndroid Build Coastguard Worker                    "4": {},
2064*da0073e9SAndroid Build Coastguard Worker                    "5": {},
2065*da0073e9SAndroid Build Coastguard Worker                },
2066*da0073e9SAndroid Build Coastguard Worker            },
2067*da0073e9SAndroid Build Coastguard Worker            "root_1": {
2068*da0073e9SAndroid Build Coastguard Worker                "6": {},
2069*da0073e9SAndroid Build Coastguard Worker                "7": {},
2070*da0073e9SAndroid Build Coastguard Worker                "8": {
2071*da0073e9SAndroid Build Coastguard Worker                    "9": {"10": {}},
2072*da0073e9SAndroid Build Coastguard Worker                },
2073*da0073e9SAndroid Build Coastguard Worker            },
2074*da0073e9SAndroid Build Coastguard Worker        }
2075*da0073e9SAndroid Build Coastguard Worker        return [MockNode(name, i) for name, i in tree.items()]
2076*da0073e9SAndroid Build Coastguard Worker
2077*da0073e9SAndroid Build Coastguard Worker    def test_dfs(self) -> None:
2078*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2079*da0073e9SAndroid Build Coastguard Worker            " ".join(i.name for i in _utils.traverse_dfs(self.make_tree())),
2080*da0073e9SAndroid Build Coastguard Worker            "root_0 1 2 3 4 5 root_1 6 7 8 9 10",
2081*da0073e9SAndroid Build Coastguard Worker        )
2082*da0073e9SAndroid Build Coastguard Worker
2083*da0073e9SAndroid Build Coastguard Worker    def test_bfs(self) -> None:
2084*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2085*da0073e9SAndroid Build Coastguard Worker            " ".join(i.name for i in _utils.traverse_bfs(self.make_tree())),
2086*da0073e9SAndroid Build Coastguard Worker            "root_0 root_1 1 3 6 7 8 2 4 5 9 10",
2087*da0073e9SAndroid Build Coastguard Worker        )
2088*da0073e9SAndroid Build Coastguard Worker
2089*da0073e9SAndroid Build Coastguard Worker    @staticmethod
2090*da0073e9SAndroid Build Coastguard Worker    def generate_mock_profile():
2091*da0073e9SAndroid Build Coastguard Worker        cuda_events = [
2092*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("cudaLaunchKernel", 400, 100, 1, 0),
2093*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("cudaLaunchKernel", 500, 100, 2, 0),
2094*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("cudaLaunchKernel", 600, 100, 3, 0),
2095*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("cudaLaunchKernel", 700, 100, 4, 0),
2096*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("cudaLaunchKernel", 800, 100, 5, 0),
2097*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("cudaLaunchKernel", 1500, 100, 6, 0),
2098*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("GPU", 900, 100, 1, 1),
2099*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("GPU", 1000, 100, 2, 1),
2100*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("GPU", 1100, 100, 3, 1),
2101*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("GPU", 1200, 100, 4, 1),
2102*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("GPU", 1300, 100, 5, 1),
2103*da0073e9SAndroid Build Coastguard Worker            MockKinetoEvent("GPU", 1700, 100, 6, 1),
2104*da0073e9SAndroid Build Coastguard Worker        ]
2105*da0073e9SAndroid Build Coastguard Worker        cpu_events = [
2106*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 1, 0, 100000),
2107*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 2, 100000, 100000),
2108*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 3, 200000, 100000),
2109*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 4, 300000, 100000),
2110*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After cudaLaunchKernel)", 5, 400000, 100000),
2111*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After cudaLaunchKernel)", 6, 500000, 100000),
2112*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After cudaLaunchKernel)", 7, 600000, 100000),
2113*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After cudaLaunchKernel)", 8, 700000, 100000),
2114*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After GPU)", 9, 800000, 100000),
2115*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After GPU)", 10, 900000, 100000),
2116*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After GPU)", 11, 1100000, 100000),
2117*da0073e9SAndroid Build Coastguard Worker            MockProfilerEvent("CPU (After GPU)", 12, 1200000, 500000),
2118*da0073e9SAndroid Build Coastguard Worker        ]
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker        profiler = unittest.mock.Mock()
2121*da0073e9SAndroid Build Coastguard Worker        profiler.kineto_results = unittest.mock.Mock()
2122*da0073e9SAndroid Build Coastguard Worker        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
2123*da0073e9SAndroid Build Coastguard Worker        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
2124*da0073e9SAndroid Build Coastguard Worker            return_value=cpu_events
2125*da0073e9SAndroid Build Coastguard Worker        )
2126*da0073e9SAndroid Build Coastguard Worker        return profiler
2127*da0073e9SAndroid Build Coastguard Worker
2128*da0073e9SAndroid Build Coastguard Worker    @staticmethod
2129*da0073e9SAndroid Build Coastguard Worker    def load_mock_profile():
2130*da0073e9SAndroid Build Coastguard Worker        accept = expecttest.ACCEPT
2131*da0073e9SAndroid Build Coastguard Worker        json_file_path = os.path.join(
2132*da0073e9SAndroid Build Coastguard Worker            os.path.dirname(os.path.realpath(__file__)),
2133*da0073e9SAndroid Build Coastguard Worker            "profiler_utils_mock_events.json",
2134*da0073e9SAndroid Build Coastguard Worker        )
2135*da0073e9SAndroid Build Coastguard Worker        if accept and torch.cuda.is_available():
2136*da0073e9SAndroid Build Coastguard Worker
2137*da0073e9SAndroid Build Coastguard Worker            def garbage_code(x):
2138*da0073e9SAndroid Build Coastguard Worker                for i in range(5):
2139*da0073e9SAndroid Build Coastguard Worker                    x[0, i] = i
2140*da0073e9SAndroid Build Coastguard Worker
2141*da0073e9SAndroid Build Coastguard Worker            x = torch.ones((4096, 4096), device="cuda")
2142*da0073e9SAndroid Build Coastguard Worker            x = x @ x
2143*da0073e9SAndroid Build Coastguard Worker            with profile(
2144*da0073e9SAndroid Build Coastguard Worker                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
2145*da0073e9SAndroid Build Coastguard Worker                record_shapes=True,
2146*da0073e9SAndroid Build Coastguard Worker                with_stack=True,
2147*da0073e9SAndroid Build Coastguard Worker            ) as prof:
2148*da0073e9SAndroid Build Coastguard Worker                for _ in range(5):
2149*da0073e9SAndroid Build Coastguard Worker                    x = x @ x
2150*da0073e9SAndroid Build Coastguard Worker                garbage_code(x)
2151*da0073e9SAndroid Build Coastguard Worker                for _ in range(5):
2152*da0073e9SAndroid Build Coastguard Worker                    x = x @ x
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker            kineto_events = [
2155*da0073e9SAndroid Build Coastguard Worker                {
2156*da0073e9SAndroid Build Coastguard Worker                    "_name": e.name,
2157*da0073e9SAndroid Build Coastguard Worker                    "_start_ns": e.start_ns(),
2158*da0073e9SAndroid Build Coastguard Worker                    "_duration_ns": e.duration_ns(),
2159*da0073e9SAndroid Build Coastguard Worker                    "_linked_correlation_id": e.linked_correlation_id(),
2160*da0073e9SAndroid Build Coastguard Worker                    "_device_type": 1 if e.device_type() == DeviceType.CUDA else 0,
2161*da0073e9SAndroid Build Coastguard Worker                }
2162*da0073e9SAndroid Build Coastguard Worker                for e in prof.profiler.kineto_results.events()
2163*da0073e9SAndroid Build Coastguard Worker            ]
2164*da0073e9SAndroid Build Coastguard Worker
2165*da0073e9SAndroid Build Coastguard Worker            def EventTreeDFS(event_tree):
2166*da0073e9SAndroid Build Coastguard Worker                from collections import deque
2167*da0073e9SAndroid Build Coastguard Worker
2168*da0073e9SAndroid Build Coastguard Worker                stack = deque(event_tree)
2169*da0073e9SAndroid Build Coastguard Worker                while stack:
2170*da0073e9SAndroid Build Coastguard Worker                    curr_event = stack.pop()
2171*da0073e9SAndroid Build Coastguard Worker                    yield curr_event
2172*da0073e9SAndroid Build Coastguard Worker                    for child_event in curr_event.children:
2173*da0073e9SAndroid Build Coastguard Worker                        stack.append(child_event)
2174*da0073e9SAndroid Build Coastguard Worker
2175*da0073e9SAndroid Build Coastguard Worker            profiler_events = [
2176*da0073e9SAndroid Build Coastguard Worker                {
2177*da0073e9SAndroid Build Coastguard Worker                    "_name": e.name,
2178*da0073e9SAndroid Build Coastguard Worker                    "id": e.id,
2179*da0073e9SAndroid Build Coastguard Worker                    "start_time_ns": e.start_time_ns,
2180*da0073e9SAndroid Build Coastguard Worker                    "duration_time_ns": e.duration_time_ns,
2181*da0073e9SAndroid Build Coastguard Worker                    "correlation_id": e.correlation_id,
2182*da0073e9SAndroid Build Coastguard Worker                    "children": [child.id for child in e.children],
2183*da0073e9SAndroid Build Coastguard Worker                    "parent": e.parent.id if e.parent else None,
2184*da0073e9SAndroid Build Coastguard Worker                }
2185*da0073e9SAndroid Build Coastguard Worker                for e in EventTreeDFS(
2186*da0073e9SAndroid Build Coastguard Worker                    prof.profiler.kineto_results.experimental_event_tree()
2187*da0073e9SAndroid Build Coastguard Worker                )
2188*da0073e9SAndroid Build Coastguard Worker            ]
2189*da0073e9SAndroid Build Coastguard Worker
2190*da0073e9SAndroid Build Coastguard Worker            with open(json_file_path, "w") as f:
2191*da0073e9SAndroid Build Coastguard Worker                json.dump([kineto_events, profiler_events], f)
2192*da0073e9SAndroid Build Coastguard Worker
2193*da0073e9SAndroid Build Coastguard Worker        assert os.path.exists(json_file_path)
2194*da0073e9SAndroid Build Coastguard Worker        with open(json_file_path) as f:
2195*da0073e9SAndroid Build Coastguard Worker            kineto_events, profiler_events = json.load(f)
2196*da0073e9SAndroid Build Coastguard Worker
2197*da0073e9SAndroid Build Coastguard Worker        cuda_events = [MockKinetoEvent(*event.values()) for event in kineto_events]
2198*da0073e9SAndroid Build Coastguard Worker        cpu_events = []
2199*da0073e9SAndroid Build Coastguard Worker        id_map = {}
2200*da0073e9SAndroid Build Coastguard Worker        for e in profiler_events:
2201*da0073e9SAndroid Build Coastguard Worker            event = MockProfilerEvent(**e)
2202*da0073e9SAndroid Build Coastguard Worker            id_map[event.id] = event
2203*da0073e9SAndroid Build Coastguard Worker            cpu_events.append(event)
2204*da0073e9SAndroid Build Coastguard Worker        for event in cpu_events:
2205*da0073e9SAndroid Build Coastguard Worker            parent = None if event.parent is None else id_map[event.parent]
2206*da0073e9SAndroid Build Coastguard Worker            children = [id_map[child] for child in event.children]
2207*da0073e9SAndroid Build Coastguard Worker            event.__post__init__(parent, children)
2208*da0073e9SAndroid Build Coastguard Worker        cpu_events = [event for event in cpu_events if event.parent is None]
2209*da0073e9SAndroid Build Coastguard Worker        profiler = unittest.mock.Mock()
2210*da0073e9SAndroid Build Coastguard Worker        profiler.kineto_results = unittest.mock.Mock()
2211*da0073e9SAndroid Build Coastguard Worker        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
2212*da0073e9SAndroid Build Coastguard Worker        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
2213*da0073e9SAndroid Build Coastguard Worker            return_value=cpu_events
2214*da0073e9SAndroid Build Coastguard Worker        )
2215*da0073e9SAndroid Build Coastguard Worker        return profiler
2216*da0073e9SAndroid Build Coastguard Worker
2217*da0073e9SAndroid Build Coastguard Worker    def test_utils_compute_self_time(self):
2218*da0073e9SAndroid Build Coastguard Worker        with profile() as prof:
2219*da0073e9SAndroid Build Coastguard Worker            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
2220*da0073e9SAndroid Build Coastguard Worker                1, requires_grad=True
2221*da0073e9SAndroid Build Coastguard Worker            )
2222*da0073e9SAndroid Build Coastguard Worker            z = torch.add(t1, t2)
2223*da0073e9SAndroid Build Coastguard Worker            y = torch.ones(1)
2224*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
2225*da0073e9SAndroid Build Coastguard Worker            loss.backward()
2226*da0073e9SAndroid Build Coastguard Worker        basic_eval = _utils.BasicEvaluation(prof.profiler)
2227*da0073e9SAndroid Build Coastguard Worker        metrics = basic_eval.metrics
2228*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(metrics) > 0)
2229*da0073e9SAndroid Build Coastguard Worker        for event_key, event_metrics in metrics.items():
2230*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2231*da0073e9SAndroid Build Coastguard Worker                event_metrics.self_time_ns,
2232*da0073e9SAndroid Build Coastguard Worker                event_key.event.duration_time_ns
2233*da0073e9SAndroid Build Coastguard Worker                - sum(child.duration_time_ns for child in event_key.event.children),
2234*da0073e9SAndroid Build Coastguard Worker            )
2235*da0073e9SAndroid Build Coastguard Worker
2236*da0073e9SAndroid Build Coastguard Worker    def test_utils_intervals_overlap(self):
2237*da0073e9SAndroid Build Coastguard Worker        event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5))
2238*da0073e9SAndroid Build Coastguard Worker        intervals = [
2239*da0073e9SAndroid Build Coastguard Worker            _utils.Interval(0, 9),
2240*da0073e9SAndroid Build Coastguard Worker            _utils.Interval(1, 2),
2241*da0073e9SAndroid Build Coastguard Worker            _utils.Interval(2, 3),
2242*da0073e9SAndroid Build Coastguard Worker            _utils.Interval(3, 4),
2243*da0073e9SAndroid Build Coastguard Worker            _utils.Interval(4, 5),
2244*da0073e9SAndroid Build Coastguard Worker            _utils.Interval(8, 12),
2245*da0073e9SAndroid Build Coastguard Worker        ]
2246*da0073e9SAndroid Build Coastguard Worker        print(event.intervals_overlap(intervals))
2247*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(event.intervals_overlap(intervals), 5)
2248*da0073e9SAndroid Build Coastguard Worker
2249*da0073e9SAndroid Build Coastguard Worker    def test_utils_compute_queue_depth(self):
2250*da0073e9SAndroid Build Coastguard Worker        def format_queue_depth(queue_depth_list, events):
2251*da0073e9SAndroid Build Coastguard Worker            res = ""
2252*da0073e9SAndroid Build Coastguard Worker            for data, event in zip(queue_depth_list, events):
2253*da0073e9SAndroid Build Coastguard Worker                res += f"{data.queue_depth} [{event.name}]\n"
2254*da0073e9SAndroid Build Coastguard Worker            return res
2255*da0073e9SAndroid Build Coastguard Worker
2256*da0073e9SAndroid Build Coastguard Worker        # We have to use Mock because time series data is too flaky to test
2257*da0073e9SAndroid Build Coastguard Worker        profiler = self.generate_mock_profile()
2258*da0073e9SAndroid Build Coastguard Worker        basic_evaluation = _utils.BasicEvaluation(profiler)
2259*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2260*da0073e9SAndroid Build Coastguard Worker            format_queue_depth(
2261*da0073e9SAndroid Build Coastguard Worker                basic_evaluation.queue_depth_list, basic_evaluation.cuda_events
2262*da0073e9SAndroid Build Coastguard Worker            ),
2263*da0073e9SAndroid Build Coastguard Worker            """\
2264*da0073e9SAndroid Build Coastguard Worker1 [cudaLaunchKernel]
2265*da0073e9SAndroid Build Coastguard Worker2 [cudaLaunchKernel]
2266*da0073e9SAndroid Build Coastguard Worker3 [cudaLaunchKernel]
2267*da0073e9SAndroid Build Coastguard Worker4 [cudaLaunchKernel]
2268*da0073e9SAndroid Build Coastguard Worker5 [cudaLaunchKernel]
2269*da0073e9SAndroid Build Coastguard Worker4 [GPU]
2270*da0073e9SAndroid Build Coastguard Worker3 [GPU]
2271*da0073e9SAndroid Build Coastguard Worker2 [GPU]
2272*da0073e9SAndroid Build Coastguard Worker1 [GPU]
2273*da0073e9SAndroid Build Coastguard Worker0 [GPU]
2274*da0073e9SAndroid Build Coastguard Worker1 [cudaLaunchKernel]
2275*da0073e9SAndroid Build Coastguard Worker0 [GPU]
2276*da0073e9SAndroid Build Coastguard Worker""",
2277*da0073e9SAndroid Build Coastguard Worker        )
2278*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2279*da0073e9SAndroid Build Coastguard Worker            format_queue_depth(
2280*da0073e9SAndroid Build Coastguard Worker                [basic_evaluation.metrics[k] for k in basic_evaluation.event_keys],
2281*da0073e9SAndroid Build Coastguard Worker                basic_evaluation.events,
2282*da0073e9SAndroid Build Coastguard Worker            ),
2283*da0073e9SAndroid Build Coastguard Worker            """\
2284*da0073e9SAndroid Build Coastguard Worker0 [CPU (Before cudaLaunchKernel)]
2285*da0073e9SAndroid Build Coastguard Worker0 [CPU (Before cudaLaunchKernel)]
2286*da0073e9SAndroid Build Coastguard Worker0 [CPU (Before cudaLaunchKernel)]
2287*da0073e9SAndroid Build Coastguard Worker0 [CPU (Before cudaLaunchKernel)]
2288*da0073e9SAndroid Build Coastguard Worker1 [CPU (After cudaLaunchKernel)]
2289*da0073e9SAndroid Build Coastguard Worker2 [CPU (After cudaLaunchKernel)]
2290*da0073e9SAndroid Build Coastguard Worker3 [CPU (After cudaLaunchKernel)]
2291*da0073e9SAndroid Build Coastguard Worker4 [CPU (After cudaLaunchKernel)]
2292*da0073e9SAndroid Build Coastguard Worker5 [CPU (After GPU)]
2293*da0073e9SAndroid Build Coastguard Worker4 [CPU (After GPU)]
2294*da0073e9SAndroid Build Coastguard Worker2 [CPU (After GPU)]
2295*da0073e9SAndroid Build Coastguard Worker1 [CPU (After GPU)]
2296*da0073e9SAndroid Build Coastguard Worker""",
2297*da0073e9SAndroid Build Coastguard Worker        )
2298*da0073e9SAndroid Build Coastguard Worker
2299*da0073e9SAndroid Build Coastguard Worker    def test_utils_compute_queue_depth_when_no_cuda_events(self):
2300*da0073e9SAndroid Build Coastguard Worker        # For traces with only cpu events, we expect empty queue depth list
2301*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((1024, 1024))
2302*da0073e9SAndroid Build Coastguard Worker        with profile() as prof:
2303*da0073e9SAndroid Build Coastguard Worker            for _ in range(5):
2304*da0073e9SAndroid Build Coastguard Worker                x = x @ x
2305*da0073e9SAndroid Build Coastguard Worker        basic_evaluation = _utils.BasicEvaluation(prof.profiler)
2306*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(basic_evaluation.compute_queue_depth())
2307*da0073e9SAndroid Build Coastguard Worker
2308*da0073e9SAndroid Build Coastguard Worker    def test_utils_compute_idle_time(self):
2309*da0073e9SAndroid Build Coastguard Worker        profiler = self.generate_mock_profile()
2310*da0073e9SAndroid Build Coastguard Worker        basic_evaluation = _utils.BasicEvaluation(profiler)
2311*da0073e9SAndroid Build Coastguard Worker        expected_output = "\n".join(
2312*da0073e9SAndroid Build Coastguard Worker            [
2313*da0073e9SAndroid Build Coastguard Worker                f"{basic_evaluation.metrics[event_key].idle_time_ns} [{event_key.event.name}]"
2314*da0073e9SAndroid Build Coastguard Worker                for event_key in basic_evaluation.event_keys
2315*da0073e9SAndroid Build Coastguard Worker            ]
2316*da0073e9SAndroid Build Coastguard Worker        )
2317*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2318*da0073e9SAndroid Build Coastguard Worker            expected_output,
2319*da0073e9SAndroid Build Coastguard Worker            """\
2320*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)]
2321*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)]
2322*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)]
2323*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)]
2324*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)]
2325*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)]
2326*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)]
2327*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)]
2328*da0073e9SAndroid Build Coastguard Worker0 [CPU (After GPU)]
2329*da0073e9SAndroid Build Coastguard Worker0 [CPU (After GPU)]
2330*da0073e9SAndroid Build Coastguard Worker0 [CPU (After GPU)]
2331*da0073e9SAndroid Build Coastguard Worker100000 [CPU (After GPU)]""",
2332*da0073e9SAndroid Build Coastguard Worker        )
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_JETSON, "JSON not behaving as expected on Jetson")
2335*da0073e9SAndroid Build Coastguard Worker    def test_utils_get_optimizable_events(self):
2336*da0073e9SAndroid Build Coastguard Worker        basic_evaluation = _utils.BasicEvaluation(self.load_mock_profile())
2337*da0073e9SAndroid Build Coastguard Worker        optimizable_events = basic_evaluation.get_optimizable_events(
2338*da0073e9SAndroid Build Coastguard Worker            2, print_enable=False
2339*da0073e9SAndroid Build Coastguard Worker        )
2340*da0073e9SAndroid Build Coastguard Worker        expected_output = "\n".join(
2341*da0073e9SAndroid Build Coastguard Worker            [f"{event_key.event.name}" for event_key in optimizable_events]
2342*da0073e9SAndroid Build Coastguard Worker        )
2343*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2344*da0073e9SAndroid Build Coastguard Worker            expected_output,
2345*da0073e9SAndroid Build Coastguard Worker            """\
2346*da0073e9SAndroid Build Coastguard Worker<built-in function _cuda_synchronize>
2347*da0073e9SAndroid Build Coastguard Workeraten::copy_""",
2348*da0073e9SAndroid Build Coastguard Worker        )
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker    def test_profiler_name_pattern(self):
2351*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((4096, 4096))
2352*da0073e9SAndroid Build Coastguard Worker        with profile() as prof:
2353*da0073e9SAndroid Build Coastguard Worker            for _ in range(5):
2354*da0073e9SAndroid Build Coastguard Worker                x = x @ x
2355*da0073e9SAndroid Build Coastguard Worker                x = x + x
2356*da0073e9SAndroid Build Coastguard Worker        matched_events = NamePattern(prof, "aten::mm").matched_events()
2357*da0073e9SAndroid Build Coastguard Worker        output = "\n".join([f"{event.name}" for event in matched_events])
2358*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
2359*da0073e9SAndroid Build Coastguard Worker            output,
2360*da0073e9SAndroid Build Coastguard Worker            """\
2361*da0073e9SAndroid Build Coastguard Workeraten::mm
2362*da0073e9SAndroid Build Coastguard Workeraten::mm
2363*da0073e9SAndroid Build Coastguard Workeraten::mm
2364*da0073e9SAndroid Build Coastguard Workeraten::mm
2365*da0073e9SAndroid Build Coastguard Workeraten::mm""",
2366*da0073e9SAndroid Build Coastguard Worker        )
2367*da0073e9SAndroid Build Coastguard Worker
    # TODO: Add logic for CUDA version of test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_pattern_match_helper(self):
        """Exercise Pattern's event-tree navigation helpers:
        siblings_of, root_of, next_of, and prev_of."""
        x = torch.ones((100, 100))
        with profile() as prof:
            for _ in range(5):
                x = x @ x
                x = x + x
        event_tree = prof.profiler.kineto_results.experimental_event_tree()
        pattern = Pattern(prof)
        # siblings_of returns a pair: (siblings before the node, siblings after it).
        # The first top-level event has nothing before it and everything else after.
        self.assertEqual([], pattern.siblings_of(event_tree[0])[0])
        self.assertEqual(event_tree[1:], pattern.siblings_of(event_tree[0])[1])
        # Same invariant one level down, among the first event's children.
        child_nodes = event_tree[0].children
        self.assertEqual([], pattern.siblings_of(child_nodes[0])[0])
        self.assertEqual(child_nodes[1:], pattern.siblings_of(child_nodes[0])[1])
        # root_of walks a nested event back up to its top-level ancestor.
        self.assertEqual(
            event_tree[0], pattern.root_of(event_tree[0].children[0].children[0])
        )
        # next_of/prev_of step across top-level events; next of the last is None.
        self.assertEqual(None, pattern.next_of(event_tree[-1]))
        self.assertEqual(event_tree[1], pattern.next_of(event_tree[0]))
        self.assertEqual(event_tree[0], pattern.prev_of(event_tree[1]))
2389*da0073e9SAndroid Build Coastguard Worker
2390*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2391*da0073e9SAndroid Build Coastguard Worker        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
2392*da0073e9SAndroid Build Coastguard Worker    )
2393*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2394*da0073e9SAndroid Build Coastguard Worker    def test_profiler_extra_cuda_copy_pattern(self):
2395*da0073e9SAndroid Build Coastguard Worker        cases = (
2396*da0073e9SAndroid Build Coastguard Worker            (0, lambda: torch.ones((100, 100), device="cuda")),
2397*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.ones((100, 100)).to("cuda")),
2398*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.zeros((100, 100)).to("cuda")),
2399*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.empty((100, 100)).fill_(5).to("cuda")),
2400*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.ones((100, 100)).cuda()),
2401*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.zeros((100, 100)).cuda()),
2402*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.empty((100, 100)).fill_(5).cuda()),
2403*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.rand((100, 100)).cuda()),
2404*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.randn((100, 100)).cuda()),
2405*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.full((100, 100), 10).cuda()),
2406*da0073e9SAndroid Build Coastguard Worker            (0, lambda: torch.rand((100, 100)).to(dtype=torch.float16)),
2407*da0073e9SAndroid Build Coastguard Worker            (0, lambda: torch.rand((100, 100)).half()),
2408*da0073e9SAndroid Build Coastguard Worker            (0, lambda: torch.rand((100, 100), device="cuda").half()),
2409*da0073e9SAndroid Build Coastguard Worker        )
2410*da0073e9SAndroid Build Coastguard Worker        num_matched = []
2411*da0073e9SAndroid Build Coastguard Worker        for _, fn in cases:
2412*da0073e9SAndroid Build Coastguard Worker            with profile(with_stack=True, record_shapes=True) as prof:
2413*da0073e9SAndroid Build Coastguard Worker                fn()
2414*da0073e9SAndroid Build Coastguard Worker            pattern = ExtraCUDACopyPattern(prof)
2415*da0073e9SAndroid Build Coastguard Worker            num_matched.append(len(pattern.matched_events()))
2416*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, [i for i, _ in cases])
2417*da0073e9SAndroid Build Coastguard Worker
2418*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2419*da0073e9SAndroid Build Coastguard Worker        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
2420*da0073e9SAndroid Build Coastguard Worker    )
2421*da0073e9SAndroid Build Coastguard Worker    def test_profiler_for_loop_indexing_pattern(self):
2422*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((100, 100))
2423*da0073e9SAndroid Build Coastguard Worker
2424*da0073e9SAndroid Build Coastguard Worker        def case1():
2425*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
2426*da0073e9SAndroid Build Coastguard Worker                x[i] = i
2427*da0073e9SAndroid Build Coastguard Worker
2428*da0073e9SAndroid Build Coastguard Worker        def case2():
2429*da0073e9SAndroid Build Coastguard Worker            y = 0
2430*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
2431*da0073e9SAndroid Build Coastguard Worker                y += x[i]
2432*da0073e9SAndroid Build Coastguard Worker
2433*da0073e9SAndroid Build Coastguard Worker        def case3():
2434*da0073e9SAndroid Build Coastguard Worker            y = 1
2435*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
2436*da0073e9SAndroid Build Coastguard Worker                y *= x[i]
2437*da0073e9SAndroid Build Coastguard Worker
2438*da0073e9SAndroid Build Coastguard Worker        def case4():
2439*da0073e9SAndroid Build Coastguard Worker            y = x
2440*da0073e9SAndroid Build Coastguard Worker            for _ in range(100):
2441*da0073e9SAndroid Build Coastguard Worker                y = y @ x
2442*da0073e9SAndroid Build Coastguard Worker
2443*da0073e9SAndroid Build Coastguard Worker        def case5():
2444*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
2445*da0073e9SAndroid Build Coastguard Worker                x[i, :] = torch.arange(100) + i
2446*da0073e9SAndroid Build Coastguard Worker
2447*da0073e9SAndroid Build Coastguard Worker        cases = ((1, case1), (1, case2), (1, case3), (0, case4), (1, case5))
2448*da0073e9SAndroid Build Coastguard Worker        num_matched = []
2449*da0073e9SAndroid Build Coastguard Worker        for _, fn in cases:
2450*da0073e9SAndroid Build Coastguard Worker            with profile(with_stack=True) as prof:
2451*da0073e9SAndroid Build Coastguard Worker                fn()
2452*da0073e9SAndroid Build Coastguard Worker            pattern = ForLoopIndexingPattern(prof)
2453*da0073e9SAndroid Build Coastguard Worker            num_matched.append(len(pattern.matched_events()))
2454*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, [i for i, _ in cases])
2455*da0073e9SAndroid Build Coastguard Worker
2456*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2457*da0073e9SAndroid Build Coastguard Worker    def test_profiler_fp32_matmul_pattern(self):
2458*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((100, 100), device="cuda")
2459*da0073e9SAndroid Build Coastguard Worker        with profile(with_stack=True) as prof:
2460*da0073e9SAndroid Build Coastguard Worker            x = x @ x
2461*da0073e9SAndroid Build Coastguard Worker        pattern = FP32MatMulPattern(prof)
2462*da0073e9SAndroid Build Coastguard Worker        has_tf32 = 0 if pattern.skip else 1
2463*da0073e9SAndroid Build Coastguard Worker        num_matched = len(pattern.matched_events())
2464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, has_tf32)
2465*da0073e9SAndroid Build Coastguard Worker
2466*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2467*da0073e9SAndroid Build Coastguard Worker    def test_profiler_extra_cuda_copy_pattern_benchmark(self):
2468*da0073e9SAndroid Build Coastguard Worker        with profile(with_stack=True, record_shapes=True) as prof:
2469*da0073e9SAndroid Build Coastguard Worker            x = torch.ones((100, 100)).to("cuda")
2470*da0073e9SAndroid Build Coastguard Worker            x = torch.ones((50, 50)).to("cuda")
2471*da0073e9SAndroid Build Coastguard Worker        pattern = ExtraCUDACopyPattern(prof)
2472*da0073e9SAndroid Build Coastguard Worker        shapes_factor_map = pattern.benchmark(pattern.matched_events())
2473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shapes_factor_map), 2)
2474*da0073e9SAndroid Build Coastguard Worker
2475*da0073e9SAndroid Build Coastguard Worker    def test_profiler_optimizer_single_tensor_pattern(self):
2476*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((100, 100))
2477*da0073e9SAndroid Build Coastguard Worker        cases = (
2478*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.optim.Adam(model.parameters())),
2479*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.optim.SGD(model.parameters(), lr=0.01)),
2480*da0073e9SAndroid Build Coastguard Worker            (1, lambda: torch.optim.AdamW(model.parameters())),
2481*da0073e9SAndroid Build Coastguard Worker            (0, lambda: torch.optim.Adam(model.parameters(), foreach=True)),
2482*da0073e9SAndroid Build Coastguard Worker            (0, lambda: torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)),
2483*da0073e9SAndroid Build Coastguard Worker            (0, lambda: torch.optim.AdamW(model.parameters(), foreach=True)),
2484*da0073e9SAndroid Build Coastguard Worker        )
2485*da0073e9SAndroid Build Coastguard Worker        num_matched = []
2486*da0073e9SAndroid Build Coastguard Worker        for _, fn in cases:
2487*da0073e9SAndroid Build Coastguard Worker            with profile(with_stack=True) as prof:
2488*da0073e9SAndroid Build Coastguard Worker                model = nn.Sequential(
2489*da0073e9SAndroid Build Coastguard Worker                    nn.Linear(100, 100),
2490*da0073e9SAndroid Build Coastguard Worker                    nn.ReLU(),
2491*da0073e9SAndroid Build Coastguard Worker                    nn.Linear(100, 10),
2492*da0073e9SAndroid Build Coastguard Worker                )
2493*da0073e9SAndroid Build Coastguard Worker                optimizer = fn()
2494*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
2495*da0073e9SAndroid Build Coastguard Worker                y_hat = model(x)
2496*da0073e9SAndroid Build Coastguard Worker                loss = torch.nn.functional.cross_entropy(
2497*da0073e9SAndroid Build Coastguard Worker                    y_hat, torch.randint(0, 10, (100,))
2498*da0073e9SAndroid Build Coastguard Worker                )
2499*da0073e9SAndroid Build Coastguard Worker                loss.backward()
2500*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
2501*da0073e9SAndroid Build Coastguard Worker            pattern = OptimizerSingleTensorPattern(prof)
2502*da0073e9SAndroid Build Coastguard Worker            num_matched.append(len(pattern.matched_events()))
2503*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, [i for i, _ in cases])
2504*da0073e9SAndroid Build Coastguard Worker
2505*da0073e9SAndroid Build Coastguard Worker    def test_profiler_synchronized_dataloader_pattern(self):
2506*da0073e9SAndroid Build Coastguard Worker        dataset = torch.rand((100, 100))
2507*da0073e9SAndroid Build Coastguard Worker        sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
2508*da0073e9SAndroid Build Coastguard Worker        async_dataloader = torch.utils.data.DataLoader(
2509*da0073e9SAndroid Build Coastguard Worker            dataset, batch_size=10, num_workers=4
2510*da0073e9SAndroid Build Coastguard Worker        )
2511*da0073e9SAndroid Build Coastguard Worker        with profile(with_stack=True) as prof:
2512*da0073e9SAndroid Build Coastguard Worker            next(iter(sync_dataloader))
2513*da0073e9SAndroid Build Coastguard Worker            next(iter(async_dataloader))
2514*da0073e9SAndroid Build Coastguard Worker        pattern = SynchronizedDataLoaderPattern(prof)
2515*da0073e9SAndroid Build Coastguard Worker        num_matched = len(pattern.matched_events())
2516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, 1)
2517*da0073e9SAndroid Build Coastguard Worker
2518*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo(
2519*da0073e9SAndroid Build Coastguard Worker        "pattern checks for aten::_zero op which might not be there with torch.compile'd graph"
2520*da0073e9SAndroid Build Coastguard Worker    )
2521*da0073e9SAndroid Build Coastguard Worker    def test_profiler_grad_not_set_to_none_pattern(self):
2522*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((100, 100))
2523*da0073e9SAndroid Build Coastguard Worker        model = nn.Sequential(
2524*da0073e9SAndroid Build Coastguard Worker            nn.Linear(100, 100),
2525*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
2526*da0073e9SAndroid Build Coastguard Worker            nn.Linear(100, 10),
2527*da0073e9SAndroid Build Coastguard Worker        )
2528*da0073e9SAndroid Build Coastguard Worker        optimizer = torch.optim.Adam(model.parameters())
2529*da0073e9SAndroid Build Coastguard Worker        cases = (
2530*da0073e9SAndroid Build Coastguard Worker            (0, lambda: optimizer.zero_grad()),
2531*da0073e9SAndroid Build Coastguard Worker            (0, lambda: model.zero_grad()),
2532*da0073e9SAndroid Build Coastguard Worker            (1, lambda: optimizer.zero_grad(set_to_none=False)),
2533*da0073e9SAndroid Build Coastguard Worker            (1, lambda: model.zero_grad(set_to_none=False)),
2534*da0073e9SAndroid Build Coastguard Worker        )
2535*da0073e9SAndroid Build Coastguard Worker        num_matched = []
2536*da0073e9SAndroid Build Coastguard Worker        for _, fn in cases:
2537*da0073e9SAndroid Build Coastguard Worker            with profile(with_stack=True) as prof:
2538*da0073e9SAndroid Build Coastguard Worker                y_hat = model(x)
2539*da0073e9SAndroid Build Coastguard Worker                loss = torch.nn.functional.cross_entropy(
2540*da0073e9SAndroid Build Coastguard Worker                    y_hat, torch.randint(0, 10, (100,))
2541*da0073e9SAndroid Build Coastguard Worker                )
2542*da0073e9SAndroid Build Coastguard Worker                loss.backward()
2543*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
2544*da0073e9SAndroid Build Coastguard Worker                fn()
2545*da0073e9SAndroid Build Coastguard Worker            pattern = GradNotSetToNonePattern(prof)
2546*da0073e9SAndroid Build Coastguard Worker            num_matched.append(len(pattern.matched_events()))
2547*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, [i for i, _ in cases])
2548*da0073e9SAndroid Build Coastguard Worker
2549*da0073e9SAndroid Build Coastguard Worker    def test_profiler_conv2d_bias_followed_by_batchnorm2d_pattern(self):
2550*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 3, 32, 32))
2551*da0073e9SAndroid Build Coastguard Worker        cases = (
2552*da0073e9SAndroid Build Coastguard Worker            (1, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.BatchNorm2d(3))),
2553*da0073e9SAndroid Build Coastguard Worker            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1, bias=False), nn.BatchNorm2d(3))),
2554*da0073e9SAndroid Build Coastguard Worker            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1))),
2555*da0073e9SAndroid Build Coastguard Worker        )
2556*da0073e9SAndroid Build Coastguard Worker        num_matched = []
2557*da0073e9SAndroid Build Coastguard Worker        for _, model in cases:
2558*da0073e9SAndroid Build Coastguard Worker            with profile(with_stack=True, record_shapes=True) as prof:
2559*da0073e9SAndroid Build Coastguard Worker                model(x)
2560*da0073e9SAndroid Build Coastguard Worker            pattern = Conv2dBiasFollowedByBatchNorm2dPattern(prof)
2561*da0073e9SAndroid Build Coastguard Worker            num_matched.append(len(pattern.matched_events()))
2562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, [i for i, _ in cases])
2563*da0073e9SAndroid Build Coastguard Worker
2564*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2565*da0073e9SAndroid Build Coastguard Worker    def test_profiler_matmul_dim_fp16_pattern(self):
2566*da0073e9SAndroid Build Coastguard Worker        cases = (
2567*da0073e9SAndroid Build Coastguard Worker            (1, torch.randn((201, 201), device="cuda", dtype=torch.float16)),
2568*da0073e9SAndroid Build Coastguard Worker            (1, torch.randn((3, 97, 97), device="cuda", dtype=torch.float16)),
2569*da0073e9SAndroid Build Coastguard Worker            (0, torch.randn((200, 200), device="cuda", dtype=torch.float16)),
2570*da0073e9SAndroid Build Coastguard Worker            (0, torch.randn((3, 200, 200), device="cuda", dtype=torch.float16)),
2571*da0073e9SAndroid Build Coastguard Worker        )
2572*da0073e9SAndroid Build Coastguard Worker        num_matched = []
2573*da0073e9SAndroid Build Coastguard Worker        for _, x in cases:
2574*da0073e9SAndroid Build Coastguard Worker            with profile(with_stack=True, record_shapes=True) as prof:
2575*da0073e9SAndroid Build Coastguard Worker                x @ x
2576*da0073e9SAndroid Build Coastguard Worker            pattern = MatMulDimInFP16Pattern(prof)
2577*da0073e9SAndroid Build Coastguard Worker            num_matched.append(len(pattern.matched_events()))
2578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_matched, [i for i, _ in cases])
2579*da0073e9SAndroid Build Coastguard Worker
2580*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
2581*da0073e9SAndroid Build Coastguard Worker    def test_profiler_pattern_matcher_json_report(self):
2582*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((100, 100))
2583*da0073e9SAndroid Build Coastguard Worker        model = nn.Sequential(
2584*da0073e9SAndroid Build Coastguard Worker            nn.Linear(100, 100),
2585*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
2586*da0073e9SAndroid Build Coastguard Worker            nn.Linear(100, 10),
2587*da0073e9SAndroid Build Coastguard Worker        )
2588*da0073e9SAndroid Build Coastguard Worker        optimizer = torch.optim.Adam(model.parameters())
2589*da0073e9SAndroid Build Coastguard Worker        with profile(with_stack=True, record_shapes=True) as prof:
2590*da0073e9SAndroid Build Coastguard Worker            y_hat = model(x)
2591*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.functional.cross_entropy(
2592*da0073e9SAndroid Build Coastguard Worker                y_hat, torch.randint(0, 10, (100,))
2593*da0073e9SAndroid Build Coastguard Worker            )
2594*da0073e9SAndroid Build Coastguard Worker            loss.backward()
2595*da0073e9SAndroid Build Coastguard Worker            optimizer.step()
2596*da0073e9SAndroid Build Coastguard Worker            optimizer.zero_grad()
2597*da0073e9SAndroid Build Coastguard Worker
2598*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryDirectory() as tmpdir:
2599*da0073e9SAndroid Build Coastguard Worker            report_all_anti_patterns(prof, json_report_dir=tmpdir, print_enable=False)
2600*da0073e9SAndroid Build Coastguard Worker
2601*da0073e9SAndroid Build Coastguard Worker            with open(os.path.join(tmpdir, "torchtidy_report.json")) as f:
2602*da0073e9SAndroid Build Coastguard Worker                report = json.load(f)
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker            # It is platform dependent whether the path will include "profiler/"
2605*da0073e9SAndroid Build Coastguard Worker            keys = [k for k in report.keys() if k.endswith("test_profiler.py")]
2606*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(keys), 1, f"{keys}")
2607*da0073e9SAndroid Build Coastguard Worker            entry = report[keys[0]]
2608*da0073e9SAndroid Build Coastguard Worker
2609*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(entry) > 0)
2610*da0073e9SAndroid Build Coastguard Worker            expected_fields = sorted(["line_number", "name", "url", "message"])
2611*da0073e9SAndroid Build Coastguard Worker            for event in entry:
2612*da0073e9SAndroid Build Coastguard Worker                actual_fields = sorted(event.keys())
2613*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_fields, actual_fields)
2614*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
    def test_fuzz_symbolize(self):
        # generate some random addresses in the text section and make sure the
        # symbolizers do not throw exceptions/crash
        def get_text_sections():
            """Scan this process's mapped .so files and return a list of
            (absolute_start_address, size) pairs for each library's .text section."""
            text_sections = []
            seen = set()
            for filename in os.listdir("/proc/self/map_files"):
                # Each entry is a symlink named "<start>-<end>" (hex addresses)
                # pointing at the file backing that mapping.
                library = os.readlink("/proc/self/map_files/" + filename)
                if ".so" not in library or library in seen:
                    continue
                seen.add(library)
                # NOTE(review): `library` is an absolute path, so os.path.join
                # returns it unchanged — the join is effectively a no-op; confirm
                # whether `filename` was intended here.
                with open(os.path.join("/proc/self/map_files", library), "rb") as f:
                    mm = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ)

                    def unpack(fmt, offset):
                        # Read one struct-packed value at `offset` in the mapped file.
                        return struct.unpack(
                            fmt, mm[offset : offset + struct.calcsize(fmt)]
                        )

                    # NOTE(review): on this early-exit the mmap is not closed
                    # (only the final `mm.close()` below runs on the happy path).
                    if mm[:4] != b"\x7fELF":
                        continue
                    # ELF64 header fields (see elf(5)): e_shoff at offset 40,
                    # e_shentsize at 58, e_shnum at 60, e_shstrndx at 62.
                    # The "Q"/"H" offsets assume a 64-bit little-endian ELF.
                    (section_headers_start,) = unpack("Q", 40)
                    (section_header_size,) = unpack("H", 58)
                    (num_section_headers,) = unpack("H", 60)
                    (shstrndx,) = unpack("H", 62)
                    # sh_offset (at +24 within a section header) of the
                    # section-name string table gives its file offset.
                    (shstrtab_offset,) = unpack(
                        "Q", section_headers_start + shstrndx * section_header_size + 24
                    )
                    for i in range(num_section_headers):
                        # sh_name (offset 0): index of this section's name
                        # within the string table.
                        (section_name_offset,) = unpack(
                            "I", section_headers_start + i * section_header_size
                        )
                        name_start = shstrtab_offset + section_name_offset
                        section_name = mm[name_start : name_start + 6]
                        if section_name != b".text\0":
                            continue
                        # sh_offset (+24) and sh_size (+32) of the .text section.
                        (section_offset,) = unpack(
                            "Q", section_headers_start + i * section_header_size + 24
                        )
                        (section_size,) = unpack(
                            "Q", section_headers_start + i * section_header_size + 32
                        )
                        # Mapping start (hex, from the symlink name) plus the
                        # section's file offset approximates the runtime address
                        # of the code.
                        start = int(filename.split("-")[0], 16) + section_offset
                        text_sections.append((start, section_size))
                        break
                    mm.close()
            return text_sections

        # Fixed seed so the fuzzed addresses are reproducible across runs.
        r = random.Random()
        r.seed(1)
        text_sections = get_text_sections()
        addrs = []
        for i in range(200):
            # Pick a random section, then a random address inside it.
            s = r.randrange(0, len(text_sections))
            start, size = text_sections[s]
            addr = r.randrange(start, start + size)
            addrs.append(addr)
        # All three symbolizer backends must survive arbitrary in-.text addresses
        # and return one result per input address.
        fast = torch._C._profiler.symbolize_addresses(addrs, "fast")
        dladdr = torch._C._profiler.symbolize_addresses(addrs, "dladdr")
        addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line")
        self.assertEqual(len(fast), len(addrs))
        self.assertEqual(len(addr2line), len(fast))
2678*da0073e9SAndroid Build Coastguard Worker
2679*da0073e9SAndroid Build Coastguard Worker
# Standard PyTorch test entry point; run_tests is supplied by the common test utils.
if __name__ == "__main__":
    run_tests()
2682